Ignore:
Timestamp:
02/23/11 11:03:05 (3 years ago)
Author:
markotoplak
Branch:
default
Convert:
d7e95406f3b5f3a5cd07c98604400a022f6a508c
Message:

Created a Python class that wraps orange.TreeClassifier. TreeLearner now returns an instance of the new class.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/classification/tree.py

    r7707 r7708  
    599599 
    600600    An abstract base class for split constructors that employ  
    601     a :obj:`Orange.feature.scoring.Measure` to assess a quality of a split. At present, 
     601    a :obj:`Orange.feature.scoring.Measure` to assess a quality of a split.  
     602    At present, 
    602603    all split constructors except for :obj:`SplitConstructor_Combined` 
    603604    are derived from this class. 
     
    607608        A component of type :obj:`Orange.feature.scoring.Measure` used for 
    608609        evaluation of a split. Note that you must select the subclass  
    609         :obj:`Orange.feature.scoring.Measure` capable of handling your class type  
     610        :obj:`Orange.feature.scoring.Measure` capable of handling your  
     611        class type  
    610612        - you cannot use :obj:`Orange.feature.scoring.GainRatio` 
    611613        for building regression trees or :obj:`Orange.feature.scoring.MSE` 
     
    630632    The constructed :obj:`branchSelector` is an instance of  
    631633    :obj:`orange.ClassifierFromVarFD` that returns a value of the  
    632     selected attribute. If the attribute is :obj:`Orange.data.variable.Discrete`, 
     634    selected attribute. If the attribute is  
     635    :obj:`Orange.data.variable.Discrete`, 
    633636    :obj:`branchDescription`'s are the attribute's values. The  
    634637    attribute is marked as spent, so that it cannot reappear in the  
     
    20132016from Orange.core import \ 
    20142017     TreeLearner as TreeLearnerBase, \ 
    2015          TreeClassifier, \ 
     2018         TreeClassifier as _TreeClassifier, \ 
    20162019         C45Learner, \ 
    20172020         C45Classifier, \ 
     
    22972300        self.__dict__.update(kw) 
    22982301       
    2299     def __setattr__(self, name, value): 
    2300         if name in ["split", "binarization", "measure", "worstAcceptable", "minSubset", 
    2301               "stop", "maxMajority", "minExamples", "nodeLearner", "maxDepth", "reliefM", "reliefK"]: 
    2302             self.learner = None 
    2303         self.__dict__[name] = value 
    2304  
    23052302    def __call__(self, examples, weight=0): 
    23062303        """ 
    23072304        Return a classifier from the given examples. 
    23082305        """ 
    2309         if not self.learner: 
    2310             self.learner = self.instance() 
     2306        bl = self._base_learner() 
     2307 
     2308        #build second part of the split 
    23112309        if not hasattr(self, "split") and not hasattr(self, "measure"): 
    23122310            if examples.domain.classVar.varType == Orange.data.Type.Discrete: 
     
    23142312            else: 
    23152313                measure = Orange.feature.scoring.MSE() 
    2316             self.learner.split.continuousSplitConstructor.measure = measure 
    2317             self.learner.split.discreteSplitConstructor.measure = measure 
     2314            bl.split.continuousSplitConstructor.measure = measure 
     2315            bl.split.discreteSplitConstructor.measure = measure 
    23182316             
    2319         tree = self.learner(examples, weight) 
     2317        tree = bl(examples, weight) 
    23202318        if getattr(self, "sameMajorityPruning", 0): 
    2321             tree = Orange.classification.tree.Pruner_SameMajority(tree) 
     2319            tree = Pruner_SameMajority(tree) 
    23222320        if getattr(self, "mForPruning", 0): 
    2323             tree = Orange.classification.tree.Pruner_m(tree, m = self.mForPruning) 
    2324         return tree 
     2321            tree = Pruner_m(tree, m = self.mForPruning) 
     2322 
     2323        return TreeClassifier(baseClassifier=tree)  
    23252324 
    23262325    def instance(self): 
    23272326        """ 
    2328         Return the constructed learner - an object of :class:`TreeLearnerBase`. 
     2327        DEPRECATED. Return a base learner - an object  
     2328        of :class:`TreeLearnerBase`.  
     2329        This method is left for backwards compatibility. 
    23292330        """ 
    2330         learner = Orange.classification.tree.TreeLearnerBase() 
    2331  
    2332         hasSplit = hasattr(self, "split") 
    2333         if hasSplit: 
     2331        return self._base_learner() 
     2332 
     2333    def _build_split(self, learner): 
     2334 
     2335        if hasattr(self, "split"): 
    23342336            learner.split = self.split 
    23352337        else: 
    2336             learner.split = Orange.classification.tree.SplitConstructor_Combined() 
    2337             learner.split.continuousSplitConstructor = Orange.classification.tree.SplitConstructor_Threshold() 
     2338            learner.split = SplitConstructor_Combined() 
     2339            learner.split.continuousSplitConstructor = \ 
     2340                SplitConstructor_Threshold() 
    23382341            binarization = getattr(self, "binarization", 0) 
    23392342            if binarization == 1: 
    2340                 learner.split.discreteSplitConstructor = Orange.classification.tree.SplitConstructor_ExhaustiveBinary() 
     2343                learner.split.discreteSplitConstructor = \ 
     2344                    SplitConstructor_ExhaustiveBinary() 
    23412345            elif binarization == 2: 
    2342                 learner.split.discreteSplitConstructor = Orange.classification.tree.SplitConstructor_OneAgainstOthers() 
     2346                learner.split.discreteSplitConstructor = \ 
     2347                    SplitConstructor_OneAgainstOthers() 
    23432348            else: 
    2344                 learner.split.discreteSplitConstructor = Orange.classification.tree.SplitConstructor_Feature() 
     2349                learner.split.discreteSplitConstructor = \ 
     2350                    SplitConstructor_Feature() 
    23452351 
    23462352            measures = {"infoGain": Orange.feature.scoring.InfoGain, 
     
    23522358 
    23532359            measure = getattr(self, "measure", None) 
    2354             if type(measure) == str: 
     2360            if isinstance(measure, str): 
    23552361                measure = measures[measure]() 
    2356             if not hasSplit and not measure: 
     2362            if not measure: 
    23572363                measure = Orange.feature.scoring.GainRatio() 
    23582364 
    2359             measureIsRelief = type(measure) == Orange.feature.scoring.Relief 
     2365            measureIsRelief = isinstance(measure, Orange.feature.scoring.Relief) 
    23602366            relM = getattr(self, "reliefM", None) 
    23612367            if relM and measureIsRelief: 
     
    23782384                learner.split.continuousSplitConstructor.minSubset = ms 
    23792385                learner.split.discreteSplitConstructor.minSubset = ms 
     2386 
     2387    def _base_learner(self): 
     2388        learner = TreeLearnerBase() 
     2389 
     2390        self._build_split(learner) 
    23802391 
    23812392        if hasattr(self, "stop"): 
     
    23902401                learner.stop.minExamples = self.minExamples 
    23912402 
    2392         for a in ["storeDistributions", "storeContingencies", "storeExamples", "storeNodeClassifier", "nodeLearner", "maxDepth"]: 
     2403        for a in ["storeDistributions", "storeContingencies", "storeExamples",  
     2404            "storeNodeClassifier", "nodeLearner", "maxDepth"]: 
    23932405            if hasattr(self, a): 
    23942406                setattr(learner, a, getattr(self, a)) 
     
    24142426    :type tree: :class:`TreeClassifier` 
    24152427    """ 
    2416     return __countNodes(isinstance(tree, Orange.classification.tree.TreeClassifier) and tree.tree or tree) 
     2428    return __countNodes(tree.tree if isinstance(tree, _TreeClassifier) or \ 
     2429        isinstance(tree, TreeClassifier) else tree) 
    24172430 
    24182431 
     
    24342447    :type tree: :class:`TreeClassifier` 
    24352448    """ 
    2436     return __countLeaves(isinstance(tree, Orange.classification.tree.TreeClassifier) and tree.tree or tree) 
    2437  
     2449    return __countLeaves(tree.tree if isinstance(tree, _TreeClassifier) or \ 
     2450        isinstance(tree, TreeClassifier) else tree) 
    24382451 
    24392452# the following is for the output 
     
    30003013        argkw.get("maxDepth", 1e10), argkw.get("simpleFirst", True), tree, 
    30013014        leafShape=leafShape, nodeShape=nodeShape, fle=fle).dotTree() 
    3002                          
     3015  
     3016class TreeClassifier(Orange.classification.Classifier): 
     3017    """ 
     3018    Wraps :class:`Orange.core.TreeClassifier`. 
     3019    """ 
     3020     
     3021    def __init__(self, baseClassifier=None): 
     3022        if not baseClassifier: baseClassifier = _TreeClassifier() 
     3023        self.nativeClassifier = baseClassifier 
     3024        for k, v in self.nativeClassifier.__dict__.items(): 
     3025            self.__dict__[k] = v 
     3026   
     3027    def __call__(self, instance, result_type=Orange.classification.Classifier.GetValue, 
     3028                 *args, **kwdargs): 
     3029        """Classify a new instance. 
     3030         
     3031        :param instance: instance to be classified. 
     3032        :type instance: :class:`Orange.data.Instance` 
     3033        :param result_type:  
     3034              :class:`Orange.classification.Classifier.GetValue` or \ 
     3035              :class:`Orange.classification.Classifier.GetProbabilities` or 
     3036              :class:`Orange.classification.Classifier.GetBoth` 
     3037         
     3038        :rtype: :class:`Orange.data.Value`,  
     3039              :class:`Orange.statistics.Distribution` or a tuple with both 
     3040        """ 
     3041        return self.nativeClassifier(instance, result_type, *args, **kwdargs) 
     3042 
     3043    def __setattr__(self, name, value): 
     3044        if name == "nativeClassifier": 
     3045            self.__dict__[name] = value 
     3046            return 
     3047        if name in self.nativeClassifier.__dict__: 
     3048            self.nativeClassifier.__dict__[name] = value 
     3049        self.__dict__[name] = value 
     3050     
     3051    
     3052 
    30033053dotTree = printDot 
    30043054""" An alias for :func:`printDot`. Left for compatibility. """ 
Note: See TracChangeset for help on using the changeset viewer.