source: orange/Orange/multilabel/brknn.py @ 10502:6b593a8cd5a0

Revision 10502:6b593a8cd5a0, 7.3 KB checked in by Matija Polajnar <matija.polajnar@…>, 2 years ago (diff)

Make multi-label warning and errors more clear on what a 'multi-label dataset' is from our perspective; Miha warned me students had troubles with this.

RevLine 
[9470]1"""
.. index:: BR-kNN Learner

***************************************
BR-kNN Learner
***************************************

BR-kNN classification is an adaptation of the kNN algorithm for multi-label classification that
is conceptually equivalent to using the popular Binary Relevance (BR) problem transformation method
in conjunction with the kNN algorithm. This module also implements two extensions of BR-kNN,
selected with the ``ext`` parameter of :class:`BRkNNLearner`.
For more information, see E. Spyromitros, G. Tsoumakas, I. Vlahavas,
`An Empirical Study of Lazy Multilabel Classification Algorithms <http://mlkd.csd.auth.gr/multilabel.html>`_,
Proc. 5th Hellenic Conference on Artificial Intelligence (SETN 2008), Springer, Syros, Greece, 2008.
14
.. index:: BR-kNN Learner
.. autoclass:: Orange.multilabel.BRkNNLearner
   :members:
   :show-inheritance:

   :param instances: a table of instances.
   :type instances: :class:`Orange.data.Table`
22
.. index:: BR-kNN Classifier
.. autoclass:: Orange.multilabel.BRkNNClassifier
   :members:
   :show-inheritance:

Examples
========

The following example demonstrates a straightforward invocation of
this algorithm (:download:`mlc-classify.py <code/mlc-classify.py>`):

.. literalinclude:: code/mlc-classify.py
   :lines: 6-9

37"""
38import random
[9500]39import math
40
[9470]41import Orange
42import multiknn as _multiknn
43
class BRkNNLearner(_multiknn.MultikNNLearner):
    """
    Class implementing the BR-kNN learner.

    .. attribute:: k

        Number of neighbours (the constructor's default is 1). A value of 0
        is meant to select the square root of the number of instances; this
        class does not handle that itself -- the neighbour search is built
        by ``_build_knn``, inherited from the base learner.

    .. attribute:: ext

        Extension type. The default is None, meaning 'Standard BR'; 'a' means
        predicting the top ranked label in case of an empty prediction set;
        'b' means predicting the top n ranked labels, where n is the rounded
        average size of the label sets of the neighbours.

    .. attribute:: knn

        :class:`Orange.classification.knn.FindNearest` for nearest neighbor search
    """

    @staticmethod
    def _check_ext(ext):
        """Raise ValueError unless *ext* is one of None, 'a' or 'b'."""
        # Membership test on a tuple uses ==, so unhashable values are
        # rejected with ValueError rather than TypeError.
        if ext not in (None, 'a', 'b'):
            raise ValueError("Invalid ext value: should be None, 'a' or 'b'.")

    def __new__(cls, instances = None, k=1, ext = None, weight_id = 0, **argkw):
        """
        Constructor of BRkNNLearner

        :param instances: a table of instances. If given (and non-empty), the
            learner is trained immediately and a classifier is returned.
        :type instances: :class:`Orange.data.Table`

        :param k: number of nearest neighbours used in classification
        :type k: int

        :param ext: extension type (default value is None which yields
            the Standard BR), values 'a' and 'b' are also possible.
        :type ext: string

        :param weight_id: ID of the weight meta attribute, forwarded to
            :meth:`__call__` when *instances* is given.
        :type weight_id: int

        :raises ValueError: if *ext* is not None, 'a' or 'b'.

        :rtype: :class:`BRkNNLearner`, or :class:`BRkNNClassifier` when
            *instances* is given
        """

        self = _multiknn.MultikNNLearner.__new__(cls, k, **argkw)

        cls._check_ext(ext)
        self.ext = ext

        if instances:
            self.instances = instances
            self.__init__(**argkw)
            return self.__call__(instances, weight_id)
        else:
            return self

    def __call__(self, instances, weight_id = 0, **kwds):
        """
        Build a :class:`BRkNNClassifier` from multi-label data.

        :param instances: a multi-label table of instances.
        :type instances: :class:`Orange.data.Table`
        :param weight_id: accepted for interface compatibility; not used here.
        :param kwds: attributes (for example ``k`` or ``ext``) that are set on
            the learner before training.
        :raises TypeError: if *instances* is not a multi-label data set.
        :raises ValueError: if ``ext`` is not None, 'a' or 'b'.
        :rtype: :class:`BRkNNClassifier`
        """
        if not Orange.multilabel.is_multilabel(instances):
            raise TypeError("The given data set is not a multi-label data set"
                            " with class values 0 and 1.")

        self.__dict__.update(kwds)
        # kwds may have replaced ext after construction, so validate it again
        # here instead of failing later with an opaque lookup error.
        self._check_ext(self.ext)
        self._build_knn(instances)

        # Choose how label probabilities are turned into predicted labels.
        labeling_f = {None: BRkNNClassifier.get_labels,
                      'a': BRkNNClassifier.get_labels_a,
                      'b': BRkNNClassifier.get_labels_b}[self.ext]

        return BRkNNClassifier(instances = instances,
                               ext = self.ext,
                               knn = self.knn,
                               k = self.k,
                               labeling_f = labeling_f)
[9470]111
class BRkNNClassifier(_multiknn.MultikNNClassifier):
    """
    Classifier produced by :class:`BRkNNLearner`.

    The methods below read the attributes ``instances`` (training data),
    ``knn`` (nearest-neighbour finder), ``k`` (number of neighbours) and
    ``labeling_f`` (function that turns label probabilities into labels).
    :class:`BRkNNLearner` passes these as keyword arguments to the
    constructor, which is inherited from the base classifier.
    """

    def __call__(self, instance, result_type=Orange.classification.Classifier.GetValue):
        """
        Predict the labels of *instance*.

        :param instance: the instance to classify.
        :type instance: :class:`Orange.data.Instance`
        :param result_type: GetValue, GetProbabilities or GetBoth.
        :rtype: a list of :class:`Orange.data.Value`, a list of :class:`Orange.statistics.distribution.Distribution`, or a tuple with both
        """
        neighbours = self.knn(instance, self.k)

        prob = self.get_prob(neighbours)

        # labeling_f is stored as a plain attribute (not bound), so self is
        # passed explicitly.
        labels = self.labeling_f(self, prob, neighbours)

        if result_type == Orange.classification.Classifier.GetValue:
            return labels

        dists = [Orange.statistics.distribution.Discrete([1-p, p]) for p in prob]
        for v, d in zip(self.instances.domain.class_vars, dists):
            d.variable = v

        if result_type == Orange.classification.Classifier.GetProbabilities:
            return dists
        return labels, dists

    def get_prob(self, neighbours):
        """
        Calculates the probabilities of the labels, based on the neighboring
        instances.

        Every label starts with a smoothing count of 1/N (N is the number of
        training instances, at least 1); each neighbour whose value for the
        label is '1' adds 1. The counts are then divided by
        ``label_count / N + len(neighbours)``.

        :param neighbours: a list of nearest neighboring instances.
        :type neighbours: list of :class:`Orange.data.Instance`

        :rtype: list of float, one probability per class variable
        """
        # NOTE(review): the smoothing mass in the denominator is
        # label_count/N rather than 2/N per binary label; kept unchanged.
        n_train = max(1, len(self.instances))
        label_count = len(self.instances.domain.class_vars)
        # Smoothing keeps every probability strictly above zero.
        confidences = [1.0 / n_train] * label_count
        total = float(label_count) / n_train

        for neigh in neighbours:
            for j, value in enumerate(neigh.get_classes()):
                if value == '1':
                    confidences[j] += 1
            total += 1

        #Normalize distribution
        if total > 0:
            confidences = [con/total for con in confidences]

        return confidences

    def get_labels(self, prob, _neighs=None, thresh=0.5):
        """
        Standard BR labelling: a label is set to '1' when its probability is
        strictly greater than *thresh*, otherwise to '0'.

        :param prob: the probabilities of the labels
        :type prob: list of float
        :param _neighs: unused; present so all labelling functions share a signature.
        :param thresh: decision threshold.
        :rtype: list of :class:`Orange.data.Value`
        """
        return [Orange.data.Value(lvar, str(int(p>thresh)))
                for p, lvar in zip(prob, self.instances.domain.class_vars)]

    def get_labels_a(self, prob, _neighs=None):
        """
        used for BRknn-a: like :meth:`get_labels`, but if no label is
        predicted, the label with the highest probability is set to '1'.

        :param prob: the probabilities of the labels
        :type prob: list of float

        :rtype: list of :class:`Orange.data.Value`
        """
        labels = self.get_labels(prob)

        #assign the class with the greatest confidence
        if all(l.value=='0' for l in labels):
            # On ties, the (value, index) comparison picks the highest index.
            index = max((v,i) for i,v in enumerate(prob))[1]
            labels[index].value = '1'

        return labels

    def get_labels_b(self, prob, neighs):
        """
        used for BRknn-b: predicts the n labels with the highest
        probabilities, where n is the rounded average number of labels set
        to '1' among the neighbours (0 when there are no neighbours).

        :param prob: the probabilities of the labels
        :type prob: list of float
        :param neighs: the nearest neighbours of the classified instance.

        :rtype: list of :class:`Orange.data.Value`
        """

        labels = [Orange.data.Value(lvar, '0')
                  for _p, lvar in zip(prob, self.instances.domain.class_vars)]

        n_neighs = len(neighs)
        if n_neighs:
            label_total = sum(sum(l.value=='1' for l in n.get_classes())
                              for n in neighs)
            avg_label_cnt = int(round(label_total / float(n_neighs)))
        else:
            # Without neighbours there is no label-set size to estimate.
            avg_label_cnt = 0

        # Rank label indices by probability only; the sort is stable, so ties
        # keep label order and Value objects are never compared.
        ranked = sorted(range(len(labels)), key=lambda i: prob[i], reverse=True)
        for i in ranked[:avg_label_cnt]:
            labels[i].value = '1'

        return labels
[9475]209   
210#########################################################################################
[9477]211# Test the code, run from DOS prompt
212# assume the data file is in proper directory
213
[9475]214if __name__ == "__main__":
215    data = Orange.data.Table("emotions.tab")
216
217    classifier = Orange.multilabel.BRkNNLearner(data,5)
218    for i in range(10):
219        c,p = classifier(data[i],Orange.classification.Classifier.GetBoth)
220        print c,p
Note: See TracBrowser for help on using the repository browser.