source: orange/Orange/ensemble/boosting.py @ 11630:5cfa71596edd

Revision 11630:5cfa71596edd, 6.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 9 months ago (diff)

Fixed pickling of Boosted/Bagged/StackedLearner.

Line 
1import Orange
2import Orange.core as orange
3
4_inf = 100000
5
class BoostedLearner(orange.Learner):
    """
    Instead of drawing a series of bootstrap samples from the training set,
    boosting maintains a weight for each instance. When a classifier is
    trained from the training set, the weights for misclassified instances
    are increased. Just like in a bagged learner, the class is decided based
    on voting of classifiers, but in boosting votes are weighted by accuracy
    obtained on the training set.

    BoostedLearner is an implementation of AdaBoost.M1 (Freund and Schapire,
    1996). From the user's viewpoint, the use of the BoostedLearner is similar
    to that of BaggedLearner. The learner passed as an argument needs to deal
    with instance weights.

    When ``instances`` are passed to the constructor, the learner is trained
    on them immediately and the resulting :class:`BoostedClassifier` is
    returned instead of the learner.

    :param learner: learner to be boosted.
    :type learner: :class:`Orange.core.Learner`
    :param t: number of boosted classifiers created from the instance set.
    :type t: int
    :param name: name of the resulting learner.
    :type name: str
    :rtype: :class:`Orange.ensemble.boosting.BoostedClassifier` or
            :class:`Orange.ensemble.boosting.BoostedLearner`
    """
    def __new__(cls, learner, instances=None, weight_id=None, **kwargs):
        self = orange.Learner.__new__(cls, **kwargs)
        if instances is not None:
            # ``self.__init__`` is already bound, so ``self`` must not be
            # passed again.  Passing it shifted every argument by one position:
            # ``learner`` became ``self`` and ``t`` became the learner.
            self.__init__(learner, **kwargs)
            return self.__call__(instances, weight_id)
        else:
            return self

    def __init__(self, learner, t=10, name='AdaBoost.M1'):
        self.t = t
        self.name = name
        self.learner = learner

    def __call__(self, instances, orig_weight=0):
        """
        Learn from the given table of data instances.

        The boosting weights are stored in a new meta attribute of
        ``instances``. That attribute is removed again before returning, also
        when the wrapped learner raises.

        :param instances: data instances to learn from.
        :type instances: Orange.data.Table
        :param orig_weight: meta id of the initial instance weights. A false
            value (0 or None) starts every instance with weight 1.0.
        :type orig_weight: int
        :rtype: :class:`Orange.ensemble.boosting.BoostedClassifier`
        """
        import math
        weight = Orange.feature.Descriptor.new_meta_id()
        if orig_weight:
            for inst in instances:
                inst.setweight(weight, inst.getweight(orig_weight))
        else:
            instances.addMetaAttribute(weight, 1.0)

        n = len(instances)
        classifiers = []
        try:
            for _ in range(self.t):
                classifier = self.learner(instances, weight)

                # Weighted training error; corr[k] records whether instance k
                # was classified correctly (1) or not (0).
                error = 0.0
                corr = []
                for ex in instances:
                    if classifier(ex) != ex.getclass():
                        error += ex.getweight(weight)
                        corr.append(0)
                    else:
                        corr.append(1)
                epsilon = error / float(
                    sum(ex.getweight(weight) for ex in instances))

                # Vote weight log((1 - e) / e).  It is only well defined and
                # positive for 0 < e < 0.5.  A perfect classifier (e == 0) gets
                # _inf.  A classifier no better than chance (e >= 0.5) also gets
                # _inf: it is kept only when it is the sole classifier (see
                # below), where any positive weight yields the same vote.  The
                # logarithm would otherwise be negative (inverting its vote) or
                # raise ValueError (e == 1).
                if 0 < epsilon < 0.5:
                    alpha = math.log((1 - epsilon) / epsilon)
                else:
                    alpha = _inf
                classifiers.append((classifier, alpha))

                if epsilon == 0 or epsilon >= 0.499:
                    if epsilon >= 0.499 and len(classifiers) > 1:
                        # Too weak to improve an existing ensemble; drop it.
                        del classifiers[-1]
                    break

                # Down-weight correctly classified instances, then renormalise
                # so that the weights sum to 1.
                beta = epsilon / (1 - epsilon)
                for idx, correct in enumerate(corr):
                    if correct:
                        inst = instances[idx]
                        inst.setweight(weight, inst.getweight(weight) * beta)
                norm = 1 / float(sum(ex.getweight(weight) for ex in instances))
                for idx in range(n):
                    inst = instances[idx]
                    inst.setweight(weight, inst.getweight(weight) * norm)
        finally:
            # Never leave the temporary weight attribute on the caller's table.
            instances.removeMetaAttribute(weight)

        return BoostedClassifier(classifiers=classifiers, name=self.name,
                                 class_var=instances.domain.class_var)

    def __reduce__(self):
        # Pickle as ``type(self)(learner)`` and restore the remaining
        # attributes (``t``, ``name``, ...) from the instance dict.
        return type(self), (self.learner,), dict(self.__dict__)
# Register the legacy camelCase names as deprecated aliases of the current
# snake_case ones.  The "weightId" target was previously misspelled
# ("weigth_id"), so the alias pointed at a name that never exists.
BoostedLearner = Orange.utils.deprecated_members({
    "examples": "instances",
    "classVar": "class_var",
    "weightId": "weight_id",
    "origWeight": "orig_weight",
})(BoostedLearner)
99
class BoostedClassifier(orange.Classifier):
    """
    A classifier that uses a boosting technique. Usually the learner
    (:class:`Orange.ensemble.boosting.BoostedLearner`) is used to construct the
    classifier.

    When constructing the classifier manually, the following parameters can
    be passed:

    :param classifiers: a list of ``(classifier, vote_weight)`` pairs.
    :type classifiers: list

    :param name: name of the resulting classifier.
    :type name: str

    :param class_var: the class feature.
    :type class_var: :class:`Orange.feature.Descriptor`

    Any further keyword arguments are stored as attributes of the classifier.
    """

    def __init__(self, classifiers, name, class_var, **kwds):
        self.classifiers = classifiers
        self.name = name
        self.class_var = class_var
        self.__dict__.update(kwds)

    def __call__(self, instance, result_type=orange.GetValue):
        """
        Classify ``instance`` by a weighted vote of the member classifiers.

        :param instance: instance to be classified.
        :type instance: :class:`Orange.data.Instance`

        :param result_type: :class:`Orange.classification.Classifier.GetValue` or
              :class:`Orange.classification.Classifier.GetProbabilities` or
              :class:`Orange.classification.Classifier.GetBoth`

        :rtype: :class:`Orange.data.Value`,
              :class:`Orange.statistics.Distribution` or a tuple with both.
              The distribution holds the vote shares. It is uniform when no
              votes were cast (e.g. ``classifiers`` is empty).
        """
        # Each member adds its vote weight to the class value it predicts.
        votes = Orange.statistics.distribution.Discrete(self.class_var)
        for classifier, vote_weight in self.classifiers:
            votes[int(classifier(instance))] += vote_weight
        index = Orange.utils.selection.select_best_index(votes)
        value = Orange.data.Value(self.class_var, index)
        if result_type == orange.GetValue:
            return value

        total = sum(votes)
        n_values = len(votes)
        if total:
            for i in range(n_values):
                votes[i] = votes[i] / total
        else:
            # No votes at all: return a uniform distribution rather than
            # dividing by zero.
            for i in range(n_values):
                votes[i] = 1.0 / n_values

        if result_type == orange.GetProbabilities:
            return votes
        elif result_type == orange.GetBoth:
            return (value, votes)
        else:
            return value

    def __reduce__(self):
        # Pickle as ``type(self)(classifiers, name, class_var)`` and restore
        # any additional attributes from the instance dict.
        return type(self), (self.classifiers, self.name, self.class_var), dict(self.__dict__)
158
# Register the legacy camelCase names as deprecated aliases of the current
# snake_case ones (via Orange.utils.deprecated_members).
BoostedClassifier = Orange.utils.deprecated_members({
    "classVar": "class_var",
    "resultType": "result_type",
})(BoostedClassifier)
Note: See TracBrowser for help on using the repository browser.