Changeset 9310:981ae7a5873b in orange


Ignore:
Timestamp:
12/06/11 13:20:05 (2 years ago)
Author:
markotoplak
Branch:
default
Convert:
a22530a3b23748dfcb7c02bc312f5a588cce38ea
Message:

Random forest uses SimpleTreeLearner by default (preliminary commit).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/ensemble/forest.py

    r9037 r9310  
    2626 
    2727    return _RandomForestTreeLearner(base=base, attributes=attributes, rand=rand) 
     28 
     29class SimpleTreeLearnerSetProb(Orange.classification.tree.SimpleTreeLearner): 
     30    """ 
     31    :class:`Orange.classification.tree.SimpleTreeLearner` which sets the  
     32    skip_prob so that the number of randomly chosen features for each split 
     33    is (on average) as the user specified (default: the square root of the 
     34    number of attributes). 
     35    """ 
     36    def __init__(self, *args, **kwargs): 
     37        self.attributes = kwargs.pop("attributes", None) 
     38        Orange.classification.tree.SimpleTreeLearner.__init__(self, *args, **kwargs) 
     39 
     40    def __call__(self, examples, weight=0): 
     41        self.skip_prob = 1-float(self.attributes)/len(examples.domain.attributes) 
     42        return Orange.classification.tree.SimpleTreeLearner.__call__(self, examples, weight) 
    2843 
    2944class RandomForestLearner(orange.Learner): 
     
    7691 
    7792    def __init__(self, trees=100, attributes=None,\ 
    78                     name='Random Forest', rand=None, callback=None, base_learner=None, learner=None): 
     93                    name='Random Forest', rand=None, callback=None, base_learner=None, learner="fast"): 
    7994        self.trees = trees 
    8095        self.name = name 
    81         self.learner = learner 
    8296        self.attributes = attributes 
    8397        self.callback = callback 
    8498        self.rand = rand 
     99 
    85100        self.base_learner = base_learner 
    86          
     101 
     102        if base_learner != None and learner not in [ None, "orig", "fast" ]: 
     103            wrongSpecification() 
     104        elif base_learner != None or learner == "orig": 
     105            learner = None   #build with base_learner 
     106 
    87107        if not self.rand: 
    88108            self.rand = random.Random(0) 
     109 
     110        if learner == "fast": 
     111            self.learner = SimpleTreeLearnerSetProb(min_instances=5) 
     112        elif learner == None: 
     113            self.learner = _default_small_learner(self.attributes, self.rand, base=self.base_learner) 
     114        else: 
     115            self.learner = learner 
    89116             
    90117        self.randstate = self.rand.getstate() #original state 
     
    96123        :param instances: learning data. 
    97124        :type instances: class:`Orange.data.Table` 
    98         :param origWeight: weight. 
    99         :type origWeight: int 
     125        :param weight: weight. 
     126        :type weight: int 
    100127        :rtype: :class:`Orange.ensemble.forest.RandomForestClassifier` 
    101128        """ 
    102          
    103         if not self.learner: 
    104             learner = _default_small_learner(self.attributes, self.rand, base=self.base_learner) 
    105         else: 
    106             learner = self.learner 
    107          
    108129        self.rand.setstate(self.randstate) #when learning again, set the same state 
     130 
     131        if "attributes" in self.learner.__dict__: 
     132            self.learner.attributes = len(instances.domain.attributes)**0.5 if self.attributes == None else self.attributes 
     133 
     134        learner = self.learner 
    109135 
    110136        n = len(instances) 
     
    202228        else: 
    203229            # Handle continuous class 
    204              
     230         
    205231            # voting for class probabilities 
    206232            if resultType == orange.GetProbabilities or resultType == orange.GetBoth: 
    207                 probs = [c(instance, orange.GetProbabilities) for c in self.classifiers] 
     233                probs = [c(instance, orange.GetBoth) for c in self.classifiers] 
    208234                cprob = dict() 
    209                 for prob in probs: 
    210                     a = dict(prob.items()) 
     235                for val,prob in probs: 
     236                    if prob != None: #no probability output 
     237                        a = dict(prob.items()) 
     238                    else: 
     239                        a = { val.value : 1. } 
    211240                    cprob = dict( (n, a.get(n, 0)+cprob.get(n, 0)) for n in set(a)|set(cprob) ) 
    212241                cprob = Orange.statistics.distribution.Continuous(cprob) 
     
    458487    def __call__(self, gen, weightID, contingencies, apriori, candidates, clsfr): 
    459488        # if number of features for subset is not set, use square root 
    460         if not self.attributes: 
    461             self.attributes = int(sqrt(len(candidates))) 
    462  
    463         cand = [1]*self.attributes + [0]*(len(candidates) - self.attributes) 
     489        cand = [1]*int(self.attributes) + [0]*(len(candidates) - int(self.attributes)) 
    464490        self.rand.shuffle(cand) 
    465491        # instead with all features, we will invoke split constructor  
Note: See TracChangeset for help on using the changeset viewer.