Ignore:
Timestamp:
07/01/13 16:25:11 (10 months ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Changed (simplified) SVMClassifier constructors (and pickling).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • source/orange/libsvm_interface.cpp

    r11606 r11607  
    727727    PExampleTable supportVectors = extract_support_vectors(model, examples); 
    728728 
    729     PVariable classVar; 
    730  
    731     if (param.svm_type == ONE_CLASS) { 
    732         classVar = mlnew TFloatVariable("one class"); 
    733     } else { 
    734         classVar = examples->domain->classVar; 
    735     } 
    736  
    737     return PClassifier(createClassifier(classVar, examples, supportVectors, model)); 
     729    PDomain domain = examples->domain; 
     730 
     731    return PClassifier(createClassifier(examples->domain, model, supportVectors, examples)); 
    738732} 
    739733 
     
    766760 
    767761TSVMClassifier* TSVMLearner::createClassifier( 
    768         PVariable classVar, PExampleTable examples, PExampleTable supportVectors, svm_model* model) { 
     762        PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples) { 
    769763    if (kernel_type != PRECOMPUTED) { 
    770764        examples = NULL; 
    771765    } 
    772     return mlnew TSVMClassifier(classVar, examples, supportVectors, model, kernelFunc); 
     766    return mlnew TSVMClassifier(domain, model, supportVectors, kernelFunc, examples); 
    773767} 
    774768 
     
    790784 
    791785TSVMClassifier* TSVMLearnerSparse::createClassifier( 
    792         PVariable classVar, PExampleTable examples, PExampleTable supportVectors, svm_model* model) { 
     786        PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples) { 
    793787    if (kernel_type != PRECOMPUTED) { 
    794788        examples = NULL; 
    795789    } 
    796     return mlnew TSVMClassifierSparse(classVar, examples, supportVectors, model, useNonMeta, kernelFunc); 
     790    return mlnew TSVMClassifierSparse(domain, model, useNonMeta, supportVectors, kernelFunc, examples); 
    797791} 
    798792 
    799793 
    800794TSVMClassifier::TSVMClassifier( 
    801         const PVariable &var, 
    802         PExampleTable examples, 
     795        PDomain domain, svm_model * model, 
    803796        PExampleTable supportVectors, 
    804         svm_model* model, 
    805         PKernelFunc kernelFunc) { 
    806     this->classVar = var; 
     797        PKernelFunc kernelFunc, 
     798        PExampleTable examples 
     799        ) : TClassifierFD(domain) { 
     800    this->model = model; 
     801    this->supportVectors = supportVectors; 
     802    this->kernelFunc = kernelFunc; 
    807803    this->examples = examples; 
    808     this->supportVectors = supportVectors; 
    809     this->model = model; 
    810     this->kernelFunc = kernelFunc; 
    811  
    812     domain = supportVectors->domain; 
     804 
    813805    svm_type = svm_get_svm_type(model); 
    814806    kernel_type = model->param.kernel_type; 
    815807 
     808    if (svm_type == ONE_CLASS) { 
     809        this->classVar = mlnew TFloatVariable("one class"); 
     810    } 
     811 
    816812    computesProbabilities = model && svm_check_probability_model(model) && \ 
    817             (svm_type != NU_SVR && svm_type != EPSILON_SVR); // Disable prob. estimation for regression 
     813                (svm_type != NU_SVR && svm_type != EPSILON_SVR); // Disable prob. estimation for regression 
    818814 
    819815    int nr_class = svm_get_nr_class(model); 
     
    823819     * class interface. 
    824820     */ 
    825     if (svm_type == C_SVC || svm_type == NU_SVC){ 
    826         nSV = mlnew TIntList(nr_class); // num of SVs for each class (sum = model->l) 
    827         for(i = 0;i < nr_class; i++) 
    828             nSV->at(i) = model->nSV[i]; 
    829     } 
     821    if (svm_type == C_SVC || svm_type == NU_SVC) { 
     822        nSV = mlnew TIntList(nr_class); // num of SVs for each class (sum(nSV) == model->l) 
     823        for(i = 0;i < nr_class; i++) { 
     824            nSV->at(i) = model->nSV[i]; 
     825        } 
     826    } 
    830827 
    831828    coef = mlnew TFloatListList(nr_class-1); 
    832     for(i = 0; i < nr_class - 1; i++){ 
     829    for(i = 0; i < nr_class - 1; i++) { 
    833830        TFloatList *coefs = mlnew TFloatList(model->l); 
    834         for(int j = 0;j < model->l; j++) 
     831        for(int j = 0;j < model->l; j++) { 
    835832            coefs->at(j) = model->sv_coef[i][j]; 
    836         coef->at(i)=coefs; 
    837     } 
    838     rho = mlnew TFloatList(nr_class*(nr_class-1)/2); 
    839     for(i = 0; i < nr_class*(nr_class-1)/2; i++) 
     833        } 
     834        coef->at(i) = coefs; 
     835    } 
     836 
     837    // Number of binary classifiers in the model 
     838    int nr_bin_cls = nr_class * (nr_class - 1) / 2; 
     839 
     840    rho = mlnew TFloatList(nr_bin_cls); 
     841    for(i = 0; i < nr_bin_cls; i++) { 
    840842        rho->at(i) = model->rho[i]; 
    841     if(model->probA){ 
    842         probA = mlnew TFloatList(nr_class*(nr_class-1)/2); 
    843         if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) // Regression has only probA 
    844             probB = mlnew TFloatList(nr_class*(nr_class-1)/2); 
    845         for(i=0; i<nr_class*(nr_class-1)/2; i++){ 
     843    } 
     844 
     845    if(model->probA) { 
     846        probA = mlnew TFloatList(nr_bin_cls); 
     847        if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) { 
     848            // Regression only has probA 
     849            probB = mlnew TFloatList(nr_bin_cls); 
     850        } 
     851 
     852        for(i=0; i<nr_bin_cls; i++) { 
    846853            probA->at(i) = model->probA[i]; 
    847             if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) 
     854            if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) { 
    848855                probB->at(i) = model->probB[i]; 
    849         } 
    850     } 
    851 } 
     856            } 
     857        } 
     858    } 
     859} 
     860 
    852861 
    853862TSVMClassifier::~TSVMClassifier(){ 
     
    856865    } 
    857866} 
     867 
    858868 
    859869PDistribution TSVMClassifier::classDistribution(const TExample & example){ 
Note: See TracChangeset for help on using the changeset viewer.