source: orange/Orange/multitarget/tree.py @ 10699:f021e2b1ed50

Revision 10699:f021e2b1ed50, 9.6 KB, checked in by Lan Zagar <lan.zagar@…>, 2 years ago

Make MultiTree work on single-class data (fixes #1167).

"""
.. index:: Multi-target Tree Learner

***************************************
Multi-target Tree Learner
***************************************

To use the tree learning algorithm for multi-target data, standard
Orange trees (:class:`Orange.classification.tree.TreeLearner`) can be used.
Only the :obj:`~Orange.classification.tree.TreeLearner.measure` for feature
scoring and the :obj:`~Orange.classification.tree.TreeLearner.node_learner`
components have to be chosen so that they work on multi-target data domains.

This module provides one such measure (:class:`MultitargetVariance`) that
can be used and a helper class :class:`MultiTreeLearner`, which extends
:class:`~Orange.classification.tree.TreeLearner` and is the same in all
aspects except for different (multi-target) defaults for
:obj:`~Orange.classification.tree.TreeLearner.measure` and
:obj:`~Orange.classification.tree.TreeLearner.node_learner`.

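For instance, a standard tree learner can be wired up for multi-target data
by hand, roughly as follows (a sketch that only combines the components
described above; ``max_depth`` is an arbitrary illustrative argument)::

    import Orange
    from Orange.multitarget.tree import MultitargetVariance

    learner = Orange.classification.tree.TreeLearner(
        measure=MultitargetVariance(),
        node_learner=Orange.multitarget.MultitargetLearner(
            Orange.classification.majority.MajorityLearner()),
        max_depth=3)
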
Examples
========

The following example demonstrates how to build a prediction model with
:class:`MultiTreeLearner` and use it to predict (multiple) class values for
a given instance (:download:`multitarget.py <code/multitarget.py>`):

.. literalinclude:: code/multitarget.py
    :lines: 1-4, 10-12


.. index:: Multi-target Variance
.. autoclass:: Orange.multitarget.tree.MultitargetVariance
    :members:
    :show-inheritance:

.. index:: Multi-target Tree Learner
.. autoclass:: Orange.multitarget.tree.MultiTreeLearner
    :members:
    :show-inheritance:

.. index:: Multi-target Tree Classifier
.. autoclass:: Orange.multitarget.tree.MultiTree
    :members:
    :show-inheritance:

"""

from operator import itemgetter

import Orange
import numpy as np


def weighted_variance(X, weights=None):
    """Computes the variance using a weighted distance to the centroid."""
    if not weights:
        weights = [1] * len(X[0])
    # Scale each class dimension by its weight before measuring squared
    # distances to the centroid (the column means).
    X = X * np.array(weights)
    return np.sum(np.sum((X - np.mean(X, 0))**2, 1))

class MultitargetVariance(Orange.feature.scoring.Score):
    """
    A multi-target score that ranks features based on the average class
    variance of the subsets.

    To compute it, a prototype has to be defined for each subset. Here, it
    is just the mean vector of class variables. Then the sum of squared
    distances to the prototypes is computed in each subset. The final score
    is obtained as the average of subset variances (weighted, to account for
    subset sizes).

    Weights can be passed to the constructor to normalize classes with values
    of different magnitudes or to increase the importance of some classes. In
    this case, class values are first scaled according to the given weights.
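
    For illustration, a scorer can be constructed and applied directly (a
    sketch, using the multitarget-synthetic data set that ships with
    Orange)::

        data = Orange.data.Table('multitarget-synthetic')
        scorer = MultitargetVariance()
        print scorer(data.domain.features[0], data)

    Higher (less negative) scores indicate lower within-subset class
    variance and thus a more informative feature.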
    """

    def __init__(self, weights=None):
        """
        :param weights: Weights of the class variables used when computing
                        distances. If None, all weights are set to 1.
        :type weights: list
        """

        # Types of classes the score can handle
        self.handles_discrete = True
        ## TODO: for discrete classes with >2 values entropy should be used
        ## instead of variance
        self.handles_continuous = True
        # Can propose split thresholds for continuous features
        self.computes_thresholds = True
        # Needs access to the data instances (not just contingencies)
        self.needs = Orange.feature.scoring.Score.Generator

        self.weights = weights


    def threshold_function(self, feature, data, cont_distrib=None, weights=0):
        """
        Evaluates possible binary splits of a continuous feature and
        scores them.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set to be split using the given continuous
                     feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`list` of :obj:`tuples <tuple>`
                 [(threshold, score),]
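
        A usage sketch, assuming a continuous feature ``f`` from a
        multi-target table ``data``::

            scorer = MultitargetVariance()
            candidates = scorer.threshold_function(f, data)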
        """

        f = data.domain[feature]
        # Candidate thresholds are midpoints between consecutive distinct values.
        values = sorted(set(ins[f].value for ins in data))
        ts = [(v1 + v2) / 2. for v1, v2 in zip(values, values[1:])]
        # Thin out the candidate list to keep the number of evaluated
        # thresholds small.
        if len(ts) > 40:
            ts = ts[::len(ts)/20]
        scores = []
        for t in ts:
            # Binarize the feature at threshold t and score the binarized feature.
            bf = Orange.feature.discretization.IntervalDiscretizer(
                points=[t]).construct_variable(f)
            dom2 = Orange.data.Domain([bf], class_vars=data.domain.class_vars)
            data2 = Orange.data.Table(dom2, data)
            scores.append((t, self.__call__(bf, data2)))
        return scores

    def best_threshold(self, feature, data):
        """
        Computes the best threshold for a split of a continuous feature.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set to be split using the given continuous
                     feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`tuple` (threshold, score, None)
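
        For example (a sketch, with ``scorer``, ``f`` and ``data`` as above),
        the best split point and its score can be obtained with::

            threshold, score, _ = scorer.best_threshold(f, data)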
        """

        scores = self.threshold_function(feature, data)
        # Pick the candidate threshold with the highest score.
        threshold, score = max(scores, key=itemgetter(1))
        return (threshold, score, None)

    def __call__(self, feature, data, apriori_class_distribution=None,
                 weights=0):
        """
        Scores the feature by the negative size-weighted sum of class
        variances of the subsets induced by its values (higher is better).

        :param feature: The feature to be scored.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set on which to score the feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`float`
        """

        # Group the class-value vectors of instances by the feature's value.
        split = dict((ins[feature].value, []) for ins in data)
        for ins in data:
            split[ins[feature].value].append(ins.get_classes())
        # Negative size-weighted sum of within-subset variances.
        score = -sum(weighted_variance(x, self.weights) * len(x)
                     for x in split.values())
        return score


class MultiTreeLearner(Orange.classification.tree.TreeLearner):
    """
    MultiTreeLearner is a multi-target version of a tree learner. It is
    the same as :class:`~Orange.classification.tree.TreeLearner`, except
    for the default values of two parameters:

    .. attribute:: measure

        A multi-target score is used by default: :class:`MultitargetVariance`.

    .. attribute:: node_learner

        Standard trees use
        :class:`~Orange.classification.majority.MajorityLearner`
        to construct prediction models in the leaves of the tree.
        MultiTreeLearner uses the multi-target equivalent, obtained by
        wrapping the majority learner in
        :class:`Orange.multitarget.MultitargetLearner`.

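    A construction sketch (any additional keyword arguments, such as the
    ``max_depth`` used here, are passed on to
    :class:`~Orange.classification.tree.TreeLearner`)::

        mt_tree = MultiTreeLearner(max_depth=3)
        classifier = mt_tree(Orange.data.Table('multitarget-synthetic'))
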
    """

    def __init__(self, **kwargs):
        """
        The constructor passes all arguments to
        :obj:`Orange.classification.tree.TreeLearner.__init__`, filling in
        the multi-target defaults for measure and node_learner when they
        are not given.
        """

        if 'measure' not in kwargs:
            kwargs['measure'] = MultitargetVariance()
        if 'node_learner' not in kwargs:
            kwargs['node_learner'] = Orange.multitarget.MultitargetLearner(
                Orange.classification.majority.MajorityLearner())
        Orange.classification.tree.TreeLearner.__init__(self, **kwargs)

    def __call__(self, data, weight=0):
        """
        :param data: Data instances to learn from.
        :type data: :class:`Orange.data.Table`

        :param weight: Id of meta attribute with weights of instances.
        :type weight: :obj:`int`
        """

        # Single-class data: mirror class_var into class_vars so that the
        # multi-target code below can rely on class_vars.
        if not data.domain.class_vars and data.domain.class_var:
            dom = Orange.data.Domain(data.domain.features,
                data.domain.class_var, class_vars=[data.domain.class_var])
            data = Orange.data.Table(dom, data)

        # Check for missing class values in the data.
        for ins in data:
            for cval in ins.get_classes():
                if cval.is_special():
                    raise ValueError('Data has missing class values.')

        # TreeLearner does not work on class-less domains,
        # so we set the class if necessary.
        if not data.domain.class_var and data.domain.class_vars:
            dom = Orange.data.Domain(data.domain.features,
                data.domain.class_vars[0], class_vars=data.domain.class_vars)
            data = Orange.data.Table(dom, data)

        tree = Orange.classification.tree.TreeLearner.__call__(
            self, data, weight)
        return MultiTree(base_classifier=tree)

class MultiTree(Orange.classification.tree.TreeClassifier):
    """
    MultiTree classifier is almost the same as the base class it extends
    (:class:`~Orange.classification.tree.TreeClassifier`). Only the
    :obj:`__call__` method is modified so it works with multi-target data.
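
    A prediction sketch (``c_mtt`` stands for a trained MultiTree, as built
    in the example at the bottom of this module; ``instance`` is an
    :class:`Orange.data.Instance`)::

        values = c_mtt(instance)
        probs = c_mtt(instance, Orange.classification.Classifier.GetProbabilities)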
    """

    def __call__(self, instance, return_type=Orange.core.GetValue):
        """
        :param instance: Instance to be classified.
        :type instance: :class:`Orange.data.Instance`

        :param return_type: One of
            :class:`Orange.classification.Classifier.GetValue`,
            :class:`Orange.classification.Classifier.GetProbabilities` or
            :class:`Orange.classification.Classifier.GetBoth`
        """

        # Descend to a leaf and let its multi-target node classifier predict.
        node = self.descender(self.tree, instance)[0]
        return node.node_classifier(instance, return_type)


if __name__ == '__main__':
    data = Orange.data.Table('multitarget-synthetic')
    print 'Actual classes:\n', data[0].get_classes()

    # Multi-target majority baseline.
    majority = Orange.classification.majority.MajorityLearner()
    mt_majority = Orange.multitarget.MultitargetLearner(majority)
    c_mtm = mt_majority(data)
    print 'Majority predictions:\n', c_mtm(data[0])

    # Multi-target tree.
    mt_tree = MultiTreeLearner(max_depth=3)
    c_mtt = mt_tree(data)
    print 'Multi-target Tree predictions:\n', c_mtt(data[0])