source: orange/Orange/multitarget/tree.py @ 9671:a7b056375472

1"""
2.. index:: Multi-target Tree Learner
3
4***************************************
5Multi-target Tree Learner
6***************************************
7
8To use the tree learning algorithm for multi-target data, standard
Orange trees (:class:`Orange.classification.tree.TreeLearner`) can be used.
Only the :obj:`~Orange.classification.tree.TreeLearner.measure` for feature
scoring and the :obj:`~Orange.classification.tree.TreeLearner.node_learner`
components have to be chosen so that they work on multi-target data domains.

This module provides one such measure (:class:`MultitargetVariance`) and a
helper class :class:`MultiTreeLearner`, which extends
:class:`~Orange.classification.tree.TreeLearner` and is identical in all
aspects except for its (multi-target) defaults for
:obj:`~Orange.classification.tree.TreeLearner.measure` and
:obj:`~Orange.classification.tree.TreeLearner.node_learner`.
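
For illustration, roughly the same setup can be assembled by hand from a
standard tree learner (a sketch of the configuration that
:class:`MultiTreeLearner` applies by default; the data set name is only a
placeholder for any multi-target data table)::

    import Orange
    from Orange.multitarget.tree import MultitargetVariance

    data = Orange.data.Table('multitarget-synthetic')
    learner = Orange.classification.tree.TreeLearner(
        measure=MultitargetVariance(),
        node_learner=Orange.multitarget.MultitargetLearner(
            Orange.classification.majority.MajorityLearner()))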

Examples
========

The following example demonstrates how to build a prediction model with
MultiTreeLearner and use it to predict (multiple) class values for
a given instance (:download:`multitarget.py <code/multitarget.py>`,
uses :download:`test-pls.tab <code/test-pls.tab>`):

.. literalinclude:: code/multitarget.py
    :lines: 1-4, 10-12


.. index:: Multi-target Variance
.. autoclass:: Orange.multitarget.tree.MultitargetVariance
    :members:
    :show-inheritance:

.. index:: Multi-target Tree Learner
.. autoclass:: Orange.multitarget.tree.MultiTreeLearner
    :members:
    :show-inheritance:

.. index:: Multi-target Tree Classifier
.. autoclass:: Orange.multitarget.tree.MultiTree
    :members:
    :show-inheritance:

"""

from operator import itemgetter

import Orange
import numpy as np


def weighted_variance(X, weights=None):
    """Computes the variance using a weighted distance to the centroid."""
    if not weights:
        weights = [1] * len(X[0])
    X = X * np.array(weights)
    return np.sum(np.sum((X - np.mean(X, 0))**2, 1))

class MultitargetVariance(Orange.feature.scoring.Score):
    """
    A multi-target score that ranks features based on the variance of the
    subsets. A weighted distance can be used to compute the variance.
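
    For example, a single feature could be scored like this (a sketch; the
    data set name is only a placeholder for any multi-target data table)::

        import Orange
        from Orange.multitarget.tree import MultitargetVariance

        data = Orange.data.Table('multitarget-synthetic')
        score = MultitargetVariance()
        print score(data.domain[0], data)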
    """

    def __init__(self, weights=None):
        """
        :param weights: Weights of the features used when computing distances.
                        If None, all weights are set to 1.
        :type weights: list
        """

        # Types of classes allowed
        self.handles_discrete = True
        ### TODO: for discrete classes with >2 values entropy should be used instead of variance
        self.handles_continuous = True
        # Can compute split thresholds for continuous features
        self.computes_thresholds = True
        # Needs data instances (not just distributions)
        self.needs = Orange.feature.scoring.Score.Generator

        self.weights = weights


    def threshold_function(self, feature, data, cont_distrib=None, weightID=0):
        """
        Evaluates possible splits of a continuous feature into a binary one
        and scores them.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.data.variable`

        :param data: The data set to be split using the given continuous feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`list` of :obj:`tuples <tuple>` [(threshold, score),]
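
        For example, the candidate thresholds of the first feature of a data
        set could be scored like this (a sketch; the data set name is only a
        placeholder for a multi-target table with continuous features)::

            import Orange
            from Orange.multitarget.tree import MultitargetVariance

            data = Orange.data.Table('multitarget-synthetic')
            scores = MultitargetVariance().threshold_function(
                data.domain[0], data)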
        """

        f = data.domain[feature]
        values = sorted(set(ins[f].value for ins in data))
        # Candidate thresholds are the midpoints between consecutive observed
        # values; if there are many, keep roughly 20 evenly spaced candidates.
        ts = [(v1 + v2) / 2. for v1, v2 in zip(values, values[1:])]
        if len(ts) > 40:
            ts = ts[::len(ts)/20]
        scores = []
        for t in ts:
            # Binarize the feature at threshold t and score the resulting split.
            bf = Orange.feature.discretization.IntervalDiscretizer(
                points=[t]).construct_variable(f)
            dom2 = Orange.data.Domain([bf], class_vars=data.domain.class_vars)
            data2 = Orange.data.Table(dom2, data)
            scores.append((t, self.__call__(bf, data2)))
        return scores

    def best_threshold(self, feature, data):
        """
        Computes the best threshold for a split of a continuous feature.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.data.variable`

        :param data: The data set to be split using the given continuous feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`tuple` (threshold, score, None)
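
        For example (a sketch; the data set name is only a placeholder for a
        multi-target table with continuous features)::

            import Orange
            from Orange.multitarget.tree import MultitargetVariance

            data = Orange.data.Table('multitarget-synthetic')
            threshold, score, _ = MultitargetVariance().best_threshold(
                data.domain[0], data)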
        """

        scores = self.threshold_function(feature, data)
        threshold, score = max(scores, key=itemgetter(1))
        return (threshold, score, None)

    def __call__(self, feature, data, apriori_class_distribution=None, weightID=0):
        """
        :param feature: The feature to be scored.
        :type feature: :class:`Orange.data.variable`

        :param data: The data set on which to score the feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`float`
        """

        # Group the class vectors of the instances by the value of the feature.
        split = dict((ins[feature].value, []) for ins in data)
        for ins in data:
            split[ins[feature].value].append(ins.get_classes())
        # The score is the negative within-subset variance, weighted by subset
        # size, so that lower-variance splits receive higher scores.
        score = -sum(weighted_variance(x, self.weights) * len(x) for x in split.values())
        return score


class MultiTreeLearner(Orange.classification.tree.TreeLearner):
    """
    MultiTreeLearner is a multi-target version of a tree learner. It is the
    same as :class:`~Orange.classification.tree.TreeLearner`, except for the
    default values of two parameters:

    .. attribute:: measure

        A multi-target score is used by default: :class:`MultitargetVariance`.

    .. attribute:: node_learner

        Standard trees use :class:`~Orange.classification.majority.MajorityLearner`
        to construct prediction models in the leaves of the tree.
        MultiTreeLearner uses the multi-target equivalent, obtained by wrapping
        the majority learner:
        ``Orange.multitarget.MultitargetLearner(Orange.classification.majority.MajorityLearner())``.

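    For example, a multi-target tree can be built and used like this (a
    sketch mirroring the example at the bottom of this module; the data set
    name is only a placeholder)::

        import Orange
        from Orange.multitarget.tree import MultiTreeLearner

        data = Orange.data.Table('multitarget-synthetic')
        mt_tree = MultiTreeLearner(max_depth=3)
        classifier = mt_tree(data)
        print classifier(data[0])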
    """

    def __init__(self, **kwargs):
        """
        The constructor passes all given arguments to
        :class:`~Orange.classification.tree.TreeLearner`'s constructor
        :obj:`Orange.classification.tree.TreeLearner.__init__`, filling in
        the multi-target defaults for measure and node_learner when these
        are not given explicitly.
        """

        # Use the multi-target defaults unless the caller supplies its own
        # measure or node_learner.
        measure = kwargs.pop('measure', MultitargetVariance())
        node_learner = kwargs.pop('node_learner',
            Orange.multitarget.MultitargetLearner(
                Orange.classification.majority.MajorityLearner()))
        Orange.classification.tree.TreeLearner.__init__(
            self, measure=measure, node_learner=node_learner, **kwargs)

    def __call__(self, data, weight=0):
        """
        :param data: Data instances to learn from.
        :type data: :class:`Orange.data.Table`

        :param weight: Id of meta attribute with weights of instances.
        :type weight: :obj:`int`
        """

        # TreeLearner does not work on class-less domains,
        # so we set the class if necessary
        if data.domain.class_var is None:
            data2 = Orange.data.Table(Orange.data.Domain(
                data.domain.attributes, data.domain.class_vars[0],
                class_vars=data.domain.class_vars), data)
        else:
            data2 = data
        tree = Orange.classification.tree.TreeLearner.__call__(
            self, data2, weight)
        return MultiTree(base_classifier=tree)

class MultiTree(Orange.classification.tree.TreeClassifier):
    """
    MultiTree classifier is almost the same as the base class it extends
    (:class:`~Orange.classification.tree.TreeClassifier`). Only the
    :obj:`__call__` method is modified so it works with multi-target data.
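
    For example (a sketch; the data set name is only a placeholder)::

        import Orange
        from Orange.multitarget.tree import MultiTreeLearner

        data = Orange.data.Table('multitarget-synthetic')
        tree = MultiTreeLearner(max_depth=3)(data)
        prediction = tree(data[0])
        both = tree(data[0], Orange.classification.Classifier.GetBoth)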
    """

    def __call__(self, instance, return_type=Orange.core.GetValue):
        """
        :param instance: Instance to be classified.
        :type instance: :class:`Orange.data.Instance`

        :param return_type: One of
            :class:`Orange.classification.Classifier.GetValue`,
            :class:`Orange.classification.Classifier.GetProbabilities` or
            :class:`Orange.classification.Classifier.GetBoth`
        """

        # Descend to the appropriate leaf and let its (multi-target) node
        # classifier make the prediction.
        node = self.descender(self.tree, instance)[0]
        return node.node_classifier(instance, return_type)


if __name__ == '__main__':
    data = Orange.data.Table('multitarget-synthetic')
    print 'Actual classes:\n', data[0].get_classes()

    majority = Orange.classification.majority.MajorityLearner()
    mt_majority = Orange.multitarget.MultitargetLearner(majority)
    c_mtm = mt_majority(data)
    print 'Majority predictions:\n', c_mtm(data[0])

    mt_tree = MultiTreeLearner(max_depth=3)
    c_mtt = mt_tree(data)
    print 'Multi-target Tree predictions:\n', c_mtt(data[0])