Ignore:
Timestamp:
03/30/12 15:55:49 (2 years ago)
Author:
Lan Zagar <lan.zagar@…>
Branch:
default
rebase_source:
18b21e913d7340679c1eaf89b3b2d3edb31d71b9
Message:

Make MultiTree work on single-class data (fixes #1167).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/multitarget/tree.py

    r10432 r10699  
    104104        :type feature: :class:`Orange.feature.Descriptor` 
    105105 
    106         :param data: The data set to be split using the given continuous feature. 
    107         :type data: :class:`Orange.data.Table` 
    108  
    109         :return: :obj:`list` of :obj:`tuples <tuple>` [(threshold, score, None),] 
     106        :param data: The data set to be split using the given continuous 
     107                     feature. 
     108        :type data: :class:`Orange.data.Table` 
     109 
     110        :return: :obj:`list` of :obj:`tuples <tuple>` 
     111                 [(threshold, score, None),] 
    110112        """ 
    111113 
     
    131133        :type feature: :class:`Orange.feature.Descriptor` 
    132134 
    133         :param data: The data set to be split using the given continuous feature. 
     135        :param data: The data set to be split using the given continuous 
     136                     feature. 
    134137        :type data: :class:`Orange.data.Table` 
    135138 
     
    173176    .. attribute:: node_learner 
    174177         
    175         Standard trees use :class:`~Orange.classification.majority.MajorityLearner` 
     178        Standard trees use 
     179        :class:`~Orange.classification.majority.MajorityLearner` 
    176180        to construct prediction models in the leaves of the tree. 
    177181        MultiTreeLearner uses the multi-target equivalent which can be  
    178182        obtained simply by wrapping the majority learner: 
    179183 
    180         :class:`Orange.multitarget.MultitargetLearner` (:class:`Orange.classification.majority.MajorityLearner()`). 
     184        :class:`Orange.multitarget.MultitargetLearner` 
     185        (:class:`Orange.classification.majority.MajorityLearner()`). 
    181186 
    182187    """ 
     
    189194        """ 
    190195         
    191         measure = MultitargetVariance() 
    192         node_learner = Orange.multitarget.MultitargetLearner( 
    193             Orange.classification.majority.MajorityLearner()) 
    194         Orange.classification.tree.TreeLearner.__init__( 
    195             self, measure=measure, node_learner=node_learner, **kwargs) 
     196        if 'measure' not in kwargs: 
     197            kwargs['measure'] = MultitargetVariance() 
     198        if 'node_learner' not in kwargs: 
     199            kwargs['node_learner'] = Orange.multitarget.MultitargetLearner( 
     200                Orange.classification.majority.MajorityLearner()) 
     201        Orange.classification.tree.TreeLearner.__init__(self, **kwargs) 
    196202 
    197203    def __call__(self, data, weight=0): 
     
    204210        """ 
    205211         
     212        # Use the class, if data does not have class_vars 
     213        if not data.domain.class_vars and data.domain.class_var: 
     214            dom = Orange.data.Domain(data.domain.features, 
     215                data.domain.class_var, class_vars=[data.domain.class_var]) 
     216            data = Orange.data.Table(dom, data) 
     217 
     218        # Check for missing class values in data 
    206219        for ins in data: 
    207220            for cval in ins.get_classes(): 
    208221                if cval.is_special(): 
    209222                    raise ValueError('Data has missing class values.') 
     223 
    210224        # TreeLearner does not work on class-less domains, 
    211225        # so we set the class if necessary 
    212         if data.domain.class_var is None: 
    213             data2 = Orange.data.Table(Orange.data.Domain( 
    214                 data.domain.attributes, data.domain.class_vars[0], 
    215                 class_vars=data.domain.class_vars), data) 
     226        if not data.domain.class_var and data.domain.class_vars: 
     227            dom = Orange.data.Domain(data.domain.features, 
     228                data.domain.class_vars[0], class_vars=data.domain.class_vars) 
     229            data = Orange.data.Table(dom, data) 
     230 
    216231        tree = Orange.classification.tree.TreeLearner.__call__( 
    217             self, data2, weight) 
     232            self, data, weight) 
    218233        return MultiTree(base_classifier=tree) 
    219234 
Note: See TracChangeset for help on using the changeset viewer.