Changeset 9414:9ba8a2753530 in orange


Ignore:
Timestamp:
12/23/11 18:05:35 (2 years ago)
Author:
lanz <lan.zagar@…>
Branch:
default
Convert:
fb97248d72e6f6ddcf2410471373427ff4048bab
Message:

Changed multitarget module to work with the new multiclass data sets.

Location:
orange/Orange/multitarget
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/multitarget/__init__.py

    r9322 r9414  
    11import Orange 
    2 from Orange.regression.earth import data_label_mask 
    32 
    43 
     
    4140        """ 
    4241 
    43         label_mask = data_label_mask(data.domain) 
    44         if sum(label_mask) == 0: 
    45             raise 'No classes/labels defined.' 
    46         x_vars = [v for v, label in zip(data.domain, label_mask) if not label] 
    47         y_vars = [v for v, label in zip(data.domain, label_mask) if label] 
    48  
    49         classifiers = [self.learner(Orange.data.Table(Orange.data.Domain(x_vars, y), 
    50             data), weight) for y in y_vars] 
    51         return MultitargetClassifier(classifiers=classifiers, x_vars=x_vars, y_vars=y_vars) 
     42        if not data.domain.class_vars: 
     43            raise Exception('No classes defined.') 
     44         
     45        domains = [Orange.data.Domain(data.domain.attributes, y) 
     46                   for y in data.domain.class_vars] 
     47        classifiers = [self.learner(Orange.data.Table(dom, data), weight) 
     48                       for dom in domains] 
     49        return MultitargetClassifier(classifiers=classifiers, domains=domains) 
    5250         
    5351 
     
    6260    """ 
    6361 
    64     def __init__(self, classifiers, x_vars, y_vars): 
     62    def __init__(self, classifiers, domains): 
    6563        self.classifiers = classifiers 
    66         self.x_vars = x_vars 
    67         self.y_vars = y_vars 
     64        self.domains = domains 
    6865 
    6966    def __call__(self, instance, return_type=Orange.core.GetValue): 
    70         predictions = [c(Orange.data.Instance(Orange.data.Domain(self.x_vars, y), 
    71             instance), return_type) for c, y in zip(self.classifiers, self.y_vars)] 
    72         return zip(*predictions) if return_type == Orange.core.GetBoth else predictions 
     67        predictions = [c(Orange.data.Instance(dom, instance), return_type) 
     68                       for c, dom in zip(self.classifiers, self.domains)] 
     69        return zip(*predictions) if return_type == Orange.core.GetBoth \ 
     70               else predictions 
    7371 
  • orange/Orange/multitarget/tree.py

    r9330 r9414  
    33import numpy as np 
    44import Orange 
    5 from Orange.multitarget import data_label_mask 
    65 
    76 
    87def weighted_variance(X, weights=None): 
    98    """Computes the variance using a weighted distance between vectors""" 
    10     if weights: 
    11         X = X * np.array(weights) 
     9    global foo 
     10    foo = X 
     11 
     12    if not weights: 
     13        weights = [1] * len(X[0]) 
     14 
     15    X = X * np.array(weights) 
    1216    return np.sum((X - np.mean(X, 0))**2, 1).sum() 
    1317 
     
    2731    def threshold_function(self, feature, data, cont_distrib=None, weightID=0): 
    2832        f = data.domain[feature] 
    29         label_mask = data_label_mask(data.domain) 
    30         classes = [v for v, label in zip(data.domain, label_mask) if label] 
    3133        values = sorted(set(ins[f].value for ins in data)) 
    3234        ts = [(v1 + v2) / 2. for v1, v2 in zip(values, values[1:])] 
     
    3739            bf = Orange.feature.discretization.IntervalDiscretizer( 
    3840                points=[t]).construct_variable(f) 
    39             data2 = data.select([bf] + classes) 
     41            dom2 = Orange.data.Domain([bf], class_vars=data.domain.class_vars) 
     42            data2 = Orange.data.Table(dom2, data) 
     43            # TODO: remove when bug is fixed (currently => very slow) 
     44            for i1, i2 in zip(data, data2): 
     45                i2.set_classes(i1.get_classes()) 
    4046            scores.append((t, self.__call__(bf, data2))) 
    4147        return scores 
     
    4753 
    4854    def __call__(self, feature, data, apriori_class_distribution=None, weightID=0): 
    49         if data.domain[feature].attributes.has_key('label'): 
    50             return float('-inf') 
    51         label_mask = data_label_mask(data.domain) 
    52         classes = [v for v, label in zip(data.domain, label_mask) if label] 
    5355        split = dict((ins[feature].value, []) for ins in data) 
    5456        for ins in data: 
    55             # TODO: does not work when there are missing class values 
    56             split[ins[feature].value].append([float(ins[c]) for c in classes]) 
     57            split[ins[feature].value].append(ins.get_classes()) 
    5758        score = -sum(weighted_variance(x, self.weights) * len(x) for x in split.values()) 
    5859        return score 
     
    7980        # so we set the class if necessary 
    8081        if data.domain.class_var is None: 
    81             for var in data.domain: 
    82                 if var.attributes.has_key('label'): 
    83                     data = Orange.data.Table(Orange.data.Domain(data.domain, var), 
    84                                              data) 
    85                     break 
    86  
    87         tree = Orange.classification.tree.TreeLearner.__call__(self, data, weight) 
     82            data2 = Orange.data.Table(Orange.data.Domain( 
     83                data.domain.attributes, data.domain.class_vars[0], 
     84                class_vars=data.domain.class_vars), data) 
     85        # until the bug is fixed, manually set correct values of classes 
     86        for i1, i2 in zip(data, data2): 
     87            i2.set_classes(i1.get_classes()) 
     88        tree = Orange.classification.tree.TreeLearner.__call__(self, data2, weight) 
    8889        return MultiTree(base_classifier=tree) 
    8990 
     
    9596        return node.node_classifier(instance, return_type) 
    9697 
     98 
     99if __name__ == '__main__': 
     100    data = Orange.data.Table('emotions') 
     101    print 'Actual classes:\n', data[0].get_classes() 
     102    mt = MultiTreeLearner(max_depth=2) 
     103    c = mt(data) 
     104    print 'Predicted classes:\n', c(data[0]) 
     105 
Note: See TracChangeset for help on using the changeset viewer.