#
source:
orange/orange/multilabel/brknn.py
@
9669:165371b04b4a

Revision 9669:165371b04b4a, 7.3 KB checked in by anze <anze.staric@…>, 2 years ago (diff) |
---|

Line | |
---|---|

1 | """ |

2 | .. index:: BR-kNN Learner |

3 | |

4 | *************************************** |

5 | BR-kNN Learner |

6 | *************************************** |

7 | |

8 | BR-kNN Classification is an adaptation of the kNN algorithm for multi-label classification that |

9 | is conceptually equivalent to using the popular Binary Relevance problem transformation method |

10 | in conjunction with the kNN algorithm. It also implements two extensions of BR-kNN. |

11 | For more information, see E. Spyromitros, G. Tsoumakas, I. Vlahavas, |

12 | `An Empirical Study of Lazy Multilabel Classification Algorithms <http://mlkd.csd.auth.gr/multilabel.html>`_, |

13 | Proc. 5th Hellenic Conference on Artificial Intelligence (SETN 2008), Springer, Syros, Greece, 2008. |

14 | |

15 | .. index:: BR-kNN Learner |

16 | .. autoclass:: Orange.multilabel.BRkNNLearner |

17 | :members: |

18 | :show-inheritance: |

19 | |

20 | :param instances: a table of instances. |

21 | :type instances: :class:`Orange.data.Table` |

22 | |

23 | .. index:: BRkNN Classifier |

24 | .. autoclass:: Orange.multilabel.BRkNNClassifier |

25 | :members: |

26 | :show-inheritance: |

27 | |

28 | Examples |

29 | ======== |

30 | |

31 | The following example demonstrates a straightforward invocation of |

32 | this algorithm (:download:`mlc-classify.py <code/mlc-classify.py>`, uses |

33 | :download:`emotions.tab <code/emotions.tab>`): |

34 | |

35 | .. literalinclude:: code/mlc-classify.py |

36 | :lines: 6-9 |

37 | |

38 | """ |

39 | import random |

40 | import math |

41 | |

42 | import Orange |

43 | import multiknn as _multiknn |

44 | |

class BRkNNLearner(_multiknn.MultikNNLearner):
    """
    Class implementing the BR-kNN learner.

    .. attribute:: k

        Number of neighbours used in classification. The constructor
        default is 1. NOTE(review): earlier documentation stated that 0 is
        the default and selects the square root of the number of instances;
        no such handling is visible in this class -- confirm against
        ``MultikNNLearner``.

    .. attribute:: ext

        Extension type. The default, None, means 'Standard BR'; 'a' means
        predicting the top ranked label in case of an empty prediction set;
        'b' means predicting the top n ranked labels based on the size of
        the label sets in the neighbours.

    .. attribute:: knn

        :class:`Orange.classification.knn.FindNearest` for nearest neighbor search

    """
    def __new__(cls, instances=None, k=1, ext=None, weight_id=0, **argkw):
        """
        Constructor of BRkNNLearner

        :param instances: a table of instances. When given (and non-empty),
            the learner is applied to it immediately and the resulting
            classifier is returned instead of the learner.
        :type instances: :class:`Orange.data.Table`

        :param k: number of nearest neighbours used in classification
        :type k: int

        :param ext: extension type (default value is None which yields
            the Standard BR), values 'a' and 'b' are also possible.
        :type ext: string

        :param weight_id: forwarded to :meth:`__call__` when `instances`
            is given.
        :type weight_id: int

        :raises ValueError: if `ext` is not None, 'a' or 'b'.

        :rtype: :class:`BRkNNLearner` or, when `instances` is given,
            :class:`BRkNNClassifier`
        """
        self = _multiknn.MultikNNLearner.__new__(cls, k, **argkw)

        if ext not in (None, 'a', 'b'):
            # Call form of raise: valid on both Python 2 and Python 3.
            raise ValueError("Invalid ext value: should be None, 'a' or 'b'.")
        self.ext = ext

        # Truthiness test kept on purpose: None and an empty table both
        # yield the (untrained) learner.
        if instances:
            self.instances = instances
            # The returned object is a classifier, not an instance of cls,
            # so Python will not run __init__ on its own; do it explicitly.
            self.__init__(**argkw)
            return self.__call__(instances, weight_id)
        return self

    def __call__(self, instances, weight_id=0, **kwds):
        """
        Build a :class:`BRkNNClassifier` from the given data.

        :param instances: a multi-label table of instances.
        :type instances: :class:`Orange.data.Table`

        :param weight_id: accepted for interface compatibility; not used
            by this method.
        :type weight_id: int

        :param kwds: extra keyword arguments, stored as attributes on the
            learner before the neighbour search is built.

        :raises TypeError: if `instances` is not a multi-label data set.
        :raises ValueError: if ``self.ext`` is not None, 'a' or 'b'.

        :rtype: :class:`BRkNNClassifier`
        """
        if not Orange.multilabel.is_multilabel(instances):
            raise TypeError("The given data set is not a multi-label data set.")

        self.__dict__.update(kwds)
        # Defined outside this class; the code below reads self.knn
        # afterwards, so it is presumably set there.
        self._build_knn(instances)

        # Built here rather than at class level because BRkNNClassifier is
        # defined further down in the module.
        labeling_functions = {
            None: BRkNNClassifier.get_labels,
            'a': BRkNNClassifier.get_labels_a,
            'b': BRkNNClassifier.get_labels_b,
        }
        try:
            labeling_f = labeling_functions[self.ext]
        except (KeyError, TypeError):
            # ext may have been reassigned after construction.
            raise ValueError("Invalid ext value: should be None, 'a' or 'b'.")

        return BRkNNClassifier(instances=instances,
                               ext=self.ext,
                               knn=self.knn,
                               k=self.k,
                               labeling_f=labeling_f)

111 | |

class BRkNNClassifier(_multiknn.MultikNNClassifier):
    """
    Classifier produced by :class:`BRkNNLearner`.

    The methods below read the attributes ``instances``, ``knn``, ``k`` and
    ``labeling_f``, which the learner passes as keyword arguments when it
    constructs the classifier.
    """

    def __call__(self, instance, result_type=Orange.classification.Classifier.GetValue):
        """
        Predict the labels of a single instance.

        :param instance: the instance to classify.
        :type instance: :class:`Orange.data.Instance`

        :param result_type: ``GetValue`` returns only the labels,
            ``GetProbabilities`` only the distributions; any other value
            returns both.

        :rtype: a list of :class:`Orange.data.Value`, a list of :class:`Orange.statistics.distribution.Distribution`, or a tuple with both
        """
        neighbours = self.knn(instance, self.k)
        prob = self.get_prob(neighbours)

        # labeling_f is stored as a plain (unbound) function, hence the
        # explicit self argument.
        labels = self.labeling_f(self, prob, neighbours)

        if result_type == Orange.classification.Classifier.GetValue:
            return labels

        # One two-valued distribution per label: [P(label=0), P(label=1)].
        dists = [Orange.statistics.distribution.Discrete([1 - p, p]) for p in prob]
        for v, d in zip(self.instances.domain.class_vars, dists):
            d.variable = v

        if result_type == Orange.classification.Classifier.GetProbabilities:
            return dists
        return labels, dists

    def get_prob(self, neighbours):
        """
        Calculates the probabilities of the labels, based on the neighboring
        instances.

        Every label starts from a small correction term, 1 / max(1, N) with
        N the number of training instances, and gains 1 for each neighbour
        that carries the label. The sums are then divided by the number of
        neighbours plus the accumulated correction.

        :param neighbours: a list of nearest neighboring instances.
        :type neighbours: list of :class:`Orange.data.Instance`

        :rtype: list of float, one confidence per label

        """
        label_count = len(self.instances.domain.class_vars)
        correction = 1.0 / max(1, len(self.instances))
        confidences = [correction] * label_count

        total = float(label_count) / max(1, len(self.instances))

        for neigh in neighbours:
            for j, value in enumerate(neigh.get_classes()):
                if value == '1':
                    confidences[j] += 1
            # NOTE(review): counted once per neighbour, not once per set
            # label. The nesting was reconstructed from a listing that lost
            # its indentation; this is the reading under which each
            # confidence is a per-label fraction, as the 0.5 threshold in
            # get_labels and the [1-p, p] distributions assume. Confirm
            # against the repository.
            total += 1

        # Normalize distribution
        if total > 0:
            confidences = [con / total for con in confidences]

        return confidences

    def get_labels(self, prob, _neighs=None, thresh=0.5):
        """
        used for standard BRknn

        :param prob: the probabilities of the labels
        :type prob: list of double

        :param _neighs: ignored; present so that all labeling functions
            can be called with the same arguments.

        :param thresh: a label is set to '1' when its probability is
            strictly greater than this value.
        :type thresh: double

        :rtype: the list label value
        """
        return [Orange.data.Value(lvar, str(int(p > thresh)))
                for p, lvar in zip(prob, self.instances.domain.class_vars)]

    def get_labels_a(self, prob, _neighs=None):
        """
        used for BRknn-a

        :param prob: the probabilities of the labels
        :type prob: list of double

        :param _neighs: ignored; present so that all labeling functions
            can be called with the same arguments.

        :rtype: the list label value
        """
        labels = self.get_labels(prob)

        # assign the class with the greatest confidence; the `labels` guard
        # avoids calling max() on an empty sequence when there are no labels
        if labels and all(l.value == '0' for l in labels):
            # On equal confidences the highest index wins.
            index = max((v, i) for i, v in enumerate(prob))[1]
            labels[index].value = '1'

        return labels

    def get_labels_b(self, prob, neighs):
        """
        used for BRknn-b

        :param prob: the probabilities of the labels
        :type prob: list of double

        :param neighs: the nearest neighboring instances; the rounded
            average number of labels set to '1' among them decides how many
            top ranked labels are predicted.
        :type neighs: list of :class:`Orange.data.Instance`

        :rtype: the list label value
        """
        labels = [Orange.data.Value(lvar, '0')
                  for lvar in self.instances.domain.class_vars]

        neigh_count = len(neighs)
        if neigh_count == 0:
            # No neighbours: nothing to average over, predict no labels
            # instead of dividing by zero.
            return labels

        set_label_cnt = sum(sum(l.value == '1' for l in n.get_classes())
                            for n in neighs)
        avg_label_cnt = int(round(set_label_cnt / float(neigh_count)))

        # Rank on the probability alone so that ties never fall through to
        # comparing Value objects; the sort is stable, so ties keep label
        # order.
        ranked = sorted(zip(prob, labels), key=lambda pair: pair[0],
                        reverse=True)
        for _p, lval in ranked[:avg_label_cnt]:
            lval.value = '1'

        return labels

209 | |

#########################################################################################
# Manual smoke test: run this module directly from a command prompt.
# Assumes the data file "emotions.tab" can be found by Orange.data.Table.

if __name__ == "__main__":
    emotions = Orange.data.Table("emotions.tab")

    # Data is passed straight to the learner, so the call is expected to
    # hand back a trained classifier (see BRkNNLearner.__new__).
    model = Orange.multilabel.BRkNNLearner(emotions, 5)
    get_both = Orange.classification.Classifier.GetBoth

    # Classify the first ten instances and show labels and distributions.
    for row in range(10):
        predicted, distributions = model(emotions[row], get_both)
        print("%s %s" % (predicted, distributions))

**Note:** See TracBrowser for help on using the repository browser.