Changeset 10206:f27323012018 in orange


Timestamp: 02/14/12 15:16:59 (2 years ago)
Author: Jure Zbontar <jure.zbontar@…>
Branch: default
Message: Implement pickling for SimpleTreeLearner. Closes #1096.

Location: source/orange
Files: 3 edited
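
For orientation, this is roughly what the change enables from Python — a minimal sketch, assuming the Orange 2.x API of the period (the iris dataset name is illustrative):

    import cPickle
    import Orange

    data = Orange.data.Table("iris")
    classifier = Orange.classification.tree.SimpleTreeLearner(data)

    # Round-trip through pickle: __reduce__ serializes the C tree to a
    # string, and the registered loader function rebuilds it on load.
    s = cPickle.dumps(classifier)
    restored = cPickle.loads(s)
    print restored(data[0])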

  • source/orange/lib_learner.cpp (r8981 → r10206)

    @@ -23,4 +23,7 @@
     #pragma warning (disable : 4786 4114 4018 4267 4244)
     #endif
    +
    +#include <iostream>
    +#include <sstream>
     
     #include "vars.hpp"

    @@ -55,4 +58,33 @@
     C_CALL(SimpleTreeLearner - Orange.classification.tree.SimpleTreeLearner, Learner, "([examples], [maxMajority=, minExamples=, maxDepth=])")
     C_NAMED(SimpleTreeClassifier - Orange.classification.tree.SimpleTreeClassifier, Classifier, "()")
    +
    +PyObject *SimpleTreeClassifier__reduce__(PyObject *self)
    +{
    +    PyTRY
    +    ostringstream ss;
    +
    +    CAST_TO(TSimpleTreeClassifier, classifier);
    +    classifier->save_model(ss);
    +    return Py_BuildValue("O(s)N", getExportedFunction("__pickleLoaderSimpleTreeClassifier"),
    +        ss.str().c_str(), packOrangeDictionary(self));
    +    PyCATCH
    +}
    +
    +PyObject *__pickleLoaderSimpleTreeClassifier(PyObject *self, PyObject *args) PYARGS(METH_VARARGS, "(buffer)")
    +{
    +    PyTRY
    +    char *cbuf;
    +    istringstream ss;
    +
    +    int buffer_size = 0;
    +    if (!PyArg_ParseTuple(args, "s:__pickleLoaderSimpleTreeClassifier", &cbuf))
    +        return NULL;
    +    ss.str(string(cbuf));
    +    PSimpleTreeClassifier classifier = mlnew TSimpleTreeClassifier();
    +    classifier->load_model(ss);
    +    return WrapOrange(classifier);
    +    PyCATCH
    +}
    +
     
     /* ************ MAJORITY AND COST ************ */
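
For readers unfamiliar with the protocol: pickle calls __reduce__ and expects a (callable, args, state) triple, which is exactly what Py_BuildValue("O(s)N", ...) assembles above — the exported loader, a 1-tuple holding the serialized tree, and the instance dictionary. A pure-Python analogue of the same contract (all names here are illustrative, not Orange API):

    import pickle

    def _loader(buf):
        # Plays the role of __pickleLoaderSimpleTreeClassifier:
        # rebuild the object from its serialized form.
        obj = Model.__new__(Model)
        obj.buf = buf
        return obj

    class Model(object):
        def __init__(self, buf):
            self.buf = buf
        def __reduce__(self):
            # (callable, argument tuple, state): on load, pickle calls
            # _loader(self.buf) and then restores __dict__ from state.
            return (_loader, (self.buf,), self.__dict__)

    m = Model("serialized-tree")
    m.extra = 1
    m2 = pickle.loads(pickle.dumps(m))
    assert m2.buf == "serialized-tree" and m2.extra == 1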

  • source/orange/tdidt_simple.cpp (r9296 → r10206)
    (hunk: the license header was reindented; whitespace-only change, shown once)
    /*
        This file is part of Orange.

        Copyright 1996-2010 Faculty of Computer and Information Science, University of Ljubljana
        Contact: janez.demsar@fri.uni-lj.si

        Orange is free software: you can redistribute it and/or modify
        it under the terms of the GNU General Public License as published by
        the Free Software Foundation, either version 3 of the License, or
        (at your option) any later version.

        Orange is distributed in the hope that it will be useful,
        but WITHOUT ANY WARRANTY; without even the implied warranty of
        MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
        GNU General Public License for more details.

        You should have received a copy of the GNU General Public License
        along with Orange.  If not, see <http://www.gnu.org/licenses/>.
    */

    @@ -20,3 +20,5 @@
     
    +#include <iostream>
    +#include <sstream>
     #include <math.h>
     #include <stdlib.h>
     
    (hunk: whitespace-only reindentation — shown once)

    #ifndef _MSC_VER
        #include "err.h"
        #define ASSERT(x) if (!(x)) err(1, "%s:%d", __FILE__, __LINE__)
    #else
        #define ASSERT(x) if(!(x)) exit(1)
        #define log2f(x) log((double) (x)) / log(2.0)
    #endif // _MSC_VER

    #ifndef INFINITY
        #include <limits>
        #define INFINITY numeric_limits<float>::infinity()
    #endif // INFINITY

    (hunk: whitespace-only reindentation — shown once)
    struct Args {
        int minInstances, maxDepth;
        float maxMajority, skipProb;

        int type, *attr_split_so_far;
        PDomain domain;
        PRandomGenerator randomGenerator;
    };

    (hunk: whitespace-only reindentation — shown once)
    struct Example {
        TExample *example;
        float weight;
    };
     
    (hunk: whitespace-only reindentation — shown once)
    compar_examples(const void *ptr1, const void *ptr2)
    {
        struct Example *e1, *e2;

        e1 = (struct Example *)ptr1;
        e2 = (struct Example *)ptr2;
        if (e1->example->values[compar_attr].isSpecial())
            return 1;
        if (e2->example->values[compar_attr].isSpecial())
            return -1;
        return e1->example->values[compar_attr].compare(e2->example->values[compar_attr]);
    }
     
    (hunk: whitespace-only reindentation — shown once)
    entropy(float *xs, int size)
    {
        float *ip, *end, sum, e;

        for (ip = xs, end = xs + size, e = 0.0, sum = 0.0; ip != end; ip++)
            if (*ip > 0.0) {
                e -= *ip * log2f(*ip);
                sum += *ip;
            }

        return sum == 0.0 ? 0.0 : e / sum + log2f(sum);
    }
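
A note on the return expression above: for unnormalized counts $x_i$ with total $S = \sum_i x_i$, the Shannon entropy of the normalized distribution factors exactly into the two terms the code accumulates,

    H = -\sum_i \frac{x_i}{S} \log_2 \frac{x_i}{S}
      = \frac{1}{S}\Bigl(-\sum_i x_i \log_2 x_i\Bigr) + \log_2 S

so `e / sum + log2f(sum)` yields the entropy without a separate normalization pass.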
     
    (hunk: whitespace-only reindentation — shown once)
    test_min_examples(float *attr_dist, int attr_vals, struct Args *args)
    {
        int i;

        for (i = 0; i < attr_vals; i++) {
            if (attr_dist[i] > 0.0 && attr_dist[i] < args->minInstances)
                return 0;
        }
        return 1;
    }
     
    (hunk: whitespace-only reindentation — shown once)
    gain_ratio_c(struct Example *examples, int size, int attr, float cls_entropy, struct Args *args, float *best_split)
    {
        struct Example *ex, *ex_end, *ex_next;
        int i, cls, cls_vals, minInstances, size_known;
        float score, *dist_lt, *dist_ge, *attr_dist, best_score, size_weight;

        cls_vals = args->domain->classVar->noOfValues();

        /* minInstances should be at least 1, otherwise there is no point in splitting */
        minInstances = args->minInstances < 1 ? 1 : args->minInstances;

        /* allocate space */
        ASSERT(dist_lt = (float *)calloc(cls_vals, sizeof *dist_lt));
        ASSERT(dist_ge = (float *)calloc(cls_vals, sizeof *dist_ge));
        ASSERT(attr_dist = (float *)calloc(2, sizeof *attr_dist));

        /* sort */
        compar_attr = attr;
        qsort(examples, size, sizeof(struct Example), compar_examples);

        /* compute gain ratio for every split */
        size_known = size;
        size_weight = 0.0;
        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
            if (ex->example->values[attr].isSpecial()) {
                size_known = ex - examples;
                break;
            }
            if (!ex->example->getClass().isSpecial())
                dist_ge[ex->example->getClass().intV] += ex->weight;
            size_weight += ex->weight;
        }

        attr_dist[1] = size_weight;
        best_score = -INFINITY;

        for (ex = examples, ex_end = ex + size_known - minInstances, ex_next = ex + 1, i = 0; ex < ex_end; ex++, ex_next++, i++) {
            if (!ex->example->getClass().isSpecial()) {
                cls = ex->example->getClass().intV;
                dist_lt[cls] += ex->weight;
                dist_ge[cls] -= ex->weight;
            }
            attr_dist[0] += ex->weight;
            attr_dist[1] -= ex->weight;

            if (ex->example->values[attr] == ex_next->example->values[attr] || i + 1 < minInstances)
                continue;

            /* gain ratio */
            score = (attr_dist[0] * entropy(dist_lt, cls_vals) + attr_dist[1] * entropy(dist_ge, cls_vals)) / size_weight;
            score = (cls_entropy - score) / entropy(attr_dist, 2);

            if (score > best_score) {
                best_score = score;
                *best_split = (ex->example->values[attr].floatV + ex_next->example->values[attr].floatV) / 2.0;
            }
        }

        /* printf("C %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), best_score); */

        /* cleanup */
        free(dist_lt);
        free(dist_ge);
        free(attr_dist);

        return best_score;
    }
     
    (hunk: whitespace-only reindentation — shown once)
    gain_ratio_d(struct Example *examples, int size, int attr, float cls_entropy, struct Args *args)
    {
        struct Example *ex, *ex_end;
        int i, cls_vals, attr_vals, attr_val, cls_val;
        float score, size_weight, size_attr_known, size_attr_cls_known, attr_entropy, *cont, *attr_dist, *attr_dist_cls_known;

        cls_vals = args->domain->classVar->noOfValues();
        attr_vals = args->domain->attributes->at(attr)->noOfValues();

        /* allocate space */
        ASSERT(cont = (float *)calloc(cls_vals * attr_vals, sizeof(float *)));
        ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof(float *)));
        ASSERT(attr_dist_cls_known = (float *)calloc(attr_vals, sizeof(float *)));

        /* contingency matrix */
        size_weight = 0.0;
        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
            if (!ex->example->values[attr].isSpecial()) {
                attr_val = ex->example->values[attr].intV;
                attr_dist[attr_val] += ex->weight;
                if (!ex->example->getClass().isSpecial()) {
                    cls_val = ex->example->getClass().intV;
                    attr_dist_cls_known[attr_val] += ex->weight;
                    cont[attr_val * cls_vals + cls_val] += ex->weight;
                }
            }
            size_weight += ex->weight;
        }

        /* min examples in leaves */
        if (!test_min_examples(attr_dist, attr_vals, args)) {
            score = -INFINITY;
            goto finish;
        }

        size_attr_known = size_attr_cls_known = 0.0;
        for (i = 0; i < attr_vals; i++) {
            size_attr_known += attr_dist[i];
            size_attr_cls_known += attr_dist_cls_known[i];
        }

        /* gain ratio */
        score = 0.0;
        for (i = 0; i < attr_vals; i++)
            score += attr_dist_cls_known[i] * entropy(cont + i * cls_vals, cls_vals);
        attr_entropy = entropy(attr_dist, attr_vals);

        if (size_attr_cls_known == 0.0 || attr_entropy == 0.0 || size_weight == 0.0) {
            score = -INFINITY;
            goto finish;
        }

        score = (cls_entropy - score / size_attr_cls_known) / attr_entropy * ((float)size_attr_known / size_weight);

        /* printf("D %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), score); */

    finish:
        free(cont);
        free(attr_dist);
        free(attr_dist_cls_known);
        return score;
    }
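
Both gain-ratio functions score a candidate split on attribute A by the same quantity. Writing H(C) for cls_entropy, H(A) for the split's own entropy, and H(C|A) for the weighted average of class entropies within the branches, the score is

    \mathrm{score}(A) = \frac{H(C) - H(C \mid A)}{H(A)}

i.e. Quinlan's gain ratio. The discrete variant additionally multiplies by the fraction of examples whose value of A is known; the continuous variant evaluates the binary partition at every admissible threshold, placing the split at the midpoint of adjacent attribute values.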
     
    (hunk: whitespace-only reindentation — shown once)
    mse_c(struct Example *examples, int size, int attr, float cls_mse, struct Args *args, float *best_split)
    {
        struct Example *ex, *ex_end, *ex_next;
        int i, cls_vals, minInstances, size_known;
        float size_attr_known, size_weight, cls_val, cls_score, best_score, size_attr_cls_known, score;

        struct Variance {
            double n, sum, sum2;
        } var_lt = {0.0, 0.0, 0.0}, var_ge = {0.0, 0.0, 0.0};

        cls_vals = args->domain->classVar->noOfValues();

        /* minInstances should be at least 1, otherwise there is no point in splitting */
        minInstances = args->minInstances < 1 ? 1 : args->minInstances;

        /* sort */
        compar_attr = attr;
        qsort(examples, size, sizeof(struct Example), compar_examples);

        /* compute mse for every split */
        size_known = size;
        size_attr_known = 0.0;
        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
            if (ex->example->values[attr].isSpecial()) {
                size_known = ex - examples;
                break;
            }
            if (!ex->example->getClass().isSpecial()) {
                cls_val = ex->example->getClass().floatV;
                var_ge.n += ex->weight;
                var_ge.sum += ex->weight * cls_val;
                var_ge.sum2 += ex->weight * cls_val * cls_val;
            }
            size_attr_known += ex->weight;
        }

        /* count the remaining examples with unknown values */
        size_weight = size_attr_known;
        for (ex_end = examples + size; ex < ex_end; ex++)
            size_weight += ex->weight;

        size_attr_cls_known = var_ge.n;
        best_score = -INFINITY;

        for (ex = examples, ex_end = ex + size_known - minInstances, ex_next = ex + 1, i = 0; ex < ex_end; ex++, ex_next++, i++) {
            if (!ex->example->getClass().isSpecial()) {
                cls_val = ex->example->getClass();
                var_lt.n += ex->weight;
                var_lt.sum += ex->weight * cls_val;
                var_lt.sum2 += ex->weight * cls_val * cls_val;

                /* this calculation might be numerically unstable - fix */
                var_ge.n -= ex->weight;
                var_ge.sum -= ex->weight * cls_val;
                var_ge.sum2 -= ex->weight * cls_val * cls_val;
            }

            /* Naive calculation of variance (used for testing)

            struct Example *ex2, *ex_end2;
            float nlt, sumlt, sum2lt, nge, sumge, sum2ge;
            nlt = sumlt = sum2lt = nge = sumge = sum2ge = 0.0;

            for (ex2 = examples, ex_end2 = ex2 + size; ex2 < ex_end2; ex2++) {
                cls_val = ex2->example->getClass();
                if (ex2 < ex) {
                    nlt += ex2->weight;
                    sumlt += ex2->weight * cls_val;
                    sum2lt += ex2->weight * cls_val * cls_val;
                } else {
                    nge += ex2->weight;
                    sumge += ex2->weight * cls_val;
                    sum2ge += ex2->weight * cls_val * cls_val;
                }
            }
            */

            if (ex->example->values[attr] == ex_next->example->values[attr] || i + 1 < minInstances)
                continue;

            /* compute mse */
            score = var_lt.sum2 - var_lt.sum * var_lt.sum / var_lt.n;
            score += var_ge.sum2 - var_ge.sum * var_ge.sum / var_ge.n;

            score = (cls_mse - score / size_attr_cls_known) / cls_mse * (size_attr_known / size_weight);

            if (score > best_score) {
                best_score = score;
                *best_split = (ex->example->values[attr].floatV + ex_next->example->values[attr].floatV) / 2.0;
            }
        }

        /* printf("C %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), best_score); */
        return best_score;
    }
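
The running sums above use the one-pass identity for the weighted sum of squared deviations,

    \sum_j w_j (y_j - \bar{y})^2 = \sum_j w_j y_j^2 - \frac{\bigl(\sum_j w_j y_j\bigr)^2}{\sum_j w_j}

which is exactly the expression assembled from var_lt and var_ge; the subtraction of two large, nearly equal terms is the numerical-instability risk the in-code comment flags.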
     
    (hunk: whitespace-only reindentation — shown once)
    mse_d(struct Example *examples, int size, int attr, float cls_mse, struct Args *args)
    {
        int i, attr_vals, attr_val;
        float *attr_dist, d, score, cls_val, size_attr_cls_known, size_attr_known, size_weight;
        struct Example *ex, *ex_end;

        struct Variance {
            float n, sum, sum2;
        } *variances, *v, *v_end;

        attr_vals = args->domain->attributes->at(attr)->noOfValues();

        ASSERT(variances = (struct Variance *)calloc(attr_vals, sizeof *variances));
        ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof *attr_dist));

        size_weight = size_attr_cls_known = size_attr_known = 0.0;
        for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) {
            if (!ex->example->values[attr].isSpecial()) {
                attr_dist[ex->example->values[attr].intV] += ex->weight;
                size_attr_known += ex->weight;

                if (!ex->example->getClass().isSpecial()) {
                    cls_val = ex->example->getClass().floatV;
                    v = variances + ex->example->values[attr].intV;
                    v->n += ex->weight;
                    v->sum += ex->weight * cls_val;
                    v->sum2 += ex->weight * cls_val * cls_val;
                    size_attr_cls_known += ex->weight;
                }
            }
            size_weight += ex->weight;
        }

        /* minimum examples in leaves */
        if (!test_min_examples(attr_dist, attr_vals, args)) {
            score = -INFINITY;
            goto finish;
        }

        score = 0.0;
        for (v = variances, v_end = variances + attr_vals; v < v_end; v++)
            if (v->n > 0.0)
                score += v->sum2 - v->sum * v->sum / v->n;
        score = (cls_mse - score / size_attr_cls_known) / cls_mse * (size_attr_known / size_weight);

        if (size_attr_cls_known <= 0.0 || cls_mse <= 0.0 || size_weight <= 0.0)
            score = 0.0;

    finish:
        free(attr_dist);
        free(variances);

        return score;
    }
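
As with gain ratio, the final score is a normalized reduction: with the parent's mean squared error cls_mse and the pooled within-branch squared error SS,

    \mathrm{score} = \frac{\mathrm{MSE}_{\mathrm{parent}} - SS / n_{\mathrm{known}}}{\mathrm{MSE}_{\mathrm{parent}}} \cdot \frac{n_{\mathrm{attr\,known}}}{n}

so a score near 1 means the split explains almost all remaining variance and a score near 0 means almost none.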
     
    @@ -404,5 +406,6 @@ (make_predictor: leaf nodes now initialize children_size)
     make_predictor(struct SimpleTreeNode *node, struct Example *examples, int size, struct Args *args)
     {
         node->type = PredictorNode;
    +    node->children_size = 0;
         return node;
     }
     
    (hunk: whitespace-only reindentation except the two lines marked "added in r10206" — shown once)
    build_tree(struct Example *examples, int size, int depth, struct SimpleTreeNode *parent, struct Args *args)
    {
        int i, cls_vals, best_attr;
        float cls_entropy, cls_mse, best_score, score, size_weight, best_split, split;
        struct SimpleTreeNode *node;
        struct Example *ex, *ex_end;
        TVarList::const_iterator it;

        cls_vals = args->domain->classVar->noOfValues();

        ASSERT(node = (SimpleTreeNode *)malloc(sizeof *node));

        if (args->type == Classification) {
            ASSERT(node->dist = (float *)calloc(cls_vals, sizeof(float *)));

            if (size == 0) {
                assert(parent);
                node->type = PredictorNode;
                node->children_size = 0;          /* added in r10206 */
                memcpy(node->dist, parent->dist, cls_vals * sizeof *node->dist);
                return node;
            }

            /* class distribution */
            size_weight = 0.0;
            for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
                if (!ex->example->getClass().isSpecial()) {
                    node->dist[ex->example->getClass().intV] += ex->weight;
                    size_weight += ex->weight;
                }

            /* stopping criterion: majority class */
            for (i = 0; i < cls_vals; i++)
                if (node->dist[i] / size_weight >= args->maxMajority)
                    return make_predictor(node, examples, size, args);

            cls_entropy = entropy(node->dist, cls_vals);
        } else {
            float n, sum, sum2, cls_val;

            assert(args->type == Regression);
            if (size == 0) {
                assert(parent);
                node->type = PredictorNode;
                node->children_size = 0;          /* added in r10206 */
                node->n = parent->n;
                node->sum = parent->sum;
                return node;
            }

            n = sum = sum2 = 0.0;
            for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
                if (!ex->example->getClass().isSpecial()) {
                    cls_val = ex->example->getClass().floatV;
                    n += ex->weight;
                    sum += ex->weight * cls_val;
                    sum2 += ex->weight * cls_val * cls_val;
                }

            node->n = n;
            node->sum = sum;
            cls_mse = (sum2 - sum * sum / n) / n;

            if (cls_mse < 1e-5) {
                return make_predictor(node, examples, size, args);
            }
        }

        /* stopping criterion: depth exceeds limit */
        if (depth == args->maxDepth)
            return make_predictor(node, examples, size, args);

        /* score attributes */
        best_score = -INFINITY;

        for (i = 0, it = args->domain->attributes->begin(); it != args->domain->attributes->end(); it++, i++) {
            if (!args->attr_split_so_far[i]) {
                /* select random subset of attributes */
                if (args->randomGenerator->randdouble() < args->skipProb)
                    continue;

                if ((*it)->varType == TValue::INTVAR) {
                    score = args->type == Classification ?
                      gain_ratio_d(examples, size, i, cls_entropy, args) :
                      mse_d(examples, size, i, cls_mse, args);
                    if (score > best_score) {
                        best_score = score;
                        best_attr = i;
                    }
                } else if ((*it)->varType == TValue::FLOATVAR) {
                    score = args->type == Classification ?
                      gain_ratio_c(examples, size, i, cls_entropy, args, &split) :
                      mse_c(examples, size, i, cls_mse, args, &split);
                    if (score > best_score) {
                        best_score = score;
                        best_split = split;
                        best_attr = i;
                    }
                }
            }
        }

        if (best_score == -INFINITY)
            return make_predictor(node, examples, size, args);

        if (args->domain->attributes->at(best_attr)->varType == TValue::INTVAR) {
            struct Example *child_examples, *child_ex;
            int attr_vals;
            float size_known, *attr_dist;

            /* printf("* %2d %3s %3d %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_score); */

            attr_vals = args->domain->attributes->at(best_attr)->noOfValues();

            node->type = DiscreteNode;
            node->split_attr = best_attr;
            node->children_size = attr_vals;

            ASSERT(child_examples = (struct Example *)calloc(size, sizeof *child_examples));
            ASSERT(node->children = (SimpleTreeNode **)calloc(attr_vals, sizeof *node->children));
            ASSERT(attr_dist = (float *)calloc(attr_vals, sizeof *attr_dist));

            /* attribute distribution */
            size_known = 0;
            for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
                if (!ex->example->values[best_attr].isSpecial()) {
                    attr_dist[ex->example->values[best_attr].intV] += ex->weight;
                    size_known += ex->weight;
                }

            args->attr_split_so_far[best_attr] = 1;

            for (i = 0; i < attr_vals; i++) {
                /* create a new example table */
                for (ex = examples, ex_end = examples + size, child_ex = child_examples; ex < ex_end; ex++) {
                    if (ex->example->values[best_attr].isSpecial()) {
                        *child_ex = *ex;
                        child_ex->weight *= attr_dist[i] / size_known;
                        child_ex++;
                    } else if (ex->example->values[best_attr].intV == i) {
                        *child_ex++ = *ex;
                    }
                }

                node->children[i] = build_tree(child_examples, child_ex - child_examples, depth + 1, node, args);
            }

            args->attr_split_so_far[best_attr] = 0;

            free(attr_dist);
            free(child_examples);
        } else {
            struct Example *examples_lt, *examples_ge, *ex_lt, *ex_ge;
            float size_lt, size_ge;

            /* printf("* %2d %3s %3d %f %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_split, best_score); */

            assert(args->domain->attributes->at(best_attr)->varType == TValue::FLOATVAR);

            ASSERT(examples_lt = (struct Example *)calloc(size, sizeof *examples));
            ASSERT(examples_ge = (struct Example *)calloc(size, sizeof *examples));

            size_lt = size_ge = 0.0;
            for (ex = examples, ex_end = examples + size; ex < ex_end; ex++)
                if (!ex->example->values[best_attr].isSpecial())
                    if (ex->example->values[best_attr].floatV < best_split)
                        size_lt += ex->weight;
                    else
                        size_ge += ex->weight;

            for (ex = examples, ex_end = examples + size, ex_lt = examples_lt, ex_ge = examples_ge; ex < ex_end; ex++)
                if (ex->example->values[best_attr].isSpecial()) {
                    *ex_lt = *ex;
                    *ex_ge = *ex;
                    ex_lt->weight *= size_lt / (size_lt + size_ge);
                    ex_ge->weight *= size_ge / (size_lt + size_ge);
                    ex_lt++;
                    ex_ge++;
                } else if (ex->example->values[best_attr].floatV < best_split) {
                    *ex_lt++ = *ex;
                } else {
                    *ex_ge++ = *ex;
                }

            node->type = ContinuousNode;
            node->split_attr = best_attr;
            node->split = best_split;
            node->children_size = 2;
            ASSERT(node->children = (SimpleTreeNode **)calloc(2, sizeof *node->children));

            node->children[0] = build_tree(examples_lt, ex_lt - examples_lt, depth + 1, node, args);
            node->children[1] = build_tree(examples_ge, ex_ge - examples_ge, depth + 1, node, args);

            free(examples_lt);
            free(examples_ge);
        }

        return node;
    }
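
One detail of the splitting code above: an example whose value of the split attribute is unknown is copied into every child, with its weight scaled by the relative mass of that branch among the examples whose value is known,

    w' = w \cdot \frac{n_{\mathrm{branch}}}{n_{\mathrm{known}}}

(attr_dist[i] / size_known in the discrete branch; size_lt and size_ge over their sum in the continuous one). Prediction handles unknowns symmetrically by aggregating over all children, as seen in predict_classification and predict_regression below.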

    (hunk: whitespace-only reindentation — shown once)
    TSimpleTreeLearner::TSimpleTreeLearner(const int &weight, float maxMajority, int minInstances, int maxDepth, float skipProb, PRandomGenerator rgen) :
        maxMajority(maxMajority),
        minInstances(minInstances),
        maxDepth(maxDepth),
        skipProb(skipProb)
    {
        randomGenerator = rgen ? rgen : PRandomGenerator(mlnew TRandomGenerator());
     
    (operator(): reindented; the marked lines are the substantive change — the class-value count is computed and handed to the classifier)
     TSimpleTreeLearner::operator()(PExampleGenerator ogen, const int &weight)
     {
         struct Example *examples, *ex;
         struct SimpleTreeNode *tree;
         struct Args args;
    +    int cls_vals;
     
         if (!ogen->domain->classVar)
             raiseError("class-less domain");
     
         /* create a table with pointers to examples */
         ASSERT(examples = (struct Example *)calloc(ogen->numberOfExamples(), sizeof *examples));
         ex = examples;
         PEITERATE(ei, ogen) {
             ex->example = &(*ei);
             ex->weight = 1.0;
             ex++;
         }
     
         ASSERT(args.attr_split_so_far = (int *)calloc(ogen->domain->attributes->size(), sizeof(int)));
         args.minInstances = minInstances;
         args.maxMajority = maxMajority;
         args.maxDepth = maxDepth;
         args.skipProb = skipProb;
         args.domain = ogen->domain;
         args.randomGenerator = randomGenerator;
         args.type = ogen->domain->classVar->varType == TValue::INTVAR ? Classification : Regression;
    +    cls_vals = ogen->domain->classVar->noOfValues();
     
         tree = build_tree(examples, ogen->numberOfExamples(), 0, NULL, &args);
     
         free(examples);
         free(args.attr_split_so_far);
     
    -    return new TSimpleTreeClassifier(ogen->domain->classVar, tree, args.type);
    +    return new TSimpleTreeClassifier(ogen->domain->classVar, tree, args.type, cls_vals);
     }
    654661 
     
    (TSimpleTreeClassifier constructor: gains a cls_vals argument)
     }
     
    -TSimpleTreeClassifier::TSimpleTreeClassifier(const PVariable &classVar, struct SimpleTreeNode *tree, int type) :
    +TSimpleTreeClassifier::TSimpleTreeClassifier(const PVariable &classVar, struct SimpleTreeNode *tree, int type, int cls_vals) :
         TClassifier(classVar, true),
         tree(tree),
    -    type(type)
    +    type(type),
    +    cls_vals(cls_vals)
     {
     }
     
    (hunk: whitespace-only reindentation — shown once)
    destroy_tree(struct SimpleTreeNode *node, int type)
    {
        int i;

        if (node->type != PredictorNode) {
            for (i = 0; i < node->children_size; i++)
                destroy_tree(node->children[i], type);
            free(node->children);
        }
        if (type == Classification)
            free(node->dist);
        free(node);
    }
    683691 
     
    (hunk: whitespace-only reindentation — shown once)
    TSimpleTreeClassifier::~TSimpleTreeClassifier()
    {
        destroy_tree(tree, type);
    }
    689697 
     
    (hunk: whitespace-only reindentation — shown once)
    predict_classification(const TExample &ex, struct SimpleTreeNode *node, int *free_dist)
    {
        int i, j, cls_vals;
        float *dist, *child_dist;

        while (node->type != PredictorNode)
            if (ex.values[node->split_attr].isSpecial()) {
                cls_vals = ex.domain->classVar->noOfValues();
                ASSERT(dist = (float *)calloc(cls_vals, sizeof *dist));
                for (i = 0; i < node->children_size; i++) {
                    child_dist = predict_classification(ex, node->children[i], free_dist);
                    for (j = 0; j < cls_vals; j++)
                        dist[j] += child_dist[j];
                    if (*free_dist)
                        free(child_dist);
                }
                *free_dist = 1;
                return dist;
            } else if (node->type == DiscreteNode) {
                node = node->children[ex.values[node->split_attr].intV];
            } else {
                assert(node->type == ContinuousNode);
                node = node->children[ex.values[node->split_attr].floatV >= node->split];
            }

        *free_dist = 0;
        return node->dist;
    }
     
    723731predict_regression(const TExample &ex, struct SimpleTreeNode *node, float *sum, float *n) 
    724732{ 
    725     int i; 
    726     float local_sum, local_n; 
    727  
    728     while (node->type != PredictorNode) { 
    729         if (ex.values[node->split_attr].isSpecial()) { 
    730             *sum = *n = 0; 
    731             for (i = 0; i < node->children_size; i++) { 
    732                 predict_regression(ex, node->children[i], &local_sum, &local_n); 
    733                 *sum += local_sum; 
    734                 *n += local_n; 
    735             } 
    736             return; 
    737         } else if (node->type == DiscreteNode) { 
    738             assert(ex.values[node->split_attr].intV < node->children_size); 
    739             node = node->children[ex.values[node->split_attr].intV]; 
    740         } else { 
    741             assert(node->type == ContinuousNode); 
    742             node = node->children[ex.values[node->split_attr].floatV > node->split]; 
    743         } 
    744     } 
    745  
    746     *sum = node->sum; 
    747     *n = node->n; 
     733    int i; 
     734    float local_sum, local_n; 
     735 
     736    while (node->type != PredictorNode) { 
     737        if (ex.values[node->split_attr].isSpecial()) { 
     738            *sum = *n = 0; 
     739            for (i = 0; i < node->children_size; i++) { 
     740                predict_regression(ex, node->children[i], &local_sum, &local_n); 
     741                *sum += local_sum; 
     742                *n += local_n; 
     743            } 
     744            return; 
     745        } else if (node->type == DiscreteNode) { 
     746            assert(ex.values[node->split_attr].intV < node->children_size); 
     747            node = node->children[ex.values[node->split_attr].intV]; 
     748        } else { 
     749            assert(node->type == ContinuousNode); 
     750            node = node->children[ex.values[node->split_attr].floatV > node->split]; 
     751        } 
     752    } 
     753 
     754    *sum = node->sum; 
     755    *n = node->n; 
     756} 
     757 
     758 
     759void 
     760TSimpleTreeClassifier::save_tree(ostringstream &ss, struct SimpleTreeNode *node) 
     761{ 
     762    int i; 
     763 
     764    ss << "{ " << node->type << " " << node->children_size << " "; 
     765 
     766    if (node->type != PredictorNode) 
     767        ss << node->split_attr << " " << node->split << " "; 
     768 
     769    for (i = 0; i < node->children_size; i++) 
     770        this->save_tree(ss, node->children[i]); 
     771 
     772    if (this->type == Classification) { 
     773        for (i = 0; i < this->cls_vals; i++) 
     774            ss << node->dist[i] << " "; 
     775    } else { 
     776        assert(this->type == Regression); 
     777        ss << node->n << " " << node->sum << " "; 
     778    } 
     779    ss << "} "; 
     780} 
     781 
     782struct SimpleTreeNode * 
     783TSimpleTreeClassifier::load_tree(istringstream &ss) 
     784{ 
     785    int i; 
     786    string lbracket, rbracket; 
     787    SimpleTreeNode *node; 
     788 
     789    ASSERT(node = (SimpleTreeNode *)malloc(sizeof *node)); 
     790    ss >> lbracket >> node->type >> node->children_size; 
     791 
     792    if (node->type != PredictorNode) 
     793        ss >> node->split_attr >> node->split; 
     794 
     795    if (node->children_size) { 
     796        ASSERT(node->children = (SimpleTreeNode **)calloc(node->children_size, sizeof *node->children)); 
     797        for (i = 0; i < node->children_size; i++) 
     798            node->children[i] = load_tree(ss); 
     799    } 
     800 
     801    if (this->type == Classification) { 
      802        ASSERT(node->dist = (float *)calloc(cls_vals, sizeof *node->dist)); /* element size, not pointer size */ 
     803        for (i = 0; i < this->cls_vals; i++) 
     804            ss >> node->dist[i]; 
     805    } else { 
     806        assert(this->type == Regression); 
     807        ss >> node->n >> node->sum; 
     808    } 
     809    ss >> rbracket; 
     810 
     811    /* Synchronization check */ 
     812    assert(lbracket == "{" && rbracket == "}"); 
     813 
     814    return node; 
     815} 
     816 
     817void 
     818TSimpleTreeClassifier::save_model(ostringstream &ss) 
     819{ 
      820    ss.precision(9); /* 9 significant digits are enough to round-trip a float */ 
     821    ss << this->type << " " << this->cls_vals << " "; 
     822    this->save_tree(ss, this->tree); 
     823} 
     824 
     825void 
     826TSimpleTreeClassifier::load_model(istringstream &ss) 
     827{ 
     828    ss >> this->type >> this->cls_vals; 
     829    this->tree = load_tree(ss); 
    748830} 
    749831 
     
    752834TSimpleTreeClassifier::operator()(const TExample &ex) 
    753835{ 
    754     if (type == Classification) { 
    755         int i, free_dist, best_val; 
    756         float *dist; 
    757  
    758         dist = predict_classification(ex, tree, &free_dist); 
    759         best_val = 0; 
    760         for (i = 1; i < ex.domain->classVar->noOfValues(); i++) 
    761             if (dist[i] > dist[best_val]) 
    762                 best_val = i; 
    763  
    764         if (free_dist) 
    765             free(dist); 
    766         return TValue(best_val); 
    767     } else { 
    768         float sum, n; 
    769  
    770         assert(type == Regression); 
    771  
    772         predict_regression(ex, tree, &sum, &n); 
    773         return TValue(sum / n); 
    774     } 
     836    if (type == Classification) { 
     837        int i, free_dist, best_val; 
     838        float *dist; 
     839 
     840        dist = predict_classification(ex, tree, &free_dist); 
     841        best_val = 0; 
     842        for (i = 1; i < ex.domain->classVar->noOfValues(); i++) 
     843            if (dist[i] > dist[best_val]) 
     844                best_val = i; 
     845 
     846        if (free_dist) 
     847            free(dist); 
     848        return TValue(best_val); 
     849    } else { 
     850        float sum, n; 
     851 
     852        assert(type == Regression); 
     853 
     854        predict_regression(ex, tree, &sum, &n); 
     855        return TValue(sum / n); 
     856    } 
    775857} 
    776858 
     
    778860TSimpleTreeClassifier::classDistribution(const TExample &ex) 
    779861{ 
    780     if (type == Classification) { 
    781         int i, free_dist; 
    782         float *dist; 
    783  
    784         dist = predict_classification(ex, tree, &free_dist); 
    785  
    786         PDistribution pdist = TDistribution::create(ex.domain->classVar); 
    787         for (i = 0; i < ex.domain->classVar->noOfValues(); i++) 
    788             pdist->setint(i, dist[i]); 
    789         pdist->normalize(); 
    790  
    791         if (free_dist) 
    792             free(dist); 
    793         return pdist; 
    794     } else { 
    795         return NULL; 
    796     } 
     862    if (type == Classification) { 
     863        int i, free_dist; 
     864        float *dist; 
     865 
     866        dist = predict_classification(ex, tree, &free_dist); 
     867 
     868        PDistribution pdist = TDistribution::create(ex.domain->classVar); 
     869        for (i = 0; i < ex.domain->classVar->noOfValues(); i++) 
     870            pdist->setint(i, dist[i]); 
     871        pdist->normalize(); 
     872 
     873        if (free_dist) 
     874            free(dist); 
     875        return pdist; 
     876    } else { 
     877        return NULL; 
     878    } 
    797879} 
    798880 
     
    800882TSimpleTreeClassifier::predictionAndDistribution(const TExample &ex, TValue &value, PDistribution &dist) 
    801883{ 
    802     value = operator()(ex); 
    803     dist = classDistribution(ex); 
    804 } 
     884    value = operator()(ex); 
     885    dist = classDistribution(ex); 
     886} 
  • source/orange/tdidt_simple.hpp

    r9296 r10206  
    11/* 
    2     This file is part of Orange. 
    3      
    4     Copyright 1996-2010 Faculty of Computer and Information Science, University of Ljubljana 
    5     Contact: janez.demsar@fri.uni-lj.si 
     2    This file is part of Orange. 
     3     
     4    Copyright 1996-2010 Faculty of Computer and Information Science, University of Ljubljana 
     5    Contact: janez.demsar@fri.uni-lj.si 
    66 
    7     Orange is free software: you can redistribute it and/or modify 
    8     it under the terms of the GNU General Public License as published by 
    9     the Free Software Foundation, either version 3 of the License, or 
    10     (at your option) any later version. 
     7    Orange is free software: you can redistribute it and/or modify 
     8    it under the terms of the GNU General Public License as published by 
     9    the Free Software Foundation, either version 3 of the License, or 
     10    (at your option) any later version. 
    1111 
    12     Orange is distributed in the hope that it will be useful, 
    13     but WITHOUT ANY WARRANTY; without even the implied warranty of 
    14     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
    15     GNU General Public License for more details. 
     12    Orange is distributed in the hope that it will be useful, 
     13    but WITHOUT ANY WARRANTY; without even the implied warranty of 
     14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
     15    GNU General Public License for more details. 
    1616 
    17     You should have received a copy of the GNU General Public License 
    18     along with Orange.  If not, see <http://www.gnu.org/licenses/>. 
     17    You should have received a copy of the GNU General Public License 
     18    along with Orange.  If not, see <http://www.gnu.org/licenses/>. 
    1919*/ 
    2020 
     
    2727 
    2828struct SimpleTreeNode { 
    29     int type, children_size, split_attr; 
    30     float split; 
    31     SimpleTreeNode **children; 
     29    int type, children_size, split_attr; 
     30    float split; 
     31    SimpleTreeNode **children; 
    3232 
    33     float *dist;  /* classification */ 
    34     float n, sum; /* regression */ 
     33    float *dist;  /* classification */ 
     34    float n, sum; /* regression */ 
    3535}; 
    3636 
     
    3838public: 
    3939    __REGISTER_CLASS 
    40     float maxMajority; //P 
    41     int minInstances; //P 
    42     int maxDepth; //P 
    43     float skipProb; //P 
     40    float maxMajority; //P 
     41    int minInstances; //P 
     42    int maxDepth; //P 
     43    float skipProb; //P 
    4444    PRandomGenerator randomGenerator; //P 
    4545 
     
    5050class ORANGE_API TSimpleTreeClassifier : public TClassifier { 
    5151private: 
    52     int type; 
     52    int type, cls_vals; 
    5353    struct SimpleTreeNode *tree; 
     54 
     55    void save_tree(ostringstream &, struct SimpleTreeNode *); 
     56    struct SimpleTreeNode *load_tree(istringstream &); 
    5457 
    5558public: 
     
    5760 
    5861    TSimpleTreeClassifier(); 
    59     TSimpleTreeClassifier(const PVariable &, struct SimpleTreeNode *, int); 
    60     ~TSimpleTreeClassifier(); 
     62    TSimpleTreeClassifier(const PVariable &, struct SimpleTreeNode *, int, int); 
     63    ~TSimpleTreeClassifier(); 
    6164 
     65    void save_model(ostringstream &); 
     66    void load_model(istringstream &); 
    6267    TValue operator()(const TExample &); 
    6368    PDistribution classDistribution(const TExample &); 
     
    6570}; 
    6671 
     72WRAPPER(SimpleTreeClassifier) 
    6773#endif 
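
With save_model/load_model in place and the corresponding pickle hooks registered for SimpleTreeClassifier, a trained classifier survives a standard pickle round trip. A minimal usage sketch, assuming an Orange 2.x environment on Python 2 (the iris data set is purely illustrative):

    import cPickle as pickle
    import Orange

    data = Orange.data.Table("iris")
    classifier = Orange.classification.tree.SimpleTreeLearner()(data)

    s = pickle.dumps(classifier)    # tree is serialized through save_model
    restored = pickle.loads(s)      # and rebuilt through load_model

    # The restored tree should classify exactly like the original.
    assert all(classifier(ex) == restored(ex) for ex in data)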