source: orange/Orange/classification/wrappers.py @ 9961:b4e8e9784a5d

Revision 9961:b4e8e9784a5d, 4.0 KB checked in by ales_erjavec, 2 years ago (diff)

Fixed camelCase names in the interface.

Line 
1import Orange.core
2import Orange.evaluation.scoring as scoring
3import Orange.data
4import Orange.evaluation.testing
5import Orange.evaluation.scoring
6
7from Orange.misc import deprecated_members
8
9class StepwiseLearner(Orange.core.Learner):
10  def __new__(cls, data=None, weight_id=None, **kwargs):
11      self = Orange.core.Learner.__new__(cls, **kwargs)
12      if data is not None:
13          self.__init__(**kwargs)
14          return self(data, weight_id)
15      else:
16          return self
17     
18  def __init__(self, **kwds):
19    self.remove_threshold = 0.3
20    self.add_threshold = 0.2
21    self.stat, self.statsign = scoring.CA, 1
22    for name, val in kwds.items():
23        setattr(self, name, val)
24
25  def __call__(self, data, weight_id = 0, **kwds):
26    import Orange.evaluation.testing, Orange.evaluation.scoring, statc
27   
28    self.__dict__.update(kwds)
29
30    if self.remove_threshold < self.add_threshold:
31        raise ValueError("'remove_threshold' should be larger or equal to 'add_threshold'")
32
33    classVar = data.domain.classVar
34   
35    indices = Orange.core.MakeRandomIndicesCV(data, folds = getattr(self, "folds", 10))
36    domain = Orange.data.Domain([], classVar)
37
38    res = Orange.evaluation.testing.test_with_indices([self.learner], Orange.data.Table(domain, data), indices)
39   
40    oldStat = self.stat(res)[0]
41    oldStats = [self.stat(x)[0] for x in Orange.evaluation.scoring.split_by_iterations(res)]
42    print ".", oldStat, domain
43    stop = False
44    while not stop:
45        stop = True
46        if len(domain.attributes)>=2:
47            bestStat = None
48            for attr in domain.attributes:
49                newdomain = Orange.data.Domain(filter(lambda x: x!=attr, domain.attributes), classVar)
50                res = Orange.evaluation.testing.test_with_indices([self.learner], (Orange.data.Table(newdomain, data), weight_id), indices)
51               
52                newStat = self.stat(res)[0]
53                newStats = [self.stat(x)[0] for x in Orange.evaluation.scoring.split_by_iterations(res)] 
54                print "-", newStat, newdomain
55                ## If stat has increased (ie newStat is better than bestStat)
56                if not bestStat or cmp(newStat, bestStat) == self.statsign:
57                    if cmp(newStat, oldStat) == self.statsign:
58                        bestStat, bestStats, bestAttr = newStat, newStats, attr
59                    elif statc.wilcoxont(oldStats, newStats)[1] > self.remove_threshold:
60                            bestStat, bestAttr, bestStats = newStat, newStats, attr
61            if bestStat:
62                domain = Orange.data.Domain(filter(lambda x: x!=bestAttr, domain.attributes), classVar)
63                oldStat, oldStats = bestStat, bestStats
64                stop = False
65                print "removed", bestAttr.name
66
67        bestStat, bestAttr = oldStat, None
68        for attr in data.domain.attributes:
69            if not attr in domain.attributes:
70                newdomain = Orange.data.Domain(domain.attributes + [attr], classVar)
71                res = Orange.evaluation.testing.test_with_indices([self.learner], (Orange.data.Table(newdomain, data), weight_id), indices)
72               
73                newStat = self.stat(res)[0]
74                newStats = [self.stat(x)[0] for x in Orange.evaluation.scoring.split_by_iterations(res)] 
75                print "+", newStat, newdomain
76
77                ## If stat has increased (ie newStat is better than bestStat)
78                if cmp(newStat, bestStat) == self.statsign and statc.wilcoxont(oldStats, newStats)[1] < self.add_threshold:
79                    bestStat, bestStats, bestAttr = newStat, newStats, attr
80        if bestAttr:
81            domain = Orange.data.Domain(domain.attributes + [bestAttr], classVar)
82            oldStat, oldStats = bestStat, bestStats
83            stop = False
84            print "added", bestAttr.name
85
86    return self.learner(Orange.data.Table(domain, data), weight_id)
87
88StepwiseLearner = deprecated_members(
89                    {"removeThreshold": "remove_threshold",
90                     "addThreshold": "add_threshold"},
91                    )(StepwiseLearner)
Note: See TracBrowser for help on using the repository browser.