source: orange/source/orange/tdidt_simple.cpp @ 10252:ec524648dbaa

Revision 10252:ec524648dbaa, 24.5 KB checked in by ales_erjavec, 2 years ago (diff)

A workaround for a strange behavior of 'operator >>' for floats (on Mac OSX gcc 4.2.1). Fixes #1104.

Line 
1/*
2    This file is part of Orange.
3
4    Copyright 1996-2010 Faculty of Computer and Information Science, University of Ljubljana
5    Contact: janez.demsar@fri.uni-lj.si
6
7    Orange is free software: you can redistribute it and/or modify
8    it under the terms of the GNU General Public License as published by
9    the Free Software Foundation, either version 3 of the License, or
10    (at your option) any later version.
11
12    Orange is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15    GNU General Public License for more details.
16
17    You should have received a copy of the GNU General Public License
18    along with Orange.  If not, see <http://www.gnu.org/licenses/>.
19*/
20
21#include <iostream>
22#include <sstream>
23#include <math.h>
24#include <stdlib.h>
25#include <cstring>
26
27#include "vars.hpp"
28#include "domain.hpp"
29#include "distvars.hpp"
30#include "examples.hpp"
31#include "examplegen.hpp"
32#include "table.hpp"
33#include "classify.hpp"
34
35#include "tdidt_simple.ppp"
36
37#ifndef _MSC_VER
38    #include "err.h"
39    #define ASSERT(x) if (!(x)) err(1, "%s:%d", __FILE__, __LINE__)
40#else
41    #define ASSERT(x) if(!(x)) exit(1)
42    #define log2f(x) log((double) (x)) / log(2.0)
43#endif // _MSC_VER
44
45#ifndef INFINITY
46    #include <limits>
47    #define INFINITY numeric_limits<float>::infinity()
48#endif // INFINITY
49
50struct Args {
51    int minInstances, maxDepth;
52    float maxMajority, skipProb;
53
54    int type, *attr_split_so_far;
55    PDomain domain;
56    PRandomGenerator randomGenerator;
57};
58
59struct Example {
60    TExample *example;
61    float weight;
62};
63
64enum { DiscreteNode, ContinuousNode, PredictorNode };
65enum { Classification, Regression };
66
67int compar_attr;
68
69/* This function uses the global variable compar_attr.
70 * Examples with unknowns are larger so that, when sorted, they appear at the bottom.
71 */
72int
73compar_examples(const void *ptr1, const void *ptr2)
74{
75    struct Example *e1, *e2;
76
77    e1 = (struct Example *)ptr1;
78    e2 = (struct Example *)ptr2;
79    if (e1->example->values[compar_attr].isSpecial())
80        return 1;
81    if (e2->example->values[compar_attr].isSpecial())
82        return -1;
83    return e1->example->values[compar_attr].compare(e2->example->values[compar_attr]);
84}
85
86
87float
88entropy(float *xs, int size)
89{
90    float *ip, *end, sum, e;
91
92    for (ip = xs, end = xs + size, e = 0.0, sum = 0.0; ip != end; ip++)
93        if (*ip > 0.0) {
94            e -= *ip * log2f(*ip);
95            sum += *ip;
96        }
97
98    return sum == 0.0 ? 0.0 : e / sum + log2f(sum);
99}
100
101int
102test_min_examples(float *attr_dist, int attr_vals, struct Args *args)
103{
104    int i;
105
106    for (i = 0; i < attr_vals; i++) {
107        if (attr_dist[i] > 0.0 && attr_dist[i] < args->minInstances)
108            return 0;
109    }
110    return 1;
111}
112
113float
114gain_ratio_c(struct Example *examples, int size, int attr, float cls_entropy, struct Args *args, float *best_split)
115{
116    struct Example *ex, *ex_end, *ex_next;
117    int i, cls, cls_vals, minInstances, size_known;
118    float score, *dist_lt, *dist_ge, *attr_dist, best_score, size_weight;
119
120    cls_vals = args->domain->classVar->noOfValues();
121
122    /* minInstances should be at least 1, otherwise there is no point in splitting */
123    minInstances = args->minInstances < 1 ? 1 : args->minInstances;
124
125    /* allocate space */
126    ASSERT(dist_lt = (float *)calloc(cls_vals, sizeof *dist_lt));
127    ASSERT(dist_ge = (float *)calloc(cls_vals, sizeof *dist_ge));
128    ASSERT(attr_dist = (float *)calloc(2, sizeof *attr_dist));
129
130    /* sort */
131    compar_attr = attr;
132    qsort(examples, size, sizeof(struct Example), compar_examples);
133
134    /* compute gain ratio for every split */
135    size_known = size;
136    size_weight = 0.0;
137    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
138        if (ex->example->values[attr].isSpecial()) {
139            size_known = ex - examples;
140            break;
141        }
142        if (!ex->example->getClass().isSpecial())
143            dist_ge[ex->example->getClass().intV] += ex->weight;
144        size_weight += ex->weight;
145    }
146
147    attr_dist[1] = size_weight;
148    best_score = -INFINITY;
149
150    for (ex = examples, ex_end = ex + size_known - minInstances, ex_next = ex + 1, i = 0; ex < ex_end; ex++, ex_next++, i++) {
151        if (!ex->example->getClass().isSpecial()) {
152            cls = ex->example->getClass().intV;
153            dist_lt[cls] += ex->weight;
154            dist_ge[cls] -= ex->weight;
155        }
156        attr_dist[0] += ex->weight;
157        attr_dist[1] -= ex->weight;
158
159        if (ex->example->values[attr] == ex_next->example->values[attr] || i + 1 < minInstances)
160            continue;
161
162        /* gain ratio */
163        score = (attr_dist[0] * entropy(dist_lt, cls_vals) + attr_dist[1] * entropy(dist_ge, cls_vals)) / size_weight;
164        score = (cls_entropy - score) / entropy(attr_dist, 2);
165
166
167        if (score > best_score) {
168            best_score = score;
169            *best_split = (ex->example->values[attr].floatV + ex_next->example->values[attr].floatV) / 2.0;
170        }
171    }
172
173    /* printf("C %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), best_score); */
174
175    /* cleanup */
176    free(dist_lt);
177    free(dist_ge);
178    free(attr_dist);
179
180    return best_score;
181}
182
183
184float
185gain_ratio_d(struct Example *examples, int size, int attr, float cls_entropy, struct Args *args)
186{
187    struct Example *ex, *ex_end;
188    int i, cls_vals, attr_vals, attr_val, cls_val;
189    float score, size_weight, size_attr_known, size_attr_cls_known, attr_entropy, *cont, *attr_dist, *attr_dist_cls_known;
190
191    cls_vals = args->domain->classVar->noOfValues();
192    attr_vals = args->domain->attributes->at(attr)->noOfValues();
193
194    /* allocate space */
195    ASSERT(cont = (float *)calloc(cls_vals * attr_vals, sizeof(float *)));
196    ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof(float *)));
197    ASSERT(attr_dist_cls_known = (float *)calloc(attr_vals, sizeof(float *)));
198
199    /* contingency matrix */
200    size_weight = 0.0;
201    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
202        if (!ex->example->values[attr].isSpecial()) {
203            attr_val = ex->example->values[attr].intV;
204            attr_dist[attr_val] += ex->weight;
205            if (!ex->example->getClass().isSpecial()) {
206                cls_val = ex->example->getClass().intV;
207                attr_dist_cls_known[attr_val] += ex->weight;
208                cont[attr_val * cls_vals + cls_val] += ex->weight;
209            }
210        }
211        size_weight += ex->weight;
212    }
213
214    /* min examples in leaves */
215    if (!test_min_examples(attr_dist, attr_vals, args)) {
216        score = -INFINITY;
217        goto finish;
218    }
219
220    size_attr_known = size_attr_cls_known = 0.0;
221    for (i = 0; i < attr_vals; i++) {
222        size_attr_known += attr_dist[i];
223        size_attr_cls_known += attr_dist_cls_known[i];
224    }
225
226    /* gain ratio */
227    score = 0.0;
228    for (i = 0; i < attr_vals; i++)
229        score += attr_dist_cls_known[i] * entropy(cont + i * cls_vals, cls_vals);
230    attr_entropy = entropy(attr_dist, attr_vals);
231
232    if (size_attr_cls_known == 0.0 || attr_entropy == 0.0 || size_weight == 0.0) {
233        score = -INFINITY;
234        goto finish;
235    }
236
237    score = (cls_entropy - score / size_attr_cls_known) / attr_entropy * ((float)size_attr_known / size_weight);
238
239    /* printf("D %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), score); */
240
241finish:
242    free(cont);
243    free(attr_dist);
244    free(attr_dist_cls_known);
245    return score;
246}
247
248
249float
250mse_c(struct Example *examples, int size, int attr, float cls_mse, struct Args *args, float *best_split)
251{
252    struct Example *ex, *ex_end, *ex_next;
253    int i, cls_vals, minInstances, size_known;
254    float size_attr_known, size_weight, cls_val, cls_score, best_score, size_attr_cls_known, score;
255
256    struct Variance {
257        double n, sum, sum2;
258    } var_lt = {0.0, 0.0, 0.0}, var_ge = {0.0, 0.0, 0.0};
259
260    cls_vals = args->domain->classVar->noOfValues();
261
262    /* minInstances should be at least 1, otherwise there is no point in splitting */
263    minInstances = args->minInstances < 1 ? 1 : args->minInstances;
264
265    /* sort */
266    compar_attr = attr;
267    qsort(examples, size, sizeof(struct Example), compar_examples);
268
269    /* compute mse for every split */
270    size_known = size;
271    size_attr_known = 0.0;
272    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
273        if (ex->example->values[attr].isSpecial()) {
274            size_known = ex - examples;
275            break;
276        }
277        if (!ex->example->getClass().isSpecial()) {
278            cls_val = ex->example->getClass().floatV;
279            var_ge.n += ex->weight;
280            var_ge.sum += ex->weight * cls_val;
281            var_ge.sum2 += ex->weight * cls_val * cls_val;
282        }
283        size_attr_known += ex->weight;
284    }
285
286    /* count the remaining examples with unknown values */
287    size_weight = size_attr_known;
288    for (ex_end = examples + size; ex < ex_end; ex++)
289        size_weight += ex->weight;
290
291    size_attr_cls_known = var_ge.n;
292    best_score = -INFINITY;
293
294    for (ex = examples, ex_end = ex + size_known - minInstances, ex_next = ex + 1, i = 0; ex < ex_end; ex++, ex_next++, i++) {
295        if (!ex->example->getClass().isSpecial()) {
296            cls_val = ex->example->getClass();
297            var_lt.n += ex->weight;
298            var_lt.sum += ex->weight * cls_val;
299            var_lt.sum2 += ex->weight * cls_val * cls_val;
300
301            /* this calculation might be numarically unstable - fix */
302            var_ge.n -= ex->weight;
303            var_ge.sum -= ex->weight * cls_val;
304            var_ge.sum2 -= ex->weight * cls_val * cls_val;
305        }
306
307        /* Naive calculation of variance (used for testing)
308         
309        struct Example *ex2, *ex_end2;
310        float nlt, sumlt, sum2lt, nge, sumge, sum2ge;
311        nlt = sumlt = sum2lt = nge = sumge = sum2ge = 0.0;
312
313        for (ex2 = examples, ex_end2 = ex2 + size; ex2 < ex_end2; ex2++) {
314            cls_val = ex2->example->getClass();
315            if (ex2 < ex) {
316                nlt += ex2->weight;
317                sumlt += ex2->weight * cls_val;
318                sum2lt += ex2->weight * cls_val * cls_val;
319            } else {
320                nge += ex2->weight;
321                sumge += ex2->weight * cls_val;
322                sum2ge += ex2->weight * cls_val * cls_val;
323            }
324        }
325        */
326
327
328        if (ex->example->values[attr] == ex_next->example->values[attr] || i + 1 < minInstances)
329            continue;
330
331        /* compute mse */
332        score = var_lt.sum2 - var_lt.sum * var_lt.sum / var_lt.n;
333        score += var_ge.sum2 - var_ge.sum * var_ge.sum / var_ge.n;
334
335        score = (cls_mse - score / size_attr_cls_known) / cls_mse * (size_attr_known / size_weight);
336
337        if (score > best_score) {
338            best_score = score;
339            *best_split = (ex->example->values[attr].floatV + ex_next->example->values[attr].floatV) / 2.0;
340        }
341    }
342
343    /* printf("C %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), best_score); */
344    return best_score;
345}
346
347
348float
349mse_d(struct Example *examples, int size, int attr, float cls_mse, struct Args *args)
350{
351    int i, attr_vals, attr_val;
352    float *attr_dist, d, score, cls_val, size_attr_cls_known, size_attr_known, size_weight;
353    struct Example *ex, *ex_end;
354
355    struct Variance {
356        float n, sum, sum2;
357    } *variances, *v, *v_end;
358
359    attr_vals = args->domain->attributes->at(attr)->noOfValues();
360
361    ASSERT(variances = (struct Variance *)calloc(attr_vals, sizeof *variances));
362    ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof *attr_dist));
363
364    size_weight = size_attr_cls_known = size_attr_known = 0.0;
365    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
366        if (!ex->example->values[attr].isSpecial()) {
367            attr_dist[ex->example->values[attr].intV] += ex->weight;
368            size_attr_known += ex->weight;
369
370            if (!ex->example->getClass().isSpecial()) {
371                    cls_val = ex->example->getClass().floatV;
372                    v = variances + ex->example->values[attr].intV;
373                    v->n += ex->weight;
374                    v->sum += ex->weight * cls_val;
375                    v->sum2 += ex->weight * cls_val * cls_val;
376                    size_attr_cls_known += ex->weight;
377            }
378        }
379        size_weight += ex->weight;
380    }
381
382    /* minimum examples in leaves */
383    if (!test_min_examples(attr_dist, attr_vals, args)) {
384        score = -INFINITY;
385        goto finish;
386    }
387
388    score = 0.0;
389    for (v = variances, v_end = variances + attr_vals; v < v_end; v++)
390        if (v->n > 0.0)
391            score += v->sum2 - v->sum * v->sum / v->n;
392    score = (cls_mse - score / size_attr_cls_known) / cls_mse * (size_attr_known / size_weight);
393
394    if (size_attr_cls_known <= 0.0 || cls_mse <= 0.0 || size_weight <= 0.0)
395        score = 0.0;
396
397finish:
398    free(attr_dist);
399    free(variances);
400
401    return score;
402}
403
404
405struct SimpleTreeNode *
406make_predictor(struct SimpleTreeNode *node, struct Example *examples, int size, struct Args *args)
407{
408    node->type = PredictorNode;
409    node->children_size = 0;
410    return node;
411}
412
413
414struct SimpleTreeNode *
415build_tree(struct Example *examples, int size, int depth, struct SimpleTreeNode *parent, struct Args *args)
416{
417    int i, cls_vals, best_attr;
418    float cls_entropy, cls_mse, best_score, score, size_weight, best_split, split;
419    struct SimpleTreeNode *node;
420    struct Example *ex, *ex_end;
421    TVarList::const_iterator it;
422
423    cls_vals = args->domain->classVar->noOfValues();
424
425    ASSERT(node = (SimpleTreeNode *)malloc(sizeof *node));
426
427    if (args->type == Classification) {
428        ASSERT(node->dist = (float *)calloc(cls_vals, sizeof(float *)));
429
430        if (size == 0) {
431            assert(parent);
432            node->type = PredictorNode;
433            node->children_size = 0;
434            memcpy(node->dist, parent->dist, cls_vals * sizeof *node->dist);
435            return node;
436        }
437
438        /* class distribution */
439        size_weight = 0.0;
440        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
441            if (!ex->example->getClass().isSpecial()) {
442                node->dist[ex->example->getClass().intV] += ex->weight;
443                size_weight += ex->weight;
444            }
445
446        /* stopping criterion: majority class */
447        for (i = 0; i < cls_vals; i++)
448            if (node->dist[i] / size_weight >= args->maxMajority)
449                return make_predictor(node, examples, size, args);
450
451        cls_entropy = entropy(node->dist, cls_vals);
452    } else {
453        float n, sum, sum2, cls_val;
454
455        assert(args->type == Regression);
456        if (size == 0) {
457            assert(parent);
458            node->type = PredictorNode;
459            node->children_size = 0;
460            node->n = parent->n;
461            node->sum = parent->sum;
462            return node;
463        }
464
465        n = sum = sum2 = 0.0;
466        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
467            if (!ex->example->getClass().isSpecial()) {
468                cls_val = ex->example->getClass().floatV;
469                n += ex->weight;
470                sum += ex->weight * cls_val;
471                sum2 += ex->weight * cls_val * cls_val;
472            }
473
474        node->n = n;
475        node->sum = sum;
476        cls_mse = (sum2 - sum * sum / n) / n;
477
478        if (cls_mse < 1e-5) {
479            return make_predictor(node, examples, size, args);
480        }
481    }
482
483    /* stopping criterion: depth exceeds limit */
484    if (depth == args->maxDepth)
485        return make_predictor(node, examples, size, args);
486
487    /* score attributes */
488    best_score = -INFINITY;
489
490    for (i = 0, it = args->domain->attributes->begin(); it != args->domain->attributes->end(); it++, i++) {
491        if (!args->attr_split_so_far[i]) {
492            /* select random subset of attributes */
493            if (args->randomGenerator->randdouble() < args->skipProb)
494                continue;
495
496            if ((*it)->varType == TValue::INTVAR) {
497                score = args->type == Classification ?
498                  gain_ratio_d(examples, size, i, cls_entropy, args) :
499                  mse_d(examples, size, i, cls_mse, args);
500                if (score > best_score) {
501                    best_score = score;
502                    best_attr = i;
503                }
504            } else if ((*it)->varType == TValue::FLOATVAR) {
505                score = args->type == Classification ?
506                  gain_ratio_c(examples, size, i, cls_entropy, args, &split) :
507                  mse_c(examples, size, i, cls_mse, args, &split);
508                if (score > best_score) {
509                    best_score = score;
510                    best_split = split;
511                    best_attr = i;
512                }
513            }
514        }
515    }
516
517    if (best_score == -INFINITY)
518        return make_predictor(node, examples, size, args);
519
520    if (args->domain->attributes->at(best_attr)->varType == TValue::INTVAR) {
521        struct Example *child_examples, *child_ex;
522        int attr_vals;
523        float size_known, *attr_dist;
524
525        /* printf("* %2d %3s %3d %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_score); */
526
527        attr_vals = args->domain->attributes->at(best_attr)->noOfValues(); 
528
529        node->type = DiscreteNode;
530        node->split_attr = best_attr;
531        node->children_size = attr_vals;
532
533        ASSERT(child_examples = (struct Example *)calloc(size, sizeof *child_examples));
534        ASSERT(node->children = (SimpleTreeNode **)calloc(attr_vals, sizeof *node->children));
535        ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof *attr_dist));
536
537        /* attribute distribution */
538        size_known = 0;
539        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
540            if (!ex->example->values[best_attr].isSpecial()) {
541                attr_dist[ex->example->values[best_attr].intV] += ex->weight;
542                size_known += ex->weight;
543            }
544
545        args->attr_split_so_far[best_attr] = 1;
546
547        for (i = 0; i < attr_vals; i++) {
548            /* create a new example table */
549            for (ex = examples, ex_end = examples + size, child_ex = child_examples; ex < ex_end; ex++) {
550                if (ex->example->values[best_attr].isSpecial()) {
551                    *child_ex = *ex;
552                    child_ex->weight *= attr_dist[i] / size_known;
553                    child_ex++;
554                } else if (ex->example->values[best_attr].intV == i) {
555                    *child_ex++ = *ex;
556                }
557            }
558
559            node->children[i] = build_tree(child_examples, child_ex - child_examples, depth + 1, node, args);
560        }
561                   
562        args->attr_split_so_far[best_attr] = 0;
563
564        free(attr_dist);
565        free(child_examples);
566    } else {
567        struct Example *examples_lt, *examples_ge, *ex_lt, *ex_ge;
568        float size_lt, size_ge;
569
570        /* printf("* %2d %3s %3d %f %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_split, best_score); */
571
572        assert(args->domain->attributes->at(best_attr)->varType == TValue::FLOATVAR);
573
574        ASSERT(examples_lt = (struct Example *)calloc(size, sizeof *examples));
575        ASSERT(examples_ge = (struct Example *)calloc(size, sizeof *examples));
576
577        size_lt = size_ge = 0.0;
578        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
579            if (!ex->example->values[best_attr].isSpecial())
580                if (ex->example->values[best_attr].floatV < best_split)
581                    size_lt += ex->weight;
582                else
583                    size_ge += ex->weight;
584
585        for (ex = examples, ex_end = examples + size, ex_lt = examples_lt, ex_ge = examples_ge; ex < ex_end; ex++)
586            if (ex->example->values[best_attr].isSpecial()) {
587                *ex_lt = *ex;
588                *ex_ge = *ex;
589                ex_lt->weight *= size_lt / (size_lt + size_ge);
590                ex_ge->weight *= size_ge / (size_lt + size_ge);
591                ex_lt++;
592                ex_ge++;
593            } else if (ex->example->values[best_attr].floatV < best_split) {
594                *ex_lt++ = *ex;
595            } else {
596                *ex_ge++ = *ex;
597            }
598
599        node->type = ContinuousNode;
600        node->split_attr = best_attr;
601        node->split = best_split;
602        node->children_size = 2;
603        ASSERT(node->children = (SimpleTreeNode **)calloc(2, sizeof *node->children));
604
605        node->children[0] = build_tree(examples_lt, ex_lt - examples_lt, depth + 1, node, args);
606        node->children[1] = build_tree(examples_ge, ex_ge - examples_ge, depth + 1, node, args);
607
608        free(examples_lt);
609        free(examples_ge);
610    }
611
612    return node;
613}
614
615TSimpleTreeLearner::TSimpleTreeLearner(const int &weight, float maxMajority, int minInstances, int maxDepth, float skipProb, PRandomGenerator rgen) :
616    maxMajority(maxMajority),
617    minInstances(minInstances),
618    maxDepth(maxDepth),
619    skipProb(skipProb)
620{
621    randomGenerator = rgen ? rgen : PRandomGenerator(mlnew TRandomGenerator());
622}
623
624PClassifier
625TSimpleTreeLearner::operator()(PExampleGenerator ogen, const int &weight)
626{
627    struct Example *examples, *ex;
628    struct SimpleTreeNode *tree;
629    struct Args args;
630    int cls_vals;
631
632    if (!ogen->domain->classVar)
633        raiseError("class-less domain");
634
635    /* create a tabel with pointers to examples */
636    ASSERT(examples = (struct Example *)calloc(ogen->numberOfExamples(), sizeof *examples));
637    ex = examples;
638    PEITERATE(ei, ogen) {
639        ex->example = &(*ei);
640        ex->weight = 1.0;
641        ex++;
642    }
643
644    ASSERT(args.attr_split_so_far = (int *)calloc(ogen->domain->attributes->size(), sizeof(int)));
645    args.minInstances = minInstances;
646    args.maxMajority = maxMajority;
647    args.maxDepth = maxDepth;
648    args.skipProb = skipProb;
649    args.domain = ogen->domain;
650    args.randomGenerator = randomGenerator;
651    args.type = ogen->domain->classVar->varType == TValue::INTVAR ? Classification : Regression;
652    cls_vals = ogen->domain->classVar->noOfValues();
653
654    tree = build_tree(examples, ogen->numberOfExamples(), 0, NULL, &args);
655
656    free(examples);
657    free(args.attr_split_so_far);
658
659    return new TSimpleTreeClassifier(ogen->domain->classVar, tree, args.type, cls_vals);
660}
661
662
663/* classifier */
664TSimpleTreeClassifier::TSimpleTreeClassifier()
665{
666}
667
668TSimpleTreeClassifier::TSimpleTreeClassifier(const PVariable &classVar, struct SimpleTreeNode *tree, int type, int cls_vals) : 
669    TClassifier(classVar, true),
670    tree(tree),
671    type(type),
672    cls_vals(cls_vals)
673{
674}
675
676
677void
678destroy_tree(struct SimpleTreeNode *node, int type)
679{
680    int i;
681
682    if (node->type != PredictorNode) {
683        for (i = 0; i < node->children_size; i++)
684            destroy_tree(node->children[i], type);
685        free(node->children);
686    }
687    if (type == Classification)
688        free(node->dist);
689    free(node);
690}
691
692
693TSimpleTreeClassifier::~TSimpleTreeClassifier()
694{
695    destroy_tree(tree, type);
696}
697
698
699float *
700predict_classification(const TExample &ex, struct SimpleTreeNode *node, int *free_dist)
701{
702    int i, j, cls_vals;
703    float *dist, *child_dist;
704
705    while (node->type != PredictorNode)
706        if (ex.values[node->split_attr].isSpecial()) {
707            cls_vals = ex.domain->classVar->noOfValues();
708            ASSERT(dist = (float *)calloc(cls_vals, sizeof *dist));
709            for (i = 0; i < node->children_size; i++) {
710                child_dist = predict_classification(ex, node->children[i], free_dist);
711                for (j = 0; j < cls_vals; j++)
712                    dist[j] += child_dist[j];
713                if (*free_dist)
714                    free(child_dist);
715            }
716            *free_dist = 1;
717            return dist;
718        } else if (node->type == DiscreteNode) {
719            node = node->children[ex.values[node->split_attr].intV];
720        } else {
721            assert(node->type == ContinuousNode);
722            node = node->children[ex.values[node->split_attr].floatV >= node->split];
723        }
724
725    *free_dist = 0;
726    return node->dist;
727}
728
729
730void
731predict_regression(const TExample &ex, struct SimpleTreeNode *node, float *sum, float *n)
732{
733    int i;
734    float local_sum, local_n;
735
736    while (node->type != PredictorNode) {
737        if (ex.values[node->split_attr].isSpecial()) {
738            *sum = *n = 0;
739            for (i = 0; i < node->children_size; i++) {
740                predict_regression(ex, node->children[i], &local_sum, &local_n);
741                *sum += local_sum;
742                *n += local_n;
743            }
744            return;
745        } else if (node->type == DiscreteNode) {
746            assert(ex.values[node->split_attr].intV < node->children_size);
747            node = node->children[ex.values[node->split_attr].intV];
748        } else {
749            assert(node->type == ContinuousNode);
750            node = node->children[ex.values[node->split_attr].floatV > node->split];
751        }
752    }
753
754    *sum = node->sum;
755    *n = node->n;
756}
757
758
759void
760TSimpleTreeClassifier::save_tree(ostringstream &ss, struct SimpleTreeNode *node)
761{
762    int i;
763
764    ss << "{ " << node->type << " " << node->children_size << " ";
765
766    if (node->type != PredictorNode)
767        ss << node->split_attr << " " << node->split << " ";
768
769    for (i = 0; i < node->children_size; i++)
770        this->save_tree(ss, node->children[i]);
771
772    if (this->type == Classification) {
773        for (i = 0; i < this->cls_vals; i++)
774            ss << node->dist[i] << " ";
775    } else {
776        assert(this->type == Regression);
777        ss << node->n << " " << node->sum << " ";
778    }
779    ss << "} ";
780}
781
782struct SimpleTreeNode *
783TSimpleTreeClassifier::load_tree(istringstream &ss)
784{
785    int i;
786    string lbracket, rbracket;
787    string split_string;
788    SimpleTreeNode *node;
789
790    ss.exceptions(istream::failbit);
791
792    ASSERT(node = (SimpleTreeNode *)malloc(sizeof *node));
793    ss >> lbracket >> node->type >> node->children_size;
794
795
796    if (node->type != PredictorNode)
797    {
798        ss >> node->split_attr;
799
800        /* Read split into a string and use strtod to parse it.
801         * istream sometimes (on some platforms) seems to have problems
802         * reading formated floats.
803         */
804        ss >> split_string;
805        node->split = float(strtod(split_string.c_str(), NULL));
806    }
807
808    if (node->children_size) {
809        ASSERT(node->children = (SimpleTreeNode **)calloc(node->children_size, sizeof *node->children));
810        for (i = 0; i < node->children_size; i++)
811            node->children[i] = load_tree(ss);
812    }
813
814    if (this->type == Classification) {
815        ASSERT(node->dist = (float *)calloc(cls_vals, sizeof(float *)));
816        for (i = 0; i < this->cls_vals; i++)
817            ss >> node->dist[i];
818    } else {
819        assert(this->type == Regression);
820        ss >> node->n >> node->sum;
821    }
822    ss >> rbracket;
823
824    /* Synchronization check */
825    assert(lbracket == "{" && rbracket == "}");
826
827    return node;
828}
829
830void
831TSimpleTreeClassifier::save_model(ostringstream &ss)
832{
833    ss.precision(9); /* we have floats */
834    ss << this->type << " " << this->cls_vals << " ";
835    this->save_tree(ss, this->tree);
836}
837
838void
839TSimpleTreeClassifier::load_model(istringstream &ss)
840{
841    ss >> this->type >> this->cls_vals;
842    this->tree = load_tree(ss);
843}
844
845
846TValue
847TSimpleTreeClassifier::operator()(const TExample &ex)
848{
849    if (type == Classification) {
850        int i, free_dist, best_val;
851        float *dist;
852
853        dist = predict_classification(ex, tree, &free_dist);
854        best_val = 0;
855        for (i = 1; i < ex.domain->classVar->noOfValues(); i++)
856            if (dist[i] > dist[best_val])
857                best_val = i;
858
859        if (free_dist)
860            free(dist);
861        return TValue(best_val);
862    } else {
863        float sum, n;
864
865        assert(type == Regression);
866
867        predict_regression(ex, tree, &sum, &n);
868        return TValue(sum / n);
869    }
870}
871
872PDistribution
873TSimpleTreeClassifier::classDistribution(const TExample &ex)
874{
875    if (type == Classification) {
876        int i, free_dist;
877        float *dist;
878
879        dist = predict_classification(ex, tree, &free_dist);
880
881        PDistribution pdist = TDistribution::create(ex.domain->classVar);
882        for (i = 0; i < ex.domain->classVar->noOfValues(); i++)
883            pdist->setint(i, dist[i]);
884        pdist->normalize();
885
886        if (free_dist)
887            free(dist);
888        return pdist;
889    } else {
890        return NULL;
891    }
892}
893
894void
895TSimpleTreeClassifier::predictionAndDistribution(const TExample &ex, TValue &value, PDistribution &dist)
896{
897    value = operator()(ex);
898    dist = classDistribution(ex);
899}
Note: See TracBrowser for help on using the repository browser.