source: orange/source/orange/tdidt_simple.cpp @ 11703:9b8d8ab7820c

Revision 11703:9b8d8ab7820c, 24.3 KB checked in by janezd <janez.demsar@…>, 7 months ago (diff)

Removed the GPL copyright notice from all files except orangeqt.

Line 
1#include <iostream>
2#include <sstream>
3#include <math.h>
4#include <stdlib.h>
5#include <cstring>
6
7#include "vars.hpp"
8#include "domain.hpp"
9#include "distvars.hpp"
10#include "examples.hpp"
11#include "examplegen.hpp"
12#include "table.hpp"
13#include "classify.hpp"
14
15#include "tdidt_simple.ppp"
16
17#ifndef _MSC_VER
18    #include "err.h"
19    #define ASSERT(x) if (!(x)) err(1, "%s:%d", __FILE__, __LINE__)
20#else
21    #define ASSERT(x) if(!(x)) exit(1)
22    #define log2f(x) log((double) (x)) / log(2.0)
23#endif // _MSC_VER
24
25#ifndef INFINITY
26    #include <limits>
27    #define INFINITY numeric_limits<float>::infinity()
28#endif // INFINITY
29
30struct Args {
31    int minInstances, maxDepth;
32    float maxMajority, skipProb;
33
34    int type, *attr_split_so_far;
35    PDomain domain;
36    PRandomGenerator randomGenerator;
37};
38
39struct Example {
40    TExample *example;
41    float weight;
42};
43
44enum { DiscreteNode, ContinuousNode, PredictorNode };
45enum { Classification, Regression };
46
47int compar_attr;
48
49/* This function uses the global variable compar_attr.
50 * Examples with unknowns are larger so that, when sorted, they appear at the bottom.
51 */
52int
53compar_examples(const void *ptr1, const void *ptr2)
54{
55    struct Example *e1, *e2;
56
57    e1 = (struct Example *)ptr1;
58    e2 = (struct Example *)ptr2;
59    if (e1->example->values[compar_attr].isSpecial())
60        return 1;
61    if (e2->example->values[compar_attr].isSpecial())
62        return -1;
63    return e1->example->values[compar_attr].compare(e2->example->values[compar_attr]);
64}
65
66
67float
68entropy(float *xs, int size)
69{
70    float *ip, *end, sum, e;
71
72    for (ip = xs, end = xs + size, e = 0.0, sum = 0.0; ip != end; ip++)
73        if (*ip > 0.0) {
74            e -= *ip * log2f(*ip);
75            sum += *ip;
76        }
77
78    return sum == 0.0 ? 0.0 : e / sum + log2f(sum);
79}
80
81int
82test_min_examples(float *attr_dist, int attr_vals, struct Args *args)
83{
84    int i;
85
86    for (i = 0; i < attr_vals; i++) {
87        if (attr_dist[i] > 0.0 && attr_dist[i] < args->minInstances)
88            return 0;
89    }
90    return 1;
91}
92
93float
94gain_ratio_c(struct Example *examples, int size, int attr, float cls_entropy, struct Args *args, float *best_split)
95{
96    struct Example *ex, *ex_end, *ex_next;
97    int i, cls, cls_vals, minInstances, size_known;
98    float score, *dist_lt, *dist_ge, *attr_dist, best_score, size_weight;
99
100    cls_vals = args->domain->classVar->noOfValues();
101
102    /* minInstances should be at least 1, otherwise there is no point in splitting */
103    minInstances = args->minInstances < 1 ? 1 : args->minInstances;
104
105    /* allocate space */
106    ASSERT(dist_lt = (float *)calloc(cls_vals, sizeof *dist_lt));
107    ASSERT(dist_ge = (float *)calloc(cls_vals, sizeof *dist_ge));
108    ASSERT(attr_dist = (float *)calloc(2, sizeof *attr_dist));
109
110    /* sort */
111    compar_attr = attr;
112    qsort(examples, size, sizeof(struct Example), compar_examples);
113
114    /* compute gain ratio for every split */
115    size_known = size;
116    size_weight = 0.0;
117    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
118        if (ex->example->values[attr].isSpecial()) {
119            size_known = ex - examples;
120            break;
121        }
122        if (!ex->example->getClass().isSpecial())
123            dist_ge[ex->example->getClass().intV] += ex->weight;
124        size_weight += ex->weight;
125    }
126
127    attr_dist[1] = size_weight;
128    best_score = -INFINITY;
129
130    for (ex = examples, ex_end = ex + size_known - minInstances, ex_next = ex + 1, i = 0; ex < ex_end; ex++, ex_next++, i++) {
131        if (!ex->example->getClass().isSpecial()) {
132            cls = ex->example->getClass().intV;
133            dist_lt[cls] += ex->weight;
134            dist_ge[cls] -= ex->weight;
135        }
136        attr_dist[0] += ex->weight;
137        attr_dist[1] -= ex->weight;
138
139        if (ex->example->values[attr] == ex_next->example->values[attr] || i + 1 < minInstances)
140            continue;
141
142        /* gain ratio */
143        score = (attr_dist[0] * entropy(dist_lt, cls_vals) + attr_dist[1] * entropy(dist_ge, cls_vals)) / size_weight;
144        score = (cls_entropy - score) / entropy(attr_dist, 2);
145
146
147        if (score > best_score) {
148            best_score = score;
149            *best_split = (ex->example->values[attr].floatV + ex_next->example->values[attr].floatV) / 2.0;
150        }
151    }
152
153    /* printf("C %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), best_score); */
154
155    /* cleanup */
156    free(dist_lt);
157    free(dist_ge);
158    free(attr_dist);
159
160    return best_score;
161}
162
163
164float
165gain_ratio_d(struct Example *examples, int size, int attr, float cls_entropy, struct Args *args)
166{
167    struct Example *ex, *ex_end;
168    int i, cls_vals, attr_vals, attr_val, cls_val;
169    float score, size_weight, size_attr_known, size_attr_cls_known, attr_entropy, *cont, *attr_dist, *attr_dist_cls_known;
170
171    cls_vals = args->domain->classVar->noOfValues();
172    attr_vals = args->domain->attributes->at(attr)->noOfValues();
173
174    /* allocate space */
175    ASSERT(cont = (float *)calloc(cls_vals * attr_vals, sizeof(float *)));
176    ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof(float *)));
177    ASSERT(attr_dist_cls_known = (float *)calloc(attr_vals, sizeof(float *)));
178
179    /* contingency matrix */
180    size_weight = 0.0;
181    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
182        if (!ex->example->values[attr].isSpecial()) {
183            attr_val = ex->example->values[attr].intV;
184            attr_dist[attr_val] += ex->weight;
185            if (!ex->example->getClass().isSpecial()) {
186                cls_val = ex->example->getClass().intV;
187                attr_dist_cls_known[attr_val] += ex->weight;
188                cont[attr_val * cls_vals + cls_val] += ex->weight;
189            }
190        }
191        size_weight += ex->weight;
192    }
193
194    /* min examples in leaves */
195    if (!test_min_examples(attr_dist, attr_vals, args)) {
196        score = -INFINITY;
197        goto finish;
198    }
199
200    size_attr_known = size_attr_cls_known = 0.0;
201    for (i = 0; i < attr_vals; i++) {
202        size_attr_known += attr_dist[i];
203        size_attr_cls_known += attr_dist_cls_known[i];
204    }
205
206    /* gain ratio */
207    score = 0.0;
208    for (i = 0; i < attr_vals; i++)
209        score += attr_dist_cls_known[i] * entropy(cont + i * cls_vals, cls_vals);
210    attr_entropy = entropy(attr_dist, attr_vals);
211
212    if (size_attr_cls_known == 0.0 || attr_entropy == 0.0 || size_weight == 0.0) {
213        score = -INFINITY;
214        goto finish;
215    }
216
217    score = (cls_entropy - score / size_attr_cls_known) / attr_entropy * ((float)size_attr_known / size_weight);
218
219    /* printf("D %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), score); */
220
221finish:
222    free(cont);
223    free(attr_dist);
224    free(attr_dist_cls_known);
225    return score;
226}
227
228
229float
230mse_c(struct Example *examples, int size, int attr, float cls_mse, struct Args *args, float *best_split)
231{
232    struct Example *ex, *ex_end, *ex_next;
233    int i, cls_vals, minInstances, size_known;
234    float size_attr_known, size_weight, cls_val, cls_score, best_score, size_attr_cls_known, score;
235
236    struct Variance {
237        double n, sum, sum2;
238    } var_lt = {0.0, 0.0, 0.0}, var_ge = {0.0, 0.0, 0.0};
239
240    cls_vals = args->domain->classVar->noOfValues();
241
242    /* minInstances should be at least 1, otherwise there is no point in splitting */
243    minInstances = args->minInstances < 1 ? 1 : args->minInstances;
244
245    /* sort */
246    compar_attr = attr;
247    qsort(examples, size, sizeof(struct Example), compar_examples);
248
249    /* compute mse for every split */
250    size_known = size;
251    size_attr_known = 0.0;
252    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
253        if (ex->example->values[attr].isSpecial()) {
254            size_known = ex - examples;
255            break;
256        }
257        if (!ex->example->getClass().isSpecial()) {
258            cls_val = ex->example->getClass().floatV;
259            var_ge.n += ex->weight;
260            var_ge.sum += ex->weight * cls_val;
261            var_ge.sum2 += ex->weight * cls_val * cls_val;
262        }
263        size_attr_known += ex->weight;
264    }
265
266    /* count the remaining examples with unknown values */
267    size_weight = size_attr_known;
268    for (ex_end = examples + size; ex < ex_end; ex++)
269        size_weight += ex->weight;
270
271    size_attr_cls_known = var_ge.n;
272    best_score = -INFINITY;
273
274    for (ex = examples, ex_end = ex + size_known - minInstances, ex_next = ex + 1, i = 0; ex < ex_end; ex++, ex_next++, i++) {
275        if (!ex->example->getClass().isSpecial()) {
276            cls_val = ex->example->getClass();
277            var_lt.n += ex->weight;
278            var_lt.sum += ex->weight * cls_val;
279            var_lt.sum2 += ex->weight * cls_val * cls_val;
280
281            /* this calculation might be numarically unstable - fix */
282            var_ge.n -= ex->weight;
283            var_ge.sum -= ex->weight * cls_val;
284            var_ge.sum2 -= ex->weight * cls_val * cls_val;
285        }
286
287        /* Naive calculation of variance (used for testing)
288         
289        struct Example *ex2, *ex_end2;
290        float nlt, sumlt, sum2lt, nge, sumge, sum2ge;
291        nlt = sumlt = sum2lt = nge = sumge = sum2ge = 0.0;
292
293        for (ex2 = examples, ex_end2 = ex2 + size; ex2 < ex_end2; ex2++) {
294            cls_val = ex2->example->getClass();
295            if (ex2 < ex) {
296                nlt += ex2->weight;
297                sumlt += ex2->weight * cls_val;
298                sum2lt += ex2->weight * cls_val * cls_val;
299            } else {
300                nge += ex2->weight;
301                sumge += ex2->weight * cls_val;
302                sum2ge += ex2->weight * cls_val * cls_val;
303            }
304        }
305        */
306
307
308        if (ex->example->values[attr] == ex_next->example->values[attr] || i + 1 < minInstances)
309            continue;
310
311        /* compute mse */
312        score = var_lt.sum2 - var_lt.sum * var_lt.sum / var_lt.n;
313        score += var_ge.sum2 - var_ge.sum * var_ge.sum / var_ge.n;
314
315        score = (cls_mse - score / size_attr_cls_known) / cls_mse * (size_attr_known / size_weight);
316
317        if (score > best_score) {
318            best_score = score;
319            *best_split = (ex->example->values[attr].floatV + ex_next->example->values[attr].floatV) / 2.0;
320        }
321    }
322
323    /* printf("C %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), best_score); */
324    return best_score;
325}
326
327
328float
329mse_d(struct Example *examples, int size, int attr, float cls_mse, struct Args *args)
330{
331    int i, attr_vals, attr_val;
332    float *attr_dist, d, score, cls_val, size_attr_cls_known, size_attr_known, size_weight;
333    struct Example *ex, *ex_end;
334
335    struct Variance {
336        float n, sum, sum2;
337    } *variances, *v, *v_end;
338
339    attr_vals = args->domain->attributes->at(attr)->noOfValues();
340
341    ASSERT(variances = (struct Variance *)calloc(attr_vals, sizeof *variances));
342    ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof *attr_dist));
343
344    size_weight = size_attr_cls_known = size_attr_known = 0.0;
345    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
346        if (!ex->example->values[attr].isSpecial()) {
347            attr_dist[ex->example->values[attr].intV] += ex->weight;
348            size_attr_known += ex->weight;
349
350            if (!ex->example->getClass().isSpecial()) {
351                    cls_val = ex->example->getClass().floatV;
352                    v = variances + ex->example->values[attr].intV;
353                    v->n += ex->weight;
354                    v->sum += ex->weight * cls_val;
355                    v->sum2 += ex->weight * cls_val * cls_val;
356                    size_attr_cls_known += ex->weight;
357            }
358        }
359        size_weight += ex->weight;
360    }
361
362    /* minimum examples in leaves */
363    if (!test_min_examples(attr_dist, attr_vals, args)) {
364        score = -INFINITY;
365        goto finish;
366    }
367
368    score = 0.0;
369    for (v = variances, v_end = variances + attr_vals; v < v_end; v++)
370        if (v->n > 0.0)
371            score += v->sum2 - v->sum * v->sum / v->n;
372    score = (cls_mse - score / size_attr_cls_known) / cls_mse * (size_attr_known / size_weight);
373
374    if (size_attr_cls_known <= 0.0 || cls_mse <= 0.0 || size_weight <= 0.0)
375        score = 0.0;
376
377finish:
378    free(attr_dist);
379    free(variances);
380
381    return score;
382}
383
384
385struct SimpleTreeNode *
386make_predictor(struct SimpleTreeNode *node, struct Example *examples, int size, struct Args *args)
387{
388    node->type = PredictorNode;
389    node->children_size = 0;
390    return node;
391}
392
393
394struct SimpleTreeNode *
395build_tree(struct Example *examples, int size, int depth, struct SimpleTreeNode *parent, struct Args *args)
396{
397    int i, cls_vals, best_attr;
398    float cls_entropy, cls_mse, best_score, score, size_weight, best_split, split;
399    struct SimpleTreeNode *node;
400    struct Example *ex, *ex_end;
401    TVarList::const_iterator it;
402
403    cls_vals = args->domain->classVar->noOfValues();
404
405    ASSERT(node = (SimpleTreeNode *)malloc(sizeof *node));
406
407    if (args->type == Classification) {
408        ASSERT(node->dist = (float *)calloc(cls_vals, sizeof(float *)));
409
410        if (size == 0) {
411            assert(parent);
412            node->type = PredictorNode;
413            node->children_size = 0;
414            memcpy(node->dist, parent->dist, cls_vals * sizeof *node->dist);
415            return node;
416        }
417
418        /* class distribution */
419        size_weight = 0.0;
420        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
421            if (!ex->example->getClass().isSpecial()) {
422                node->dist[ex->example->getClass().intV] += ex->weight;
423                size_weight += ex->weight;
424            }
425
426        /* stopping criterion: majority class */
427        for (i = 0; i < cls_vals; i++)
428            if (node->dist[i] / size_weight >= args->maxMajority)
429                return make_predictor(node, examples, size, args);
430
431        cls_entropy = entropy(node->dist, cls_vals);
432    } else {
433        float n, sum, sum2, cls_val;
434
435        assert(args->type == Regression);
436        if (size == 0) {
437            assert(parent);
438            node->type = PredictorNode;
439            node->children_size = 0;
440            node->n = parent->n;
441            node->sum = parent->sum;
442            return node;
443        }
444
445        n = sum = sum2 = 0.0;
446        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
447            if (!ex->example->getClass().isSpecial()) {
448                cls_val = ex->example->getClass().floatV;
449                n += ex->weight;
450                sum += ex->weight * cls_val;
451                sum2 += ex->weight * cls_val * cls_val;
452            }
453
454        node->n = n;
455        node->sum = sum;
456        cls_mse = (sum2 - sum * sum / n) / n;
457
458        if (cls_mse < 1e-5) {
459            return make_predictor(node, examples, size, args);
460        }
461    }
462
463    /* stopping criterion: depth exceeds limit */
464    if (depth == args->maxDepth)
465        return make_predictor(node, examples, size, args);
466
467    /* score attributes */
468    best_score = -INFINITY;
469
470    for (i = 0, it = args->domain->attributes->begin(); it != args->domain->attributes->end(); it++, i++) {
471        if (!args->attr_split_so_far[i]) {
472            /* select random subset of attributes */
473            if (args->randomGenerator->randdouble() < args->skipProb)
474                continue;
475
476            if ((*it)->varType == TValue::INTVAR) {
477                score = args->type == Classification ?
478                  gain_ratio_d(examples, size, i, cls_entropy, args) :
479                  mse_d(examples, size, i, cls_mse, args);
480                if (score > best_score) {
481                    best_score = score;
482                    best_attr = i;
483                }
484            } else if ((*it)->varType == TValue::FLOATVAR) {
485                score = args->type == Classification ?
486                  gain_ratio_c(examples, size, i, cls_entropy, args, &split) :
487                  mse_c(examples, size, i, cls_mse, args, &split);
488                if (score > best_score) {
489                    best_score = score;
490                    best_split = split;
491                    best_attr = i;
492                }
493            }
494        }
495    }
496
497    if (best_score == -INFINITY)
498        return make_predictor(node, examples, size, args);
499
500    if (args->domain->attributes->at(best_attr)->varType == TValue::INTVAR) {
501        struct Example *child_examples, *child_ex;
502        int attr_vals;
503        float size_known, *attr_dist;
504
505        /* printf("* %2d %3s %3d %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_score); */
506
507        attr_vals = args->domain->attributes->at(best_attr)->noOfValues(); 
508
509        node->type = DiscreteNode;
510        node->split_attr = best_attr;
511        node->children_size = attr_vals;
512
513        ASSERT(child_examples = (struct Example *)calloc(size, sizeof *child_examples));
514        ASSERT(node->children = (SimpleTreeNode **)calloc(attr_vals, sizeof *node->children));
515        ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof *attr_dist));
516
517        /* attribute distribution */
518        size_known = 0;
519        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
520            if (!ex->example->values[best_attr].isSpecial()) {
521                attr_dist[ex->example->values[best_attr].intV] += ex->weight;
522                size_known += ex->weight;
523            }
524
525        args->attr_split_so_far[best_attr] = 1;
526
527        for (i = 0; i < attr_vals; i++) {
528            /* create a new example table */
529            for (ex = examples, ex_end = examples + size, child_ex = child_examples; ex < ex_end; ex++) {
530                if (ex->example->values[best_attr].isSpecial()) {
531                    *child_ex = *ex;
532                    child_ex->weight *= attr_dist[i] / size_known;
533                    child_ex++;
534                } else if (ex->example->values[best_attr].intV == i) {
535                    *child_ex++ = *ex;
536                }
537            }
538
539            node->children[i] = build_tree(child_examples, child_ex - child_examples, depth + 1, node, args);
540        }
541                   
542        args->attr_split_so_far[best_attr] = 0;
543
544        free(attr_dist);
545        free(child_examples);
546    } else {
547        struct Example *examples_lt, *examples_ge, *ex_lt, *ex_ge;
548        float size_lt, size_ge;
549
550        /* printf("* %2d %3s %3d %f %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_split, best_score); */
551
552        assert(args->domain->attributes->at(best_attr)->varType == TValue::FLOATVAR);
553
554        ASSERT(examples_lt = (struct Example *)calloc(size, sizeof *examples));
555        ASSERT(examples_ge = (struct Example *)calloc(size, sizeof *examples));
556
557        size_lt = size_ge = 0.0;
558        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
559            if (!ex->example->values[best_attr].isSpecial())
560                if (ex->example->values[best_attr].floatV < best_split)
561                    size_lt += ex->weight;
562                else
563                    size_ge += ex->weight;
564
565        for (ex = examples, ex_end = examples + size, ex_lt = examples_lt, ex_ge = examples_ge; ex < ex_end; ex++)
566            if (ex->example->values[best_attr].isSpecial()) {
567                *ex_lt = *ex;
568                *ex_ge = *ex;
569                ex_lt->weight *= size_lt / (size_lt + size_ge);
570                ex_ge->weight *= size_ge / (size_lt + size_ge);
571                ex_lt++;
572                ex_ge++;
573            } else if (ex->example->values[best_attr].floatV < best_split) {
574                *ex_lt++ = *ex;
575            } else {
576                *ex_ge++ = *ex;
577            }
578
579        /*
580         * Check there was an actual reduction of size in the the two subsets.
581         * This test fails when all best_attr's (the only attr) values  are
582         * the same (and equal best_split) so the data is split in 0 | n size
583         * subsets and recursing would lead to an infinite recursion.
584         */
585        if ((ex_lt - examples_lt) < size && (ex_ge - examples_ge) < size) {
586            node->type = ContinuousNode;
587            node->split_attr = best_attr;
588            node->split = best_split;
589            node->children_size = 2;
590            ASSERT(node->children = (SimpleTreeNode **)calloc(2, sizeof *node->children));
591
592            node->children[0] = build_tree(examples_lt, ex_lt - examples_lt, depth + 1, node, args);
593            node->children[1] = build_tree(examples_ge, ex_ge - examples_ge, depth + 1, node, args);
594        } else {
595            node = make_predictor(node, examples, size, args);
596        }
597
598        free(examples_lt);
599        free(examples_ge);
600    }
601
602    return node;
603}
604
605TSimpleTreeLearner::TSimpleTreeLearner(const int &weight, float maxMajority, int minInstances, int maxDepth, float skipProb, PRandomGenerator rgen) :
606    maxMajority(maxMajority),
607    minInstances(minInstances),
608    maxDepth(maxDepth),
609    skipProb(skipProb)
610{
611    randomGenerator = rgen ? rgen : PRandomGenerator(mlnew TRandomGenerator());
612}
613
614PClassifier
615TSimpleTreeLearner::operator()(PExampleGenerator ogen, const int &weight)
616{
617    struct Example *examples, *ex;
618    struct SimpleTreeNode *tree;
619    struct Args args;
620    int cls_vals;
621
622    if (!ogen->domain->classVar)
623        raiseError("class-less domain");
624
625    if (!ogen->numberOfExamples() > 0)
626        raiseError("no examples");
627
628    /* create a tabel with pointers to examples */
629    ASSERT(examples = (struct Example *)calloc(ogen->numberOfExamples(), sizeof *examples));
630    ex = examples;
631    PEITERATE(ei, ogen) {
632        ex->example = &(*ei);
633        ex->weight = 1.0;
634        ex++;
635    }
636
637    ASSERT(args.attr_split_so_far = (int *)calloc(ogen->domain->attributes->size(), sizeof(int)));
638    args.minInstances = minInstances;
639    args.maxMajority = maxMajority;
640    args.maxDepth = maxDepth;
641    args.skipProb = skipProb;
642    args.domain = ogen->domain;
643    args.randomGenerator = randomGenerator;
644    args.type = ogen->domain->classVar->varType == TValue::INTVAR ? Classification : Regression;
645    cls_vals = ogen->domain->classVar->noOfValues();
646
647    tree = build_tree(examples, ogen->numberOfExamples(), 0, NULL, &args);
648
649    free(examples);
650    free(args.attr_split_so_far);
651
652    return new TSimpleTreeClassifier(ogen->domain->classVar, tree, args.type, cls_vals);
653}
654
655
656/* classifier */
657TSimpleTreeClassifier::TSimpleTreeClassifier()
658{
659}
660
661TSimpleTreeClassifier::TSimpleTreeClassifier(const PVariable &classVar, struct SimpleTreeNode *tree, int type, int cls_vals) : 
662    TClassifier(classVar, true),
663    tree(tree),
664    type(type),
665    cls_vals(cls_vals)
666{
667}
668
669
670void
671destroy_tree(struct SimpleTreeNode *node, int type)
672{
673    int i;
674
675    if (node->type != PredictorNode) {
676        for (i = 0; i < node->children_size; i++)
677            destroy_tree(node->children[i], type);
678        free(node->children);
679    }
680    if (type == Classification)
681        free(node->dist);
682    free(node);
683}
684
685
686TSimpleTreeClassifier::~TSimpleTreeClassifier()
687{
688    destroy_tree(tree, type);
689}
690
691
692float *
693predict_classification(const TExample &ex, struct SimpleTreeNode *node, int *free_dist, int cls_vals)
694{
695    int i, j;
696    float *dist, *child_dist;
697
698    while (node->type != PredictorNode)
699        if (ex.values[node->split_attr].isSpecial()) {
700            ASSERT(dist = (float *)calloc(cls_vals, sizeof *dist));
701            for (i = 0; i < node->children_size; i++) {
702                child_dist = predict_classification(ex, node->children[i], free_dist, cls_vals);
703                for (j = 0; j < cls_vals; j++)
704                    dist[j] += child_dist[j];
705                if (*free_dist)
706                    free(child_dist);
707            }
708            *free_dist = 1;
709            return dist;
710        } else if (node->type == DiscreteNode) {
711            node = node->children[ex.values[node->split_attr].intV];
712        } else {
713            assert(node->type == ContinuousNode);
714            node = node->children[ex.values[node->split_attr].floatV >= node->split];
715        }
716
717    *free_dist = 0;
718    return node->dist;
719}
720
721
722void
723predict_regression(const TExample &ex, struct SimpleTreeNode *node, float *sum, float *n)
724{
725    int i;
726    float local_sum, local_n;
727
728    while (node->type != PredictorNode) {
729        if (ex.values[node->split_attr].isSpecial()) {
730            *sum = *n = 0;
731            for (i = 0; i < node->children_size; i++) {
732                predict_regression(ex, node->children[i], &local_sum, &local_n);
733                *sum += local_sum;
734                *n += local_n;
735            }
736            return;
737        } else if (node->type == DiscreteNode) {
738            assert(ex.values[node->split_attr].intV < node->children_size);
739            node = node->children[ex.values[node->split_attr].intV];
740        } else {
741            assert(node->type == ContinuousNode);
742            node = node->children[ex.values[node->split_attr].floatV > node->split];
743        }
744    }
745
746    *sum = node->sum;
747    *n = node->n;
748}
749
750
751void
752TSimpleTreeClassifier::save_tree(ostringstream &ss, struct SimpleTreeNode *node)
753{
754    int i;
755
756    ss << "{ " << node->type << " " << node->children_size << " ";
757
758    if (node->type != PredictorNode)
759        ss << node->split_attr << " " << node->split << " ";
760
761    for (i = 0; i < node->children_size; i++)
762        this->save_tree(ss, node->children[i]);
763
764    if (this->type == Classification) {
765        for (i = 0; i < this->cls_vals; i++)
766            ss << node->dist[i] << " ";
767    } else {
768        assert(this->type == Regression);
769        ss << node->n << " " << node->sum << " ";
770    }
771    ss << "} ";
772}
773
774struct SimpleTreeNode *
775TSimpleTreeClassifier::load_tree(istringstream &ss)
776{
777    int i;
778    string lbracket, rbracket;
779    string split_string;
780    SimpleTreeNode *node;
781
782    ss.exceptions(istream::failbit);
783
784    ASSERT(node = (SimpleTreeNode *)malloc(sizeof *node));
785    ss >> lbracket >> node->type >> node->children_size;
786
787
788    if (node->type != PredictorNode)
789    {
790        ss >> node->split_attr;
791
792        /* Read split into a string and use strtod to parse it.
793         * istream sometimes (on some platforms) seems to have problems
794         * reading formated floats.
795         */
796        ss >> split_string;
797        node->split = float(strtod(split_string.c_str(), NULL));
798    }
799
800    if (node->children_size) {
801        ASSERT(node->children = (SimpleTreeNode **)calloc(node->children_size, sizeof *node->children));
802        for (i = 0; i < node->children_size; i++)
803            node->children[i] = load_tree(ss);
804    }
805
806    if (this->type == Classification) {
807        ASSERT(node->dist = (float *)calloc(cls_vals, sizeof(float *)));
808        for (i = 0; i < this->cls_vals; i++)
809            ss >> node->dist[i];
810    } else {
811        assert(this->type == Regression);
812        ss >> node->n >> node->sum;
813    }
814    ss >> rbracket;
815
816    /* Synchronization check */
817    assert(lbracket == "{" && rbracket == "}");
818
819    return node;
820}
821
822void
823TSimpleTreeClassifier::save_model(ostringstream &ss)
824{
825    ss.precision(9); /* we have floats */
826    ss << this->type << " " << this->cls_vals << " ";
827    this->save_tree(ss, this->tree);
828}
829
830void
831TSimpleTreeClassifier::load_model(istringstream &ss)
832{
833    ss >> this->type >> this->cls_vals;
834    this->tree = load_tree(ss);
835}
836
837
838TValue
839TSimpleTreeClassifier::operator()(const TExample &ex)
840{
841    if (type == Classification) {
842        int i, free_dist, best_val;
843        float *dist;
844
845        dist = predict_classification(ex, tree, &free_dist, this->cls_vals);
846        best_val = 0;
847        for (i = 1; i < this->cls_vals; i++)
848            if (dist[i] > dist[best_val])
849                best_val = i;
850
851        if (free_dist)
852            free(dist);
853        return TValue(best_val);
854    } else {
855        float sum, n;
856
857        assert(type == Regression);
858
859        predict_regression(ex, tree, &sum, &n);
860        return TValue(sum / n);
861    }
862}
863
864PDistribution
865TSimpleTreeClassifier::classDistribution(const TExample &ex)
866{
867    if (type == Classification) {
868        int i, free_dist;
869        float *dist;
870
871        dist = predict_classification(ex, tree, &free_dist, this->cls_vals);
872
873        PDistribution pdist = mlnew TDiscDistribution(this->cls_vals, 0.0);
874        pdist->variable = this->classVar;
875        for (i = 0; i < this->cls_vals; i++)
876            pdist->setint(i, dist[i]);
877        pdist->normalize();
878
879        if (free_dist)
880            free(dist);
881        return pdist;
882    } else {
883        return NULL;
884    }
885}
886
887void
888TSimpleTreeClassifier::predictionAndDistribution(const TExample &ex, TValue &value, PDistribution &dist)
889{
890    value = operator()(ex);
891    dist = classDistribution(ex);
892}
Note: See TracBrowser for help on using the repository browser.