source: orange/Orange/ensemble/boosting.py @ 11630:5cfa71596edd

Revision 11630:5cfa71596edd, 6.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 9 months ago (diff)

Fixed pickling of Boosted/Bagged/StackedLearner.

Rev Line
[8042]1import Orange
2import Orange.core as orange
3
# Large finite stand-in used as the vote weight of a component classifier
# whose weighted training error is zero (where the AdaBoost log-ratio
# log((1 - epsilon) / epsilon) would otherwise divide by zero).
_inf = 100000
5
class BoostedLearner(orange.Learner):
    """
    Instead of drawing a series of bootstrap samples from the training set,
    boosting maintains a weight for each instance. When a classifier is
    trained from the training set, the weights for misclassified instances
    are increased. Just like in a bagged learner, the class is decided based
    on voting of classifiers, but in boosting votes are weighted by accuracy
    obtained on training set.

    BoostedLearner is an implementation of AdaBoost.M1 (Freund and Schapire,
    1996). From user's viewpoint, the use of the BoostedLearner is similar to
    that of BaggedLearner. The learner passed as an argument needs to deal
    with instance weights.

    :param learner: learner to be boosted.
    :type learner: :class:`Orange.core.Learner`
    :param t: number of boosted classifiers created from the instance set.
    :type t: int
    :param name: name of the resulting learner.
    :type name: str
    :rtype: :class:`Orange.ensemble.boosting.BoostedClassifier` or
            :class:`Orange.ensemble.boosting.BoostedLearner`
    """
    def __new__(cls, learner, instances=None, weight_id=None, **kwargs):
        """
        Create a learner or, when ``instances`` is given, train right away
        and return the resulting classifier instead.

        :param learner: learner to be boosted.
        :param instances: optional data to learn from immediately.
        :param weight_id: id of the meta attribute with instance weights.
        :param kwargs: forwarded to ``orange.Learner.__new__`` and
            ``__init__`` (e.g. ``t``, ``name``).
        """
        self = orange.Learner.__new__(cls, **kwargs)
        if instances is None:
            return self
        # The object returned below is a classifier, not an instance of
        # ``cls``, so Python will not run __init__ for us; call it by hand.
        # (It is a bound method: ``self`` must not be passed again.)
        self.__init__(learner, **kwargs)
        return self.__call__(instances, weight_id)

    def __init__(self, learner, t=10, name='AdaBoost.M1'):
        self.t = t
        self.name = name
        self.learner = learner

    def __call__(self, instances, orig_weight=0):
        """
        Learn from the given table of data instances.

        A temporary meta attribute holding the boosting weights is attached
        to ``instances`` for the duration of the call and removed again
        before returning, also when learning fails with an exception.

        :param instances: data instances to learn from.
        :type instances: Orange.data.Table
        :param orig_weight: id of the meta attribute with the initial
            instance weights; a false value means all weights start at 1.0.
        :type orig_weight: int
        :rtype: :class:`Orange.ensemble.boosting.BoostedClassifier`

        """
        weight = Orange.feature.Descriptor.new_meta_id()
        if orig_weight:
            for inst in instances:
                inst.setweight(weight, inst.getweight(orig_weight))
        else:
            instances.addMetaAttribute(weight, 1.0)

        try:
            classifiers = self._boost(instances, weight)
        finally:
            # Never leave the scratch weights on the caller's table.
            instances.removeMetaAttribute(weight)

        return BoostedClassifier(classifiers=classifiers, name=self.name,
                                 class_var=instances.domain.class_var)

    def _boost(self, instances, weight):
        """
        Run up to ``self.t`` boosting rounds and return a list of
        ``(classifier, vote_weight)`` pairs.

        :param instances: data instances to learn from; their ``weight``
            meta attribute is updated in place after every round.
        :param weight: id of the meta attribute with the boosting weights.
        """
        import math

        classifiers = []
        for _ in range(self.t):
            classifier = self.learner(instances, weight)

            # Weighted training error and a per-instance correctness mask.
            epsilon = 0.0
            corr = []
            for ex in instances:
                if classifier(ex) != ex.getclass():
                    epsilon += ex.getweight(weight)
                    corr.append(0)
                else:
                    corr.append(1)
            epsilon = epsilon / float(
                sum(ex.getweight(weight) for ex in instances))

            if epsilon:
                # ``or _inf`` keeps the historical behaviour of replacing a
                # log-ratio of exactly zero (epsilon == 0.5) with _inf.
                vote = math.log((1 - epsilon) / epsilon) or _inf
            else:
                vote = _inf
            classifiers.append((classifier, vote))

            # Stop on a perfect classifier or on one that is (almost) no
            # better than chance; the latter is discarded unless it is the
            # only classifier we have.
            if epsilon == 0 or epsilon >= 0.499:
                if epsilon >= 0.499 and len(classifiers) > 1:
                    del classifiers[-1]
                break

            # Down-weight correctly classified instances ...
            beta = epsilon / (1 - epsilon)
            for idx, ex in enumerate(instances):
                if corr[idx]:
                    ex.setweight(weight, ex.getweight(weight) * beta)
            # ... and renormalise so that the weights sum to 1.
            f = 1 / float(sum([ex.getweight(weight) for ex in instances]))
            for ex in instances:
                ex.setweight(weight, ex.getweight(weight) * f)

        return classifiers

    def __reduce__(self):
        """Pickle as ``type(self)(self.learner)`` plus the instance dict."""
        return type(self), (self.learner,), dict(self.__dict__)
97
# Keep the old camelCase names usable; each maps to its current snake_case
# spelling ("weightId" must map to "weight_id", the name __new__ declares).
BoostedLearner = Orange.utils.deprecated_members({"examples":"instances", "classVar":"class_var", "weightId":"weight_id", "origWeight":"orig_weight"})(BoostedLearner)
[8042]99
class BoostedClassifier(orange.Classifier):
    """
    A classifier that uses a boosting technique. Usually the learner
    (:class:`Orange.ensemble.boosting.BoostedLearner`) is used to construct the
    classifier.

    When constructing the classifier manually, the following parameters can
    be passed:

    :param classifiers: a list of ``(classifier, vote_weight)`` pairs.
    :type classifiers: list

    :param name: name of the resulting classifier.
    :type name: str

    :param class_var: the class feature.
    :type class_var: :class:`Orange.feature.Descriptor`

    """

    def __init__(self, classifiers, name, class_var, **kwds):
        self.classifiers = classifiers
        self.name = name
        self.class_var = class_var
        # Any extra keyword arguments become plain instance attributes.
        self.__dict__.update(kwds)

    def __call__(self, instance, result_type=orange.GetValue):
        """
        Classify ``instance`` by weighted voting of the component classifiers.

        :param instance: instance to be classified.
        :type instance: :class:`Orange.data.Instance`

        :param result_type: :class:`Orange.classification.Classifier.GetValue` or \
              :class:`Orange.classification.Classifier.GetProbabilities` or
              :class:`Orange.classification.Classifier.GetBoth`

        :rtype: :class:`Orange.data.Value`,
              :class:`Orange.statistics.Distribution` or a tuple with both
        """
        # Accumulate each member's vote weight on the class it predicts.
        tally = Orange.statistics.distribution.Discrete(self.class_var)
        for member, vote_weight in self.classifiers:
            predicted = int(member(instance))
            tally[predicted] += vote_weight

        best = Orange.utils.selection.select_best_index(tally)
        # TODO
        prediction = Orange.data.Value(self.class_var, best)
        if result_type == orange.GetValue:
            return prediction

        # Scale the votes so that they sum to one.
        total = sum(tally)
        for k in range(len(tally)):
            tally[k] = tally[k] / total

        if result_type == orange.GetProbabilities:
            return tally
        if result_type == orange.GetBoth:
            return (prediction, tally)
        return prediction

    def __reduce__(self):
        """Pickle via the constructor arguments plus the instance dict."""
        ctor_args = (self.classifiers, self.name, self.class_var)
        state = dict(self.__dict__)
        return type(self), ctor_args, state
158
# Re-bind the class through Orange's deprecation wrapper so the old
# camelCase names listed here are translated to their snake_case successors.
BoostedClassifier = Orange.utils.deprecated_members(
    {
        "classVar": "class_var",
        "resultType": "result_type",
    }
)(BoostedClassifier)
Note: See TracBrowser for help on using the repository browser.