Changeset 10948:c1bc9e5b584e in orange


Ignore:
Timestamp:
07/10/12 00:07:10 (22 months ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Fixed weight vector initialization when class values are missing from training data.

Fixes #1214.

Location:
source/orange
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • source/orange/liblinear_interface.cpp

    r10772 r10948  
    329329        delete param; 
    330330        destroy_problem(prob); 
    331         raiseError("LIBLINEAR error: %s" , error_msg); 
     331        raiseError("LIBLINEAR error: %s", error_msg); 
    332332    } 
    333333    /* The solvers in liblinear use rand() function. 
    334      * To make the results reporoducible we set the seed from the data table's 
     334     * To make the results reproducible we set the seed from the data table's 
    335335     * crc 
    336336     */ 
     
    353353 
    354354    computesProbabilities = check_probability_model(linmodel) != 0; 
    355     int nr_classifier = (linmodel->nr_class==2 && linmodel->param.solver_type != MCSVM_CS)? 1 : linmodel->nr_class; 
     355    // Number of class values 
     356    int nr_values = this->get_nr_values(); 
     357 
     358    /* Number of liblinear classifiers (if some class values are missing 
     359     * from the training set they are not present in the liblinear model). 
     360     */ 
     361    int nr_classifier = linmodel->nr_class; 
     362    if (linmodel->nr_class == 2 && linmodel->param.solver_type != MCSVM_CS) 
     363    { 
     364        nr_classifier = 1; 
     365    } 
     366 
     367    // Number of weight vectors exposed in orange. 
     368    int nr_orange_weights = nr_values; 
     369    if (nr_values == 2 && linmodel->param.solver_type != MCSVM_CS) 
     370    { 
     371        nr_orange_weights = 1; 
     372    } 
    356373 
    357374    int nr_feature = linmodel->nr_feature; 
     375 
    358376    if (linmodel->bias >= 0.0) 
     377    { 
    359378        nr_feature++; 
     379    } 
    360380 
    361381    int* labels = new int[linmodel->nr_class]; 
    362382    get_labels(linmodel, labels); 
    363383 
    364     weights = mlnew TFloatListList(nr_classifier); 
    365     for (int i = 0; i < nr_classifier; i++) 
    366     { 
    367         weights->at(i) = mlnew TFloatList(nr_feature); 
    368     } 
    369  
    370     for (int i = 0; i < nr_classifier; i++) 
    371     { 
    372         for (int j = 0; j < nr_feature; j++) 
    373         { 
    374             weights->at((nr_classifier > 1)? labels[i]: 0)->at(j) = \ 
    375                     linmodel->w[j*nr_classifier + i]; 
    376         } 
    377     } 
    378     delete[] labels; 
     384    // Initialize nr_orange_weights vectors 
     385    weights = mlnew TFloatListList(nr_orange_weights); 
     386    for (int i = 0; i < nr_orange_weights; i++) 
     387    { 
     388        weights->at(i) = mlnew TFloatList(nr_feature, 0.0f); 
     389    } 
     390 
     391    if (nr_classifier > 1) 
     392    { 
     393        for (int i = 0; i < nr_classifier; i++) 
     394        { 
     395            for (int j = 0; j < nr_feature; j++) 
     396            { 
     397                weights->at(labels[i])->at(j) = \ 
     398                        linmodel->w[j*nr_classifier + i]; 
     399            } 
     400        } 
     401} 
     402    else 
     403    { 
     404        for (int j = 0; j < nr_feature; j++) 
     405        { 
     406            /* If there are more than 2 class values 
     407             */ 
     408            if (nr_orange_weights > 1) 
     409            { 
     410                weights->at(labels[0])->at(j) = linmodel->w[j]; 
     411                weights->at(labels[1])->at(j) = - linmodel->w[j]; 
     412            } 
     413            else 
     414            { 
     415                weights->at(0)->at(j) = linmodel->w[j]; 
     416            } 
     417        } 
     418    } 
     419    delete[] labels; 
    379420} 
    380421 
     
    382423    if (linmodel) 
    383424        free_and_destroy_model(&linmodel); 
     425} 
     426 
     427/* Return the number of discrete class values, or raise an error 
     428 * if the class_var is not discrete. 
     429 */ 
     430int TLinearClassifier::get_nr_values() 
     431{ 
     432    int nr_values = 0; 
     433    TEnumVariable * enum_var = NULL; 
     434    enum_var = dynamic_cast<TEnumVariable*>(classVar.getUnwrappedPtr()); 
     435    if (enum_var) 
     436    { 
     437        nr_values = enum_var->noOfValues(); 
     438    } 
     439    else 
     440    { 
     441        raiseError("Discrete class expected."); 
     442    } 
    384443} 
    385444 
  • source/orange/liblinear_interface.hpp

    r10771 r10948  
    7575    model *linmodel; 
    7676    double dbias; 
     77    int get_nr_values(); 
    7778}; 
    7879 
Note: See TracChangeset for help on using the changeset viewer.