Ignore:
Timestamp:
07/01/13 16:25:32 (10 months ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Sort the training data on the class values before training...

so LIBSVM preserves the class label order.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • source/orange/libsvm_interface.cpp

    r11607 r11608  
    662662    } 
    663663 
     664    PDomain domain = examples->domain; 
     665 
    664666    int classVarType; 
    665     if(examples->domain->classVar) 
    666         classVarType=examples->domain->classVar->varType; 
     667    if (domain->classVar) 
     668        classVarType = domain->classVar->varType; 
    667669    else{ 
    668         classVarType=TValue::NONE; 
    669         if(svm_type!=ONE_CLASS) 
     670        classVarType = TValue::NONE; 
     671        if(svm_type != ONE_CLASS) 
    670672            raiseError("Domain has no class variable"); 
    671673    } 
    672     if(classVarType==TValue::FLOATVAR && !(svm_type==EPSILON_SVR || svm_type==NU_SVR ||svm_type==ONE_CLASS)) 
     674    if (classVarType == TValue::FLOATVAR && !(svm_type == EPSILON_SVR || svm_type == NU_SVR ||svm_type == ONE_CLASS)) 
    673675        raiseError("Domain has continuous class"); 
    674676 
    675     if(kernel_type==PRECOMPUTED && !kernelFunc) 
     677    if (kernel_type == PRECOMPUTED && !kernelFunc) 
    676678        raiseError("Custom kernel function not supplied"); 
    677679 
    678     int numElements=getNumOfElements(examples); 
    679  
    680     if(kernel_type != PRECOMPUTED) 
    681         x_space = init_problem(prob, examples, numElements); 
     680    PExampleTable train_data = mlnew TExampleTable(examples, /* owns= */ false); 
     681 
     682    if (classVarType == TValue::INTVAR && svm_type != ONE_CLASS) { 
     683        /* Sort the train data by the class columns so the order of 
     684         * classVar.values is preserved in libsvm's model. 
     685         */ 
     686        vector<int> sort_columns(domain->variables->size() - 1); 
     687        train_data->sort(sort_columns); 
     688    } 
     689 
     690    int numElements = getNumOfElements(train_data); 
     691 
     692    if (kernel_type != PRECOMPUTED) 
     693        x_space = init_problem(prob, train_data, numElements); 
    682694    else // Compute the matrix using the kernelFunc 
    683         x_space = init_precomputed_problem(prob, examples, kernelFunc.getReference()); 
    684  
    685     if(param.gamma==0) 
     695        x_space = init_precomputed_problem(prob, train_data, kernelFunc.getReference()); 
     696 
     697    if (param.gamma==0) 
    686698        param.gamma=1.0f/(float(numElements)/float(prob.l)-1); 
    687699 
    688700    const char* error=svm_check_parameter(&prob,&param); 
    689     if(error){ 
     701    if (error){ 
    690702        free(x_space); 
    691703        free(prob.y); 
     
    725737    free(x_space); 
    726738 
    727     PExampleTable supportVectors = extract_support_vectors(model, examples); 
    728  
    729     PDomain domain = examples->domain; 
    730  
    731     return PClassifier(createClassifier(examples->domain, model, supportVectors, examples)); 
     739    PExampleTable supportVectors = extract_support_vectors(model, train_data); 
     740 
     741    return PClassifier(createClassifier(domain, model, supportVectors, train_data)); 
    732742} 
    733743 
     
    761771TSVMClassifier* TSVMLearner::createClassifier( 
    762772        PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples) { 
     773    PKernelFunc kfunc; 
    763774    if (kernel_type != PRECOMPUTED) { 
     775        // Classifier does not need the train data and the kernelFunc. 
    764776        examples = NULL; 
    765     } 
    766     return mlnew TSVMClassifier(domain, model, supportVectors, kernelFunc, examples); 
     777        kfunc = NULL; 
     778    } else { 
     779        kfunc = kernelFunc; 
     780    } 
     781 
     782    return mlnew TSVMClassifier(domain, model, supportVectors, kfunc, examples); 
    767783} 
    768784 
     
    772788 
    773789    if(weight) 
    774             free(weight); 
     790        free(weight); 
    775791} 
    776792 
     
    785801TSVMClassifier* TSVMLearnerSparse::createClassifier( 
    786802        PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples) { 
     803    PKernelFunc kfunc; 
    787804    if (kernel_type != PRECOMPUTED) { 
     805        // Classifier does not need the train data and the kernelFunc. 
    788806        examples = NULL; 
    789     } 
    790     return mlnew TSVMClassifierSparse(domain, model, useNonMeta, supportVectors, kernelFunc, examples); 
     807        kfunc = NULL; 
     808    } else { 
     809        kfunc = kernelFunc; 
     810    } 
     811    return mlnew TSVMClassifierSparse(domain, model, useNonMeta, supportVectors, kfunc, examples); 
    791812} 
    792813 
Note: See TracChangeset for help on using the changeset viewer.