Changeset 8131:9af5fb2d4d3f in orange


Ignore:
Timestamp:
08/01/11 14:35:24 (3 years ago)
Author:
jzbontar <jure.zbontar@…>
Branch:
default
Convert:
453d32bda826fdaff38ccc048af184b153d44ade
Message:

SimpleTreeLearner: fixed minExamples bug

Location:
source/orange
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • source/orange/tdidt_simple.cpp

    r8091 r8131  
    5454 
    5555    int *attr_split_so_far; 
     56    PDomain domain; 
    5657}; 
    5758 
    5859int compar_attr; 
    5960 
    60 /* This function uses the global variable compar_attr */ 
     61/* This function uses the global variable compar_attr. 
     62 * Examples with unknowns are larger so that when sorted, appear at the bottom. 
     63 */ 
    6164int 
    62 compar_examples(const void * ptr1, const void * ptr2) 
     65compar_examples(const void *ptr1, const void *ptr2) 
    6366{ 
    6467    TExample **e1 = (TExample **) ptr1; 
     
    8386        } 
    8487 
    85     return e / sum + LOG2(sum); 
     88    return sum == 0 ? 0.0 : e / sum + LOG2(sum); 
    8689} 
    8790 
     
    9093score_attribute_c(TExample **examples, int size, int attr, float cls_entropy, int *rank, struct Args *args) 
    9194{ 
    92     PDomain domain; 
    9395    TExample **ex, **ex_end, **ex_next; 
    94     int i, cls_vals, *attr_dist, *dist_lt, *dist_ge; 
     96    int i, cls_vals, *attr_dist, *dist_lt, *dist_ge, minExamples; 
    9597    float best_score; 
    9698 
    9799    assert(size > 0); 
    98     domain = examples[0]->domain; 
    99  
    100     cls_vals = domain->classVar->noOfValues(); 
     100 
     101    cls_vals = args->domain->classVar->noOfValues(); 
    101102 
    102103    /* allocate space */ 
     
    117118    } 
    118119 
    119     best_score = -HUGE_VAL; 
     120    best_score = -INFINITY; 
    120121    attr_dist[1] = size; 
    121     ex = examples + args->minExamples - 1; 
    122     ex_end = examples + size - (args->minExamples - 1); 
     122 
     123    /* minExamples should be at least 1, otherwise there is no point in splitting */ 
     124    minExamples = minExamples < 1 ? 1 : minExamples; 
     125    ex = examples + minExamples - 1; 
     126    ex_end = examples + size - (minExamples - 1); 
    123127    for (ex_next = ex + 1, i = 0; ex_next < ex_end; ex++, ex_next++, i++) { 
    124128        int cls; 
     
    146150    } 
    147151 
     152    /* printf("C %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), best_score); */ 
     153 
    148154    /* cleanup */ 
    149155    free(dist_lt); 
     
    157163score_attribute_d(TExample **examples, int size, int attr, float cls_entropy, struct Args *args) 
    158164{ 
    159     PDomain domain; 
    160165    TExample **ex, **ex_end; 
    161166    int i, j, cls_vals, attr_vals, *cont, *attr_dist, *attr_dist_cls_known, min, size_attr_known, size_attr_cls_known; 
     
    163168 
    164169    assert(size > 0); 
    165     domain = examples[0]->domain; 
    166  
    167     cls_vals = domain->classVar->noOfValues(); 
    168     attr_vals = domain->attributes->at(attr)->noOfValues(); 
     170 
     171    cls_vals = args->domain->classVar->noOfValues(); 
     172    attr_vals = args->domain->attributes->at(attr)->noOfValues(); 
    169173 
    170174    /* allocate space */ 
     
    205209    score = (cls_entropy - score / size_attr_cls_known) / entropy(attr_dist, attr_vals) * ((float)size_attr_known / size); 
    206210 
     211    /* printf("D %s %f\n", args->domain->attributes->at(attr)->get_name().c_str(), score); */ 
     212 
    207213finish: 
    208214    free(cont); 
     
    213219 
    214220struct SimpleTreeNode * 
    215 build_tree(TExample **examples, int size, int depth, struct Args *args) 
    216 { 
    217     PDomain domain; 
     221build_tree(TExample **examples, int size, int depth, struct SimpleTreeNode *parent, struct Args *args) 
     222{ 
    218223    TExample **ex, **ex_top, **ex_end; 
    219224    TVarList::const_iterator it; 
     
    222227    int i, best_attr, best_rank, sum, best_val, finish, cls_vals; 
    223228 
    224     assert(size > 0); 
    225     domain = examples[0]->domain; 
    226  
    227229    ASSERT(node = (SimpleTreeNode *) malloc(sizeof(*node))); 
     230    cls_vals = args->domain->classVar->noOfValues(); 
     231    ASSERT(node->dist = (int *) calloc(cls_vals, sizeof *node->dist)); 
     232 
     233    if (size == 0) { 
     234        assert(parent); 
     235        node->type = PredictorNode; 
     236        memcpy(node->dist, parent->dist, cls_vals * sizeof *node->dist); 
     237        return node; 
     238    } 
    228239 
    229240    /* class distribution */ 
    230     cls_vals = domain->classVar->noOfValues(); 
    231  
    232     ASSERT(node->dist = (int *) calloc(cls_vals, sizeof *node->dist)); 
    233241    for (ex = examples, ex_end = examples + size; ex != ex_end; ex++) 
    234242        if (!(*ex)->getClass().isSpecial()) 
     
    246254        cls_entropy = entropy(node->dist, cls_vals); 
    247255        best_score = -INFINITY; 
    248         for (i = 0, it = domain->attributes->begin(); it != domain->attributes->end(); it++, i++) 
     256        for (i = 0, it = args->domain->attributes->begin(); it != args->domain->attributes->end(); it++, i++) 
    249257            if (!args->attr_split_so_far[i]) { 
    250258 
     
    287295    size = ex_top - examples; 
    288296 
    289     if (domain->attributes->at(best_attr)->varType == TValue::INTVAR) { 
     297    if (args->domain->attributes->at(best_attr)->varType == TValue::INTVAR) { 
    290298        TExample **tmp; 
    291299        int *cnt, no_of_values; 
    292300 
    293         /* printf("%2d %3s %3d %f\n", depth, domain->attributes->at(best_attr)->get_name().c_str(), size, best_score); */ 
     301        /* printf("* %2d %3s %3d %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_score); */ 
    294302 
    295303        node->type = DiscreteNode; 
     
    297305         
    298306        /* counting sort */ 
    299         no_of_values = domain->attributes->at(best_attr)->noOfValues(); 
     307        no_of_values = args->domain->attributes->at(best_attr)->noOfValues(); 
    300308 
    301309        ASSERT(tmp = (TExample **) calloc(size, sizeof *tmp)); 
     
    322330 
    323331            new_size = (i == no_of_values - 1) ? size - cnt[i] : cnt[i + 1] - cnt[i]; 
    324             node->children[i] = build_tree(examples + cnt[i], new_size, depth + 1, args); 
     332            node->children[i] = build_tree(examples + cnt[i], new_size, depth + 1, node, args); 
    325333        } 
    326334        args->attr_split_so_far[best_attr] = 0; 
     
    328336        free(tmp); 
    329337        free(cnt); 
    330     } else if (domain->attributes->at(best_attr)->varType == TValue::FLOATVAR) { 
     338    } else if (args->domain->attributes->at(best_attr)->varType == TValue::FLOATVAR) { 
    331339        compar_attr = best_attr; 
    332340        qsort(examples, size, sizeof(TExample *), compar_examples); 
     
    336344        node->split = (examples[best_rank]->values[best_attr].floatV + examples[best_rank + 1]->values[best_attr].floatV) / 2.0; 
    337345 
    338         /* printf("%2d %3s %.4f\n", depth, domain->attributes->at(best_attr)->get_name().c_str(), node->split); */ 
     346        /* printf("%2d %3s %.4f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), node->split); */ 
    339347 
    340348        /* recursively build subtrees */ 
     
    342350        ASSERT(node->children = (SimpleTreeNode **) calloc(2, sizeof *node->children)); 
    343351         
    344         node->children[0] = build_tree(examples, best_rank + 1, depth + 1, args); 
    345         node->children[1] = build_tree(examples + best_rank + 1, size - (best_rank + 1), depth + 1, args); 
     352        node->children[0] = build_tree(examples, best_rank + 1, depth + 1, node, args); 
     353        node->children[1] = build_tree(examples + best_rank + 1, size - (best_rank + 1), depth + 1, node, args); 
    346354    } 
    347355 
     
    379387    args.maxDepth = maxDepth; 
    380388    args.skipProb = skipProb; 
    381  
    382     tree = build_tree(examples, ogen->numberOfExamples(), 0, &args); 
     389    args.domain = ogen->domain; 
     390 
     391    tree = build_tree(examples, ogen->numberOfExamples(), 0, NULL, &args); 
    383392 
    384393    free(examples); 
  • source/orange/tdidt_simple.hpp

    r8065 r8131  
    4444    float skipProb; //P 
    4545 
    46     TSimpleTreeLearner(const int & =0, float=1.0, int=1, int=INT_MAX, float=0.0); 
     46    TSimpleTreeLearner(const int & =0, float=1.0, int=0, int=INT_MAX, float=0.0); 
    4747    PClassifier operator()(PExampleGenerator, const int & =0); 
    4848}; 
Note: See TracChangeset for help on using the changeset viewer.