# source: orange/source/orange/tdidt_simple.cpp @ revision 10965:873ff9bf106c

Revision 10965:873ff9bf106c (24.6 KB), checked in by Ales Erjavec <ales.erjavec@…>, 20 months ago (diff)

[8378] | 1 | /* |

[10206] | 2 | This file is part of Orange. |

[8770] | 3 | |

[10206] | 4 | Copyright 1996-2010 Faculty of Computer and Information Science, University of Ljubljana |

5 | Contact: janez.demsar@fri.uni-lj.si | |

[8378] | 6 | |

[10206] | 7 | Orange is free software: you can redistribute it and/or modify |

8 | it under the terms of the GNU General Public License as published by | |

9 | the Free Software Foundation, either version 3 of the License, or | |

10 | (at your option) any later version. | |

[8378] | 11 | |

[10206] | 12 | Orange is distributed in the hope that it will be useful, |

13 | but WITHOUT ANY WARRANTY; without even the implied warranty of | |

14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |

15 | GNU General Public License for more details. | |

[8378] | 16 | |

[10206] | 17 | You should have received a copy of the GNU General Public License |

18 | along with Orange. If not, see <http://www.gnu.org/licenses/>. | |

[8378] | 19 | */ |

20 | ||

[10206] | 21 | #include <iostream> |

22 | #include <sstream> | |

[8378] | 23 | #include <math.h> |

24 | #include <stdlib.h> | |

25 | #include <cstring> | |

26 | ||

27 | #include "vars.hpp" | |

28 | #include "domain.hpp" | |

29 | #include "distvars.hpp" | |

30 | #include "examples.hpp" | |

31 | #include "examplegen.hpp" | |

32 | #include "table.hpp" | |

33 | #include "classify.hpp" | |

34 | ||

35 | #include "tdidt_simple.ppp" | |

36 | ||

[8396] | 37 | #ifndef _MSC_VER |

[10206] | 38 | #include "err.h" |

39 | #define ASSERT(x) if (!(x)) err(1, "%s:%d", __FILE__, __LINE__) | |

[8396] | 40 | #else |

[10206] | 41 | #define ASSERT(x) if(!(x)) exit(1) |

42 | #define log2f(x) log((double) (x)) / log(2.0) | |

[8396] | 43 | #endif // _MSC_VER |

44 | ||

45 | #ifndef INFINITY | |

[10206] | 46 | #include <limits> |

47 | #define INFINITY numeric_limits<float>::infinity() | |

[8396] | 48 | #endif // INFINITY |

49 | ||

/* Learner parameters plus per-call mutable state threaded through the
 * recursive tree build. */
struct Args {
	int minInstances, maxDepth;        /* stopping criteria */
	float maxMajority, skipProb;       /* majority-class cutoff; P(attribute is skipped at a node) */

	int type, *attr_split_so_far;      /* Classification/Regression; per-attribute "currently split on" flags */
	PDomain domain;
	PRandomGenerator randomGenerator;
};

/* One training example with its (possibly fractional) weight; weights
 * become fractional when unknown values are spread across branches. */
struct Example {
	TExample *example;
	float weight;
};

enum { DiscreteNode, ContinuousNode, PredictorNode };
enum { Classification, Regression };

/* Attribute index used by compar_examples; a global because qsort()
 * provides no user-data pointer to its comparator. */
int compar_attr;

68 | ||

[8131] | 69 | /* This function uses the global variable compar_attr. |

[8770] | 70 | * Examples with unknowns are larger so that, when sorted, they appear at the bottom. |

[8131] | 71 | */ |

[8378] | 72 | int |

[8131] | 73 | compar_examples(const void *ptr1, const void *ptr2) |

[8378] | 74 | { |

[10206] | 75 | struct Example *e1, *e2; |

[8770] | 76 | |

[10206] | 77 | e1 = (struct Example *)ptr1; |

78 | e2 = (struct Example *)ptr2; | |

79 | if (e1->example->values[compar_attr].isSpecial()) | |

80 | return 1; | |

81 | if (e2->example->values[compar_attr].isSpecial()) | |

82 | return -1; | |

83 | return e1->example->values[compar_attr].compare(e2->example->values[compar_attr]); | |

[8770] | 84 | } |

85 | ||

86 | ||

/* Entropy (in bits) of an unnormalized distribution.
 * xs holds non-negative weights; entries that are not strictly positive
 * are skipped.  Returns 0 for an empty (all-zero) distribution.
 * Uses the identity  -sum(p_i log2 p_i) = acc/total + log2(total)
 * where acc = -sum(x_i log2 x_i) over the raw weights. */
float
entropy(float *xs, int size)
{
	float total, acc;
	int i;

	total = 0.0;
	acc = 0.0;
	for (i = 0; i < size; i++) {
		if (xs[i] > 0.0) {
			acc -= xs[i] * log2f(xs[i]);
			total += xs[i];
		}
	}

	if (total == 0.0)
		return 0.0;
	return acc / total + log2f(total);
}

100 | ||

101 | int | |

102 | test_min_examples(float *attr_dist, int attr_vals, struct Args *args) | |

103 | { | |

[10206] | 104 | int i; |

[8770] | 105 | |

[10206] | 106 | for (i = 0; i < attr_vals; i++) { |

107 | if (attr_dist[i] > 0.0 && attr_dist[i] < args->minInstances) | |

108 | return 0; | |

109 | } | |

110 | return 1; | |

[8378] | 111 | } |

112 | ||

/* Gain ratio of the best binary split on continuous attribute attr.
 * Sorts examples on attr (unknowns last, see compar_examples), then
 * sweeps every candidate threshold, updating the below/above class
 * distributions incrementally.  The winning threshold (midpoint of the
 * two neighbouring values) is written to *best_split.
 * Returns -INFINITY when no valid split exists.
 * NOTE: reorders the examples array in place (qsort). */
float
gain_ratio_c(struct Example *examples, int size, int attr, float cls_entropy, struct Args *args, float *best_split)
{
	struct Example *ex, *ex_end, *ex_next;
	int i, cls, cls_vals, minInstances, size_known;
	float score, *dist_lt, *dist_ge, *attr_dist, best_score, size_weight;

	cls_vals = args->domain->classVar->noOfValues();

	/* minInstances should be at least 1, otherwise there is no point in splitting */
	minInstances = args->minInstances < 1 ? 1 : args->minInstances;

	/* allocate space */
	ASSERT(dist_lt = (float *)calloc(cls_vals, sizeof *dist_lt));
	ASSERT(dist_ge = (float *)calloc(cls_vals, sizeof *dist_ge));
	ASSERT(attr_dist = (float *)calloc(2, sizeof *attr_dist));

	/* sort on attr; examples with unknown values end up at the back */
	compar_attr = attr;
	qsort(examples, size, sizeof(struct Example), compar_examples);

	/* start with everything on the ">=" side; examples with unknown
	 * attribute values (after index size_known) are excluded entirely */
	size_known = size;
	size_weight = 0.0;
	for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
		if (ex->example->values[attr].isSpecial()) {
			size_known = ex - examples;
			break;
		}
		if (!ex->example->getClass().isSpecial())
			dist_ge[ex->example->getClass().intV] += ex->weight;
		size_weight += ex->weight;
	}

	attr_dist[1] = size_weight;
	best_score = -INFINITY;

	/* sweep: move one example at a time from the ">=" to the "<" side;
	 * a threshold is only scored between two distinct values and once
	 * both sides hold at least minInstances examples */
	for (ex = examples, ex_end = ex + size_known - minInstances, ex_next = ex + 1, i = 0; ex < ex_end; ex++, ex_next++, i++) {
		if (!ex->example->getClass().isSpecial()) {
			cls = ex->example->getClass().intV;
			dist_lt[cls] += ex->weight;
			dist_ge[cls] -= ex->weight;
		}
		attr_dist[0] += ex->weight;
		attr_dist[1] -= ex->weight;

		if (ex->example->values[attr] == ex_next->example->values[attr] || i + 1 < minInstances)
			continue;

		/* gain ratio = information gain / split entropy */
		score = (attr_dist[0] * entropy(dist_lt, cls_vals) + attr_dist[1] * entropy(dist_ge, cls_vals)) / size_weight;
		score = (cls_entropy - score) / entropy(attr_dist, 2);

		if (score > best_score) {
			best_score = score;
			/* threshold halfway between the two neighbouring values */
			*best_split = (ex->example->values[attr].floatV + ex_next->example->values[attr].floatV) / 2.0;
		}
	}

	/* printf("C %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), best_score); */

	/* cleanup */
	free(dist_lt);
	free(dist_ge);
	free(attr_dist);

	return best_score;
}

182 | ||

[8770] | 183 | |

[8378] | 184 | float |

[8770] | 185 | gain_ratio_d(struct Example *examples, int size, int attr, float cls_entropy, struct Args *args) |

[8378] | 186 | { |

[10206] | 187 | struct Example *ex, *ex_end; |

188 | int i, cls_vals, attr_vals, attr_val, cls_val; | |

189 | float score, size_weight, size_attr_known, size_attr_cls_known, attr_entropy, *cont, *attr_dist, *attr_dist_cls_known; | |

[8378] | 190 | |

[10206] | 191 | cls_vals = args->domain->classVar->noOfValues(); |

192 | attr_vals = args->domain->attributes->at(attr)->noOfValues(); | |

[8378] | 193 | |

[10206] | 194 | /* allocate space */ |

195 | ASSERT(cont = (float *)calloc(cls_vals * attr_vals, sizeof(float *))); | |

196 | ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof(float *))); | |

197 | ASSERT(attr_dist_cls_known = (float *)calloc(attr_vals, sizeof(float *))); | |

[8378] | 198 | |

[10206] | 199 | /* contingency matrix */ |

200 | size_weight = 0.0; | |

201 | for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) { | |

202 | if (!ex->example->values[attr].isSpecial()) { | |

203 | attr_val = ex->example->values[attr].intV; | |

204 | attr_dist[attr_val] += ex->weight; | |

205 | if (!ex->example->getClass().isSpecial()) { | |

206 | cls_val = ex->example->getClass().intV; | |

207 | attr_dist_cls_known[attr_val] += ex->weight; | |

208 | cont[attr_val * cls_vals + cls_val] += ex->weight; | |

209 | } | |

210 | } | |

211 | size_weight += ex->weight; | |

212 | } | |

[8378] | 213 | |

[10206] | 214 | /* min examples in leaves */ |

215 | if (!test_min_examples(attr_dist, attr_vals, args)) { | |

216 | score = -INFINITY; | |

217 | goto finish; | |

218 | } | |

[8378] | 219 | |

[10206] | 220 | size_attr_known = size_attr_cls_known = 0.0; |

221 | for (i = 0; i < attr_vals; i++) { | |

222 | size_attr_known += attr_dist[i]; | |

223 | size_attr_cls_known += attr_dist_cls_known[i]; | |

224 | } | |

[8378] | 225 | |

[10206] | 226 | /* gain ratio */ |

227 | score = 0.0; | |

228 | for (i = 0; i < attr_vals; i++) | |

229 | score += attr_dist_cls_known[i] * entropy(cont + i * cls_vals, cls_vals); | |

230 | attr_entropy = entropy(attr_dist, attr_vals); | |

[8770] | 231 | |

[10206] | 232 | if (size_attr_cls_known == 0.0 || attr_entropy == 0.0 || size_weight == 0.0) { |

233 | score = -INFINITY; | |

234 | goto finish; | |

235 | } | |

[8770] | 236 | |

[10206] | 237 | score = (cls_entropy - score / size_attr_cls_known) / attr_entropy * ((float)size_attr_known / size_weight); |

[8770] | 238 | |

[10206] | 239 | /* printf("D %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), score); */ |

[8770] | 240 | |

241 | finish: | |

[10206] | 242 | free(cont); |

243 | free(attr_dist); | |

244 | free(attr_dist_cls_known); | |

245 | return score; | |

[8770] | 246 | } |

247 | ||

248 | ||

/* MSE-based score of the best binary split on continuous attribute attr
 * (regression analogue of gain_ratio_c).  Sorts examples on attr
 * (unknowns last) and sweeps every candidate threshold while maintaining
 * running variance statistics for both sides.  Writes the winning
 * threshold to *best_split; returns -INFINITY when no valid split exists.
 * NOTE: reorders the examples array in place (qsort). */
float
mse_c(struct Example *examples, int size, int attr, float cls_mse, struct Args *args, float *best_split)
{
	struct Example *ex, *ex_end, *ex_next;
	int i, cls_vals, minInstances, size_known;
	float size_attr_known, size_weight, cls_val, cls_score, best_score, size_attr_cls_known, score;

	struct Variance {
		double n, sum, sum2;
	} var_lt = {0.0, 0.0, 0.0}, var_ge = {0.0, 0.0, 0.0};

	cls_vals = args->domain->classVar->noOfValues();

	/* minInstances should be at least 1, otherwise there is no point in splitting */
	minInstances = args->minInstances < 1 ? 1 : args->minInstances;

	/* sort on attr; examples with unknown values end up at the back */
	compar_attr = attr;
	qsort(examples, size, sizeof(struct Example), compar_examples);

	/* accumulate the ">=" side statistics over examples with known attr */
	size_known = size;
	size_attr_known = 0.0;
	for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
		if (ex->example->values[attr].isSpecial()) {
			size_known = ex - examples;
			break;
		}
		if (!ex->example->getClass().isSpecial()) {
			cls_val = ex->example->getClass().floatV;
			var_ge.n += ex->weight;
			var_ge.sum += ex->weight * cls_val;
			var_ge.sum2 += ex->weight * cls_val * cls_val;
		}
		size_attr_known += ex->weight;
	}

	/* count the remaining examples with unknown values */
	size_weight = size_attr_known;
	for (ex_end = examples + size; ex < ex_end; ex++)
		size_weight += ex->weight;

	size_attr_cls_known = var_ge.n;
	best_score = -INFINITY;

	/* sweep: move one example at a time from ">=" to "<" */
	for (ex = examples, ex_end = ex + size_known - minInstances, ex_next = ex + 1, i = 0; ex < ex_end; ex++, ex_next++, i++) {
		if (!ex->example->getClass().isSpecial()) {
			/* NOTE(review): relies on TValue's implicit conversion to float
			 * (presumably floatV — elsewhere .floatV is read explicitly);
			 * confirm the conversion operator */
			cls_val = ex->example->getClass();
			var_lt.n += ex->weight;
			var_lt.sum += ex->weight * cls_val;
			var_lt.sum2 += ex->weight * cls_val * cls_val;

			/* this calculation might be numerically unstable - fix */
			var_ge.n -= ex->weight;
			var_ge.sum -= ex->weight * cls_val;
			var_ge.sum2 -= ex->weight * cls_val * cls_val;
		}

		/* Naive calculation of variance (used for testing)

		struct Example *ex2, *ex_end2;
		float nlt, sumlt, sum2lt, nge, sumge, sum2ge;
		nlt = sumlt = sum2lt = nge = sumge = sum2ge = 0.0;

		for (ex2 = examples, ex_end2 = ex2 + size; ex2 < ex_end2; ex2++) {
			cls_val = ex2->example->getClass();
			if (ex2 < ex) {
				nlt += ex2->weight;
				sumlt += ex2->weight * cls_val;
				sum2lt += ex2->weight * cls_val * cls_val;
			} else {
				nge += ex2->weight;
				sumge += ex2->weight * cls_val;
				sum2ge += ex2->weight * cls_val * cls_val;
			}
		}
		*/

		if (ex->example->values[attr] == ex_next->example->values[attr] || i + 1 < minInstances)
			continue;

		/* compute mse: within-branch sums of squared deviations
		 * NOTE(review): divides by var_lt.n / var_ge.n, which can be zero
		 * when every class value on one side is unknown — confirm the
		 * caller guarantees known classes, else this yields inf/nan */
		score = var_lt.sum2 - var_lt.sum * var_lt.sum / var_lt.n;
		score += var_ge.sum2 - var_ge.sum * var_ge.sum / var_ge.n;

		/* relative MSE reduction, scaled by the fraction of known values */
		score = (cls_mse - score / size_attr_cls_known) / cls_mse * (size_attr_known / size_weight);

		if (score > best_score) {
			best_score = score;
			*best_split = (ex->example->values[attr].floatV + ex_next->example->values[attr].floatV) / 2.0;
		}
	}

	/* printf("C %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), best_score); */
	return best_score;
}

346 | ||

347 | ||

348 | float | |

349 | mse_d(struct Example *examples, int size, int attr, float cls_mse, struct Args *args) | |

350 | { | |

[10206] | 351 | int i, attr_vals, attr_val; |

352 | float *attr_dist, d, score, cls_val, size_attr_cls_known, size_attr_known, size_weight; | |

353 | struct Example *ex, *ex_end; | |

[8770] | 354 | |

[10206] | 355 | struct Variance { |

356 | float n, sum, sum2; | |

357 | } *variances, *v, *v_end; | |

[8770] | 358 | |

[10206] | 359 | attr_vals = args->domain->attributes->at(attr)->noOfValues(); |

[8770] | 360 | |

[10206] | 361 | ASSERT(variances = (struct Variance *)calloc(attr_vals, sizeof *variances)); |

362 | ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof *attr_dist)); | |

[8770] | 363 | |

[10206] | 364 | size_weight = size_attr_cls_known = size_attr_known = 0.0; |

365 | for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) { | |

366 | if (!ex->example->values[attr].isSpecial()) { | |

367 | attr_dist[ex->example->values[attr].intV] += ex->weight; | |

368 | size_attr_known += ex->weight; | |

[8770] | 369 | |

[10206] | 370 | if (!ex->example->getClass().isSpecial()) { |

371 | cls_val = ex->example->getClass().floatV; | |

372 | v = variances + ex->example->values[attr].intV; | |

373 | v->n += ex->weight; | |

374 | v->sum += ex->weight * cls_val; | |

375 | v->sum2 += ex->weight * cls_val * cls_val; | |

376 | size_attr_cls_known += ex->weight; | |

377 | } | |

378 | } | |

379 | size_weight += ex->weight; | |

380 | } | |

[8770] | 381 | |

[10206] | 382 | /* minimum examples in leaves */ |

383 | if (!test_min_examples(attr_dist, attr_vals, args)) { | |

384 | score = -INFINITY; | |

385 | goto finish; | |

386 | } | |

[8770] | 387 | |

[10206] | 388 | score = 0.0; |

389 | for (v = variances, v_end = variances + attr_vals; v < v_end; v++) | |

390 | if (v->n > 0.0) | |

391 | score += v->sum2 - v->sum * v->sum / v->n; | |

392 | score = (cls_mse - score / size_attr_cls_known) / cls_mse * (size_attr_known / size_weight); | |

[8770] | 393 | |

[10206] | 394 | if (size_attr_cls_known <= 0.0 || cls_mse <= 0.0 || size_weight <= 0.0) |

395 | score = 0.0; | |

[8770] | 396 | |

397 | finish: | |

[10206] | 398 | free(attr_dist); |

399 | free(variances); | |

[8770] | 400 | |

[10206] | 401 | return score; |

[8770] | 402 | } |

403 | ||

404 | ||

405 | struct SimpleTreeNode * | |

406 | make_predictor(struct SimpleTreeNode *node, struct Example *examples, int size, struct Args *args) | |

407 | { | |

[10206] | 408 | node->type = PredictorNode; |

409 | node->children_size = 0; | |

410 | return node; | |

[8770] | 411 | } |

412 | ||

413 | ||

/* Recursively build a (sub)tree over the size entries of examples.
 * depth is the current recursion depth; parent supplies the fallback
 * prediction when size == 0.  Returns a freshly malloc'd node; ownership
 * passes to the caller (freed via destroy_tree).
 * Child example tables alias the same TExample pointers — only the
 * weights are adjusted when unknown values are spread across branches. */
struct SimpleTreeNode *
build_tree(struct Example *examples, int size, int depth, struct SimpleTreeNode *parent, struct Args *args)
{
	int i, cls_vals, best_attr;
	float cls_entropy, cls_mse, best_score, score, size_weight, best_split, split;
	struct SimpleTreeNode *node;
	struct Example *ex, *ex_end;
	TVarList::const_iterator it;

	cls_vals = args->domain->classVar->noOfValues();

	ASSERT(node = (SimpleTreeNode *)malloc(sizeof *node));

	if (args->type == Classification) {
		ASSERT(node->dist = (float *)calloc(cls_vals, sizeof(float *)));

		/* no examples left: predict with the parent's class distribution */
		if (size == 0) {
			assert(parent);
			node->type = PredictorNode;
			node->children_size = 0;
			memcpy(node->dist, parent->dist, cls_vals * sizeof *node->dist);
			return node;
		}

		/* class distribution */
		size_weight = 0.0;
		for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
			if (!ex->example->getClass().isSpecial()) {
				node->dist[ex->example->getClass().intV] += ex->weight;
				size_weight += ex->weight;
			}

		/* stopping criterion: majority class
		 * NOTE(review): size_weight is 0 when every class is unknown —
		 * the division then yields nan/inf; confirm callers exclude
		 * fully-unlabelled data */
		for (i = 0; i < cls_vals; i++)
			if (node->dist[i] / size_weight >= args->maxMajority)
				return make_predictor(node, examples, size, args);

		cls_entropy = entropy(node->dist, cls_vals);
	} else {
		float n, sum, sum2, cls_val;

		assert(args->type == Regression);
		/* no examples left: predict with the parent's mean statistics */
		if (size == 0) {
			assert(parent);
			node->type = PredictorNode;
			node->children_size = 0;
			node->n = parent->n;
			node->sum = parent->sum;
			return node;
		}

		/* weighted variance statistics of the class value */
		n = sum = sum2 = 0.0;
		for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
			if (!ex->example->getClass().isSpecial()) {
				cls_val = ex->example->getClass().floatV;
				n += ex->weight;
				sum += ex->weight * cls_val;
				sum2 += ex->weight * cls_val * cls_val;
			}

		node->n = n;
		node->sum = sum;
		cls_mse = (sum2 - sum * sum / n) / n;

		/* stopping criterion: class variance (near) zero */
		if (cls_mse < 1e-5) {
			return make_predictor(node, examples, size, args);
		}
	}

	/* stopping criterion: depth exceeds limit */
	if (depth == args->maxDepth)
		return make_predictor(node, examples, size, args);

	/* score attributes; discrete attributes already used on this path are
	 * skipped, and each attribute is independently skipped with
	 * probability skipProb (random-forest style attribute subsetting) */
	best_score = -INFINITY;

	for (i = 0, it = args->domain->attributes->begin(); it != args->domain->attributes->end(); it++, i++) {
		if (!args->attr_split_so_far[i]) {
			/* select random subset of attributes */
			if (args->randomGenerator->randdouble() < args->skipProb)
				continue;

			if ((*it)->varType == TValue::INTVAR) {
				score = args->type == Classification ?
				  gain_ratio_d(examples, size, i, cls_entropy, args) :
				  mse_d(examples, size, i, cls_mse, args);
				if (score > best_score) {
					best_score = score;
					best_attr = i;
				}
			} else if ((*it)->varType == TValue::FLOATVAR) {
				score = args->type == Classification ?
				  gain_ratio_c(examples, size, i, cls_entropy, args, &split) :
				  mse_c(examples, size, i, cls_mse, args, &split);
				if (score > best_score) {
					best_score = score;
					best_split = split;
					best_attr = i;
				}
			}
		}
	}

	/* no usable split found: make this node a leaf */
	if (best_score == -INFINITY)
		return make_predictor(node, examples, size, args);

	if (args->domain->attributes->at(best_attr)->varType == TValue::INTVAR) {
		/* multiway split on a discrete attribute */
		struct Example *child_examples, *child_ex;
		int attr_vals;
		float size_known, *attr_dist;

		/* printf("* %2d %3s %3d %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_score); */

		attr_vals = args->domain->attributes->at(best_attr)->noOfValues();

		node->type = DiscreteNode;
		node->split_attr = best_attr;
		node->children_size = attr_vals;

		ASSERT(child_examples = (struct Example *)calloc(size, sizeof *child_examples));
		ASSERT(node->children = (SimpleTreeNode **)calloc(attr_vals, sizeof *node->children));
		ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof *attr_dist));

		/* attribute distribution (to down-weight unknowns per branch) */
		size_known = 0;
		for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
			if (!ex->example->values[best_attr].isSpecial()) {
				attr_dist[ex->example->values[best_attr].intV] += ex->weight;
				size_known += ex->weight;
			}

		args->attr_split_so_far[best_attr] = 1;

		for (i = 0; i < attr_vals; i++) {
			/* create a new example table: examples with value i keep
			 * their weight; unknowns enter every branch, weight scaled
			 * by the branch's share of the known mass */
			for (ex = examples, ex_end = examples + size, child_ex = child_examples; ex < ex_end; ex++) {
				if (ex->example->values[best_attr].isSpecial()) {
					*child_ex = *ex;
					child_ex->weight *= attr_dist[i] / size_known;
					child_ex++;
				} else if (ex->example->values[best_attr].intV == i) {
					*child_ex++ = *ex;
				}
			}

			node->children[i] = build_tree(child_examples, child_ex - child_examples, depth + 1, node, args);
		}

		/* the attribute may be reused in sibling subtrees */
		args->attr_split_so_far[best_attr] = 0;

		free(attr_dist);
		free(child_examples);
	} else {
		/* binary split on a continuous attribute at threshold best_split */
		struct Example *examples_lt, *examples_ge, *ex_lt, *ex_ge;
		float size_lt, size_ge;

		/* printf("* %2d %3s %3d %f %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_split, best_score); */

		assert(args->domain->attributes->at(best_attr)->varType == TValue::FLOATVAR);

		ASSERT(examples_lt = (struct Example *)calloc(size, sizeof *examples));
		ASSERT(examples_ge = (struct Example *)calloc(size, sizeof *examples));

		/* weight mass on each side (known values only) */
		size_lt = size_ge = 0.0;
		for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
			if (!ex->example->values[best_attr].isSpecial())
				if (ex->example->values[best_attr].floatV < best_split)
					size_lt += ex->weight;
				else
					size_ge += ex->weight;

		/* unknowns go into both branches with proportionally scaled weight */
		for (ex = examples, ex_end = examples + size, ex_lt = examples_lt, ex_ge = examples_ge; ex < ex_end; ex++)
			if (ex->example->values[best_attr].isSpecial()) {
				*ex_lt = *ex;
				*ex_ge = *ex;
				ex_lt->weight *= size_lt / (size_lt + size_ge);
				ex_ge->weight *= size_ge / (size_lt + size_ge);
				ex_lt++;
				ex_ge++;
			} else if (ex->example->values[best_attr].floatV < best_split) {
				*ex_lt++ = *ex;
			} else {
				*ex_ge++ = *ex;
			}

		node->type = ContinuousNode;
		node->split_attr = best_attr;
		node->split = best_split;
		node->children_size = 2;
		ASSERT(node->children = (SimpleTreeNode **)calloc(2, sizeof *node->children));

		node->children[0] = build_tree(examples_lt, ex_lt - examples_lt, depth + 1, node, args);
		node->children[1] = build_tree(examples_ge, ex_ge - examples_ge, depth + 1, node, args);

		free(examples_lt);
		free(examples_ge);
	}

	return node;
}

614 | ||

/* weight is accepted for interface compatibility and not used here. */
TSimpleTreeLearner::TSimpleTreeLearner(const int &weight, float maxMajority, int minInstances, int maxDepth, float skipProb, PRandomGenerator rgen) :
	maxMajority(maxMajority),
	minInstances(minInstances),
	maxDepth(maxDepth),
	skipProb(skipProb)
{
	/* fall back to a fresh default generator when none is supplied */
	randomGenerator = rgen ? rgen : PRandomGenerator(mlnew TRandomGenerator());
}

623 | ||

[8770] | 624 | PClassifier |

[8378] | 625 | TSimpleTreeLearner::operator()(PExampleGenerator ogen, const int &weight) |

[8770] | 626 | { |

[10206] | 627 | struct Example *examples, *ex; |

628 | struct SimpleTreeNode *tree; | |

629 | struct Args args; | |

630 | int cls_vals; | |

[8378] | 631 | |

[10206] | 632 | if (!ogen->domain->classVar) |

633 | raiseError("class-less domain"); | |

[8378] | 634 | |

[10767] | 635 | if (!ogen->numberOfExamples() > 0) |

636 | raiseError("no examples"); | |

637 | ||

[10206] | 638 | /* create a tabel with pointers to examples */ |

639 | ASSERT(examples = (struct Example *)calloc(ogen->numberOfExamples(), sizeof *examples)); | |

640 | ex = examples; | |

641 | PEITERATE(ei, ogen) { | |

642 | ex->example = &(*ei); | |

643 | ex->weight = 1.0; | |

644 | ex++; | |

645 | } | |

[8770] | 646 | |

[10206] | 647 | ASSERT(args.attr_split_so_far = (int *)calloc(ogen->domain->attributes->size(), sizeof(int))); |

648 | args.minInstances = minInstances; | |

649 | args.maxMajority = maxMajority; | |

650 | args.maxDepth = maxDepth; | |

651 | args.skipProb = skipProb; | |

652 | args.domain = ogen->domain; | |

[9296] | 653 | args.randomGenerator = randomGenerator; |

[10206] | 654 | args.type = ogen->domain->classVar->varType == TValue::INTVAR ? Classification : Regression; |

655 | cls_vals = ogen->domain->classVar->noOfValues(); | |

[8378] | 656 | |

[10206] | 657 | tree = build_tree(examples, ogen->numberOfExamples(), 0, NULL, &args); |

[8378] | 658 | |

[10206] | 659 | free(examples); |

660 | free(args.attr_split_so_far); | |

[8378] | 661 | |

[10206] | 662 | return new TSimpleTreeClassifier(ogen->domain->classVar, tree, args.type, cls_vals); |

[8378] | 663 | } |

664 | ||

665 | ||

/* classifier */
/* Default constructor; used when a classifier is about to be
 * deserialized (see load_tree / the model loading path). */
TSimpleTreeClassifier::TSimpleTreeClassifier()
{
}

670 | ||

/* Takes ownership of tree (freed in the destructor via destroy_tree).
 * type is Classification or Regression; cls_vals is the length of the
 * per-node dist arrays (classification trees only). */
TSimpleTreeClassifier::TSimpleTreeClassifier(const PVariable &classVar, struct SimpleTreeNode *tree, int type, int cls_vals) :
	TClassifier(classVar, true),
	tree(tree),
	type(type),
	cls_vals(cls_vals)
{
}

678 | ||

[8770] | 679 | |

[8378] | 680 | void |

[8770] | 681 | destroy_tree(struct SimpleTreeNode *node, int type) |

[8378] | 682 | { |

[10206] | 683 | int i; |

[8378] | 684 | |

[10206] | 685 | if (node->type != PredictorNode) { |

686 | for (i = 0; i < node->children_size; i++) | |

687 | destroy_tree(node->children[i], type); | |

688 | free(node->children); | |

689 | } | |

690 | if (type == Classification) | |

691 | free(node->dist); | |

692 | free(node); | |

[8378] | 693 | } |

694 | ||

[8770] | 695 | |

/* Frees the owned tree; type tells destroy_tree whether per-node class
 * distributions were allocated (Classification only). */
TSimpleTreeClassifier::~TSimpleTreeClassifier()
{
	destroy_tree(tree, type);
}

700 | ||

[8770] | 701 | |

/* Walk the tree for example ex and return a class distribution of
 * cls_vals floats.  When a split attribute's value is unknown, descend
 * into every child and sum the children's distributions.
 * Ownership protocol: *free_dist is set to 1 when the returned array is
 * heap-allocated and must be free()d by the caller, and to 0 when it
 * points into the tree itself (a leaf's dist). */
float *
predict_classification(const TExample &ex, struct SimpleTreeNode *node, int *free_dist, int cls_vals)
{
	int i, j;
	float *dist, *child_dist;

	while (node->type != PredictorNode)
		if (ex.values[node->split_attr].isSpecial()) {
			/* unknown split value: aggregate over all children */
			ASSERT(dist = (float *)calloc(cls_vals, sizeof *dist));
			for (i = 0; i < node->children_size; i++) {
				child_dist = predict_classification(ex, node->children[i], free_dist, cls_vals);
				for (j = 0; j < cls_vals; j++)
					dist[j] += child_dist[j];
				if (*free_dist)
					free(child_dist);
			}
			*free_dist = 1;
			return dist;
		} else if (node->type == DiscreteNode) {
			node = node->children[ex.values[node->split_attr].intV];
		} else {
			assert(node->type == ContinuousNode);
			/* NOTE(review): values equal to the threshold go right (>=)
			 * here, but left (>) in predict_regression — confirm whether
			 * the asymmetry is intentional */
			node = node->children[ex.values[node->split_attr].floatV >= node->split];
		}

	/* leaf: distribution is owned by the tree; caller must not free it */
	*free_dist = 0;
	return node->dist;
}

730 | ||

[8770] | 731 | |

/* Walk the tree for example ex and return the leaf's weighted count and
 * sum through *n and *sum (the mean prediction is sum / n).  When a
 * split attribute's value is unknown, the statistics of all children
 * are accumulated instead. */
void
predict_regression(const TExample &ex, struct SimpleTreeNode *node, float *sum, float *n)
{
	int i;
	float local_sum, local_n;

	while (node->type != PredictorNode) {
		if (ex.values[node->split_attr].isSpecial()) {
			/* unknown split value: aggregate over all children */
			*sum = *n = 0;
			for (i = 0; i < node->children_size; i++) {
				predict_regression(ex, node->children[i], &local_sum, &local_n);
				*sum += local_sum;
				*n += local_n;
			}
			return;
		} else if (node->type == DiscreteNode) {
			assert(ex.values[node->split_attr].intV < node->children_size);
			node = node->children[ex.values[node->split_attr].intV];
		} else {
			assert(node->type == ContinuousNode);
			/* NOTE(review): threshold-equal values go left (>) here but
			 * right (>=) in predict_classification — confirm the
			 * asymmetry is intentional */
			node = node->children[ex.values[node->split_attr].floatV > node->split];
		}
	}

	*sum = node->sum;
	*n = node->n;
}

759 | ||

760 | ||

761 | void | |

762 | TSimpleTreeClassifier::save_tree(ostringstream &ss, struct SimpleTreeNode *node) | |

763 | { | |

764 | int i; | |

765 | ||

766 | ss << "{ " << node->type << " " << node->children_size << " "; | |

767 | ||

768 | if (node->type != PredictorNode) | |

769 | ss << node->split_attr << " " << node->split << " "; | |

770 | ||

771 | for (i = 0; i < node->children_size; i++) | |

772 | this->save_tree(ss, node->children[i]); | |

773 | ||

774 | if (this->type == Classification) { | |

775 | for (i = 0; i < this->cls_vals; i++) | |

776 | ss << node->dist[i] << " "; | |

777 | } else { | |

778 | assert(this->type == Regression); | |

779 | ss << node->n << " " << node->sum << " "; | |

780 | } | |

781 | ss << "} "; | |

782 | } | |

783 | ||

784 | struct SimpleTreeNode * | |

785 | TSimpleTreeClassifier::load_tree(istringstream &ss) | |

786 | { | |

787 | int i; | |

788 | string lbracket, rbracket; | |

[10252] | 789 | string split_string; |

[10206] | 790 | SimpleTreeNode *node; |

791 | ||

[10252] | 792 | ss.exceptions(istream::failbit); |

793 | ||

[10206] | 794 | ASSERT(node = (SimpleTreeNode *)malloc(sizeof *node)); |

795 | ss >> lbracket >> node->type >> node->children_size; | |

796 | ||

[10252] | 797 | |

[10206] | 798 | if (node->type != PredictorNode) |

[10252] | 799 | { |

800 | ss >> node->split_attr; | |

801 | ||

802 | /* Read split into a string and use strtod to parse it. | |

803 | * istream sometimes (on some platforms) seems to have problems | |

804 | * reading formated floats. | |

805 | */ | |

806 | ss >> split_string; | |

807 | node->split = float(strtod(split_string.c_str(), NULL)); | |

808 | } | |

[10206] | 809 | |

810 | if (node->children_size) { | |

811 | ASSERT(node->children = (SimpleTreeNode **)calloc(node->children_size, sizeof *node->children)); | |

812 | for (i = 0; i < node->children_size; i++) | |

813 | node->children[i] = load_tree(ss); | |

814 | } | |

815 | ||

816 | if (this->type == Classification) { | |

817 | ASSERT(node->dist = (float *)calloc(cls_vals, sizeof(float *))); | |

818 | for (i = 0; i < this->cls_vals; i++) | |

819 | ss >> node->dist[i]; | |

820 | } else { | |

821 | assert(this->type == Regression); | |

822 | ss >> node->n >> node->sum; | |

823 | } | |

824 | ss >> rbracket; | |

825 | ||

826 | /* Synchronization check */ | |

827 | assert(lbracket == "{" && rbracket == "}"); | |

828 | ||

829 | return node; | |

830 | } | |

831 | ||

832 | void | |

833 | TSimpleTreeClassifier::save_model(ostringstream &ss) | |

834 | { | |

835 | ss.precision(9); /* we have floats */ | |

836 | ss << this->type << " " << this->cls_vals << " "; | |

837 | this->save_tree(ss, this->tree); | |

838 | } | |

839 | ||

840 | void | |

841 | TSimpleTreeClassifier::load_model(istringstream &ss) | |

842 | { | |

843 | ss >> this->type >> this->cls_vals; | |

844 | this->tree = load_tree(ss); | |

[8770] | 845 | } |

846 | ||

847 | ||

848 | TValue | |

[8378] | 849 | TSimpleTreeClassifier::operator()(const TExample &ex) |

850 | { | |

[10206] | 851 | if (type == Classification) { |

852 | int i, free_dist, best_val; | |

853 | float *dist; | |

[8378] | 854 | |

[10492] | 855 | dist = predict_classification(ex, tree, &free_dist, this->cls_vals); |

[10206] | 856 | best_val = 0; |

[10492] | 857 | for (i = 1; i < this->cls_vals; i++) |

[10206] | 858 | if (dist[i] > dist[best_val]) |

859 | best_val = i; | |

[8378] | 860 | |

[10206] | 861 | if (free_dist) |

862 | free(dist); | |

863 | return TValue(best_val); | |

864 | } else { | |

865 | float sum, n; | |

[8770] | 866 | |

[10206] | 867 | assert(type == Regression); |

[8770] | 868 | |

[10206] | 869 | predict_regression(ex, tree, &sum, &n); |

870 | return TValue(sum / n); | |

871 | } | |

[8378] | 872 | } |

873 | ||

874 | PDistribution | |

875 | TSimpleTreeClassifier::classDistribution(const TExample &ex) | |

876 | { | |

[10206] | 877 | if (type == Classification) { |

878 | int i, free_dist; | |

879 | float *dist; | |

[8378] | 880 | |

[10492] | 881 | dist = predict_classification(ex, tree, &free_dist, this->cls_vals); |

[8378] | 882 | |

[10492] | 883 | PDistribution pdist = mlnew TDiscDistribution(this->cls_vals, 0.0); |

[10965] | 884 | pdist->variable = this->classVar; |

[10492] | 885 | for (i = 0; i < this->cls_vals; i++) |

[10206] | 886 | pdist->setint(i, dist[i]); |

887 | pdist->normalize(); | |

[8378] | 888 | |

[10206] | 889 | if (free_dist) |

890 | free(dist); | |

891 | return pdist; | |

892 | } else { | |

893 | return NULL; | |

894 | } | |

[8378] | 895 | } |

896 | ||

897 | void | |

898 | TSimpleTreeClassifier::predictionAndDistribution(const TExample &ex, TValue &value, PDistribution &dist) | |

899 | { | |

[10206] | 900 | value = operator()(ex); |

901 | dist = classDistribution(ex); | |

[8378] | 902 | } |

**Note:**See TracBrowser for help on using the repository browser.