Changeset 10987:ce640127b31d in orange


Ignore:
Timestamp:
09/11/12 20:07:09 (20 months ago)
Author:
Miran Levar <mlevar@…>
Branch:
default
Message:

Changes to methods in Clustering Trees.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • source/orange/tdidt_clustering.cpp

    r10970 r10987  
    105105 
    106106 
    107 float distance_gini(struct Example *examples, int size, int attr, int *cls_vals, struct Arguments *args) 
     107float distance_gini(struct Example *examples, int size, int attr, int *cls_vals, float gini_prior, struct Arguments *args) 
    108108{ 
    109109    TValue *cls, *cls_end; 
     
    173173    free(size_weight); 
    174174 
    175     return score; 
     175    return score - gini_prior; 
    176176} 
    177177 
     
    298298float dist_silhuette(float **ptypes, int ptypes_size, struct Example *examples, int size, int attr, struct Arguments *args, float split) { 
    299299    int i, j, n_classes, attr_val, n_dist=0; 
    300     float dist = 0, temp, d, inter_dist, intra_dist; 
    301     struct Example *ex, *ex_end; 
     300    float dist = 0, temp, d, *cls_vals, *cls_vals_n, inter_dist, intra_dist; 
     301    struct Example *ex,  *ex_end; 
    302302    TValue *cls, *cls_end; 
    303303 
    304304    n_classes = args->domain->classVars->size(); 
     305     
     306    ASSERT(cls_vals = (float *)calloc(n_classes, sizeof(float))); 
     307    ASSERT(cls_vals_n = (float *)calloc(n_classes, sizeof(float))); 
    305308     
    306309    for (ex = examples, ex_end = examples + size; ex < ex_end; ex++) { 
    307310        if (!ex->example->values[attr].isSpecial()) { 
     311            for (i = 0; i < n_classes; i++) 
     312                cls_vals_n[i] = 0; 
     313             
     314            for (cls = ex->example->values_end,  cls_end = ex->example->classes_end; 
     315                    cls < cls_end; cls++) { 
     316                i = cls  + n_classes - cls_end; 
     317                if (!cls->isSpecial()) { 
     318                    cls_vals[i] = args->type == Classification ? cls->intV : cls->floatV; 
     319                    cls_vals_n[i] += ex->weight; 
     320                }else 
     321                    cls_vals[i]=-INFINITY; 
     322            } 
     323             
    308324            if (split != INFINITY){ 
    309325                attr_val = ex->example->values[attr].floatV; 
     
    311327            }else 
    312328                attr_val = ex->example->values[attr].intV; 
    313  
    314             inter_dist = INFINITY; 
     329             
     330 
     331            for (i = 0; i < n_classes; i++){ 
     332                if(cls_vals_n[i]!=0) 
     333                    cls_vals[i] = cls_vals[i] / cls_vals_n[i]; 
     334                else 
     335                    cls_vals[i]=INFINITY; 
     336            } 
     337 
    315338            intra_dist = 0; 
    316             n_dist = 0; 
    317             for (cls = ex->example->values_end, cls_end = ex->example->classes_end; 
    318                     cls < cls_end; cls++) { 
    319                 if (!cls->isSpecial()) { 
    320                     i = cls  + n_classes - cls_end; 
    321                     temp = args->type == Classification ? cls->intV : cls->floatV; 
    322                     d = ptypes[attr_val][i] - temp; 
    323                     intra_dist += d * d; 
    324                     n_dist++; 
    325                 } 
    326             } 
    327             intra_dist /= n_dist; 
    328             for (j = 0; j < ptypes_size; j++) { 
    329                 if(j==attr_val) 
    330                     continue; 
     339            intra_dist = INFINITY; 
     340            for (i = 0; i < ptypes_size; i++){ 
    331341                temp = 0; 
    332                 n_dist = 0; 
    333                 for (cls = ex->example->values_end, cls_end = ex->example->classes_end; 
    334                     cls < cls_end; cls++) { 
    335                     if (!cls->isSpecial()) { 
    336                         i = cls  + n_classes - cls_end; 
    337                          
    338                         d = ptypes[j][i] - cls->intV; 
    339                         temp += d * d; 
    340                         n_dist++; 
    341                     } 
    342                 } 
    343                 inter_dist /= n_dist; 
    344                 if( temp < inter_dist) 
     342                for(j = 0; j < n_classes; j++){ 
     343                    d = ptypes[i][j] - cls_vals[i]; 
     344                    temp += d*d; 
     345                } 
     346                if(i == attr_val){ 
    345347                    inter_dist = temp; 
     348                }else{ 
     349                    if(temp < intra_dist) 
     350                        intra_dist = temp; 
     351                } 
    346352            } 
    347353             
    348354            temp = inter_dist - intra_dist; 
    349355            if (intra_dist > inter_dist) 
    350                 temp/= intra_dist; 
     356                temp /= intra_dist; 
    351357            else 
    352                 temp/=inter_dist; 
     358                temp /= inter_dist; 
    353359            dist += temp; 
    354         } 
    355     } 
    356     return dist / size; 
     360            n_dist++; 
     361        } 
     362    } 
     363    free(cls_vals); 
     364    free(cls_vals_n); 
     365    dist = dist / n_dist; 
     366 
     367    if(dist < -1) 
     368        return -1; 
     369    else if (dist > 1) 
     370        return 1; 
     371    else 
     372        return dist; 
    357373} 
    358374 
     
    360376    int i; 
    361377    float dist = 0, ptypes_size, **ptypes; 
    362  
     378     
    363379    ptypes = protottype_d(examples, size, attr, args, &ptypes_size); 
    364380 
     
    370386    ASSERT(ptypes); 
    371387     
    372     if(args->method == Silhouette) 
    373         dist = dist_silhuette(ptypes, ptypes_size, examples, size, attr, args, INFINITY); 
    374     else if (args->method == Intra) 
    375         dist_intra(ptypes, ptypes_size, examples, size, attr, args, INFINITY); 
     388     
     389    if (args->method == Intra) 
     390        dist = dist_intra(ptypes, ptypes_size, examples, size, attr, args, INFINITY); 
     391    else if(args->method == Silhouette) 
     392        return dist_silhuette(ptypes, ptypes_size, examples, size, attr, args, INFINITY); 
    376393    else 
    377394        dist = dist_inter(ptypes, ptypes_size, args); 
     
    443460            dist = dist_silhuette(ptypes, 2, examples, size, attr, args, split); 
    444461        else if (args->method == Intra) 
    445             dist_intra(ptypes, 2, examples, size, attr, args, split); 
     462            dist = dist_intra(ptypes, 2, examples, size, attr, args, split); 
    446463        else 
    447464            dist = dist_inter(ptypes, 2, args); 
     
    477494        struct ClusteringTreeNode *parent, struct Arguments *args) { 
    478495    int i, j, n_classes, best_attr, *cls_vals, stop_maj; 
    479     float cls_mse, best_score, score, *size_weight, best_split, split; 
     496    float cls_mse, best_score, score, *size_weight, best_split, split, gini_prior; 
    480497    struct ClusteringTreeNode *node; 
    481498    struct Example *ex, *ex_end; 
     
    534551            return make_predictor(node, examples, size, args); 
    535552        } 
     553         
     554        if(args->method == Gini){ 
     555            gini_prior = 0.0; 
     556            for (i = 0; i < n_classes; i++){ 
     557                for (j=0; j < cls_vals[i]; j++)  
     558                    gini_prior  = node->dist[i][j] / size_weight[i] * node->dist[i][j] / size_weight[i]; 
     559            } 
     560            gini_prior /= n_classes; 
     561        } 
    536562 
    537563    } else { 
     
    604630             
    605631            if ((*it)->varType == TValue::INTVAR) { 
    606                 score = args->method == Gini ? distance_gini(examples, size, i, cls_vals, args) : 
     632                score = args->method == Gini ? distance_gini(examples, size, i, cls_vals, gini_prior, args) : 
    607633                    distance_d(examples, size, i, args); 
    608634                if (score > best_score) { 
Note: See TracChangeset for help on using the changeset viewer.