source: orange/orange/Orange/multitarget/__init__.py @ 9414:9ba8a2753530

Revision 9414:9ba8a2753530, 2.4 KB checked in by lanz <lan.zagar@…>, 2 years ago (diff)

Changed multitarget module to work with the new multiclass data sets.

Line 
1import Orange
2
3
4# Other algorithms which also work with multitarget data
5from Orange.regression import pls
6# change the default value of multi_label=True in init
7##from Orange.regression import earth
8
9
class MultitargetLearner(Orange.classification.Learner):
    """
    Wrapper for multitarget problems that constructs independent models
    of a base learner for each class variable.

    If ``data`` is given to the constructor, the learner is trained on it
    immediately and the resulting :class:`MultitargetClassifier` is returned
    instead of the learner itself.

    .. attribute:: learner

        The base learner used to learn models for each class.
    """

    def __new__(cls, learner, data=None, weight=0, **kwargs):
        self = Orange.classification.Learner.__new__(cls, **kwargs)
        if data:
            # Shortcut MultitargetLearner(learner, data): train right away.
            # Python will not call __init__ itself here because the object
            # returned is a classifier, not an instance of cls, so do it
            # explicitly before training.
            self.__init__(learner, **kwargs)
            return self.__call__(data, weight)
        else:
            return self

    def __init__(self, learner, data=None, weight=0, **kwargs):
        """
        :param learner: Base learner trained separately for each class
            variable.
        :param data: Ignored here; it is consumed by ``__new__``. It is
            accepted because Python passes the constructor arguments on to
            ``__init__`` whenever ``__new__`` returns the learner itself.
        :param weight: Ignored here for the same reason as ``data``.

        Any further keyword arguments are stored as attributes of the
        learner.
        """
        self.learner = learner
        self.__dict__.update(kwargs)

    def __call__(self, data, weight=0):
        """Learn independent models of the base learner for each class.

        :param data: Multitarget data instances (with more than 1 class).
        :type data: Orange.data.Table
        :param weight: Id of meta attribute with weights of instances
        :type weight: int
        :rtype: :class:`Orange.multitarget.MultitargetClassifier`
        :raises ValueError: if ``data.domain`` has no class variables.
        """
        if not data.domain.class_vars:
            raise ValueError('No classes defined.')

        # Build one single-target domain per class variable: all original
        # attributes, with that class variable as the class.
        # NOTE(review): the per-target domains carry no meta attributes;
        # confirm that the ``weight`` meta survives the Table conversion.
        domains = [Orange.data.Domain(data.domain.attributes, y)
                   for y in data.domain.class_vars]
        classifiers = [self.learner(Orange.data.Table(dom, data), weight)
                       for dom in domains]
        return MultitargetClassifier(classifiers=classifiers, domains=domains)
50       
51
class MultitargetClassifier(Orange.classification.Classifier):
    """
    Multitarget classifier returning a list of predictions from each
    of the independent base classifiers.

    .. attribute:: classifiers

        List of individual classifiers for each class.

    .. attribute:: domains

        List of single-target domains, parallel to :obj:`classifiers`.
        Each instance is converted into ``domains[i]`` before it is passed
        to ``classifiers[i]``.
    """

    def __init__(self, classifiers, domains):
        self.classifiers = classifiers
        self.domains = domains

    def __call__(self, instance, return_type=Orange.core.GetValue):
        """Predict every class variable of ``instance``.

        :param instance: Instance to classify; it is converted into each
            base classifier's domain first.
        :param return_type: ``Orange.core.GetValue``,
            ``Orange.core.GetProbabilities`` or ``Orange.core.GetBoth``.
        :return: A list with one prediction per class variable. For
            ``GetBoth``, each base classifier is expected to return a
            ``(value, distribution)`` pair, and the result is transposed
            with ``zip`` into ``[values, distributions]``.
        """
        predictions = [c(Orange.data.Instance(dom, instance), return_type)
                       for c, dom in zip(self.classifiers, self.domains)]
        return zip(*predictions) if return_type == Orange.core.GetBoth \
               else predictions
71
Note: See TracBrowser for help on using the repository browser.