source: orange/Orange/classification/wrappers.py @ 10580:c4cbae8dcf8b

Revision 10580:c4cbae8dcf8b, 4.0 KB checked in by markotoplak, 2 years ago (diff)

Moved deprecation functions, progress bar support and environ into Orange.utils. Orange imports cleanly, although it is not tested yet.

Line 
1import Orange.core
2import Orange.evaluation.scoring as scoring
3import Orange.data
4import Orange.evaluation.testing
5import Orange.evaluation.scoring
6
7from Orange.utils import deprecated_members
8
9class StepwiseLearner(Orange.core.Learner):
10  def __new__(cls, data=None, weight_id=None, **kwargs):
11      self = Orange.core.Learner.__new__(cls, **kwargs)
12      if data is not None:
13          self.__init__(**kwargs)
14          return self(data, weight_id)
15      else:
16          return self
17     
18  def __init__(self, **kwds):
19    self.remove_threshold = 0.3
20    self.add_threshold = 0.2
21    self.stat, self.statsign = scoring.CA, 1
22    for name, val in kwds.items():
23        setattr(self, name, val)
24
25  def __call__(self, data, weight_id = 0, **kwds):
26    import Orange.evaluation.testing, Orange.evaluation.scoring, statc
27   
28    self.__dict__.update(kwds)
29
30    if self.remove_threshold < self.add_threshold:
31        raise ValueError("'remove_threshold' should be larger or equal to 'add_threshold'")
32
33    classVar = data.domain.classVar
34   
35    indices = Orange.core.MakeRandomIndicesCV(data, folds = getattr(self, "folds", 10))
36    domain = Orange.data.Domain([], classVar)
37
38    res = Orange.evaluation.testing.test_with_indices([self.learner], Orange.data.Table(domain, data), indices)
39   
40    oldStat = self.stat(res)[0]
41    oldStats = [self.stat(x)[0] for x in Orange.evaluation.scoring.split_by_iterations(res)]
42    print ".", oldStat, domain
43    stop = False
44    while not stop:
45        stop = True
46        if len(domain.attributes)>=2:
47            bestStat = None
48            for attr in domain.attributes:
49                newdomain = Orange.data.Domain(filter(lambda x: x!=attr, domain.attributes), classVar)
50                res = Orange.evaluation.testing.test_with_indices([self.learner], (Orange.data.Table(newdomain, data), weight_id), indices)
51               
52                newStat = self.stat(res)[0]
53                newStats = [self.stat(x)[0] for x in Orange.evaluation.scoring.split_by_iterations(res)] 
54                print "-", newStat, newdomain
55                ## If stat has increased (ie newStat is better than bestStat)
56                if not bestStat or cmp(newStat, bestStat) == self.statsign:
57                    if cmp(newStat, oldStat) == self.statsign:
58                        bestStat, bestStats, bestAttr = newStat, newStats, attr
59                    elif statc.wilcoxont(oldStats, newStats)[1] > self.remove_threshold:
60                            bestStat, bestAttr, bestStats = newStat, newStats, attr
61            if bestStat:
62                domain = Orange.data.Domain(filter(lambda x: x!=bestAttr, domain.attributes), classVar)
63                oldStat, oldStats = bestStat, bestStats
64                stop = False
65                print "removed", bestAttr.name
66
67        bestStat, bestAttr = oldStat, None
68        for attr in data.domain.attributes:
69            if not attr in domain.attributes:
70                newdomain = Orange.data.Domain(domain.attributes + [attr], classVar)
71                res = Orange.evaluation.testing.test_with_indices([self.learner], (Orange.data.Table(newdomain, data), weight_id), indices)
72               
73                newStat = self.stat(res)[0]
74                newStats = [self.stat(x)[0] for x in Orange.evaluation.scoring.split_by_iterations(res)] 
75                print "+", newStat, newdomain
76
77                ## If stat has increased (ie newStat is better than bestStat)
78                if cmp(newStat, bestStat) == self.statsign and statc.wilcoxont(oldStats, newStats)[1] < self.add_threshold:
79                    bestStat, bestStats, bestAttr = newStat, newStats, attr
80        if bestAttr:
81            domain = Orange.data.Domain(domain.attributes + [bestAttr], classVar)
82            oldStat, oldStats = bestStat, bestStats
83            stop = False
84            print "added", bestAttr.name
85
86    return self.learner(Orange.data.Table(domain, data), weight_id)
87
88StepwiseLearner = deprecated_members(
89                    {"removeThreshold": "remove_threshold",
90                     "addThreshold": "add_threshold"},
91                    )(StepwiseLearner)
Note: See TracBrowser for help on using the repository browser.