source: orange/source/orange/tdidt_simple.cpp @ 10965:873ff9bf106c

Revision 10965:873ff9bf106c, 24.6 KB checked in by Ales Erjavec <ales.erjavec@…>, 20 months ago (diff)

Set distribution variable in 'SimpleTreeLearner.classDistribution'.

RevLine 
[8378]1/*
[10206]2    This file is part of Orange.
[8770]3
[10206]4    Copyright 1996-2010 Faculty of Computer and Information Science, University of Ljubljana
5    Contact: janez.demsar@fri.uni-lj.si
[8378]6
[10206]7    Orange is free software: you can redistribute it and/or modify
8    it under the terms of the GNU General Public License as published by
9    the Free Software Foundation, either version 3 of the License, or
10    (at your option) any later version.
[8378]11
[10206]12    Orange is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15    GNU General Public License for more details.
[8378]16
[10206]17    You should have received a copy of the GNU General Public License
18    along with Orange.  If not, see <http://www.gnu.org/licenses/>.
[8378]19*/
20
[10206]21#include <iostream>
22#include <sstream>
[8378]23#include <math.h>
24#include <stdlib.h>
25#include <cstring>
26
27#include "vars.hpp"
28#include "domain.hpp"
29#include "distvars.hpp"
30#include "examples.hpp"
31#include "examplegen.hpp"
32#include "table.hpp"
33#include "classify.hpp"
34
35#include "tdidt_simple.ppp"
36
[8396]37#ifndef _MSC_VER
[10206]38    #include "err.h"
39    #define ASSERT(x) if (!(x)) err(1, "%s:%d", __FILE__, __LINE__)
[8396]40#else
[10206]41    #define ASSERT(x) if(!(x)) exit(1)
42    #define log2f(x) log((double) (x)) / log(2.0)
[8396]43#endif // _MSC_VER
44
45#ifndef INFINITY
[10206]46    #include <limits>
47    #define INFINITY numeric_limits<float>::infinity()
[8396]48#endif // INFINITY
49
[8378]50struct Args {
[10206]51    int minInstances, maxDepth;
52    float maxMajority, skipProb;
[8378]53
[10206]54    int type, *attr_split_so_far;
55    PDomain domain;
[9296]56    PRandomGenerator randomGenerator;
[8378]57};
58
[8770]59struct Example {
[10206]60    TExample *example;
61    float weight;
[8770]62};
63
64enum { DiscreteNode, ContinuousNode, PredictorNode };
65enum { Classification, Regression };
66
[8378]67int compar_attr;
68
[8131]69/* This function uses the global variable compar_attr.
[8770]70 * Examples with unknowns are larger so that, when sorted, they appear at the bottom.
[8131]71 */
[8378]72int
[8131]73compar_examples(const void *ptr1, const void *ptr2)
[8378]74{
[10206]75    struct Example *e1, *e2;
[8770]76
[10206]77    e1 = (struct Example *)ptr1;
78    e2 = (struct Example *)ptr2;
79    if (e1->example->values[compar_attr].isSpecial())
80        return 1;
81    if (e2->example->values[compar_attr].isSpecial())
82        return -1;
83    return e1->example->values[compar_attr].compare(e2->example->values[compar_attr]);
[8770]84}
85
86
87float
88entropy(float *xs, int size)
89{
[10206]90    float *ip, *end, sum, e;
[8770]91
[10206]92    for (ip = xs, end = xs + size, e = 0.0, sum = 0.0; ip != end; ip++)
93        if (*ip > 0.0) {
94            e -= *ip * log2f(*ip);
95            sum += *ip;
96        }
[8770]97
[10206]98    return sum == 0.0 ? 0.0 : e / sum + log2f(sum);
[8770]99}
100
101int
102test_min_examples(float *attr_dist, int attr_vals, struct Args *args)
103{
[10206]104    int i;
[8770]105
[10206]106    for (i = 0; i < attr_vals; i++) {
107        if (attr_dist[i] > 0.0 && attr_dist[i] < args->minInstances)
108            return 0;
109    }
110    return 1;
[8378]111}
112
113float
[8770]114gain_ratio_c(struct Example *examples, int size, int attr, float cls_entropy, struct Args *args, float *best_split)
[8378]115{
[10206]116    struct Example *ex, *ex_end, *ex_next;
117    int i, cls, cls_vals, minInstances, size_known;
118    float score, *dist_lt, *dist_ge, *attr_dist, best_score, size_weight;
[8378]119
[10206]120    cls_vals = args->domain->classVar->noOfValues();
[8378]121
[10206]122    /* minInstances should be at least 1, otherwise there is no point in splitting */
123    minInstances = args->minInstances < 1 ? 1 : args->minInstances;
[8770]124
[10206]125    /* allocate space */
126    ASSERT(dist_lt = (float *)calloc(cls_vals, sizeof *dist_lt));
127    ASSERT(dist_ge = (float *)calloc(cls_vals, sizeof *dist_ge));
128    ASSERT(attr_dist = (float *)calloc(2, sizeof *attr_dist));
[8378]129
[10206]130    /* sort */
131    compar_attr = attr;
132    qsort(examples, size, sizeof(struct Example), compar_examples);
[8378]133
[10206]134    /* compute gain ratio for every split */
135    size_known = size;
136    size_weight = 0.0;
137    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
138        if (ex->example->values[attr].isSpecial()) {
139            size_known = ex - examples;
140            break;
141        }
142        if (!ex->example->getClass().isSpecial())
143            dist_ge[ex->example->getClass().intV] += ex->weight;
144        size_weight += ex->weight;
145    }
[8378]146
[10206]147    attr_dist[1] = size_weight;
148    best_score = -INFINITY;
[8131]149
[10206]150    for (ex = examples, ex_end = ex + size_known - minInstances, ex_next = ex + 1, i = 0; ex < ex_end; ex++, ex_next++, i++) {
151        if (!ex->example->getClass().isSpecial()) {
152            cls = ex->example->getClass().intV;
153            dist_lt[cls] += ex->weight;
154            dist_ge[cls] -= ex->weight;
155        }
156        attr_dist[0] += ex->weight;
157        attr_dist[1] -= ex->weight;
[8378]158
[10206]159        if (ex->example->values[attr] == ex_next->example->values[attr] || i + 1 < minInstances)
160            continue;
[8378]161
[10206]162        /* gain ratio */
163        score = (attr_dist[0] * entropy(dist_lt, cls_vals) + attr_dist[1] * entropy(dist_ge, cls_vals)) / size_weight;
164        score = (cls_entropy - score) / entropy(attr_dist, 2);
[8378]165
[8770]166
[10206]167        if (score > best_score) {
168            best_score = score;
169            *best_split = (ex->example->values[attr].floatV + ex_next->example->values[attr].floatV) / 2.0;
170        }
171    }
[8378]172
[10206]173    /* printf("C %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), best_score); */
[8131]174
[10206]175    /* cleanup */
176    free(dist_lt);
177    free(dist_ge);
178    free(attr_dist);
[8378]179
[10206]180    return best_score;
[8378]181}
182
[8770]183
[8378]184float
[8770]185gain_ratio_d(struct Example *examples, int size, int attr, float cls_entropy, struct Args *args)
[8378]186{
[10206]187    struct Example *ex, *ex_end;
188    int i, cls_vals, attr_vals, attr_val, cls_val;
189    float score, size_weight, size_attr_known, size_attr_cls_known, attr_entropy, *cont, *attr_dist, *attr_dist_cls_known;
[8378]190
[10206]191    cls_vals = args->domain->classVar->noOfValues();
192    attr_vals = args->domain->attributes->at(attr)->noOfValues();
[8378]193
[10206]194    /* allocate space */
195    ASSERT(cont = (float *)calloc(cls_vals * attr_vals, sizeof(float *)));
196    ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof(float *)));
197    ASSERT(attr_dist_cls_known = (float *)calloc(attr_vals, sizeof(float *)));
[8378]198
[10206]199    /* contingency matrix */
200    size_weight = 0.0;
201    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
202        if (!ex->example->values[attr].isSpecial()) {
203            attr_val = ex->example->values[attr].intV;
204            attr_dist[attr_val] += ex->weight;
205            if (!ex->example->getClass().isSpecial()) {
206                cls_val = ex->example->getClass().intV;
207                attr_dist_cls_known[attr_val] += ex->weight;
208                cont[attr_val * cls_vals + cls_val] += ex->weight;
209            }
210        }
211        size_weight += ex->weight;
212    }
[8378]213
[10206]214    /* min examples in leaves */
215    if (!test_min_examples(attr_dist, attr_vals, args)) {
216        score = -INFINITY;
217        goto finish;
218    }
[8378]219
[10206]220    size_attr_known = size_attr_cls_known = 0.0;
221    for (i = 0; i < attr_vals; i++) {
222        size_attr_known += attr_dist[i];
223        size_attr_cls_known += attr_dist_cls_known[i];
224    }
[8378]225
[10206]226    /* gain ratio */
227    score = 0.0;
228    for (i = 0; i < attr_vals; i++)
229        score += attr_dist_cls_known[i] * entropy(cont + i * cls_vals, cls_vals);
230    attr_entropy = entropy(attr_dist, attr_vals);
[8770]231
[10206]232    if (size_attr_cls_known == 0.0 || attr_entropy == 0.0 || size_weight == 0.0) {
233        score = -INFINITY;
234        goto finish;
235    }
[8770]236
[10206]237    score = (cls_entropy - score / size_attr_cls_known) / attr_entropy * ((float)size_attr_known / size_weight);
[8770]238
[10206]239    /* printf("D %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), score); */
[8770]240
241finish:
[10206]242    free(cont);
243    free(attr_dist);
244    free(attr_dist_cls_known);
245    return score;
[8770]246}
247
248
249float
250mse_c(struct Example *examples, int size, int attr, float cls_mse, struct Args *args, float *best_split)
251{
[10206]252    struct Example *ex, *ex_end, *ex_next;
253    int i, cls_vals, minInstances, size_known;
254    float size_attr_known, size_weight, cls_val, cls_score, best_score, size_attr_cls_known, score;
[8770]255
[10206]256    struct Variance {
257        double n, sum, sum2;
258    } var_lt = {0.0, 0.0, 0.0}, var_ge = {0.0, 0.0, 0.0};
[8770]259
[10206]260    cls_vals = args->domain->classVar->noOfValues();
[8770]261
[10206]262    /* minInstances should be at least 1, otherwise there is no point in splitting */
263    minInstances = args->minInstances < 1 ? 1 : args->minInstances;
[8770]264
[10206]265    /* sort */
266    compar_attr = attr;
267    qsort(examples, size, sizeof(struct Example), compar_examples);
[8770]268
[10206]269    /* compute mse for every split */
270    size_known = size;
271    size_attr_known = 0.0;
272    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
273        if (ex->example->values[attr].isSpecial()) {
274            size_known = ex - examples;
275            break;
276        }
277        if (!ex->example->getClass().isSpecial()) {
278            cls_val = ex->example->getClass().floatV;
279            var_ge.n += ex->weight;
280            var_ge.sum += ex->weight * cls_val;
281            var_ge.sum2 += ex->weight * cls_val * cls_val;
282        }
283        size_attr_known += ex->weight;
284    }
[8770]285
[10206]286    /* count the remaining examples with unknown values */
287    size_weight = size_attr_known;
288    for (ex_end = examples + size; ex < ex_end; ex++)
289        size_weight += ex->weight;
[8770]290
[10206]291    size_attr_cls_known = var_ge.n;
292    best_score = -INFINITY;
[8770]293
[10206]294    for (ex = examples, ex_end = ex + size_known - minInstances, ex_next = ex + 1, i = 0; ex < ex_end; ex++, ex_next++, i++) {
295        if (!ex->example->getClass().isSpecial()) {
296            cls_val = ex->example->getClass();
297            var_lt.n += ex->weight;
298            var_lt.sum += ex->weight * cls_val;
299            var_lt.sum2 += ex->weight * cls_val * cls_val;
[8770]300
[10206]301            /* this calculation might be numarically unstable - fix */
302            var_ge.n -= ex->weight;
303            var_ge.sum -= ex->weight * cls_val;
304            var_ge.sum2 -= ex->weight * cls_val * cls_val;
305        }
[8770]306
[10206]307        /* Naive calculation of variance (used for testing)
308         
309        struct Example *ex2, *ex_end2;
310        float nlt, sumlt, sum2lt, nge, sumge, sum2ge;
311        nlt = sumlt = sum2lt = nge = sumge = sum2ge = 0.0;
[9168]312
[10206]313        for (ex2 = examples, ex_end2 = ex2 + size; ex2 < ex_end2; ex2++) {
314            cls_val = ex2->example->getClass();
315            if (ex2 < ex) {
316                nlt += ex2->weight;
317                sumlt += ex2->weight * cls_val;
318                sum2lt += ex2->weight * cls_val * cls_val;
319            } else {
320                nge += ex2->weight;
321                sumge += ex2->weight * cls_val;
322                sum2ge += ex2->weight * cls_val * cls_val;
323            }
324        }
325        */
[9168]326
327
[10206]328        if (ex->example->values[attr] == ex_next->example->values[attr] || i + 1 < minInstances)
329            continue;
[8770]330
[10206]331        /* compute mse */
332        score = var_lt.sum2 - var_lt.sum * var_lt.sum / var_lt.n;
333        score += var_ge.sum2 - var_ge.sum * var_ge.sum / var_ge.n;
[9168]334
[10206]335        score = (cls_mse - score / size_attr_cls_known) / cls_mse * (size_attr_known / size_weight);
[8770]336
[10206]337        if (score > best_score) {
338            best_score = score;
339            *best_split = (ex->example->values[attr].floatV + ex_next->example->values[attr].floatV) / 2.0;
340        }
341    }
[8770]342
[10206]343    /* printf("C %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), best_score); */
344    return best_score;
[8770]345}
346
347
348float
349mse_d(struct Example *examples, int size, int attr, float cls_mse, struct Args *args)
350{
[10206]351    int i, attr_vals, attr_val;
352    float *attr_dist, d, score, cls_val, size_attr_cls_known, size_attr_known, size_weight;
353    struct Example *ex, *ex_end;
[8770]354
[10206]355    struct Variance {
356        float n, sum, sum2;
357    } *variances, *v, *v_end;
[8770]358
[10206]359    attr_vals = args->domain->attributes->at(attr)->noOfValues();
[8770]360
[10206]361    ASSERT(variances = (struct Variance *)calloc(attr_vals, sizeof *variances));
362    ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof *attr_dist));
[8770]363
[10206]364    size_weight = size_attr_cls_known = size_attr_known = 0.0;
365    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
366        if (!ex->example->values[attr].isSpecial()) {
367            attr_dist[ex->example->values[attr].intV] += ex->weight;
368            size_attr_known += ex->weight;
[8770]369
[10206]370            if (!ex->example->getClass().isSpecial()) {
371                    cls_val = ex->example->getClass().floatV;
372                    v = variances + ex->example->values[attr].intV;
373                    v->n += ex->weight;
374                    v->sum += ex->weight * cls_val;
375                    v->sum2 += ex->weight * cls_val * cls_val;
376                    size_attr_cls_known += ex->weight;
377            }
378        }
379        size_weight += ex->weight;
380    }
[8770]381
[10206]382    /* minimum examples in leaves */
383    if (!test_min_examples(attr_dist, attr_vals, args)) {
384        score = -INFINITY;
385        goto finish;
386    }
[8770]387
[10206]388    score = 0.0;
389    for (v = variances, v_end = variances + attr_vals; v < v_end; v++)
390        if (v->n > 0.0)
391            score += v->sum2 - v->sum * v->sum / v->n;
392    score = (cls_mse - score / size_attr_cls_known) / cls_mse * (size_attr_known / size_weight);
[8770]393
[10206]394    if (size_attr_cls_known <= 0.0 || cls_mse <= 0.0 || size_weight <= 0.0)
395        score = 0.0;
[8770]396
397finish:
[10206]398    free(attr_dist);
399    free(variances);
[8770]400
[10206]401    return score;
[8770]402}
403
404
405struct SimpleTreeNode *
406make_predictor(struct SimpleTreeNode *node, struct Example *examples, int size, struct Args *args)
407{
[10206]408    node->type = PredictorNode;
409    node->children_size = 0;
410    return node;
[8770]411}
412
413
414struct SimpleTreeNode *
415build_tree(struct Example *examples, int size, int depth, struct SimpleTreeNode *parent, struct Args *args)
416{
[10206]417    int i, cls_vals, best_attr;
418    float cls_entropy, cls_mse, best_score, score, size_weight, best_split, split;
419    struct SimpleTreeNode *node;
420    struct Example *ex, *ex_end;
421    TVarList::const_iterator it;
[8770]422
[10206]423    cls_vals = args->domain->classVar->noOfValues();
[8770]424
[10206]425    ASSERT(node = (SimpleTreeNode *)malloc(sizeof *node));
[8770]426
[10206]427    if (args->type == Classification) {
428        ASSERT(node->dist = (float *)calloc(cls_vals, sizeof(float *)));
[8770]429
[10206]430        if (size == 0) {
431            assert(parent);
432            node->type = PredictorNode;
433            node->children_size = 0;
434            memcpy(node->dist, parent->dist, cls_vals * sizeof *node->dist);
435            return node;
436        }
[8770]437
[10206]438        /* class distribution */
439        size_weight = 0.0;
440        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
441            if (!ex->example->getClass().isSpecial()) {
442                node->dist[ex->example->getClass().intV] += ex->weight;
443                size_weight += ex->weight;
444            }
[8770]445
[10206]446        /* stopping criterion: majority class */
447        for (i = 0; i < cls_vals; i++)
448            if (node->dist[i] / size_weight >= args->maxMajority)
449                return make_predictor(node, examples, size, args);
[8770]450
[10206]451        cls_entropy = entropy(node->dist, cls_vals);
452    } else {
453        float n, sum, sum2, cls_val;
[8770]454
[10206]455        assert(args->type == Regression);
456        if (size == 0) {
457            assert(parent);
458            node->type = PredictorNode;
459            node->children_size = 0;
460            node->n = parent->n;
461            node->sum = parent->sum;
462            return node;
463        }
[8770]464
[10206]465        n = sum = sum2 = 0.0;
466        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
467            if (!ex->example->getClass().isSpecial()) {
468                cls_val = ex->example->getClass().floatV;
469                n += ex->weight;
470                sum += ex->weight * cls_val;
471                sum2 += ex->weight * cls_val * cls_val;
472            }
[8770]473
[10206]474        node->n = n;
475        node->sum = sum;
476        cls_mse = (sum2 - sum * sum / n) / n;
[9168]477
[10206]478        if (cls_mse < 1e-5) {
479            return make_predictor(node, examples, size, args);
480        }
481    }
[8770]482
[10206]483    /* stopping criterion: depth exceeds limit */
484    if (depth == args->maxDepth)
485        return make_predictor(node, examples, size, args);
[8770]486
[10206]487    /* score attributes */
488    best_score = -INFINITY;
[8770]489
[10206]490    for (i = 0, it = args->domain->attributes->begin(); it != args->domain->attributes->end(); it++, i++) {
491        if (!args->attr_split_so_far[i]) {
492            /* select random subset of attributes */
[9296]493            if (args->randomGenerator->randdouble() < args->skipProb)
[10206]494                continue;
[8770]495
[10206]496            if ((*it)->varType == TValue::INTVAR) {
497                score = args->type == Classification ?
498                  gain_ratio_d(examples, size, i, cls_entropy, args) :
499                  mse_d(examples, size, i, cls_mse, args);
500                if (score > best_score) {
501                    best_score = score;
502                    best_attr = i;
503                }
504            } else if ((*it)->varType == TValue::FLOATVAR) {
505                score = args->type == Classification ?
506                  gain_ratio_c(examples, size, i, cls_entropy, args, &split) :
507                  mse_c(examples, size, i, cls_mse, args, &split);
508                if (score > best_score) {
509                    best_score = score;
510                    best_split = split;
511                    best_attr = i;
512                }
513            }
514        }
515    }
[8378]516
[10206]517    if (best_score == -INFINITY)
518        return make_predictor(node, examples, size, args);
[8378]519
[10206]520    if (args->domain->attributes->at(best_attr)->varType == TValue::INTVAR) {
521        struct Example *child_examples, *child_ex;
522        int attr_vals;
523        float size_known, *attr_dist;
[8378]524
[10206]525        /* printf("* %2d %3s %3d %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_score); */
[8378]526
[10206]527        attr_vals = args->domain->attributes->at(best_attr)->noOfValues(); 
[8378]528
[10206]529        node->type = DiscreteNode;
530        node->split_attr = best_attr;
531        node->children_size = attr_vals;
[8378]532
[10206]533        ASSERT(child_examples = (struct Example *)calloc(size, sizeof *child_examples));
534        ASSERT(node->children = (SimpleTreeNode **)calloc(attr_vals, sizeof *node->children));
535        ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof *attr_dist));
[8378]536
[10206]537        /* attribute distribution */
538        size_known = 0;
539        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
540            if (!ex->example->values[best_attr].isSpecial()) {
541                attr_dist[ex->example->values[best_attr].intV] += ex->weight;
542                size_known += ex->weight;
543            }
[8378]544
[10206]545        args->attr_split_so_far[best_attr] = 1;
[8378]546
[10206]547        for (i = 0; i < attr_vals; i++) {
548            /* create a new example table */
549            for (ex = examples, ex_end = examples + size, child_ex = child_examples; ex < ex_end; ex++) {
550                if (ex->example->values[best_attr].isSpecial()) {
551                    *child_ex = *ex;
552                    child_ex->weight *= attr_dist[i] / size_known;
553                    child_ex++;
554                } else if (ex->example->values[best_attr].intV == i) {
555                    *child_ex++ = *ex;
556                }
557            }
[8378]558
[10206]559            node->children[i] = build_tree(child_examples, child_ex - child_examples, depth + 1, node, args);
560        }
561                   
562        args->attr_split_so_far[best_attr] = 0;
[8378]563
[10206]564        free(attr_dist);
565        free(child_examples);
566    } else {
567        struct Example *examples_lt, *examples_ge, *ex_lt, *ex_ge;
568        float size_lt, size_ge;
[8378]569
[10206]570        /* printf("* %2d %3s %3d %f %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_split, best_score); */
[8770]571
[10206]572        assert(args->domain->attributes->at(best_attr)->varType == TValue::FLOATVAR);
[8770]573
[10206]574        ASSERT(examples_lt = (struct Example *)calloc(size, sizeof *examples));
575        ASSERT(examples_ge = (struct Example *)calloc(size, sizeof *examples));
[8770]576
[10206]577        size_lt = size_ge = 0.0;
578        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
579            if (!ex->example->values[best_attr].isSpecial())
580                if (ex->example->values[best_attr].floatV < best_split)
581                    size_lt += ex->weight;
582                else
583                    size_ge += ex->weight;
[8770]584
[10206]585        for (ex = examples, ex_end = examples + size, ex_lt = examples_lt, ex_ge = examples_ge; ex < ex_end; ex++)
586            if (ex->example->values[best_attr].isSpecial()) {
587                *ex_lt = *ex;
588                *ex_ge = *ex;
589                ex_lt->weight *= size_lt / (size_lt + size_ge);
590                ex_ge->weight *= size_ge / (size_lt + size_ge);
591                ex_lt++;
592                ex_ge++;
593            } else if (ex->example->values[best_attr].floatV < best_split) {
594                *ex_lt++ = *ex;
595            } else {
596                *ex_ge++ = *ex;
597            }
[8378]598
[10206]599        node->type = ContinuousNode;
600        node->split_attr = best_attr;
601        node->split = best_split;
602        node->children_size = 2;
603        ASSERT(node->children = (SimpleTreeNode **)calloc(2, sizeof *node->children));
[8378]604
[10206]605        node->children[0] = build_tree(examples_lt, ex_lt - examples_lt, depth + 1, node, args);
606        node->children[1] = build_tree(examples_ge, ex_ge - examples_ge, depth + 1, node, args);
[8378]607
[10206]608        free(examples_lt);
609        free(examples_ge);
610    }
[8378]611
[10206]612    return node;
[8378]613}
614
[9296]615TSimpleTreeLearner::TSimpleTreeLearner(const int &weight, float maxMajority, int minInstances, int maxDepth, float skipProb, PRandomGenerator rgen) :
[10206]616    maxMajority(maxMajority),
617    minInstances(minInstances),
618    maxDepth(maxDepth),
619    skipProb(skipProb)
[8378]620{
[9296]621    randomGenerator = rgen ? rgen : PRandomGenerator(mlnew TRandomGenerator());
[8378]622}
623
[8770]624PClassifier
[8378]625TSimpleTreeLearner::operator()(PExampleGenerator ogen, const int &weight)
[8770]626{
[10206]627    struct Example *examples, *ex;
628    struct SimpleTreeNode *tree;
629    struct Args args;
630    int cls_vals;
[8378]631
[10206]632    if (!ogen->domain->classVar)
633        raiseError("class-less domain");
[8378]634
[10767]635    if (!ogen->numberOfExamples() > 0)
636        raiseError("no examples");
637
[10206]638    /* create a tabel with pointers to examples */
639    ASSERT(examples = (struct Example *)calloc(ogen->numberOfExamples(), sizeof *examples));
640    ex = examples;
641    PEITERATE(ei, ogen) {
642        ex->example = &(*ei);
643        ex->weight = 1.0;
644        ex++;
645    }
[8770]646
[10206]647    ASSERT(args.attr_split_so_far = (int *)calloc(ogen->domain->attributes->size(), sizeof(int)));
648    args.minInstances = minInstances;
649    args.maxMajority = maxMajority;
650    args.maxDepth = maxDepth;
651    args.skipProb = skipProb;
652    args.domain = ogen->domain;
[9296]653    args.randomGenerator = randomGenerator;
[10206]654    args.type = ogen->domain->classVar->varType == TValue::INTVAR ? Classification : Regression;
655    cls_vals = ogen->domain->classVar->noOfValues();
[8378]656
[10206]657    tree = build_tree(examples, ogen->numberOfExamples(), 0, NULL, &args);
[8378]658
[10206]659    free(examples);
660    free(args.attr_split_so_far);
[8378]661
[10206]662    return new TSimpleTreeClassifier(ogen->domain->classVar, tree, args.type, cls_vals);
[8378]663}
664
665
666/* classifier */
667TSimpleTreeClassifier::TSimpleTreeClassifier()
668{
669}
670
[10206]671TSimpleTreeClassifier::TSimpleTreeClassifier(const PVariable &classVar, struct SimpleTreeNode *tree, int type, int cls_vals) : 
672    TClassifier(classVar, true),
673    tree(tree),
674    type(type),
675    cls_vals(cls_vals)
[8378]676{
677}
678
[8770]679
[8378]680void
[8770]681destroy_tree(struct SimpleTreeNode *node, int type)
[8378]682{
[10206]683    int i;
[8378]684
[10206]685    if (node->type != PredictorNode) {
686        for (i = 0; i < node->children_size; i++)
687            destroy_tree(node->children[i], type);
688        free(node->children);
689    }
690    if (type == Classification)
691        free(node->dist);
692    free(node);
[8378]693}
694
[8770]695
[8378]696TSimpleTreeClassifier::~TSimpleTreeClassifier()
697{
[10206]698    destroy_tree(tree, type);
[8378]699}
700
[8770]701
702float *
[10492]703predict_classification(const TExample &ex, struct SimpleTreeNode *node, int *free_dist, int cls_vals)
[8378]704{
[10492]705    int i, j;
[10206]706    float *dist, *child_dist;
[8770]707
[10206]708    while (node->type != PredictorNode)
709        if (ex.values[node->split_attr].isSpecial()) {
710            ASSERT(dist = (float *)calloc(cls_vals, sizeof *dist));
711            for (i = 0; i < node->children_size; i++) {
[10492]712                child_dist = predict_classification(ex, node->children[i], free_dist, cls_vals);
[10206]713                for (j = 0; j < cls_vals; j++)
714                    dist[j] += child_dist[j];
715                if (*free_dist)
716                    free(child_dist);
717            }
718            *free_dist = 1;
719            return dist;
720        } else if (node->type == DiscreteNode) {
721            node = node->children[ex.values[node->split_attr].intV];
722        } else {
723            assert(node->type == ContinuousNode);
724            node = node->children[ex.values[node->split_attr].floatV >= node->split];
725        }
[8770]726
[10206]727    *free_dist = 0;
728    return node->dist;
[8378]729}
730
[8770]731
732void
733predict_regression(const TExample &ex, struct SimpleTreeNode *node, float *sum, float *n)
734{
[10206]735    int i;
736    float local_sum, local_n;
[8770]737
[10206]738    while (node->type != PredictorNode) {
739        if (ex.values[node->split_attr].isSpecial()) {
740            *sum = *n = 0;
741            for (i = 0; i < node->children_size; i++) {
742                predict_regression(ex, node->children[i], &local_sum, &local_n);
743                *sum += local_sum;
744                *n += local_n;
745            }
746            return;
747        } else if (node->type == DiscreteNode) {
748            assert(ex.values[node->split_attr].intV < node->children_size);
749            node = node->children[ex.values[node->split_attr].intV];
750        } else {
751            assert(node->type == ContinuousNode);
752            node = node->children[ex.values[node->split_attr].floatV > node->split];
753        }
754    }
[8770]755
[10206]756    *sum = node->sum;
757    *n = node->n;
758}
759
760
761void
762TSimpleTreeClassifier::save_tree(ostringstream &ss, struct SimpleTreeNode *node)
763{
764    int i;
765
766    ss << "{ " << node->type << " " << node->children_size << " ";
767
768    if (node->type != PredictorNode)
769        ss << node->split_attr << " " << node->split << " ";
770
771    for (i = 0; i < node->children_size; i++)
772        this->save_tree(ss, node->children[i]);
773
774    if (this->type == Classification) {
775        for (i = 0; i < this->cls_vals; i++)
776            ss << node->dist[i] << " ";
777    } else {
778        assert(this->type == Regression);
779        ss << node->n << " " << node->sum << " ";
780    }
781    ss << "} ";
782}
783
784struct SimpleTreeNode *
785TSimpleTreeClassifier::load_tree(istringstream &ss)
786{
787    int i;
788    string lbracket, rbracket;
[10252]789    string split_string;
[10206]790    SimpleTreeNode *node;
791
[10252]792    ss.exceptions(istream::failbit);
793
[10206]794    ASSERT(node = (SimpleTreeNode *)malloc(sizeof *node));
795    ss >> lbracket >> node->type >> node->children_size;
796
[10252]797
[10206]798    if (node->type != PredictorNode)
[10252]799    {
800        ss >> node->split_attr;
801
802        /* Read split into a string and use strtod to parse it.
803         * istream sometimes (on some platforms) seems to have problems
804         * reading formated floats.
805         */
806        ss >> split_string;
807        node->split = float(strtod(split_string.c_str(), NULL));
808    }
[10206]809
810    if (node->children_size) {
811        ASSERT(node->children = (SimpleTreeNode **)calloc(node->children_size, sizeof *node->children));
812        for (i = 0; i < node->children_size; i++)
813            node->children[i] = load_tree(ss);
814    }
815
816    if (this->type == Classification) {
817        ASSERT(node->dist = (float *)calloc(cls_vals, sizeof(float *)));
818        for (i = 0; i < this->cls_vals; i++)
819            ss >> node->dist[i];
820    } else {
821        assert(this->type == Regression);
822        ss >> node->n >> node->sum;
823    }
824    ss >> rbracket;
825
826    /* Synchronization check */
827    assert(lbracket == "{" && rbracket == "}");
828
829    return node;
830}
831
832void
833TSimpleTreeClassifier::save_model(ostringstream &ss)
834{
835    ss.precision(9); /* we have floats */
836    ss << this->type << " " << this->cls_vals << " ";
837    this->save_tree(ss, this->tree);
838}
839
840void
841TSimpleTreeClassifier::load_model(istringstream &ss)
842{
843    ss >> this->type >> this->cls_vals;
844    this->tree = load_tree(ss);
[8770]845}
846
847
848TValue
[8378]849TSimpleTreeClassifier::operator()(const TExample &ex)
850{
[10206]851    if (type == Classification) {
852        int i, free_dist, best_val;
853        float *dist;
[8378]854
[10492]855        dist = predict_classification(ex, tree, &free_dist, this->cls_vals);
[10206]856        best_val = 0;
[10492]857        for (i = 1; i < this->cls_vals; i++)
[10206]858            if (dist[i] > dist[best_val])
859                best_val = i;
[8378]860
[10206]861        if (free_dist)
862            free(dist);
863        return TValue(best_val);
864    } else {
865        float sum, n;
[8770]866
[10206]867        assert(type == Regression);
[8770]868
[10206]869        predict_regression(ex, tree, &sum, &n);
870        return TValue(sum / n);
871    }
[8378]872}
873
874PDistribution
875TSimpleTreeClassifier::classDistribution(const TExample &ex)
876{
[10206]877    if (type == Classification) {
878        int i, free_dist;
879        float *dist;
[8378]880
[10492]881        dist = predict_classification(ex, tree, &free_dist, this->cls_vals);
[8378]882
[10492]883        PDistribution pdist = mlnew TDiscDistribution(this->cls_vals, 0.0);
[10965]884        pdist->variable = this->classVar;
[10492]885        for (i = 0; i < this->cls_vals; i++)
[10206]886            pdist->setint(i, dist[i]);
887        pdist->normalize();
[8378]888
[10206]889        if (free_dist)
890            free(dist);
891        return pdist;
892    } else {
893        return NULL;
894    }
[8378]895}
896
897void
898TSimpleTreeClassifier::predictionAndDistribution(const TExample &ex, TValue &value, PDistribution &dist)
899{
[10206]900    value = operator()(ex);
901    dist = classDistribution(ex);
[8378]902}
Note: See TracBrowser for help on using the repository browser.