source: orange/orange/Orange/multitarget/tree.py @ 9416:0e72e57bf87b

Revision 9416:0e72e57bf87b, 3.6 KB checked in by lanz <lan.zagar@…>, 2 years ago (diff)

Removed a workaround for a bug that was fixed (#1012).

Line 
1from operator import itemgetter
2
3import numpy as np
4import Orange
5
6
def weighted_variance(X, weights=None):
    """Compute the variance using a weighted distance between vectors.

    Every column of ``X`` is multiplied by its weight; the result is the
    sum, over all rows, of the squared Euclidean distance between the row
    and the column-wise mean of the scaled data.

    X: 2-D array-like with one vector per row.
    weights: optional sequence holding one multiplier per column. ``None``
        or an empty sequence gives every column a weight of 1.

    Returns the summed squared deviation as a numpy scalar.
    """
    # Compare against None / emptiness explicitly: truth-testing a numpy
    # array with more than one element raises ValueError.
    if weights is None or np.size(weights) == 0:
        weights = [1] * len(X[0])

    scaled = X * np.array(weights)
    deviations = scaled - np.mean(scaled, 0)
    return np.sum(deviations ** 2, 1).sum()
17
class MultitargetVariance(Orange.feature.scoring.Score):
    """Feature score based on the weighted variance of multi-target classes.

    Instances are grouped by the value of the scored feature; the score is
    the negated sum of each group's ``weighted_variance`` multiplied by the
    group size, so a larger (less negative) score means a better split.
    """

    def __init__(self, weights=None):
        """Store the per-class ``weights`` passed on to ``weighted_variance``."""
        # Types of classes allowed
        self.handles_discrete = True
        self.handles_continuous = True
        # Can handle continuous features
        self.computes_thresholds = True
        # Needs instances
        self.needs = Orange.feature.scoring.Score.Generator

        self.weights = weights

    def threshold_function(self, feature, data, cont_distrib=None, weightID=0):
        """Score candidate binarization thresholds of ``feature``.

        Candidates are the midpoints between consecutive distinct values of
        the feature in ``data``. When there are more than 40 of them, only
        every ``len(candidates) // 20``-th one is evaluated.

        ``cont_distrib`` and ``weightID`` are accepted but not used here.

        Returns a list of ``(threshold, score)`` tuples.
        """
        f = data.domain[feature]
        values = sorted(set(ins[f].value for ins in data))
        ts = [(v1 + v2) / 2. for v1, v2 in zip(values, values[1:])]
        if len(ts) > 40:
            # Floor division: a slice step must be an integer, and plain `/`
            # yields a float whenever true division is in effect.
            ts = ts[::len(ts) // 20]
        scores = []
        for t in ts:
            # Binarize the feature at t and score the resulting two-way split.
            bf = Orange.feature.discretization.IntervalDiscretizer(
                points=[t]).construct_variable(f)
            dom2 = Orange.data.Domain([bf], class_vars=data.domain.class_vars)
            data2 = Orange.data.Table(dom2, data)
            scores.append((t, self.__call__(bf, data2)))
        return scores

    def best_threshold(self, feature, data):
        """Return ``(threshold, score, None)`` for the highest-scoring threshold.

        Raises ValueError (from ``max``) when ``threshold_function`` yields no
        candidates, i.e. the feature has fewer than two distinct values.
        """
        scores = self.threshold_function(feature, data)
        threshold, score = max(scores, key=itemgetter(1))
        return (threshold, score, None)

    def __call__(self, feature, data, apriori_class_distribution=None, weightID=0):
        """Score ``feature`` on ``data``.

        ``apriori_class_distribution`` and ``weightID`` are accepted but not
        used here.

        Returns the negated sum, over the feature's values, of the weighted
        variance of the class vectors in that group times the group size.
        """
        # Group the class vectors by feature value in a single pass.
        split = {}
        for ins in data:
            split.setdefault(ins[feature].value, []).append(ins.get_classes())
        score = -sum(weighted_variance(x, self.weights) * len(x)
                     for x in split.values())
        return score
57
58
class MultiTreeLearner(Orange.classification.tree.TreeLearner):
    """
    MultiTreeLearner is a multitarget equivalent of the TreeLearner.
    It is the same as Orange.classification.tree.TreeLearner, except for
    the default values of two parameters:
    measure: MultitargetVariance
    node_learner: Orange.multitarget.MultitargetLearner(Orange.classification.majority.MajorityLearner())
    """

    def __init__(self, measure=MultitargetVariance(),
                 node_learner=Orange.multitarget.MultitargetLearner(
                     Orange.classification.majority.MajorityLearner()),
                 **kwargs):
        """Initialize the TreeLearner with multitarget defaults.

        ``measure`` and ``node_learner`` override the TreeLearner defaults;
        any other keyword arguments are forwarded unchanged.
        """
        Orange.classification.tree.TreeLearner.__init__(
            self, measure=measure, node_learner=node_learner, **kwargs)

    def __call__(self, data, weight=0):
        """Build a tree on ``data`` and return it wrapped in a MultiTree."""
        # TreeLearner does not work on class-less domains,
        # so we set the class if necessary
        if data.domain.class_var is None:
            data = Orange.data.Table(Orange.data.Domain(
                data.domain.attributes, data.domain.class_vars[0],
                class_vars=data.domain.class_vars), data)
        # When the domain already has a class variable the data is used as
        # is; previously this path referenced a name that was never bound.
        tree = Orange.classification.tree.TreeLearner.__call__(self, data, weight)
        return MultiTree(base_classifier=tree)
84
class MultiTree(Orange.classification.tree.TreeClassifier):
    """Tree classifier that answers with the prediction of the reached node."""

    def __call__(self, instance, return_type=Orange.core.GetValue):
        """Descend the tree for ``instance`` and query that node's classifier.

        The first element returned by the descender is taken as the node
        whose ``node_classifier`` is called with ``instance`` and
        ``return_type``.
        """
        descent = self.descender(self.tree, instance)
        reached = descent[0]
        return reached.node_classifier(instance, return_type)
91
92
93if __name__ == '__main__':
94    data = Orange.data.Table('emotions')
95    print 'Actual classes:\n', data[0].get_classes()
96    mt = MultiTreeLearner(max_depth=2)
97    c = mt(data)
98    print 'Predicted classes:\n', c(data[0])
99
Note: See TracBrowser for help on using the repository browser.