source: orange/Orange/multitarget/tree.py @ 10092:cb849cfe6729

Revision 10092:cb849cfe6729, 8.6 KB checked in by Lan Zagar <lan.zagar@…>, 2 years ago

Added checking for missing class values.

"""
.. index:: Multi-target Tree Learner

***************************************
Multi-target Tree Learner
***************************************

To use the tree learning algorithm for multi-target data, standard
Orange trees (:class:`Orange.classification.tree.TreeLearner`) can be used.
Only the :obj:`~Orange.classification.tree.TreeLearner.measure` for feature
scoring and the :obj:`~Orange.classification.tree.TreeLearner.node_learner`
components have to be chosen so that they work on multi-target data domains.

This module provides one such measure (:class:`MultitargetVariance`) and a
helper class :class:`MultiTreeLearner`, which extends
:class:`~Orange.classification.tree.TreeLearner` and is the same in all
aspects except for different (multi-target) defaults for
:obj:`~Orange.classification.tree.TreeLearner.measure` and
:obj:`~Orange.classification.tree.TreeLearner.node_learner`.

Examples
========

The following example demonstrates how to build a prediction model with
:class:`MultiTreeLearner` and use it to predict (multiple) class values for
a given instance (:download:`multitarget.py <code/multitarget.py>`):

.. literalinclude:: code/multitarget.py
    :lines: 1-4, 10-12

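
As described above, a standard :class:`~Orange.classification.tree.TreeLearner`
can also be configured for multi-target learning by hand. The following sketch
mirrors what :class:`MultiTreeLearner` does internally (it assumes the
``multitarget-synthetic`` data set that is also used in this module's
``__main__`` block)::

    import Orange
    from Orange.multitarget.tree import MultitargetVariance, MultiTree

    data = Orange.data.Table('multitarget-synthetic')
    # TreeLearner needs a single class variable, so promote the first target
    # while keeping all targets in class_vars (MultiTreeLearner does the same).
    domain = Orange.data.Domain(data.domain.attributes,
                                data.domain.class_vars[0],
                                class_vars=data.domain.class_vars)
    data = Orange.data.Table(domain, data)

    node_learner = Orange.multitarget.MultitargetLearner(
        Orange.classification.majority.MajorityLearner())
    tree = Orange.classification.tree.TreeLearner(
        measure=MultitargetVariance(), node_learner=node_learner, max_depth=3)
    classifier = MultiTree(base_classifier=tree(data))
    print classifier(data[0])
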

.. index:: Multi-target Variance
.. autoclass:: Orange.multitarget.tree.MultitargetVariance
    :members:
    :show-inheritance:

.. index:: Multi-target Tree Learner
.. autoclass:: Orange.multitarget.tree.MultiTreeLearner
    :members:
    :show-inheritance:

.. index:: Multi-target Tree Classifier
.. autoclass:: Orange.multitarget.tree.MultiTree
    :members:
    :show-inheritance:

"""

from operator import itemgetter

import Orange
import numpy as np


def weighted_variance(X, weights=None):
    """Computes the variance using a weighted distance to the centroid."""
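    # Each row of X is one instance's vector of class values; columns are
    # scaled by `weights` before computing the squared distances of the rows
    # to their centroid, e.g. [[1, 2], [3, 4]] has centroid [2, 3] and an
    # (unweighted) variance of (1 + 1) + (1 + 1) = 4.0.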
    if not weights:
        weights = [1] * len(X[0])
    X = X * np.array(weights)
    return np.sum(np.sum((X - np.mean(X, 0))**2, 1))

class MultitargetVariance(Orange.feature.scoring.Score):
    """
    A multi-target score that ranks features by the variance of the class
    vectors within the subsets induced by splitting on the feature. A weighted
    distance can be used to compute the variance.
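
    A minimal usage sketch (assuming the ``multitarget-synthetic`` data set
    used in this module's ``__main__`` block)::

        import Orange
        from Orange.multitarget.tree import MultitargetVariance

        data = Orange.data.Table('multitarget-synthetic')
        score = MultitargetVariance()
        # Score the first feature; higher (less negative) values are better.
        print score(data.domain.attributes[0], data)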
    """

    def __init__(self, weights=None):
        """
        :param weights: Weights of the class variables (prediction targets)
                        used when computing distances. If None, all weights
                        are set to 1.
        :type weights: list
        """

        # Types of classes allowed
        self.handles_discrete = True
        ### TODO: for discrete classes with >2 values entropy should be used instead of variance
        self.handles_continuous = True
        # Can compute thresholds for splitting continuous features
        self.computes_thresholds = True
        # Needs the data instances (an example generator)
        self.needs = Orange.feature.scoring.Score.Generator

        self.weights = weights


    def threshold_function(self, feature, data, cont_distrib=None, weightID=0):
        """
        Evaluates possible splits of a continuous feature into a binary one
        and scores them.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set to be split using the given continuous feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`list` of :obj:`tuples <tuple>` [(threshold, score),]
        """

        f = data.domain[feature]
        values = sorted(set(ins[f].value for ins in data))
        ts = [(v1 + v2) / 2. for v1, v2 in zip(values, values[1:])]
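        # If there are many candidate thresholds, subsample them (keeping
        # roughly 20) to limit the number of splits that have to be scored.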
        if len(ts) > 40:
            ts = ts[::len(ts)/20]
        scores = []
        for t in ts:
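            # Discretize the feature at threshold t into a binary feature and
            # score the induced split on the (multi-target) class variables.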
            bf = Orange.feature.discretization.IntervalDiscretizer(
                points=[t]).construct_variable(f)
            dom2 = Orange.data.Domain([bf], class_vars=data.domain.class_vars)
            data2 = Orange.data.Table(dom2, data)
            scores.append((t, self.__call__(bf, data2)))
        return scores

    def best_threshold(self, feature, data):
        """
        Computes the best threshold for a split of a continuous feature.

        :param feature: Continuous feature to be split.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set to be split using the given continuous feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`tuple` (threshold, score, None)
        """

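        # Score all candidate thresholds and return the best-scoring one.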
        scores = self.threshold_function(feature, data)
        threshold, score = max(scores, key=itemgetter(1))
        return (threshold, score, None)

    def __call__(self, feature, data, apriori_class_distribution=None, weightID=0):
        """
        :param feature: The feature to be scored.
        :type feature: :class:`Orange.feature.Descriptor`

        :param data: The data set on which to score the feature.
        :type data: :class:`Orange.data.Table`

        :return: :obj:`float`
        """

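        # Group the class-value vectors of the instances by the value of
        # `feature`; the score is the negative size-weighted variance within
        # the groups, so better splits receive higher (less negative) scores.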
        split = dict((ins[feature].value, []) for ins in data)
        for ins in data:
            split[ins[feature].value].append(ins.get_classes())
        score = -sum(weighted_variance(x, self.weights) * len(x) for x in split.values())
        return score


class MultiTreeLearner(Orange.classification.tree.TreeLearner):
    """
    MultiTreeLearner is a multi-target version of a tree learner. It is the
    same as :class:`~Orange.classification.tree.TreeLearner`, except for the
    default values of two parameters:

    .. attribute:: measure

        A multi-target score is used by default: :class:`MultitargetVariance`.

    .. attribute:: node_learner

        Standard trees use :class:`~Orange.classification.majority.MajorityLearner`
        to construct prediction models in the leaves of the tree.
        MultiTreeLearner uses the multi-target equivalent, obtained simply by
        wrapping a :class:`~Orange.classification.majority.MajorityLearner`
        with :class:`Orange.multitarget.MultitargetLearner`.

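    A minimal usage sketch (assuming the ``multitarget-synthetic`` data set;
    any :class:`~Orange.classification.tree.TreeLearner` argument, such as
    ``max_depth`` below, is passed through)::

        import Orange
        from Orange.multitarget.tree import MultiTreeLearner

        data = Orange.data.Table('multitarget-synthetic')
        classifier = MultiTreeLearner(max_depth=3)(data)
        print classifier(data[0])
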
    """

    def __init__(self, **kwargs):
        """
        The constructor passes all arguments on to the constructor of
        :class:`~Orange.classification.tree.TreeLearner`
        (:obj:`Orange.classification.tree.TreeLearner.__init__`).
        """

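        # Multi-target defaults: variance-based feature scoring and a
        # majority learner wrapped to predict all targets in the leaves.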
        measure = MultitargetVariance()
        node_learner = Orange.multitarget.MultitargetLearner(
            Orange.classification.majority.MajorityLearner())
        Orange.classification.tree.TreeLearner.__init__(
            self, measure=measure, node_learner=node_learner, **kwargs)

    def __call__(self, data, weight=0):
        """
        :param data: Data instances to learn from.
        :type data: :class:`Orange.data.Table`

        :param weight: ID of the meta attribute holding instance weights.
        :type weight: :obj:`int`
        """

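        # The underlying tree learner cannot handle missing class values,
        # so reject such data explicitly.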
        for ins in data:
            for cval in ins.get_classes():
                if cval.is_special():
                    raise ValueError('Data has missing class values.')
        # TreeLearner does not work on class-less domains,
        # so we set the class if necessary
        if data.domain.class_var is None:
            data2 = Orange.data.Table(Orange.data.Domain(
                data.domain.attributes, data.domain.class_vars[0],
                class_vars=data.domain.class_vars), data)
        else:
            data2 = data
        tree = Orange.classification.tree.TreeLearner.__call__(
            self, data2, weight)
        return MultiTree(base_classifier=tree)

class MultiTree(Orange.classification.tree.TreeClassifier):
    """
    MultiTree classifier is almost the same as the base class it extends
    (:class:`~Orange.classification.tree.TreeClassifier`). Only the
    :obj:`__call__` method is modified so it works with multi-target data.
    """

    def __call__(self, instance, return_type=Orange.core.GetValue):
        """
        :param instance: Instance to be classified.
        :type instance: :class:`Orange.data.Instance`

        :param return_type: One of
            :class:`Orange.classification.Classifier.GetValue`,
            :class:`Orange.classification.Classifier.GetProbabilities` or
            :class:`Orange.classification.Classifier.GetBoth`
        """

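        # Descend to the leaf corresponding to the instance and let its
        # multi-target node classifier make the prediction.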
        node = self.descender(self.tree, instance)[0]
        return node.node_classifier(instance, return_type)


if __name__ == '__main__':
    data = Orange.data.Table('multitarget-synthetic')
    print 'Actual classes:\n', data[0].get_classes()

    majority = Orange.classification.majority.MajorityLearner()
    mt_majority = Orange.multitarget.MultitargetLearner(majority)
    c_mtm = mt_majority(data)
    print 'Majority predictions:\n', c_mtm(data[0])

    mt_tree = MultiTreeLearner(max_depth=3)
    c_mtt = mt_tree(data)
    print 'Multi-target Tree predictions:\n', c_mtt(data[0])