source: orange/orange/Orange/multitarget/tree.py @ 9536:aa05bd0e20ac

Revision 9536:aa05bd0e20ac, 8.3 KB checked in by lanz <lan.zagar@…>, 2 years ago

Example moved to top.

1"""
2.. index:: Multi-target Tree Learner
3
4***************************************
5Multi-target Tree Learner
6***************************************
7
To use the tree learning algorithm for multi-target data, standard
Orange trees (:class:`Orange.classification.tree.TreeLearner`) can be used;
only the :obj:`Orange.classification.tree.TreeLearner.measure` component for
feature scoring and the :obj:`Orange.classification.tree.TreeLearner.node_learner`
component have to be chosen so that they work on multi-target data domains.

This module provides one such measure (:class:`MultitargetVariance`) and a
helper class :class:`MultiTreeLearner`, which extends
:class:`Orange.classification.tree.TreeLearner` and differs from it only in
its (multi-target) defaults for `measure` and `node_learner`.
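
As a minimal sketch (mirroring the defaults that :class:`MultiTreeLearner`
sets up, see below), a standard tree learner can also be configured for
multi-target domains by hand::

    import Orange
    from Orange.multitarget.tree import MultitargetVariance

    tree_learner = Orange.classification.tree.TreeLearner(
        measure=MultitargetVariance(),
        node_learner=Orange.multitarget.MultitargetLearner(
            Orange.classification.majority.MajorityLearner()))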

Examples
========

The following example demonstrates how to build a prediction model with
:class:`MultiTreeLearner` and use it to predict (multiple) class values for
a given instance (:download:`multitarget.py <code/multitarget.py>`,
uses :download:`emotions.tab <code/emotions.tab>`):

.. literalinclude:: code/multitarget.py
   :lines: 1-4, 10-12


.. index:: Multi-target Variance
.. autoclass:: Orange.multitarget.tree.MultitargetVariance
   :members:
   :show-inheritance:

.. index:: Multi-target Tree Learner
.. autoclass:: Orange.multitarget.tree.MultiTreeLearner
   :members:
   :show-inheritance:

.. index:: Multi-target Tree Classifier
.. autoclass:: Orange.multitarget.tree.MultiTree
   :members:
   :show-inheritance:

"""

from operator import itemgetter

import numpy as np
import Orange


def weighted_variance(X, weights=None):
    """Computes the variance using a weighted distance to the centroid."""
    if not weights:
        weights = [1] * len(X[0])
    X = X * np.array(weights)
    return np.sum(np.sum((X - np.mean(X, 0))**2, 1))
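
# Illustrative sketch (not part of the original module): for plain numeric
# input the result is the sum of squared, optionally weighted, distances of
# the rows to their centroid, e.g.
#
#     weighted_variance(np.array([[0., 0.], [2., 2.]]))  # == 4.0
#
# since both rows lie at squared distance 2 from the centroid [1., 1.].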

class MultitargetVariance(Orange.feature.scoring.Score):
    """
    A multi-target score that ranks features by the variance of the
    class-value vectors within the subsets of instances induced by the
    feature. A weighted distance to the centroid can be used to compute
    the variance.
    """

    def __init__(self, weights=None):
        """
        :param weights: Weights of the class variables (targets) used when
                        computing the distance to the centroid.
                        If None, all weights are set to 1.
        :type weights: list
        """

        # Types of class variables this score can handle
        self.handles_discrete = True
        ### TODO: for discrete classes with >2 values entropy should be used instead of variance
        self.handles_continuous = True
        # Can compute split thresholds for continuous features
        self.computes_thresholds = True
        # Needs access to the data instances (not just distributions)
        self.needs = Orange.feature.scoring.Score.Generator

        self.weights = weights


    def threshold_function(self, feature, data, cont_distrib=None, weightID=0):
        """
        Evaluates possible splits of a continuous feature into a binary one
        and scores them.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.data.variable`

        :param data: The data set to be split using the given continuous feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`list` of tuples (threshold, score)
        """

        f = data.domain[feature]
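        # Candidate thresholds are the midpoints between consecutive
        # distinct values of the feature observed in the data.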
        values = sorted(set(ins[f].value for ins in data))
        ts = [(v1 + v2) / 2. for v1, v2 in zip(values, values[1:])]
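        # If there are many candidates, keep only roughly 20 evenly
        # spaced ones to limit the number of evaluations.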
        if len(ts) > 40:
            ts = ts[::len(ts)/20]
        scores = []
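        # Score each candidate by binarizing the feature at the threshold
        # and applying this measure to the transformed data.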
        for t in ts:
            bf = Orange.feature.discretization.IntervalDiscretizer(
                points=[t]).construct_variable(f)
            dom2 = Orange.data.Domain([bf], class_vars=data.domain.class_vars)
            data2 = Orange.data.Table(dom2, data)
            scores.append((t, self.__call__(bf, data2)))
        return scores

    def best_threshold(self, feature, data):
        """
        Computes the best threshold for a split of a continuous feature.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.data.variable`

        :param data: The data set to be split using the given continuous feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`tuple` (threshold, score, None)
        """

        scores = self.threshold_function(feature, data)
        threshold, score = max(scores, key=itemgetter(1))
        return (threshold, score, None)

    def __call__(self, feature, data, apriori_class_distribution=None, weightID=0):
        """
        :param feature: The feature to be scored.
        :type feature: :class:`Orange.data.variable`

        :param data: The data set on which to score the feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`float`
        """

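        # Group the class-value vectors of the instances by the value of
        # the evaluated feature.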
        split = dict((ins[feature].value, []) for ins in data)
        for ins in data:
            split[ins[feature].value].append(ins.get_classes())
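        # The score is the negative within-group variance weighted by group
        # size, so features that produce more homogeneous subsets score
        # higher.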
        score = -sum(weighted_variance(x, self.weights) * len(x) for x in split.values())
        return score


class MultiTreeLearner(Orange.classification.tree.TreeLearner):
    """
    MultiTreeLearner is a multi-target version of a tree learner. It is the
    same as :class:`Orange.classification.tree.TreeLearner`, except for the
    default values of two parameters:

    .. attribute:: measure

        A multi-target score is used by default: :class:`Orange.multitarget.tree.MultitargetVariance`.

    .. attribute:: node_learner

        Standard trees use :class:`Orange.classification.majority.MajorityLearner`
        to construct prediction models in the leaves of the tree.
        MultiTreeLearner uses the multi-target equivalent, obtained by
        wrapping the majority learner:
        ``Orange.multitarget.MultitargetLearner(Orange.classification.majority.MajorityLearner())``.

    """

    def __init__(self, **kwargs):
        """
        The constructor sets the multi-target defaults for `measure` and
        `node_learner` and passes all other arguments on to
        :obj:`Orange.classification.tree.TreeLearner.__init__`.
        """

        measure = MultitargetVariance()
        node_learner = Orange.multitarget.MultitargetLearner(
            Orange.classification.majority.MajorityLearner())
        Orange.classification.tree.TreeLearner.__init__(
            self, measure=measure, node_learner=node_learner, **kwargs)

    def __call__(self, data, weight=0):
        """
        :param data: Data instances to learn from.
        :type data: :class:`Orange.data.Table`

        :param weight: Id of meta attribute with weights of instances.
        :type weight: :obj:`int`
        """

        # TreeLearner does not work on class-less domains,
        # so we set the class if necessary
        if data.domain.class_var is None:
            data2 = Orange.data.Table(Orange.data.Domain(
                data.domain.attributes, data.domain.class_vars[0],
                class_vars=data.domain.class_vars), data)
        else:
            data2 = data
        tree = Orange.classification.tree.TreeLearner.__call__(
            self, data2, weight)
        return MultiTree(base_classifier=tree)

class MultiTree(Orange.classification.tree.TreeClassifier):
    """
    MultiTree classifier is almost the same as the base class it extends
    (:class:`Orange.classification.tree.TreeClassifier`). Only the
    :obj:`__call__` method is modified so it works with multi-target data.
    """

    def __call__(self, instance, return_type=Orange.core.GetValue):
        """
        :param instance: Instance to be classified.
        :type instance: :class:`Orange.data.Instance`

        :param return_type: One of
            :class:`Orange.classification.Classifier.GetValue`,
            :class:`Orange.classification.Classifier.GetProbabilities` or
            :class:`Orange.classification.Classifier.GetBoth`
        """

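        # Descend to the leaf for this instance and let its (multi-target)
        # node classifier produce the prediction.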
        node = self.descender(self.tree, instance)[0]
        return node.node_classifier(instance, return_type)


if __name__ == '__main__':
    data = Orange.data.Table('test-pls')
    print 'Actual classes:\n', data[0].get_classes()

    majority = Orange.classification.majority.MajorityLearner()
    mt_majority = Orange.multitarget.MultitargetLearner(majority)
    c_mtm = mt_majority(data)
    print 'Majority predictions:\n', c_mtm(data[0])

    mt_tree = MultiTreeLearner(max_depth=3)
    c_mtt = mt_tree(data)
    print 'Multi-target Tree predictions:\n', c_mtt(data[0])