source: orange/orange/classification/wrappers.py @ 9669:165371b04b4a

Revision 9669:165371b04b4a, 3.8 KB checked in by anze <anze.staric@…>, 2 years ago (diff)

Moved content of Orange dir to package dir

Line 
1import Orange.core
2import Orange.evaluation.scoring as scoring
3import Orange.data
4import Orange.evaluation.testing
5import Orange.evaluation.scoring
6
7class StepwiseLearner(Orange.core.Learner):
8  def __new__(cls, data=None, weightId=None, **kwargs):
9      self = Orange.core.Learner.__new__(cls, **kwargs)
10      if data is not None:
11          self.__init__(**kwargs)
12          return self(data, weightId)
13      else:
14          return self
15     
16  def __init__(self, **kwds):
17    self.removeThreshold = 0.3
18    self.addThreshold = 0.2
19    self.stat, self.statsign = scoring.CA, 1
20    self.__dict__.update(kwds)
21
22  def __call__(self, examples, weightID = 0, **kwds):
23    import Orange.evaluation.testing, Orange.evaluation.scoring, statc
24   
25    self.__dict__.update(kwds)
26
27    if self.removeThreshold < self.addThreshold:
28        raise ValueError("'removeThreshold' should be larger or equal to 'addThreshold'")
29
30    classVar = examples.domain.classVar
31   
32    indices = Orange.core.MakeRandomIndicesCV(examples, folds = getattr(self, "folds", 10))
33    domain = Orange.data.Domain([], classVar)
34
35    res = Orange.evaluation.testing.test_with_indices([self.learner], Orange.data.Table(domain, examples), indices)
36   
37    oldStat = self.stat(res)[0]
38    oldStats = [self.stat(x)[0] for x in Orange.evaluation.scoring.splitByIterations(res)]
39    print ".", oldStat, domain
40    stop = False
41    while not stop:
42        stop = True
43        if len(domain.attributes)>=2:
44            bestStat = None
45            for attr in domain.attributes:
46                newdomain = Orange.data.Domain(filter(lambda x: x!=attr, domain.attributes), classVar)
47                res = Orange.evaluation.testing.test_with_indices([self.learner], (Orange.data.Table(newdomain, examples), weightID), indices)
48               
49                newStat = self.stat(res)[0]
50                newStats = [self.stat(x)[0] for x in Orange.evaluation.scoring.splitByIterations(res)] 
51                print "-", newStat, newdomain
52                ## If stat has increased (ie newStat is better than bestStat)
53                if not bestStat or cmp(newStat, bestStat) == self.statsign:
54                    if cmp(newStat, oldStat) == self.statsign:
55                        bestStat, bestStats, bestAttr = newStat, newStats, attr
56                    elif statc.wilcoxont(oldStats, newStats)[1] > self.removeThreshold:
57                            bestStat, bestAttr, bestStats = newStat, newStats, attr
58            if bestStat:
59                domain = Orange.data.Domain(filter(lambda x: x!=bestAttr, domain.attributes), classVar)
60                oldStat, oldStats = bestStat, bestStats
61                stop = False
62                print "removed", bestAttr.name
63
64        bestStat, bestAttr = oldStat, None
65        for attr in examples.domain.attributes:
66            if not attr in domain.attributes:
67                newdomain = Orange.data.Domain(domain.attributes + [attr], classVar)
68                res = Orange.evaluation.testing.test_with_indices([self.learner], (Orange.data.Table(newdomain, examples), weightID), indices)
69               
70                newStat = self.stat(res)[0]
71                newStats = [self.stat(x)[0] for x in Orange.evaluation.scoring.splitByIterations(res)] 
72                print "+", newStat, newdomain
73
74                ## If stat has increased (ie newStat is better than bestStat)
75                if cmp(newStat, bestStat) == self.statsign and statc.wilcoxont(oldStats, newStats)[1] < self.addThreshold:
76                    bestStat, bestStats, bestAttr = newStat, newStats, attr
77        if bestAttr:
78            domain = Orange.data.Domain(domain.attributes + [bestAttr], classVar)
79            oldStat, oldStats = bestStat, bestStats
80            stop = False
81            print "added", bestAttr.name
82
83    return self.learner(Orange.data.Table(domain, examples), weightID)
84
Note: See TracBrowser for help on using the repository browser.