Changeset 9168:bc9806d90207 in orange


Ignore:
Timestamp:
11/02/11 21:15:16 (2 years ago)
Author:
jzbontar <jure.zbontar@…>
Branch:
default
Convert:
49ac026abb394eab296682f6ea466bd09ac766c5
Message:

SimpleTreeLearner: calculating the variance is numerically instable, temporary fix

Location:
source/orange
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • source/orange/tdidt_simple.cpp

    r9042 r9168  
    101101    int i; 
    102102 
    103     for (i = 0; i < attr_vals; i++) 
     103    for (i = 0; i < attr_vals; i++) { 
    104104        if (attr_dist[i] > 0.0 && attr_dist[i] < args->minInstances) 
    105105            return 0; 
     106    } 
    106107    return 1; 
    107108} 
     
    251252 
    252253    struct Variance { 
    253         float n, sum, sum2; 
     254        double n, sum, sum2; 
    254255    } var_lt = {0.0, 0.0, 0.0}, var_ge = {0.0, 0.0, 0.0}; 
    255256 
     
    295296            var_lt.sum2 += ex->weight * cls_val * cls_val; 
    296297 
     298            /* this calculation might be numarically unstable - fix */ 
    297299            var_ge.n -= ex->weight; 
    298300            var_ge.sum -= ex->weight * cls_val; 
     
    300302        } 
    301303 
     304        /* Naive calculation of variance (used for testing) 
     305          
     306        struct Example *ex2, *ex_end2; 
     307        float nlt, sumlt, sum2lt, nge, sumge, sum2ge; 
     308        nlt = sumlt = sum2lt = nge = sumge = sum2ge = 0.0; 
     309 
     310        for (ex2 = examples, ex_end2 = ex2 + size; ex2 < ex_end2; ex2++) { 
     311            cls_val = ex2->example->getClass(); 
     312            if (ex2 < ex) { 
     313                nlt += ex2->weight; 
     314                sumlt += ex2->weight * cls_val; 
     315                sum2lt += ex2->weight * cls_val * cls_val; 
     316            } else { 
     317                nge += ex2->weight; 
     318                sumge += ex2->weight * cls_val; 
     319                sum2ge += ex2->weight * cls_val * cls_val; 
     320            } 
     321        } 
     322        */ 
     323 
     324 
    302325        if (ex->example->values[attr] == ex_next->example->values[attr] || i + 1 < minInstances) 
    303326            continue; 
     
    306329        score = var_lt.sum2 - var_lt.sum * var_lt.sum / var_lt.n; 
    307330        score += var_ge.sum2 - var_ge.sum * var_ge.sum / var_ge.n; 
     331 
    308332        score = (cls_mse - score / size_attr_cls_known) / cls_mse * (size_attr_known / size_weight); 
    309333 
     
    379403make_predictor(struct SimpleTreeNode *node, struct Example *examples, int size, struct Args *args) 
    380404{ 
    381     struct Example *ex, *ex_end; 
    382  
    383405    node->type = PredictorNode; 
    384     if (args->type == Regression) { 
    385         node->n = node->sum = 0.0; 
    386         for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) 
    387             if (!ex->example->getClass().isSpecial()) { 
    388                 node->sum += ex->weight * ex->example->getClass().floatV; 
    389                 node->n += ex->weight; 
    390             } 
    391  
    392     } 
    393  
    394406    return node; 
    395407} 
     
    454466            } 
    455467 
     468        node->n = n; 
     469        node->sum = sum; 
    456470        cls_mse = (sum2 - sum * sum / n) / n; 
     471 
     472        if (cls_mse < 1e-5) { 
     473            return make_predictor(node, examples, size, args); 
     474        } 
    457475    } 
    458476 
     
    500518 
    501519        /* printf("* %2d %3s %3d %f\n", depth, args->domain->attributes->at(best_attr)->get_name().c_str(), size, best_score); */ 
     520        printf("* %2d %d %3d %f\n", depth, best_attr, size, best_score);  
    502521 
    503522        attr_vals = args->domain->attributes->at(best_attr)->noOfValues();  
     
    707726    float local_sum, local_n; 
    708727 
    709     while (node->type != PredictorNode) 
     728    while (node->type != PredictorNode) { 
    710729        if (ex.values[node->split_attr].isSpecial()) { 
    711730            *sum = *n = 0; 
     
    717736            return; 
    718737        } else if (node->type == DiscreteNode) { 
     738            assert(ex.values[node->split_attr].intV < node->children_size); 
    719739            node = node->children[ex.values[node->split_attr].intV]; 
    720740        } else { 
     
    722742            node = node->children[ex.values[node->split_attr].floatV > node->split]; 
    723743        } 
     744    } 
    724745 
    725746    *sum = node->sum; 
  • source/orange/tdidt_simple.hpp

    r9042 r9168  
    4444    int seed; //P 
    4545 
    46     TSimpleTreeLearner(const int & =0, float=1.0, int=0, int=INT_MAX, float=0.0, unsigned int=0); 
     46    TSimpleTreeLearner(const int & =0, float=1.0, int=2, int=INT_MAX, float=0.0, unsigned int=0); 
    4747    PClassifier operator()(PExampleGenerator, const int & =0); 
    4848}; 
Note: See TracChangeset for help on using the changeset viewer.