source: orange/Orange/ensemble/forest.py @ 10565:51a6505b8231

Revision 10565:51a6505b8231, 20.5 KB checked in by markotoplak, 2 years ago (diff)

Fixed pickling of RandomForestLearner.

Line 
1from math import sqrt, floor
2import Orange.core as orange
3import Orange
4import Orange.feature.scoring
5import random
6import copy
7from Orange.misc import deprecated_keywords
8
def _default_small_learner(attributes=None, rand=None, base=None):
    """Return a randomized tree learner set up as Breiman (2001) suggests.

    When no ``base`` learner is supplied, a TreeLearner is created that
    stores only class distributions and refuses to split nodes with
    fewer than 5 instances.  ``attributes`` is accepted for signature
    compatibility; the feature-subset size is assigned later by the
    forest learner.
    """
    tree = base
    if not tree:
        tree = Orange.classification.tree.TreeLearner(
            store_node_classifier=0, store_contingencies=0,
            store_distributions=1, min_instances=5)
    return _RandomForestTreeLearner(base=tree, rand=rand)
17
def _default_simple_learner(base, randorange):
    """Wrap a SimpleTreeLearner for use inside a random forest.

    :param base: a SimpleTreeLearner or None; when None a default one
        that does not split nodes with fewer than 5 instances is built.
    :param randorange: Orange random generator used for random feature
        skipping inside the simple tree learner.
    """
    # `is None` instead of `== None`: identity test is the correct idiom
    # and avoids invoking any __eq__ overload on the learner object.
    if base is None:
        base = Orange.classification.tree.SimpleTreeLearner(min_instances=5)
    return _RandomForestSimpleTreeLearner(base=base, rand=randorange)
22
def _wrap_learner(base, rand, randorange):
    """Wrap a base tree learner into its randomized random-forest variant.

    :param base: None, a SimpleTreeLearner, or a TreeLearner.
    :param rand: Python random generator (used by the TreeLearner wrapper).
    :param randorange: Orange random generator (used by the
        SimpleTreeLearner wrapper).
    :raises ValueError: if ``base`` is of an unsupported learner type.
    """
    if base is None or isinstance(base, Orange.classification.tree.SimpleTreeLearner):
        return _default_simple_learner(base, randorange)
    elif isinstance(base, Orange.classification.tree.TreeLearner):
        return _default_small_learner(None, rand, base)
    else:
        # Was `notRightLearnerToWrap()` -- an undefined name that raised a
        # bare NameError.  Raise an explicit, descriptive error instead.
        raise ValueError("base_learner must be None, a TreeLearner, "
                         "or a SimpleTreeLearner")
30 
class _RandomForestSimpleTreeLearner(Orange.core.Learner):
    """A learner which wraps an ordinary SimpleTreeLearner.  Sets the
    skip_prob so that the number of randomly chosen features for each
    split is (on average) as specified."""

    __new__ = Orange.misc._orange__new__(Orange.core.Learner)

    def __init__(self, base=None, rand=None): #pickle needs an empty init
        self.base = base
        self.attributes = None  # set by the forest learner before training
        self.rand = rand

    def __call__(self, instances, weight=0):
        """Train the wrapped SimpleTreeLearner on ``instances``.

        Temporarily overrides the base learner's ``skip_prob`` (so that
        on average ``self.attributes`` features are considered per
        split) and its random generator, then restores both.
        """
        osp, orand = self.base.skip_prob, self.base.random_generator
        self.base.skip_prob = 1 - float(self.attributes) / len(instances.domain.attributes)
        self.base.random_generator = self.rand
        try:
            return self.base(instances, weight)
        finally:
            # Restore the shared base learner's settings even if training
            # raises; previously an exception left the learner mutated.
            self.base.skip_prob, self.base.random_generator = osp, orand

_RandomForestSimpleTreeLearner = Orange.misc.deprecated_members({"weightID":"weight_id", "examples":"instances"})(_RandomForestSimpleTreeLearner)
52   
class RandomForestLearner(Orange.core.Learner):
    """
    Trains an ensemble predictor consisting of trees trained
    on bootstrap
    samples of training data. To increase
    randomness, the tree learner considers only a subset of
    candidate features at each node. The algorithm closely follows
    the original procedure (Breiman, 2001) both in implementation and parameter
    defaults.

    :param trees: number of trees in the forest.
    :type trees: int

    :param attributes: number of randomly drawn features among
            which to select the best one to split the data sets
            in tree nodes. The default, None, means the square root of
            the number of features in the training data. Ignored if
            :obj:`learner` is specified.
    :type attributes: int

    :param base_learner: A base tree learner. The base learner will be
        randomized with Random Forest's random
        feature subset selection.  If None (default),
        :class:`~Orange.classification.tree.SimpleTreeLearner` and it
        will not split nodes with less than 5 data instances.
    :type base_learner: None or
        :class:`Orange.classification.tree.TreeLearner` or
        :class:`Orange.classification.tree.SimpleTreeLearner`

    :param rand: random generator used in bootstrap sampling. If None (default),
        then ``random.Random(0)`` is used.

    :param learner: Tree induction learner. If `None` (default),
        the :obj:`base_learner` will be used (and randomized). If
        :obj:`learner` is specified, it will be used as such
        with no additional transformations.
    :type learner: None or :class:`Orange.core.Learner`

    :param callback: a function to be called after every iteration of
            induction of classifier. The call includes a parameter
            (from 0.0 to 1.0) that provides an estimate
            of completion of the learning progress.

    :param name: learner name.
    :type name: string

    :raises ValueError: if both :obj:`base_learner` and :obj:`learner`
        are specified.

    :rtype: :class:`~Orange.ensemble.forest.RandomForestClassifier` or
            :class:`~Orange.ensemble.forest.RandomForestLearner`

    """

    __new__ = Orange.misc._orange__new__(Orange.core.Learner)

    def __init__(self, trees=100, attributes=None,
                    name='Random Forest', rand=None, callback=None,
                    base_learner=None, learner=None):
        self.trees = trees
        self.name = name
        self.attributes = attributes
        self.callback = callback
        self.rand = rand

        self.base_learner = base_learner

        if base_learner is not None and learner is not None:
            # Was `wrongSpecification()` -- an undefined name that raised
            # a bare NameError.  Raise a descriptive error instead.
            raise ValueError("Specify either base_learner or learner, "
                             "not both.")

        if not self.rand:
            self.rand = random.Random(0)
        # Derive an Orange-side generator from the Python one so both
        # randomness sources are seeded reproducibly.
        self.randorange = Orange.misc.Random(self.rand.randint(0, 2**31 - 1))

        if learner is None:
            self.learner = _wrap_learner(base=self.base_learner, rand=self.rand, randorange=self.randorange)
        else:
            self.learner = learner

        self.randstate = self.rand.getstate() #original state

    def __call__(self, instances, weight=0):
        """
        Learn from the given table of data instances.

        :param instances: learning data.
        :type instances: class:`Orange.data.Table`
        :param weight: weight.
        :type weight: int
        :rtype: :class:`Orange.ensemble.forest.RandomForestClassifier`
        """
        self.rand.setstate(self.randstate) #when learning again, set the same state
        self.randorange.reset()

        if "attributes" in self.learner.__dict__:
            # Default number of considered features: sqrt of feature count.
            self.learner.attributes = len(instances.domain.attributes)**0.5 \
                if self.attributes is None else self.attributes

        learner = self.learner

        n = len(instances)
        # build the forest
        classifiers = []
        for i in range(self.trees):
            # draw bootstrap sample (with replacement, same size as data)
            selection = []
            for j in range(n):
                selection.append(self.rand.randrange(n))
            data = instances.get_items_ref(selection)
            # build the model from the bootstrap sample
            classifiers.append(learner(data, weight))
            if self.callback:
                self.callback()
            # if self.callback: self.callback((i+1.)/self.trees)

        return RandomForestClassifier(classifiers=classifiers, name=self.name,
                    domain=instances.domain, class_var=instances.domain.class_var)
RandomForestLearner = Orange.misc.deprecated_members({"examples":"instances"})(RandomForestLearner)
166
class RandomForestClassifier(orange.Classifier):
    """
    Uses the trees induced by the :obj:`RandomForestLearner`. An input
    instance is classified into the class with the most frequent vote.
    However, this implementation returns the averaged probabilities from
    each of the trees if class probability is requested.

    When constructed manually, the following parameters have to
    be passed:

    :param classifiers: a list of classifiers to be used.
    :type classifiers: list

    :param name: name of the resulting classifier.
    :type name: str

    :param domain: the domain of the learning set.
    :type domain: :class:`Orange.data.Domain`

    :param class_var: the class feature.
    :type class_var: :class:`Orange.feature.Descriptor`

    """
    def __init__(self, classifiers, name, domain, class_var, **kwds):
        self.classifiers = classifiers
        self.name = name
        self.domain = domain
        self.class_var = class_var
        # extra keyword arguments become attributes (also used by pickle)
        self.__dict__.update(kwds)

    def __call__(self, instance, result_type = orange.GetValue):
        """
        Classify ``instance`` by aggregating the forest's trees.

        :param instance: instance to be classified.
        :type instance: :class:`Orange.data.Instance`

        :param result_type: :class:`Orange.classification.Classifier.GetValue` or \
              :class:`Orange.classification.Classifier.GetProbabilities` or
              :class:`Orange.classification.Classifier.GetBoth`

        :rtype: :class:`Orange.data.Value`,
              :class:`Orange.statistics.Distribution` or a tuple with both
        """
        from operator import add

        # handle discrete class

        if self.class_var.var_type == Orange.feature.Discrete.Discrete:

            # voting for class probabilities: sum each tree's predicted
            # distribution element-wise, then normalize
            if result_type == orange.GetProbabilities or result_type == orange.GetBoth:
                prob = [0.] * len(self.domain.class_var.values)
                for c in self.classifiers:
                    a = [x for x in c(instance, orange.GetProbabilities)]
                    # NOTE: relies on Python 2 map() returning a list
                    prob = map(add, prob, a)
                norm = sum(prob)
                cprob = Orange.statistics.distribution.Discrete(self.class_var)
                for i in range(len(prob)):
                    cprob[i] = prob[i]/norm

            # voting for crisp class membership, notice that
            # this may not be the same class as one obtaining the
            # highest probability through probability voting
            if result_type == orange.GetValue or result_type == orange.GetBoth:
                cfreq = [0] * len(self.domain.class_var.values)
                for c in self.classifiers:
                    cfreq[int(c(instance))] += 1
                # ties resolve to the lowest class index (list.index semantics)
                index = cfreq.index(max(cfreq))
                cvalue = Orange.data.Value(self.domain.class_var, index)

            if result_type == orange.GetValue: return cvalue
            elif result_type == orange.GetProbabilities: return cprob
            else: return (cvalue, cprob)

        else:
            # Handle continuous class

            # voting for class probabilities: merge each tree's predicted
            # distribution (or a point mass at its value) into one dict
            if result_type == orange.GetProbabilities or result_type == orange.GetBoth:
                probs = [c(instance, orange.GetBoth) for c in self.classifiers]
                cprob = dict()
                for val,prob in probs:
                    if prob != None: # tree produced a distribution
                        a = dict(prob.items())
                    else:
                        # no distribution available: point mass at the value
                        a = { val.value : 1. }
                    # element-wise sum over the union of observed values
                    cprob = dict( (n, a.get(n, 0)+cprob.get(n, 0)) for n in set(a)|set(cprob) )
                cprob = Orange.statistics.distribution.Continuous(cprob)
                cprob.normalize()

            # gather average class value across trees
            if result_type == orange.GetValue or result_type == orange.GetBoth:
                values = [c(instance).value for c in self.classifiers]
                cvalue = Orange.data.Value(self.domain.class_var, sum(values) / len(self.classifiers))

            if result_type == orange.GetValue: return cvalue
            elif result_type == orange.GetProbabilities: return cprob
            else: return (cvalue, cprob)

    def __reduce__(self):
        # Support pickling: rebuild via __init__ args plus the instance dict.
        return type(self), (self.classifiers, self.name, self.domain, self.class_var), dict(self.__dict__)
RandomForestClassifier = Orange.misc.deprecated_members({"resultType":"result_type", "classVar":"class_var", "example":"instance"})(RandomForestClassifier)
268### MeasureAttribute_randomForests
269
class ScoreFeature(orange.MeasureAttribute):
    """
    Scores features by permutation importance estimated from a random
    forest: a feature's importance is the average drop in out-of-bag
    accuracy when its values are randomly permuted.

    :param trees: number of trees in the forest.
    :type trees: int

    :param attributes: number of randomly drawn features among
            which to select the best to split the nodes in tree
            induction. The default, None, means the square root of
            the number of features in the training data. Ignored if
            :obj:`learner` is specified.
    :type attributes: int

    :param base_learner: A base tree learner. The base learner will be
        randomized with Random Forest's random
        feature subset selection.  If None (default),
        :class:`~Orange.classification.tree.SimpleTreeLearner` and it
        will not split nodes with less than 5 data instances.
    :type base_learner: None or
        :class:`Orange.classification.tree.TreeLearner` or
        :class:`Orange.classification.tree.SimpleTreeLearner`

    :param rand: random generator used in bootstrap sampling. If None (default),
        then ``random.Random(0)`` is used.

    :param learner: Tree induction learner. If `None` (default),
        the :obj:`base_learner` will be used (and randomized). If
        :obj:`learner` is specified, it will be used as such
        with no additional transformations.
    :type learner: None or :class:`Orange.core.Learner`

    """
    def __init__(self, trees=100, attributes=None, rand=None, base_learner=None, learner=None):

        self.trees = trees
        self.learner = learner
        # last dataset for which importances were computed (lazy cache)
        self._bufinstances = None
        self.attributes = attributes
        self.rand = rand
        self.base_learner = base_learner

        # NOTE(review): wrongSpecification is not defined anywhere in this
        # file, so this branch raises a bare NameError -- confirm intent.
        if base_learner != None and learner != None:
            wrongSpecification()

        if not self.rand:
            self.rand = random.Random(0)
        self.randorange = Orange.misc.Random(self.rand.randint(0,2**31-1))

        if learner == None:
            self.learner = _wrap_learner(base=self.base_learner, rand=self.rand, randorange=self.randorange)
        else:
            self.learner = learner

    @deprecated_keywords({"apriorClass":"aprior_class"})
    def __call__(self, feature, instances, aprior_class=None):
        """
        Return importance of a given feature.
        Only the first call on a given data set is computationally expensive.

        :param feature: feature to evaluate (by index, name or
            :class:`Orange.feature.Descriptor` object).
        :type feature: int, str or :class:`Orange.feature.Descriptor`.

        :param instances: data instances to use for importance evaluation.
        :type instances: :class:`Orange.data.Table`

        :param aprior_class: not used!

        """
        attrNo = None

        if type(feature) == int: #by attr. index
          attrNo  = feature
        elif type(feature) == type("a"): #by attr. name
          attrName = feature
          attrNo = instances.domain.index(attrName)
        elif isinstance(feature, Orange.feature.Descriptor):
          atrs = [a for a in instances.domain.attributes]
          attrNo = atrs.index(feature)
        else:
          raise Exception("MeasureAttribute_rf can not be called with (\
                contingency,classDistribution, aprior_class) as fuction arguments.")

        # recompute importances only when the dataset changed
        self._buffer(instances)

        # scale the accumulated importance by the number of trees used
        return self._avimp[attrNo]*100/self._acu

    def importances(self, table):
        """
        DEPRECATED. Return importance of all features in the dataset as a list.

        :param table: dataset of which the features' importance needs to be
            measured.
        :type table: :class:`Orange.data.Table`

        """
        self._buffer(table)
        return [a*100/self._acu for a in self._avimp]

    def _buffer(self, instances):
        """
        Recalculate importance of features if needed (i.e. if they have
        not yet been computed for the given dataset).

        :param instances: dataset of which the features' importance needs
            to be measured.
        :type instances: :class:`Orange.data.Table`

        """
        # NOTE(review): on the first call _bufinstances is None; the first
        # comparison short-circuits before .version is dereferenced.
        if instances != self._bufinstances or \
            instances.version != self._bufinstances.version:

            self._bufinstances = instances
            self._avimp = [0.0]*len(self._bufinstances.domain.attributes)
            self._acu = 0
            self._importanceAcu(self._bufinstances, self.trees, self._avimp)

    def _getOOB(self, instances, selection, nexamples):
        # out-of-bag set: every instance index not drawn into the bootstrap
        ooblist = filter(lambda x: x not in selection, range(nexamples))
        return instances.getitems(ooblist)

    def _numRight(self, oob, classifier):
        """
        Return a number of instances which are classified correctly.
        """
        #TODO How to accommodate regression?
        return sum(1 for el in oob if el.getclass() == classifier(el))

    def _numRightMix(self, oob, classifier, attr):
        """
        Return a number of instances which are classified
        correctly even if a feature is shuffled.
        """
        # random permutation of OOB indices used to shuffle one feature
        perm = range(len(oob))
        self.rand.shuffle(perm)

        def shuffle_ex(index):
            # copy the instance, replacing only the permuted feature value
            ex = Orange.data.Instance(oob[index])
            ex[attr] = oob[perm[index]][attr]
            return ex
        #TODO How to accommodate regression?
        return sum(1 for i in range(len(oob)) if oob[i].getclass() == classifier(shuffle_ex(i)))

    def _importanceAcu(self, instances, trees, avimp):
        """Accumulate avimp by importances for a given number of trees."""
        n = len(instances)

        attrs = len(instances.domain.attributes)

        # map feature name -> feature index for _presentInTree lookups
        attrnum = {}
        for attr in range(len(instances.domain.attributes)):
           attrnum[instances.domain.attributes[attr].name] = attr

        if "attributes" in self.learner.__dict__:
            self.learner.attributes = len(instances.domain.attributes)**0.5 if self.attributes == None else self.attributes

        # build the forest
        classifiers = []
        for i in range(trees):
            # draw bootstrap sample (with replacement, same size as data)
            selection = []
            for j in range(n):
                selection.append(self.rand.randrange(n))
            data = instances.getitems(selection)

            # build the model from the bootstrap sample
            cla = self.learner(data)

            #prepare OOB data
            oob = self._getOOB(instances, selection, n)

            #right on unmixed
            right = self._numRight(oob, cla)

            # default: permute every feature when tree introspection fails
            presl = range(attrs)
            try: #FIXME SimpleTreeLearner does not know how to output attributes yet
                presl = list(self._presentInTree(cla.tree, attrnum))
            except:
                pass

            #randomize each feature in data and test
            #only those on which there was a split
            for attr in presl:
                #calculate number of right classifications
                #if the values of this features are permutated randomly
                rightimp = self._numRightMix(oob, cla, attr)
                avimp[attr] += (float(right-rightimp))/len(oob)
        self._acu += trees

    def _presentInTree(self, node, attrnum):
        """Return features present in tree (features that split)."""
        if not node:
          return set([])

        if  node.branchSelector:
            # index of the feature this node splits on
            j = attrnum[node.branchSelector.class_var.name]
            cs = set([])
            for i in range(len(node.branches)):
                # union of split features over all subtrees
                s = self._presentInTree(node.branches[i], attrnum)
                cs = s | cs
            cs = cs | set([j])
            return cs
        else:
          return set([])
473
class _RandomForestTreeLearner(Orange.core.Learner):
    """ A learner which wraps an ordinary TreeLearner with
    a new split constructor.
    """

    __new__ = Orange.misc._orange__new__(Orange.core.Learner)

    def __init__(self, base=None, rand=None): #pickle needs an empty init
        # Defaults added for consistency with _RandomForestSimpleTreeLearner:
        # unpickling constructs the object without arguments.
        self.base = base
        self.attributes = None  # set by the forest learner before training
        self.rand = rand
        if not self.rand: #for all the built trees
            self.rand = random.Random(0)

    @deprecated_keywords({"examples":"instances"})
    def __call__(self, instances, weight=0):
        """ A current tree learner is copied, modified and then used.
        Modification: set a different split constructor, which uses
        a random subset of attributes.
        """
        bcopy = copy.copy(self.base)

        #if base tree learner has no measure set
        if not bcopy.measure:
            bcopy.measure = Orange.feature.scoring.Gini() \
                if isinstance(instances.domain.class_var, Orange.feature.Discrete) \
                else Orange.feature.scoring.MSE()

        # wrap the original split constructor so only a random feature
        # subset is offered at each node
        bcopy.split = SplitConstructor_AttributeSubset(\
            bcopy.split, self.attributes, self.rand)

        return bcopy(instances, weight=weight)
506
class SplitConstructor_AttributeSubset(orange.TreeSplitConstructor):
    """Split constructor that delegates to a wrapped constructor but
    exposes only a random subset of the candidate features at each node."""

    def __init__(self, scons, attributes, rand = None):
        self.scons = scons           # split constructor of original tree
        self.attributes = attributes # number of features to consider
        self.rand = rand
        if not self.rand:
            self.rand = random.Random(0)

    @deprecated_keywords({"weightID":"weight_id"})
    def __call__(self, gen, weight_id, contingencies, apriori, candidates, clsfr):
        """Build a split considering only a random subset of candidates."""
        n_cand = len(candidates)
        # If the number of features for the subset is not set, use the
        # square root (the original comment promised this, but the code
        # crashed with a TypeError on None).
        subset = int(self.attributes) if self.attributes else int(sqrt(n_cand))
        # Never mark more candidates than exist, otherwise the mask below
        # would be longer than the candidate list.
        subset = min(subset, n_cand)
        cand = [1] * subset + [0] * (n_cand - subset)
        self.rand.shuffle(cand)
        # instead of all features, invoke the wrapped split constructor
        # only on the randomly chosen subset
        return self.scons(gen, weight_id, contingencies, apriori, cand, clsfr)
Note: See TracBrowser for help on using the repository browser.