Changeset 9317:6341e2ee5017 in orange


Ignore:
Timestamp:
12/06/11 15:00:18 (2 years ago)
Author:
markotoplak
Branch:
default
Convert:
14d3af9de3bf921a6c454082fc914cde63945f17
Message:

Random forest: changed constructor arguments.

Location:
orange
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/ensemble/forest.py

    r9316 r9317  
    77 
    88def _default_small_learner(attributes=None, rand=None, base=None): 
    9     """ A helper function for constructing a small tree learner for use 
    10     in the RandomForestLearner. 
    11     :param instances: data - to select the measure. 
    12     :param attributes: number of features used in a randomly drawn 
    13             subset when searching for best feature to split the node 
    14             in tree growing 
    15     :type attributes: int 
    16     :param rand: random generator used in feature subset selection in split constructor.  
    17             If None is passed, then Python's Random from random library is  
    18             used, with seed initialized to 0. 
    19     :type rand: function 
    20     """ 
    219    # tree learner assembled as suggested by Breiman (2001) 
    2210    if not base: 
     
    2513            store_distributions=1, min_instances=5) 
    2614 
    27     return _RandomForestTreeLearner(base=base, attributes=attributes, rand=rand) 
    28  
    29 class SimpleTreeLearnerSetProb(Orange.classification.tree.SimpleTreeLearner): 
    30     """ 
    31     :class:`Orange.classification.tree.SimpleTreeLearner` which sets the  
    32     skip_prob so that the number of randomly chosen features for each split 
    33     is (on average) as the user specified (default: the square root of the 
    34     number of attributes). 
    35     """ 
    36     def __init__(self, *args, **kwargs): 
    37         self.attributes = kwargs.pop("attributes", None) 
    38         Orange.classification.tree.SimpleTreeLearner.__init__(self, *args, **kwargs) 
    39  
     15    return _RandomForestTreeLearner(base=base, rand=rand) 
     16 
     17def _default_simple_learner(base, randorange): 
     18    if base == None: 
     19        base = Orange.classification.tree.SimpleTreeLearner(min_instances=5) 
     20    return _RandomForestSimpleTreeLearner(base=base, rand=randorange) 
     21 
     22def _wrap_learner(base, rand, randorange): 
     23    if base == None or isinstance(base, Orange.classification.tree.SimpleTreeLearner): 
     24        return _default_simple_learner(base, randorange) 
     25    elif isinstance(base, Orange.classification.tree.TreeLearner): 
     26        return _default_small_learner(None, rand, base) 
     27    else: 
     28        notRightLearnerToWrap() 
     29  
     30class _RandomForestSimpleTreeLearner(Orange.core.Learner): 
     31    """ A learner which wraps an ordinary SimpleTreeLearner.  Sets the 
     32    skip_prob so that the number of randomly chosen features for each 
     33    split is  (on average) as specified.  """ 
     34 
     35    def __new__(cls, examples = None, weightID = 0, **argkw): 
     36        self = Orange.core.Learner.__new__(cls, **argkw) 
     37        if examples: 
     38            self.__init__(**argkw) 
     39            return self.__call__(examples, weightID) 
     40        else: 
     41            return self 
     42       
     43    def __init__(self, base, rand): 
     44        self.base = base 
     45        self.attributes = None 
     46        self.rand = rand 
     47     
    4048    def __call__(self, examples, weight=0): 
    41         self.skip_prob = 1-float(self.attributes)/len(examples.domain.attributes) 
    42         return Orange.classification.tree.SimpleTreeLearner.__call__(self, examples, weight) 
    43  
     49        osp,orand = self.base.skip_prob, self.base.random_generator 
     50        self.base.skip_prob = 1-float(self.attributes)/len(examples.domain.attributes) 
     51        self.base.random_generator = self.rand 
     52        r = self.base(examples, weight) 
     53        self.base.skip_prob, self.base.random_generator = osp, orand 
     54        return r 
     55    
    4456class RandomForestLearner(orange.Learner): 
    4557    """ 
     
    6274        randomized with Random Forest's random 
    6375        feature subset selection.  If None (default), 
    64         :class:`~Orange.classification.tree.TreeLearner` with Gini index 
    65         or MSE for attribute scoring will be used, and it will not split 
     76        :class:`~Orange.classification.tree.SimpleTreeLearner` and it will not split 
    6677        nodes with less than 5 data instances. 
    67     :type base_learner: None or :class:`Orange.classification.tree.TreeLearner` 
     78    :type base_learner: None or :class:`Orange.classification.tree.TreeLearner` or  
     79        :class:`Orange.classification.tree.SimpleTreeLearner` 
    6880    :param rand: random generator used in bootstrap sampling. If None (default),  
    6981        then ``random.Random(0)`` is used. 
    70     :param learner: Tree induction learner. If `"fast"` (default), 
    71         :obj:`~Orange.classification.tree.SimpleTreeLearner` with 
    72         random feature subset selection will be used.  If `None`,  
     82    :param learner: Tree induction learner. If `None` (default),  
    7383        the :obj:`base_learner` will be used (and randomized). If 
    7484        :obj:`learner` is specified, it will be used as such 
     
    93103 
    94104    def __init__(self, trees=100, attributes=None,\ 
    95                     name='Random Forest', rand=None, callback=None, base_learner=None, learner="fast"): 
     105                    name='Random Forest', rand=None, callback=None, base_learner=None, learner=None): 
    96106        self.trees = trees 
    97107        self.name = name 
     
    102112        self.base_learner = base_learner 
    103113 
    104         if base_learner != None and learner not in [ None, "fast" ]: 
     114        if base_learner != None and learner != None: 
    105115            wrongSpecification() 
    106         elif base_learner != None: 
    107             learner = None   #build with base_learner 
    108116 
    109117        if not self.rand: 
    110118            self.rand = random.Random(0) 
    111  
    112119        self.randorange = Orange.core.RandomGenerator(self.rand.randint(0,2**31-1)) 
    113120 
    114         if learner == "fast": 
    115             self.learner = SimpleTreeLearnerSetProb(min_instances=5, random_generator=self.randorange) 
    116         elif learner == None: 
    117             self.learner = _default_small_learner(self.attributes, self.rand, base=self.base_learner) 
     121        if learner == None: 
     122            self.learner = _wrap_learner(base=self.base_learner, rand=self.rand, randorange=self.randorange) 
    118123        else: 
    119124            self.learner = learner 
    120              
     125            
    121126        self.randstate = self.rand.getstate() #original state 
    122127 
     
    460465            return self 
    461466       
    462     def __init__(self, base, attributes, rand): 
     467    def __init__(self, base, rand): 
    463468        self.base = base 
    464         self.attributes = attributes 
     469        self.attributes = None 
    465470        self.rand = rand 
    466471        if not self.rand: #for all the built trees 
     
    480485                else Orange.feature.scoring.MSE() 
    481486 
    482         #ats = self.attributes if self.attributes else int(sqrt(len(candidates))) 
    483         ats = self.attributes 
    484  
    485487        bcopy.split = SplitConstructor_AttributeSubset(\ 
    486             bcopy.split, ats, self.rand) 
     488            bcopy.split, self.attributes, self.rand) 
    487489 
    488490        return bcopy(examples, weight=weight) 
  • orange/doc/modules/ensemble2.py

    r9312 r9317  
    88 
    99data = orange.ExampleTable('bupa.tab') 
    10 import random 
     10tree = orngTree.TreeLearner(minExamples=2, mForPrunning=2, \ 
     11                            sameMajorityPruning=True, name='tree') 
    1112forest = orngEnsemble.RandomForestLearner(trees=50, name="forest") 
    12 tree = orngTree.TreeLearner(min_instances=2, m_for_prunning=2, \ 
    13                             same_majority_pruning=True, name='tree') 
    1413learners = [tree, forest] 
    1514 
Note: See TracChangeset for help on using the changeset viewer.