Changeset 9322:207d1bb67d5c in orange


Ignore:
Timestamp:
12/07/11 17:25:00 (2 years ago)
Author:
lanz <lan.zagar@…>
Branch:
default
Convert:
94b65e83b099c35a81814b3dab483993660828fb
Message:

Added MultitargetLearner (wrapper for single-class learners) to multitarget module.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/multitarget/__init__.py

    r9308 r9322  
     1import Orange 
     2from Orange.regression.earth import data_label_mask 
     3 
     4 
     5# Other algorithms which also work with multitarget data 
    16from Orange.regression import pls 
     7# change the default value of multi_label=True in init 
     8##from Orange.regression import earth 
     9 
     10 
     11class MultitargetLearner(Orange.classification.Learner): 
     12    """ 
     13    Wrapper for multitarget problems that constructs independent models 
     14    of a base learner for each class variable. 
     15 
     16    .. attribute:: learner 
     17 
     18        The base learner used to learn models for each class. 
     19    """ 
     20 
     21    def __new__(cls, learner, data=None, weight=0, **kwargs): 
     22        self = Orange.classification.Learner.__new__(cls, **kwargs) 
     23        if data: 
     24            self.__init__(learner, **kwargs) 
     25            return self.__call__(data, weight) 
     26        else: 
     27            return self 
     28     
     29    def __init__(self, learner, **kwargs): 
     30        self.learner = learner 
     31        self.__dict__.update(kwargs) 
     32 
     33    def __call__(self, data, weight=0): 
     34        """Learn independent models of the base learner for each class. 
     35 
     36        :param data: Multitarget data instances (with more than 1 class). 
     37        :type data: Orange.data.Table 
     38        :param weight: Id of meta attribute with weights of instances 
     39        :type weight: int 
     40        :rtype: :class:`Orange.multitarget.MultitargetClassifier` 
     41        """ 
     42 
     43        label_mask = data_label_mask(data.domain) 
     44        if sum(label_mask) == 0: 
     45            raise 'No classes/labels defined.' 
     46        x_vars = [v for v, label in zip(data.domain, label_mask) if not label] 
     47        y_vars = [v for v, label in zip(data.domain, label_mask) if label] 
     48 
     49        classifiers = [self.learner(Orange.data.Table(Orange.data.Domain(x_vars, y), 
     50            data), weight) for y in y_vars] 
     51        return MultitargetClassifier(classifiers=classifiers, x_vars=x_vars, y_vars=y_vars) 
     52         
     53 
     54class MultitargetClassifier(Orange.classification.Classifier): 
     55    """ 
     56    Multitarget classifier returning a list of predictions from each 
     57    of the independent base classifiers. 
     58 
     59    .. attribute classifiers 
     60 
     61        List of individual classifiers for each class. 
     62    """ 
     63 
     64    def __init__(self, classifiers, x_vars, y_vars): 
     65        self.classifiers = classifiers 
     66        self.x_vars = x_vars 
     67        self.y_vars = y_vars 
     68 
     69    def __call__(self, instance, return_type=Orange.core.GetValue): 
     70        predictions = [c(Orange.data.Instance(Orange.data.Domain(self.x_vars, y), 
     71            instance), return_type) for c, y in zip(self.classifiers, self.y_vars)] 
     72        return zip(*predictions) if return_type == Orange.core.GetBoth else predictions 
     73 
Note: See TracChangeset for help on using the changeset viewer.