source: orange/Orange/multitarget/tree.py @ 10335:18f3ac9e1ec6

Revision 10335:18f3ac9e1ec6, 9.1 KB checked in by Lan Zagar <lan.zagar@…>, 2 years ago (diff)

Expanded description of MultitargetVariance.

Line 
1"""
2.. index:: Multi-target Tree Learner
3
4***************************************
5Multi-target Tree Learner
6***************************************
7
8To use the tree learning algorithm for multi-target data, standard
9orange trees (:class:`Orange.classification.tree.TreeLearner`) can be used.
10Only the :obj:`~Orange.classification.tree.TreeLearner.measure` for feature
11scoring and the :obj:`~Orange.classification.tree.TreeLearner.node_learner`
12components have to be chosen so that they work on multi-target data domains.
13
14This module provides one such measure (:class:`MultitargetVariance`) that
15can be used and a helper class :class:`MultiTreeLearner` which extends
16:class:`~Orange.classification.tree.TreeLearner` and is the same in all
17aspects except for different (multi-target) defaults for
18:obj:`~Orange.classification.tree.TreeLearner.measure` and
19:obj:`~Orange.classification.tree.TreeLearner.node_learner`.
20
21Examples
22========
23
The following example demonstrates how to build a prediction model with
:class:`MultiTreeLearner` and use it to predict (multiple) class values for
a given instance (:download:`multitarget.py <code/multitarget.py>`):
27
28.. literalinclude:: code/multitarget.py
29    :lines: 1-4, 10-12
30
31
32.. index:: Multi-target Variance
33.. autoclass:: Orange.multitarget.tree.MultitargetVariance
34    :members:
35    :show-inheritance:
36
37.. index:: Multi-target Tree Learner
38.. autoclass:: Orange.multitarget.tree.MultiTreeLearner
39    :members:
40    :show-inheritance:
41
42.. index:: Multi-target Tree Classifier
43.. autoclass:: Orange.multitarget.tree.MultiTree
44    :members:
45    :show-inheritance:
46
47"""
48
49from operator import itemgetter
50
51import Orange
52import numpy as np
53
54
def weighted_variance(X, weights=None):
    """Computes the variance using a weighted distance to the centroid.

    :param X: a sequence of equal-length numeric rows (class-value vectors).
    :param weights: per-column weights used to scale the columns before
        computing distances; if None, all columns have weight 1. May be a
        list or a numpy array.
    :type weights: list
    :return: the sum over rows of squared Euclidean distances to the
        column-wise mean of the (scaled) data.
    """
    # Explicit None check instead of truthiness: `not weights` raises
    # ValueError for multi-element numpy arrays and wrongly treats an
    # empty list as "no weights".
    if weights is not None:
        # Broadcasting scales each column by its weight.
        X = X * np.array(weights)
    X = np.asarray(X)
    return np.sum(np.sum((X - np.mean(X, 0)) ** 2, 1))
61
class MultitargetVariance(Orange.feature.scoring.Score):
    """
    A multi-target score that ranks features based on the average class
    variance of the subsets.

    To compute it, a prototype has to be defined for each subset. Here, it
    is just the mean vector of class variables. Then the sum of squared
    distances to the prototypes is computed in each subset. The final score
    is obtained as the average of subset variances (weighted, to account for
    subset sizes).

    Weights can be passed to the constructor to normalize classes with values
    of different magnitudes or to increase the importance of some classes. In
    this case, class values are first scaled according to the given weights.
    """

    def __init__(self, weights=None):
        """
        :param weights: Weights of the class variables used when computing
                        distances. If None, all weights are set to 1.
        :type weights: list
        """

        # Types of classes allowed
        self.handles_discrete = True
        ### TODO: for discrete classes with >2 values entropy should be used instead of variance
        self.handles_continuous = True
        # Can handle continuous features
        self.computes_thresholds = True
        # Needs instances
        self.needs = Orange.feature.scoring.Score.Generator

        self.weights = weights


    def threshold_function(self, feature, data, cont_distrib=None, weightID=0):
        """
        Evaluates possible splits of a continuous feature into a binary one
        and scores them.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set to be split using the given continuous feature.
        :type data: :class:`Orange.data.Table`

        :param cont_distrib: unused; present for interface compatibility.
        :param weightID: unused; present for interface compatibility.

        :return: :obj:`list` of :obj:`tuples <tuple>` [(threshold, score),]
            (unlike :obj:`best_threshold`, without a trailing None element).
        """

        f = data.domain[feature]
        # Candidate thresholds are midpoints between consecutive distinct
        # values of the feature.
        values = sorted(set(ins[f].value for ins in data))
        ts = [(v1 + v2) / 2. for v1, v2 in zip(values, values[1:])]
        if len(ts) > 40:
            # Keep roughly 20 evenly spaced candidates to limit the cost.
            # NOTE: `/` is Python 2 integer division here.
            ts = ts[::len(ts)/20]
        scores = []
        for t in ts:
            # Build a binary variable splitting f at threshold t, project the
            # data onto it (keeping the class variables), and score the split.
            bf = Orange.feature.discretization.IntervalDiscretizer(
                points=[t]).construct_variable(f)
            dom2 = Orange.data.Domain([bf], class_vars=data.domain.class_vars)
            data2 = Orange.data.Table(dom2, data)
            scores.append((t, self.__call__(bf, data2)))
        return scores

    def best_threshold(self, feature, data):
        """
        Computes the best threshold for a split of a continuous feature.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set to be split using the given continuous feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`tuple` (threshold, score, None)
        """

        scores = self.threshold_function(feature, data)
        # Scores are negated variances (see __call__), so the maximum score
        # corresponds to the split with the least remaining class variance.
        threshold, score = max(scores, key=itemgetter(1))
        return (threshold, score, None)

    def __call__(self, feature, data, apriori_class_distribution=None, weightID=0):
        """
        Scores the feature as the negative sum of size-weighted class
        variances of the subsets induced by its values (higher, i.e. closer
        to zero, is better).

        :param feature: The feature to be scored.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set on which to score the feature.
        :type data: :class:`Orange.data.Table`

        :param apriori_class_distribution: unused; interface compatibility.
        :param weightID: unused; interface compatibility.

        :return: :obj:`float`
        """

        # Group the class-value vectors of the instances by feature value.
        split = dict((ins[feature].value, []) for ins in data)
        for ins in data:
            split[ins[feature].value].append(ins.get_classes())
        # Negate so that a smaller total (size-weighted) variance yields a
        # larger score.
        score = -sum(weighted_variance(x, self.weights) * len(x) for x in split.values())
        return score
158
159
class MultiTreeLearner(Orange.classification.tree.TreeLearner):
    """
    MultiTreeLearner is a multi-target version of a tree learner. It is the
    same as :class:`~Orange.classification.tree.TreeLearner`, except for the
    default values of two parameters:

    .. attribute:: measure

        A multi-target score is used by default: :class:`MultitargetVariance`.

    .. attribute:: node_learner

        Standard trees use :class:`~Orange.classification.majority.MajorityLearner`
        to construct prediction models in the leaves of the tree.
        MultiTreeLearner uses the multi-target equivalent which can be
        obtained simply by wrapping the majority learner:

        :class:`Orange.multitarget.MultitargetLearner` (:class:`Orange.classification.majority.MajorityLearner()`).

    """

    def __init__(self, **kwargs):
        """
        The constructor passes all arguments to
        :class:`~Orange.classification.tree.TreeLearner`'s constructor
        :obj:`Orange.classification.tree.TreeLearner.__init__`.
        """

        measure = MultitargetVariance()
        node_learner = Orange.multitarget.MultitargetLearner(
            Orange.classification.majority.MajorityLearner())
        Orange.classification.tree.TreeLearner.__init__(
            self, measure=measure, node_learner=node_learner, **kwargs)

    def __call__(self, data, weight=0):
        """
        Learn a :class:`MultiTree` from the given data.

        :param data: Data instances to learn from.
        :type data: :class:`Orange.data.Table`

        :param weight: Id of meta attribute with weights of instances.
        :type weight: :obj:`int`

        :raises ValueError: if any instance has a missing class value.
        """

        for ins in data:
            for cval in ins.get_classes():
                if cval.is_special():
                    raise ValueError('Data has missing class values.')
        # TreeLearner does not work on class-less domains, so we designate
        # the first class variable as the (single) class if necessary.
        # BUG FIX: the converted table was previously bound to a name that
        # was only defined inside this branch, raising NameError whenever
        # the domain already had a class_var.
        if data.domain.class_var is None:
            data = Orange.data.Table(Orange.data.Domain(
                data.domain.attributes, data.domain.class_vars[0],
                class_vars=data.domain.class_vars), data)
        tree = Orange.classification.tree.TreeLearner.__call__(
            self, data, weight)
        return MultiTree(base_classifier=tree)
216
class MultiTree(Orange.classification.tree.TreeClassifier):
    """
    MultiTree classifier is almost the same as the base class it extends
    (:class:`~Orange.classification.tree.TreeClassifier`). Only the
    :obj:`__call__` method is modified so it works with multi-target data.
    """

    def __call__(self, instance, return_type=Orange.core.GetValue):
        """
        Predict the class values of the given instance.

        :param instance: Instance to be classified.
        :type instance: :class:`Orange.data.Instance`

        :param return_type: One of
            :class:`Orange.classification.Classifier.GetValue`,
            :class:`Orange.classification.Classifier.GetProbabilities` or
            :class:`Orange.classification.Classifier.GetBoth`
        """

        # Descend the tree to the leaf this instance falls into, then let
        # the leaf's (multi-target) node classifier make the prediction.
        leaf = self.descender(self.tree, instance)[0]
        return leaf.node_classifier(instance, return_type)
237
238
239if __name__ == '__main__':
240    data = Orange.data.Table('multitarget-synthetic')
241    print 'Actual classes:\n', data[0].get_classes()
242   
243    majority = Orange.classification.majority.MajorityLearner()
244    mt_majority = Orange.multitarget.MultitargetLearner(majority)
245    c_mtm = mt_majority(data)
246    print 'Majority predictions:\n', c_mtm(data[0])
247
248    mt_tree = MultiTreeLearner(max_depth=3)
249    c_mtt = mt_tree(data)
250    print 'Multi-target Tree predictions:\n', c_mtt(data[0])
Note: See TracBrowser for help on using the repository browser.