Changeset 11607:8ecd4831def9 in orange


Ignore:
Timestamp:
07/01/13 16:25:11 (10 months ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Changed (simplified) SVMClassifier constructors (and pickling).

Location:
source/orange
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • source/orange/lib_learner.cpp

    r11606 r11607  
    14531453        raiseError("Error saving SVM model"); 
    14541454    } 
    1455     if(svm->kernelFunc) 
    1456         return Py_BuildValue("O(OOOsO)N", self->ob_type, 
    1457                                     WrapOrange(svm->classVar), 
     1455 
     1456    return Py_BuildValue("O(OsOOO)N", self->ob_type, 
     1457                                    WrapOrange(svm->domain), 
     1458                                    buf.c_str(), 
     1459                                    WrapOrange(svm->supportVectors), 
     1460                                    WrapOrange(svm->kernelFunc), 
    14581461                                    WrapOrange(svm->examples), 
    1459                                     WrapOrange(svm->supportVectors), 
    1460                                     buf.c_str(), 
    1461                                     WrapOrange(svm->kernelFunc), 
    1462                                     packOrangeDictionary(self)); 
    1463     else 
    1464         return Py_BuildValue("O(OOOs)N", self->ob_type, 
    1465                                     WrapOrange(svm->classVar), 
    1466                                     WrapOrange(svm->examples), 
    1467                                     WrapOrange(svm->supportVectors), 
    1468                                     buf.c_str(), 
    14691462                                    packOrangeDictionary(self)); 
    14701463  PyCATCH 
     
    14801473        raiseError("Error saving SVM model."); 
    14811474    } 
    1482     if(svm->kernelFunc) 
    1483         return Py_BuildValue("O(OOOsbO)N", self->ob_type, 
    1484                                     WrapOrange(svm->classVar), 
    1485                                     WrapOrange(svm->examples), 
    1486                                     WrapOrange(svm->supportVectors), 
    1487                                     buf.c_str(), 
    1488                                     (char)(svm->useNonMeta? 1: 0), 
    1489                                     WrapOrange(svm->kernelFunc), 
    1490                                     packOrangeDictionary(self)); 
    1491     else 
    1492         return Py_BuildValue("O(OOOsb)N", self->ob_type, 
    1493                                     WrapOrange(svm->classVar), 
    1494                                     WrapOrange(svm->examples), 
    1495                                     WrapOrange(svm->supportVectors), 
     1475 
     1476    return Py_BuildValue("O(OsbOOO)N", self->ob_type, 
     1477                                    WrapOrange(svm->domain), 
    14961478                                    buf.c_str(), 
    14971479                                    (char)(svm->useNonMeta? 1: 0), 
     1480                                    WrapOrange(svm->supportVectors), 
     1481                                    WrapOrange(svm->kernelFunc), 
     1482                                    WrapOrange(svm->examples), 
    14981483                                    packOrangeDictionary(self)); 
    14991484  PyCATCH 
     
    15231508 
    15241509 
    1525 PyObject * SVMClassifier_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) BASED_ON(ClassifierFD, "(Variable, Examples, Examples, string, [kernelFunc]) -> SVMClassifier") 
    1526 {PyTRY 
    1527     PVariable classVar; 
     1510PyObject * SVMClassifier_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) BASED_ON(ClassifierFD, "(Domain, model_string, supportVectors, [kernelFunc, examples]) -> SVMClassifier") 
     1511{ 
     1512PyTRY 
     1513    PDomain domain; 
     1514    char*  model_string = NULL; 
     1515    PExampleTable supportVectors; 
     1516    PKernelFunc kernel; 
    15281517    PExampleTable examples; 
    1529     PExampleTable supportVectors; 
    1530     char*  model_string; 
    1531     PKernelFunc kernel; 
     1518 
    15321519    if (PyArg_ParseTuple(args, "")) 
    15331520        return WrapNewOrange(mlnew TSVMClassifier(), type); 
     1521 
    15341522    PyErr_Clear(); 
    1535      
    1536     if (!PyArg_ParseTuple(args, "O&O&O&s|O&:__new__", cc_Variable, &classVar, ccn_ExampleTable, &examples, cc_ExampleTable, &supportVectors, &model_string, cc_KernelFunc, &kernel)) 
    1537         return NULL; 
     1523 
     1524    if (!PyArg_ParseTuple(args, "O&sO&|O&O&", 
     1525            cc_Domain, &domain, 
     1526            &model_string, 
     1527            cc_ExampleTable, &supportVectors, 
     1528            ccn_KernelFunc, &kernel, 
     1529            ccn_ExampleTable, &examples)) { 
     1530         // Old pickle arguments format. 
     1531        PVariable classVar; 
     1532        if (!PyArg_ParseTuple(args, "O&O&O&s|O&:__new__", 
     1533                cc_Variable, &classVar, 
     1534                ccn_ExampleTable, &examples, 
     1535                cc_ExampleTable, &supportVectors, 
     1536                &model_string, 
     1537                cc_KernelFunc, &kernel)) { 
     1538            return NULL; 
     1539        } 
     1540        PyErr_Clear(); 
     1541        domain = supportVectors->domain; 
     1542    } 
    15381543 
    15391544    string buffer(model_string); 
     
    15421547        raiseError("Error building LibSVM Model"); 
    15431548 
    1544     PSVMClassifier svm = mlnew TSVMClassifier(classVar, examples, supportVectors, model, kernel); 
     1549    PSVMClassifier svm = mlnew TSVMClassifier(domain, model, supportVectors, kernel, examples); 
    15451550 
    15461551    return WrapOrange(svm); 
     
    15501555 
    15511556 
    1552 PyObject * SVMClassifierSparse_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) BASED_ON(SVMClassifier, "(Variable, Examples, Examples, string, [useNonMeta, kernelFunc]) -> SVMClassifierSparse") 
    1553 {PyTRY 
    1554     PVariable classVar; 
     1557PyObject * SVMClassifierSparse_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) BASED_ON(SVMClassifier, "(Domain, model_string, useNonMeta, supportVectors, [kernelFunc, examples]) -> SVMClassifierSparse") 
     1558{ 
     1559PyTRY 
     1560    PDomain domain; 
     1561    char*  model_string = NULL; 
     1562    char useNonMeta = 0; 
     1563 
     1564    PExampleTable supportVectors; 
     1565    PKernelFunc kernel; 
    15551566    PExampleTable examples; 
    1556     PExampleTable supportVectors; 
    1557     char*  model_string; 
    1558     char useNonMeta; 
    1559     PKernelFunc kernel; 
     1567 
    15601568    if (PyArg_ParseTuple(args, "")) 
    15611569        return WrapNewOrange(mlnew TSVMClassifierSparse(), type); 
     1570 
    15621571    PyErr_Clear(); 
    1563      
    1564     if (!PyArg_ParseTuple(args, "O&O&O&s|bO&:__new__", cc_Variable, &classVar, ccn_ExampleTable, &examples, cc_ExampleTable, &supportVectors, &model_string, &useNonMeta, cc_KernelFunc, &kernel)) 
    1565         return NULL; 
    1566      
     1572 
     1573    if (!PyArg_ParseTuple(args, "O&sbO&|O&O&:__new__", 
     1574            cc_Domain, &domain, 
     1575            &model_string, 
     1576            &useNonMeta, 
     1577            cc_ExampleTable, &supportVectors, 
     1578            ccn_KernelFunc, &kernel, 
     1579            ccn_ExampleTable, &examples)) { 
     1580 
     1581         // Old pickle arguments format. 
     1582        PVariable classVar; 
     1583        if (!PyArg_ParseTuple(args, "O&O&O&s|bO&:__new__", 
     1584                cc_Variable, &classVar, 
     1585                ccn_ExampleTable, &examples, 
     1586                cc_ExampleTable, &supportVectors, 
     1587                &model_string, 
     1588                &useNonMeta, 
     1589                cc_KernelFunc, &kernel)) { 
     1590            return NULL; 
     1591        } 
     1592 
     1593        PyErr_Clear(); 
     1594        domain = supportVectors->domain; 
     1595    } 
     1596 
    15671597    string buffer(model_string); 
    15681598    svm_model* model = svm_load_model_alt(buffer); 
     
    15701600        raiseError("Error building LibSVM Model"); 
    15711601 
    1572     PSVMClassifier svm = mlnew TSVMClassifierSparse(classVar, examples, supportVectors, model, useNonMeta != 0, kernel); 
     1602    PSVMClassifier svm = mlnew TSVMClassifierSparse(domain, model, useNonMeta != 0, supportVectors, kernel, examples); 
    15731603 
    15741604    return WrapOrange(svm); 
  • source/orange/libsvm_interface.cpp

    r11606 r11607  
    727727    PExampleTable supportVectors = extract_support_vectors(model, examples); 
    728728 
    729     PVariable classVar; 
    730  
    731     if (param.svm_type == ONE_CLASS) { 
    732         classVar = mlnew TFloatVariable("one class"); 
    733     } else { 
    734         classVar = examples->domain->classVar; 
    735     } 
    736  
    737     return PClassifier(createClassifier(classVar, examples, supportVectors, model)); 
     729    PDomain domain = examples->domain; 
     730 
     731    return PClassifier(createClassifier(examples->domain, model, supportVectors, examples)); 
    738732} 
    739733 
     
    766760 
    767761TSVMClassifier* TSVMLearner::createClassifier( 
    768         PVariable classVar, PExampleTable examples, PExampleTable supportVectors, svm_model* model) { 
     762        PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples) { 
    769763    if (kernel_type != PRECOMPUTED) { 
    770764        examples = NULL; 
    771765    } 
    772     return mlnew TSVMClassifier(classVar, examples, supportVectors, model, kernelFunc); 
     766    return mlnew TSVMClassifier(domain, model, supportVectors, kernelFunc, examples); 
    773767} 
    774768 
     
    790784 
    791785TSVMClassifier* TSVMLearnerSparse::createClassifier( 
    792         PVariable classVar, PExampleTable examples, PExampleTable supportVectors, svm_model* model) { 
     786        PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples) { 
    793787    if (kernel_type != PRECOMPUTED) { 
    794788        examples = NULL; 
    795789    } 
    796     return mlnew TSVMClassifierSparse(classVar, examples, supportVectors, model, useNonMeta, kernelFunc); 
     790    return mlnew TSVMClassifierSparse(domain, model, useNonMeta, supportVectors, kernelFunc, examples); 
    797791} 
    798792 
    799793 
    800794TSVMClassifier::TSVMClassifier( 
    801         const PVariable &var, 
    802         PExampleTable examples, 
     795        PDomain domain, svm_model * model, 
    803796        PExampleTable supportVectors, 
    804         svm_model* model, 
    805         PKernelFunc kernelFunc) { 
    806     this->classVar = var; 
     797        PKernelFunc kernelFunc, 
     798        PExampleTable examples 
     799        ) : TClassifierFD(domain) { 
     800    this->model = model; 
     801    this->supportVectors = supportVectors; 
     802    this->kernelFunc = kernelFunc; 
    807803    this->examples = examples; 
    808     this->supportVectors = supportVectors; 
    809     this->model = model; 
    810     this->kernelFunc = kernelFunc; 
    811  
    812     domain = supportVectors->domain; 
     804 
    813805    svm_type = svm_get_svm_type(model); 
    814806    kernel_type = model->param.kernel_type; 
    815807 
     808    if (svm_type == ONE_CLASS) { 
     809        this->classVar = mlnew TFloatVariable("one class"); 
     810    } 
     811 
    816812    computesProbabilities = model && svm_check_probability_model(model) && \ 
    817             (svm_type != NU_SVR && svm_type != EPSILON_SVR); // Disable prob. estimation for regression 
     813                (svm_type != NU_SVR && svm_type != EPSILON_SVR); // Disable prob. estimation for regression 
    818814 
    819815    int nr_class = svm_get_nr_class(model); 
     
    823819     * class interface. 
    824820     */ 
    825     if (svm_type == C_SVC || svm_type == NU_SVC){ 
    826         nSV = mlnew TIntList(nr_class); // num of SVs for each class (sum = model->l) 
    827         for(i = 0;i < nr_class; i++) 
    828             nSV->at(i) = model->nSV[i]; 
    829     } 
     821    if (svm_type == C_SVC || svm_type == NU_SVC) { 
     822        nSV = mlnew TIntList(nr_class); // num of SVs for each class (sum(nSV) == model->l) 
     823        for(i = 0;i < nr_class; i++) { 
     824            nSV->at(i) = model->nSV[i]; 
     825        } 
     826    } 
    830827 
    831828    coef = mlnew TFloatListList(nr_class-1); 
    832     for(i = 0; i < nr_class - 1; i++){ 
     829    for(i = 0; i < nr_class - 1; i++) { 
    833830        TFloatList *coefs = mlnew TFloatList(model->l); 
    834         for(int j = 0;j < model->l; j++) 
     831        for(int j = 0;j < model->l; j++) { 
    835832            coefs->at(j) = model->sv_coef[i][j]; 
    836         coef->at(i)=coefs; 
    837     } 
    838     rho = mlnew TFloatList(nr_class*(nr_class-1)/2); 
    839     for(i = 0; i < nr_class*(nr_class-1)/2; i++) 
     833        } 
     834        coef->at(i) = coefs; 
     835    } 
     836 
     837    // Number of binary classifiers in the model 
     838    int nr_bin_cls = nr_class * (nr_class - 1) / 2; 
     839 
     840    rho = mlnew TFloatList(nr_bin_cls); 
     841    for(i = 0; i < nr_bin_cls; i++) { 
    840842        rho->at(i) = model->rho[i]; 
    841     if(model->probA){ 
    842         probA = mlnew TFloatList(nr_class*(nr_class-1)/2); 
    843         if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) // Regression has only probA 
    844             probB = mlnew TFloatList(nr_class*(nr_class-1)/2); 
    845         for(i=0; i<nr_class*(nr_class-1)/2; i++){ 
     843    } 
     844 
     845    if(model->probA) { 
     846        probA = mlnew TFloatList(nr_bin_cls); 
     847        if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) { 
     848            // Regression only has probA 
     849            probB = mlnew TFloatList(nr_bin_cls); 
     850        } 
     851 
     852        for(i=0; i<nr_bin_cls; i++) { 
    846853            probA->at(i) = model->probA[i]; 
    847             if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) 
     854            if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) { 
    848855                probB->at(i) = model->probB[i]; 
    849         } 
    850     } 
    851 } 
     856            } 
     857        } 
     858    } 
     859} 
     860 
    852861 
    853862TSVMClassifier::~TSVMClassifier(){ 
     
    856865    } 
    857866} 
     867 
    858868 
    859869PDistribution TSVMClassifier::classDistribution(const TExample & example){ 
  • source/orange/libsvm_interface.hpp

    r11606 r11607  
    6767WRAPPER(KernelFunc) 
    6868 
    69 //#include "callback.hpp" 
    7069 
    7170class ORANGE_API TSVMLearner : public TLearner{ 
     
    108107    virtual int getNumOfElements(PExampleGenerator examples); 
    109108    virtual TSVMClassifier* createClassifier( 
    110             PVariable var, PExampleTable examples, PExampleTable supportVectors, svm_model* model); 
     109                PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples); 
    111110}; 
    112111 
     
    119118    virtual int getNumOfElements(PExampleGenerator examples); 
    120119    virtual TSVMClassifier* createClassifier( 
    121             PVariable classVar, PExampleTable examples, PExampleTable supportVectors, svm_model* model); 
     120            PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples); 
    122121}; 
    123122 
    124123 
    125 class ORANGE_API TSVMClassifier : public TClassifierFD{ 
     124class ORANGE_API TSVMClassifier : public TClassifierFD { 
    126125public: 
    127126    __REGISTER_CLASS 
    128     TSVMClassifier(){ 
     127    TSVMClassifier() { 
    129128        this->model = NULL; 
    130129    }; 
    131130 
    132     TSVMClassifier(const PVariable & , PExampleTable examples, PExampleTable supportVectors, 
    133             svm_model* model, PKernelFunc kernelFunc); 
     131    TSVMClassifier(PDomain, svm_model * model, PExampleTable supportVectors, 
     132            PKernelFunc kernelFunc=NULL, PExampleTable examples=NULL); 
    134133 
    135134    ~TSVMClassifier(); 
     
    146145    PFloatList probB; //P probB - pairwise probability information 
    147146    PExampleTable supportVectors; //P support vectors 
    148     PExampleTable examples; //P (training instances when svm_type == Custom) 
    149     PKernelFunc kernelFunc; //P custom kernel function 
     147 
     148    PExampleTable examples; //P training instances when svm_type == Custom 
     149    PKernelFunc kernelFunc; //P custom kernel function used when svm_type == Custom 
    150150 
    151151    int svm_type; //P(&SVMLearner_SVMType)  SVM type (C_SVC=0, NU_SVC, ONE_CLASS, EPSILON_SVR=3, NU_SVR=4) 
     
    162162}; 
    163163 
    164 class ORANGE_API TSVMClassifierSparse : public TSVMClassifier{ 
     164class ORANGE_API TSVMClassifierSparse : public TSVMClassifier { 
    165165public: 
    166166    __REGISTER_CLASS 
    167     TSVMClassifierSparse(){}; 
    168     TSVMClassifierSparse(PVariable var, PExampleTable examples,  PExampleTable supportVectors, 
    169             svm_model* model, bool useNonMeta, PKernelFunc kernelFunc 
    170             ) :TSVMClassifier(var, examples, supportVectors, model, kernelFunc){ 
     167    TSVMClassifierSparse() {}; 
     168 
     169    TSVMClassifierSparse( 
     170            PDomain domain, svm_model * model, bool useNonMeta, 
     171            PExampleTable supportVectors, 
     172            PKernelFunc kernelFunc=NULL, 
     173            PExampleTable examples=NULL 
     174            ) : TSVMClassifier(domain, model, supportVectors, kernelFunc, examples) { 
    171175        this->useNonMeta = useNonMeta; 
    172176    } 
    173177 
    174     bool useNonMeta; //P include non meta attributes 
     178    bool useNonMeta; //PR include non meta attributes 
    175179 
    176180protected: 
Note: See TracChangeset for help on using the changeset viewer.