Ignore:
Timestamp:
03/02/11 13:20:05 (3 years ago)
Author:
markotoplak
Branch:
default
Convert:
9554fef7e943bdad771739bdefa46cc868a339b2
Message:

Added build_stop and build_split to the TreeLearner.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/classification/tree.py

    r7718 r7719  
    21212121 
    21222122 
     2123import Orange.feature.scoring as fscoring 
     2124 
    21232125class TreeLearner(Orange.core.Learner): 
    21242126    """ 
     
    22742276       
    22752277    def __init__(self, **kw): 
     2278        self.split = None 
     2279        self.stop = None 
     2280        self.measure = None 
    22762281        self.__dict__.update(kw) 
    22772282       
     
    22822287        bl = self._base_learner() 
    22832288 
    2284         #build second part of the split 
    2285         if not hasattr(self, "split") and not hasattr(self, "measure"): 
    2286             if examples.domain.classVar.varType == Orange.data.Type.Discrete: 
    2287                 measure = Orange.feature.scoring.GainRatio() 
    2288             else: 
    2289                 measure = Orange.feature.scoring.MSE() 
     2289        #set the scoring criteria for regression if it was not 
     2290        #set by the user 
     2291        if not self.split and not self.measure: 
     2292            measure = fscoring.GainRatio() \ 
     2293                if examples.domain.classVar.varType == Orange.data.Type.Discrete \ 
     2294                else fscoring.MSE() 
    22902295            bl.split.continuousSplitConstructor.measure = measure 
    22912296            bl.split.discreteSplitConstructor.measure = measure 
    2292              
     2297          
     2298        #post pruning 
    22932299        tree = bl(examples, weight) 
    22942300        if getattr(self, "sameMajorityPruning", 0): 
    22952301            tree = Pruner_SameMajority(tree) 
    22962302        if getattr(self, "mForPruning", 0): 
    2297             tree = Pruner_m(tree, m = self.mForPruning) 
     2303            tree = Pruner_m(tree, m=self.mForPruning) 
    22982304 
    22992305        return TreeClassifier(baseClassifier=tree)  
     
    23072313        return self._base_learner() 
    23082314 
    2309     def _build_split(self, learner): 
    2310  
    2311         if hasattr(self, "split"): 
    2312             learner.split = self.split 
    2313         else: 
    2314             learner.split = SplitConstructor_Combined() 
    2315             learner.split.continuousSplitConstructor = \ 
     2315    def build_split(self): 
     2316        """ 
     2317        Return the split constructor built according to object attributes. 
     2318        """ 
     2319         
     2320        split = self.split 
     2321 
     2322        if not split: 
     2323            split = SplitConstructor_Combined() 
     2324            split.continuousSplitConstructor = \ 
    23162325                SplitConstructor_Threshold() 
    23172326            binarization = getattr(self, "binarization", 0) 
    23182327            if binarization == 1: 
    2319                 learner.split.discreteSplitConstructor = \ 
     2328                split.discreteSplitConstructor = \ 
    23202329                    SplitConstructor_ExhaustiveBinary() 
    23212330            elif binarization == 2: 
    2322                 learner.split.discreteSplitConstructor = \ 
     2331                split.discreteSplitConstructor = \ 
    23232332                    SplitConstructor_OneAgainstOthers() 
    23242333            else: 
    2325                 learner.split.discreteSplitConstructor = \ 
     2334                split.discreteSplitConstructor = \ 
    23262335                    SplitConstructor_Feature() 
    23272336 
    2328             measures = {"infoGain": Orange.feature.scoring.InfoGain, 
    2329                 "gainRatio": Orange.feature.scoring.GainRatio, 
    2330                 "gini": Orange.feature.scoring.Gini, 
    2331                 "relief": Orange.feature.scoring.Relief, 
    2332                 "retis": Orange.feature.scoring.MSE 
     2337            measures = {"infoGain": fscoring.InfoGain, 
     2338                "gainRatio": fscoring.GainRatio, 
     2339                "gini": fscoring.Gini, 
     2340                "relief": fscoring.Relief, 
     2341                "retis": fscoring.MSE 
    23332342                } 
    23342343 
    2335             measure = getattr(self, "measure", None) 
     2344            measure = self.measure 
    23362345            if isinstance(measure, str): 
    23372346                measure = measures[measure]() 
    23382347            if not measure: 
    2339                 measure = Orange.feature.scoring.GainRatio() 
    2340  
    2341             measureIsRelief = isinstance(measure, Orange.feature.scoring.Relief) 
     2348                measure = fscoring.GainRatio() 
     2349 
     2350            measureIsRelief = isinstance(measure, fscoring.Relief) 
    23422351            relM = getattr(self, "reliefM", None) 
    23432352            if relM and measureIsRelief: 
     
    23482357                measure.k = relK 
    23492358 
    2350             learner.split.continuousSplitConstructor.measure = measure 
    2351             learner.split.discreteSplitConstructor.measure = measure 
     2359            split.continuousSplitConstructor.measure = measure 
     2360            split.discreteSplitConstructor.measure = measure 
    23522361 
    23532362            wa = getattr(self, "worstAcceptable", 0) 
    23542363            if wa: 
    2355                 learner.split.continuousSplitConstructor.worstAcceptable = wa 
    2356                 learner.split.discreteSplitConstructor.worstAcceptable = wa 
     2364                split.continuousSplitConstructor.worstAcceptable = wa 
     2365                split.discreteSplitConstructor.worstAcceptable = wa 
    23572366 
    23582367            ms = getattr(self, "minSubset", 0) 
    23592368            if ms: 
    2360                 learner.split.continuousSplitConstructor.minSubset = ms 
    2361                 learner.split.discreteSplitConstructor.minSubset = ms 
     2369                split.continuousSplitConstructor.minSubset = ms 
     2370                split.discreteSplitConstructor.minSubset = ms 
     2371 
     2372        return split 
     2373 
     2374    def build_stop(self): 
     2375        """ 
     2376        Return the stop criteria built according to object's attributes. 
     2377        """ 
     2378        stop = self.stop 
     2379        if not stop: 
     2380            stop = Orange.classification.tree.StopCriteria_common() 
     2381            mm = getattr(self, "maxMajority", 1.0) 
     2382            if mm < 1.0: 
     2383                stop.maxMajority = self.maxMajority 
     2384            me = getattr(self, "minExamples", 0) 
     2385            if me: 
     2386                stop.minExamples = self.minExamples 
     2387        return stop 
    23622388 
    23632389    def _base_learner(self): 
    23642390        learner = TreeLearnerBase() 
    23652391 
    2366         self._build_split(learner) 
    2367  
    2368         if hasattr(self, "stop"): 
    2369             learner.stop = self.stop 
    2370         else: 
    2371             learner.stop = Orange.classification.tree.StopCriteria_common() 
    2372             mm = getattr(self, "maxMajority", 1.0) 
    2373             if mm < 1.0: 
    2374                 learner.stop.maxMajority = self.maxMajority 
    2375             me = getattr(self, "minExamples", 0) 
    2376             if me: 
    2377                 learner.stop.minExamples = self.minExamples 
     2392        learner.split = self.build_split() 
     2393        learner.stop = self.build_stop() 
    23782394 
    23792395        for a in ["storeDistributions", "storeContingencies", "storeExamples",  
Note: See TracChangeset for help on using the changeset viewer.