source: orange/Orange/multitarget/tree.py @ 10432:d7544a94c00f

Revision 10432:d7544a94c00f, 9.1 KB checked in by Lan Zagar <lan.zagar@…>, 2 years ago (diff)

Minor code clean-up.

Line 
1"""
2.. index:: Multi-target Tree Learner
3
4***************************************
5Multi-target Tree Learner
6***************************************
7
8To use the tree learning algorithm for multi-target data, standard
9orange trees (:class:`Orange.classification.tree.TreeLearner`) can be used.
10Only the :obj:`~Orange.classification.tree.TreeLearner.measure` for feature
11scoring and the :obj:`~Orange.classification.tree.TreeLearner.node_learner`
12components have to be chosen so that they work on multi-target data domains.
13
14This module provides one such measure (:class:`MultitargetVariance`) that
15can be used and a helper class :class:`MultiTreeLearner` which extends
16:class:`~Orange.classification.tree.TreeLearner` and is the same in all
17aspects except for different (multi-target) defaults for
18:obj:`~Orange.classification.tree.TreeLearner.measure` and
19:obj:`~Orange.classification.tree.TreeLearner.node_learner`.
20
21Examples
22========
23
The following example demonstrates how to build a prediction model with
MultiTreeLearner and use it to predict (multiple) class values for
a given instance (:download:`multitarget.py <code/multitarget.py>`):
27
28.. literalinclude:: code/multitarget.py
29    :lines: 1-4, 10-12
30
31
32.. index:: Multi-target Variance
33.. autoclass:: Orange.multitarget.tree.MultitargetVariance
34    :members:
35    :show-inheritance:
36
37.. index:: Multi-target Tree Learner
38.. autoclass:: Orange.multitarget.tree.MultiTreeLearner
39    :members:
40    :show-inheritance:
41
42.. index:: Multi-target Tree Classifier
43.. autoclass:: Orange.multitarget.tree.MultiTree
44    :members:
45    :show-inheritance:
46
47"""
48
49from operator import itemgetter
50
51import Orange
52import numpy as np
53
54
def weighted_variance(X, weights=None):
    """Compute the variance of the rows of X as the sum of squared
    (optionally column-weighted) distances to the centroid (column means).

    :param X: a sequence of equal-length numeric rows (class-value vectors).
    :param weights: per-column weights used to scale values before measuring
                    distances. If None, all weights are set to 1.
    :type weights: list
    :return: sum over rows of the squared distance to the mean row.
    :rtype: float
    """
    # Test `is None` rather than truthiness: a numpy-array weight argument
    # would make `not weights` raise, and falsy non-None values should not
    # be silently replaced.
    if weights is None:
        weights = [1] * len(X[0])
    X = X * np.array(weights)
    return np.sum(np.sum((X - np.mean(X, 0)) ** 2, 1))
61
class MultitargetVariance(Orange.feature.scoring.Score):
    """
    A multi-target score that ranks features based on the average class
    variance of the subsets.

    To compute it, a prototype has to be defined for each subset. Here, it
    is just the mean vector of class variables. Then the sum of squared
    distances to the prototypes is computed in each subset. The final score
    is obtained as the average of subset variances (weighted, to account for
    subset sizes).

    Weights can be passed to the constructor to normalize classes with values
    of different magnitudes or to increase the importance of some classes. In
    this case, class values are first scaled according to the given weights.
    """

    def __init__(self, weights=None):
        """
        :param weights: Weights of the class variables used when computing
                        distances. If None, all weights are set to 1.
        :type weights: list
        """
        # Types of classes allowed.
        self.handles_discrete = True
        # TODO: for discrete classes with >2 values entropy should be used
        # instead of variance.
        self.handles_continuous = True
        # Can handle continuous features.
        self.computes_thresholds = True
        # Scoring needs access to the data instances themselves.
        self.needs = Orange.feature.scoring.Score.Generator

        self.weights = weights

    def threshold_function(self, feature, data, cont_distrib=None, weights=0):
        """
        Evaluates possible splits of a continuous feature into a binary one
        and scores them.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set to be split using the given continuous feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`list` of :obj:`tuples <tuple>` [(threshold, score),]
        """
        f = data.domain[feature]
        # Candidate thresholds: midpoints between consecutive distinct values.
        values = sorted(set(ins[f].value for ins in data))
        ts = [(v1 + v2) / 2. for v1, v2 in zip(values, values[1:])]
        if len(ts) > 40:
            # Subsample to roughly 20 candidates to keep scoring fast.
            # Floor division keeps the slice step an int on both
            # Python 2 and 3 (same result as `/` under Python 2).
            ts = ts[::len(ts) // 20]
        scores = []
        for t in ts:
            # Discretize the feature at threshold t and score the resulting
            # binary feature on the multi-target domain.
            bf = Orange.feature.discretization.IntervalDiscretizer(
                points=[t]).construct_variable(f)
            dom2 = Orange.data.Domain([bf], class_vars=data.domain.class_vars)
            data2 = Orange.data.Table(dom2, data)
            scores.append((t, self(bf, data2)))
        return scores

    def best_threshold(self, feature, data):
        """
        Computes the best threshold for a split of a continuous feature.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set to be split using the given continuous feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`tuple` (threshold, score, None)
        """
        scores = self.threshold_function(feature, data)
        threshold, score = max(scores, key=itemgetter(1))
        return (threshold, score, None)

    def __call__(self, feature, data, apriori_class_distribution=None,
                 weights=0):
        """
        Scores a feature as the negative size-weighted sum of class
        variances of the subsets induced by its values (higher is better).

        :param feature: The feature to be scored.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set on which to score the feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`float`
        """
        # Group the class-value vectors of instances by the feature's value.
        split = dict((ins[feature].value, []) for ins in data)
        for ins in data:
            split[ins[feature].value].append(ins.get_classes())
        # Negate so that lower total variance yields a higher score.
        score = -sum(weighted_variance(x, self.weights) * len(x)
                     for x in split.values())
        return score
161
162
class MultiTreeLearner(Orange.classification.tree.TreeLearner):
    """
    MultiTreeLearner is a multi-target version of a tree learner. It is the
    same as :class:`~Orange.classification.tree.TreeLearner`, except for the
    default values of two parameters:

    .. attribute:: measure

        A multi-target score is used by default: :class:`MultitargetVariance`.

    .. attribute:: node_learner

        Standard trees use :class:`~Orange.classification.majority.MajorityLearner`
        to construct prediction models in the leaves of the tree.
        MultiTreeLearner uses the multi-target equivalent which can be
        obtained simply by wrapping the majority learner:

        :class:`Orange.multitarget.MultitargetLearner` (:class:`Orange.classification.majority.MajorityLearner()`).

    """

    def __init__(self, **kwargs):
        """
        The constructor passes all arguments to
        :class:`~Orange.classification.tree.TreeLearner`'s constructor
        :obj:`Orange.classification.tree.TreeLearner.__init__`.
        """
        measure = MultitargetVariance()
        node_learner = Orange.multitarget.MultitargetLearner(
            Orange.classification.majority.MajorityLearner())
        Orange.classification.tree.TreeLearner.__init__(
            self, measure=measure, node_learner=node_learner, **kwargs)

    def __call__(self, data, weight=0):
        """
        Learns a :class:`MultiTree` from the given data.

        :param data: Data instances to learn from.
        :type data: :class:`Orange.data.Table`

        :param weight: Id of meta attribute with weights of instances.
        :type weight: :obj:`int`

        :raises ValueError: if any instance has a missing class value.
        """
        for ins in data:
            for cval in ins.get_classes():
                if cval.is_special():
                    raise ValueError('Data has missing class values.')
        # TreeLearner does not work on class-less domains, so if the data
        # has no (single) class variable we designate the first of the
        # multi-target class variables as the class.
        # (Previously data2 was only assigned inside the `if`, causing a
        # NameError for domains that already have a class_var.)
        data2 = data
        if data.domain.class_var is None:
            data2 = Orange.data.Table(Orange.data.Domain(
                data.domain.attributes, data.domain.class_vars[0],
                class_vars=data.domain.class_vars), data)
        tree = Orange.classification.tree.TreeLearner.__call__(
            self, data2, weight)
        return MultiTree(base_classifier=tree)
219
class MultiTree(Orange.classification.tree.TreeClassifier):
    """
    MultiTree classifier is almost the same as the base class it extends
    (:class:`~Orange.classification.tree.TreeClassifier`). Only the
    :obj:`__call__` method is modified so it works with multi-target data.
    """

    def __call__(self, instance, return_type=Orange.core.GetValue):
        """
        :param instance: Instance to be classified.
        :type instance: :class:`Orange.data.Instance`

        :param return_type: One of
            :class:`Orange.classification.Classifier.GetValue`,
            :class:`Orange.classification.Classifier.GetProbabilities` or
            :class:`Orange.classification.Classifier.GetBoth`
        """
        # Descend to the leaf reached by this instance, then delegate the
        # (multi-target) prediction to that leaf's node classifier.
        leaf = self.descender(self.tree, instance)[0]
        return leaf.node_classifier(instance, return_type)
240
241
242if __name__ == '__main__':
243    data = Orange.data.Table('multitarget-synthetic')
244    print 'Actual classes:\n', data[0].get_classes()
245   
246    majority = Orange.classification.majority.MajorityLearner()
247    mt_majority = Orange.multitarget.MultitargetLearner(majority)
248    c_mtm = mt_majority(data)
249    print 'Majority predictions:\n', c_mtm(data[0])
250
251    mt_tree = MultiTreeLearner(max_depth=3)
252    c_mtt = mt_tree(data)
253    print 'Multi-target Tree predictions:\n', c_mtt(data[0])
Note: See TracBrowser for help on using the repository browser.