source: orange/source/orange/libsvm_interface.cpp @ 11607:8ecd4831def9

Revision 11607:8ecd4831def9, 24.6 KB checked in by Ales Erjavec <ales.erjavec@…>, 10 months ago (diff)

Changed (simplified) SVMClassifier constructors (and pickling).

RevLine 
[8978]1/*
2    This file is part of Orange.
3
4    Copyright 1996-2011 Faculty of Computer and Information Science, University of Ljubljana
5    Contact: janez.demsar@fri.uni-lj.si
6
7    Orange is free software: you can redistribute it and/or modify
8    it under the terms of the GNU General Public License as published by
9    the Free Software Foundation, either version 3 of the License, or
10    (at your option) any later version.
11
12    Orange is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15    GNU General Public License for more details.
16
17    You should have received a copy of the GNU General Public License
18    along with Orange.  If not, see <http://www.gnu.org/licenses/>.
19*/
20
21#include <iostream>
22#include <sstream>
23
24#include "libsvm_interface.ppp"
25
26// Defined in svm.cpp. If new svm or kernel types are added this should be updated.
27
28static const char *svm_type_table[] =
29{
30    "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",NULL
31};
32
33static const char *kernel_type_table[]=
34{
35    "linear","polynomial","rbf","sigmoid","precomputed",NULL
36};
37
38#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
39
40/*
41 * Save load functions for use with orange pickling.
42 * They are a copy of the standard libSVM save-load functions
43 * except that they read/write from/to std::iostream objects.
44 */
45
46int svm_save_model_alt(std::ostream& stream, const svm_model *model){
47    const svm_parameter& param = model->param;
48    stream.precision(17);
49
50    stream << "svm_type " << svm_type_table[param.svm_type] << endl;
51    stream << "kernel_type " << kernel_type_table[param.kernel_type] << endl;
52
53    if(param.kernel_type == POLY)
54        stream << "degree " << param.degree << endl;
55
56    if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
57        stream << "gamma " << param.gamma << endl;
58
59    if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
60        stream << "coef0 " << param.coef0 << endl;
61
62    int nr_class = model->nr_class;
63    int l = model->l;
64    stream << "nr_class " << nr_class << endl;
65    stream << "total_sv " << l << endl;
66    {
67        stream << "rho";
68        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
69            stream << " " << model->rho[i];
70        stream << endl;
71    }
72
73    if(model->label)
74    {
75        stream << "label";
76        for(int i=0;i<nr_class;i++)
77            stream << " " << model->label[i];
78        stream << endl;
79    }
80
81    if(model->probA) // regression has probA only
82    {
83        stream << "probA";
84        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
85            stream << " " << model->probA[i];
86        stream << endl;
87    }
88    if(model->probB)
89    {
90        stream << "probB";
91        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
92            stream << " " << model->probB[i];
93        stream << endl;
94    }
95
96    if(model->nSV)
97    {
98        stream << "nr_sv";
99        for(int i=0;i<nr_class;i++)
100            stream << " " << model->nSV[i];
101        stream << endl;
102    }
103
104    stream << "SV" << endl;
105    const double * const *sv_coef = model->sv_coef;
106    const svm_node * const *SV = model->SV;
107
108    for(int i=0;i<l;i++)
109    {
110        for(int j=0;j<nr_class-1;j++)
111            stream << sv_coef[j][i] << " ";
112
113        const svm_node *p = SV[i];
114
115        if(param.kernel_type == PRECOMPUTED)
[9039]116            stream << "0:" << (int)(p->value) << " ";
[8978]117        else
118            while(p->index != -1)
119            {
120                stream << (int)(p->index) << ":" << p->value << " ";
121                p++;
122            }
123        stream << endl;
124    }
125
126    if (!stream.fail())
127        return 0;
128    else
129        return 1;
130}
131
132int svm_save_model_alt(std::string& buffer, const svm_model *model){
133    std::ostringstream strstream;
134    int ret = svm_save_model_alt(strstream, model);
135    buffer = strstream.rdbuf()->str();
136    return ret;
137}
138
139
140#include <algorithm>
141
142svm_model *svm_load_model_alt(std::istream& stream)
143{
144    svm_model *model = Malloc(svm_model,1);
145    svm_parameter& param = model->param;
146    model->rho = NULL;
147    model->probA = NULL;
148    model->probB = NULL;
149    model->label = NULL;
150    model->nSV = NULL;
151
152    char cmd[81];
153    stream.width(80);
154    while (stream.good())
155    {
156        stream >> cmd;
157
158        if(strcmp(cmd, "svm_type") == 0)
159        {
160            stream >> cmd;
161            int i;
162            for(i=0; svm_type_table[i]; i++)
163            {
164                if(strcmp(cmd, svm_type_table[i]) == 0)
165                {
166                    param.svm_type=i;
167                    break;
168                }
169            }
170            if(svm_type_table[i] == NULL)
171            {
172                fprintf(stderr, "unknown svm type.\n");
173                free(model->rho);
174                free(model->label);
175                free(model->nSV);
176                free(model);
177                return NULL;
178            }
179        }
180        else if(strcmp(cmd, "kernel_type") == 0)
181        {
182            stream >> cmd;
183            int i;
184            for(i=0;kernel_type_table[i];i++)
185            {
186                if(strcmp(kernel_type_table[i], cmd)==0)
187                {
188                    param.kernel_type=i;
189                    break;
190                }
191            }
192            if(kernel_type_table[i] == NULL)
193            {
194                fprintf(stderr,"unknown kernel function.\n");
195                free(model->rho);
196                free(model->label);
197                free(model->nSV);
198                free(model);
199                return NULL;
200            }
201        }
202        else if(strcmp(cmd,"degree")==0)
203            stream >> param.degree;
204        else if(strcmp(cmd,"gamma")==0)
205            stream >> param.gamma;
206        else if(strcmp(cmd,"coef0")==0)
207            stream >> param.coef0;
208        else if(strcmp(cmd,"nr_class")==0)
209            stream >> model->nr_class;
210        else if(strcmp(cmd,"total_sv")==0)
211            stream >> model->l;
212        else if(strcmp(cmd,"rho")==0)
213        {
214            int n = model->nr_class * (model->nr_class-1)/2;
215            model->rho = Malloc(double,n);
[9389]216            string rho_str;
217            for(int i=0;i<n;i++){
218                // Read the number into a string and then use strtod
219                // for proper handling of NaN's
220                stream >> rho_str;
221                model->rho[i] = strtod(rho_str.c_str(), NULL);
222            }
[8978]223        }
224        else if(strcmp(cmd,"label")==0)
225        {
226            int n = model->nr_class;
227            model->label = Malloc(int,n);
228            for(int i=0;i<n;i++)
229                stream >> model->label[i];
230        }
231        else if(strcmp(cmd,"probA")==0)
232        {
233            int n = model->nr_class * (model->nr_class-1)/2;
234            model->probA = Malloc(double,n);
235            for(int i=0;i<n;i++)
236                stream >> model->probA[i];
237        }
238        else if(strcmp(cmd,"probB")==0)
239        {
240            int n = model->nr_class * (model->nr_class-1)/2;
241            model->probB = Malloc(double,n);
242            for(int i=0;i<n;i++)
243                stream >> model->probB[i];
244        }
245        else if(strcmp(cmd,"nr_sv")==0)
246        {
247            int n = model->nr_class;
248            model->nSV = Malloc(int,n);
249            for(int i=0;i<n;i++)
250                stream >> model->nSV[i];
251        }
252        else if(strcmp(cmd,"SV")==0)
253        {
254            while(1)
255            {
256                int c = stream.get();
257                if(stream.eof() || c=='\n') break;
258            }
259            break;
260        }
261        else
262        {
263            fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
264            free(model->rho);
265            free(model->label);
266            free(model->nSV);
267            free(model);
268            return NULL;
269        }
270    }
271    if (stream.fail()){
272        free(model->rho);
273        free(model->label);
274        free(model->nSV);
275        free(model);
276        return NULL;
277
278    }
279
280    // read sv_coef and SV
281
282    int elements = 0;
283    long pos = stream.tellg();
284
285    char *p,*endptr,*idx,*val;
286    string str_line;
287    while (!stream.eof() && !stream.fail())
288    {
289        getline(stream, str_line);
290        elements += std::count(str_line.begin(), str_line.end(), ':');
291    }
292
293    elements += model->l;
294
295    stream.clear();
296    stream.seekg(pos, ios::beg);
297
298    int m = model->nr_class - 1;
299    int l = model->l;
300    model->sv_coef = Malloc(double *,m);
301    int i;
302    for(i=0;i<m;i++)
303        model->sv_coef[i] = Malloc(double,l);
304    model->SV = Malloc(svm_node*,l);
305    svm_node *x_space = NULL;
306    if(l>0) x_space = Malloc(svm_node,elements);
307
308    int j=0;
309    char *line;
310    for(i=0;i<l;i++)
311    {
312        getline(stream, str_line);
313        if (str_line.size() == 0)
314            continue;
315
316        line = (char *) Malloc(char, str_line.size() + 1);
317        // Copy the line for strtok.
318        strcpy(line, str_line.c_str());
319
320        model->SV[i] = &x_space[j];
321
322        p = strtok(line, " \t");
323        model->sv_coef[0][i] = strtod(p,&endptr);
324        for(int k=1;k<m;k++)
325        {
326            p = strtok(NULL, " \t");
327            model->sv_coef[k][i] = strtod(p,&endptr);
328        }
329
330        while(1)
331        {
332            idx = strtok(NULL, ":");
333            val = strtok(NULL, " \t");
334
335            if(val == NULL)
336                break;
337            x_space[j].index = (int) strtol(idx,&endptr,10);
338            x_space[j].value = strtod(val,&endptr);
339
340            ++j;
341        }
342        x_space[j++].index = -1;
343        free(line);
344    }
345
346    if (stream.fail())
347        return NULL;
348
349    model->free_sv = 1; // XXX
350    return model;
351}
352
353svm_model *svm_load_model_alt(std::string& stream)
354{
355    std::istringstream strstream(stream);
356    return svm_load_model_alt(strstream);
357}
358
[11606]359
360/*
361 * Return a formated string representing a svm data instance (svm_node *)
362 * (useful for debugging)
363 */
364string svm_node_to_string(svm_node * node) {
365    std::ostringstream strstream;
366    strstream.precision(17);
367    while (node->index != -1) {
368        strstream << node->index << ":" << node->value << " ";
369        node++;
370    }
371    strstream << node->index << ":" << node->value << " ";
372    return strstream.rdbuf()->str();
373}
374
375
[8978]376svm_node* example_to_svm(const TExample &ex, svm_node* node, float last=0.0, int type=0){
377    if(type==0){
378        int index=1;
379        for(TExample::iterator i=ex.begin(); i!=ex.end(); i++){
380            if(i->isRegular() && i!=&ex.getClass()){
381                if(i->varType==TValue::FLOATVAR)
382                    node->value=float(*i);
383                else
384                    node->value=int(*i);
385                node->index=index++;
386                if(node->value==numeric_limits<float>::signaling_NaN() ||
387                    node->value==numeric_limits<int>::max())
388                    node--;
389                node++;
390            }
391        }
392    }
393    if(type == 1){ /*one dummy attr so we can pickle the classifier and keep the SV index in the training table*/
394        node->index=1;
395        node->value=last;
396        node++;
397    }
398    //cout<<(node-1)->index<<endl<<(node-2)->index<<endl;
399    node->index=-1;
400    node->value=last;
401    node++;
402    return node;
403}
404
405class SVM_NodeSort{
406public:
407    bool operator() (const svm_node &lhs, const svm_node &rhs){
408        return lhs.index < rhs.index;
409    }
410};
411
412svm_node* example_to_svm_sparse(const TExample &ex, svm_node* node, float last=0.0, bool useNonMeta=false){
413    svm_node *first=node;
414    int j=1;
415    int index=1;
416    if(useNonMeta)
417        for(TExample::iterator i=ex.begin(); i!=ex.end(); i++){
418            if(i->isRegular() && i!=&ex.getClass()){
419                if(i->varType==TValue::FLOATVAR)
420                    node->value=float(*i);
421                else
422                    node->value=int(*i);
423                node->index=index;
424                if(node->value==numeric_limits<float>::signaling_NaN() ||
425                    node->value==numeric_limits<int>::max())
426                    node--;
427                node++;
428            }
429            index++;
430        }
431    for(TMetaValues::const_iterator i=ex.meta.begin(); i!=ex.meta.end();i++,j++){
432        if(i->second.isRegular()){
433            if(i->second.varType==TValue::FLOATVAR)
434                node->value=float(i->second);
435            else
436                node->value=int(i->second);
[11606]437            node->index = index - i->first;
438
[8978]439            if(node->value==numeric_limits<float>::signaling_NaN() ||
440                node->value==numeric_limits<int>::max())
441                node--;
442            node++;
443        }
444    }
445    sort(first, node, SVM_NodeSort());
446    //cout<<first->index<<endl<<(first+1)->index<<endl;
447    node->index=-1;
448    node->value=last;
449    node++;
450    return node;
451}
452
453/*
454 * Precompute Gram matrix row for ex.
455 * Used for prediction when using the PRECOMPUTED kernel.
456 */
457svm_node* example_to_svm_precomputed(const TExample &ex, PExampleGenerator examples, PKernelFunc kernel, svm_node* node){
458    node->index = 0;
[11606]459    node->value = 0.0; // Can be any value.
[8978]460    node++;
461    int k = 0;
462    PEITERATE(iter, examples){
463        node->index = ++k;
464        node->value = kernel.getReference()(*iter, ex);
465        node++;
466    }
467    node->index = -1; // sentry
468    node++;
469    return node;
470}
471
472int getNumOfElements(const TExample &ex, bool meta=false, bool useNonMeta=false){
473    if(!meta)
474        return std::max(ex.domain->attributes->size()+1, 2);
475    else{
476        int count=1; //we need one to indicate the end of a sequence
477        if(useNonMeta)
478            count+=ex.domain->attributes->size();
479        for(TMetaValues::const_iterator i=ex.meta.begin(); i!=ex.meta.end();i++)
480            if(i->second.isRegular())
481                count++;
482        return std::max(count,2);
483    }
484}
485
486int getNumOfElements(PExampleGenerator &examples, bool meta=false, bool useNonMeta=false){
487    if(!meta)
488        return getNumOfElements(*(examples->begin()), meta)*examples->numberOfExamples();
489    else{
490        int count=0;
491        for(TExampleGenerator::iterator ex(examples->begin()); ex!=examples->end(); ++ex){
492            count+=getNumOfElements(*ex, meta, useNonMeta);
493        }
494        return count;
495    }
496}
497
498#include "symmatrix.hpp"
499svm_node* init_precomputed_problem(svm_problem &problem, PExampleTable examples, TKernelFunc &kernel){
500    int n_examples = examples->numberOfExamples();
501    int i,j;
502    PSymMatrix matrix = mlnew TSymMatrix(n_examples, 0.0);
503    for (i = 0; i < n_examples; i++)
504        for (j = 0; j <= i; j++){
505            matrix->getref(i, j) = kernel(examples->at(i), examples->at(j));
506//          cout << i << " " << j << " " << matrix->getitem(i, j) << endl;
507        }
508    svm_node *x_space = Malloc(svm_node, n_examples * (n_examples + 2));
509    svm_node *node = x_space;
510
511    problem.l = n_examples;
512    problem.x = Malloc(svm_node*, n_examples);
513    problem.y = Malloc(double, n_examples);
514
515    for (i = 0; i < n_examples; i++){
516        problem.x[i] = node;
517        if (examples->domain->classVar->varType == TValue::FLOATVAR)
518            problem.y[i] = examples->at(i).getClass().floatV;
519        else
520            problem.y[i] = examples->at(i).getClass().intV;
521
522        node->index = 0;
523        node->value = i + 1; // instance indices are 1 based
524        node++;
525        for (j = 0; j < n_examples; j++){
526            node->index = j + 1;
527            node->value = matrix->getitem(i, j);
528            node++;
529        }
530        node->index = -1; // sentry
531        node++;
532    }
533    return x_space;
534}
535
[11606]536/*
537 * Extract an ExampleTable corresponding to the support vectors from the
538 * trained model.
539 */
540PExampleTable extract_support_vectors(svm_model * model, PExampleTable train_instances)
541{
542    PExampleTable vectors = mlnew TExampleTable(train_instances->domain);
543
544    for (int i = 0; i < model->l; i++) {
545        svm_node *node = model->SV[i];
546        int sv_index = -1;
547        if(model->param.kernel_type != PRECOMPUTED){
548            /* The value of the last node (with index == -1) holds the
549             * index of the training example.
550             */
551            while(node->index != -1) {
552                node++;
553            }
554            sv_index = int(node->value);
555        } else {
556            /* The value of the first node contains the training instance
557             * index (indices 1 based).
558             */
559            sv_index = int(node->value) - 1;
560        }
561        vectors->addExample(mlnew TExample(train_instances->at(sv_index)));
562    }
563
564    return vectors;
565}
566
567
568/*
569 * Consolidate model->SV[1] .. SV[l] vectors into a single contiguous
570 * memory block. The model will 'own' the new *(model->SV) array and
571 * will be freed in destroy_svm_model (model->free_sv == 1). Note that
572 * the original 'x_space' is left intact, it is the caller's
573 * responsibility to free it. However the model->SV array itself is
574 * reused (overwritten).
575 */
576
577void svm_model_consolidate_SV(svm_model * model) {
578    int count = 0;
579    svm_node * x_space = NULL;
580    svm_node * ptr = NULL;
581    svm_node * ptr_source = NULL;
582
583    // Count the number of elements.
584    for (int i = 0; i < model->l; i++) {
585        ptr = model->SV[i];
586        while (ptr->index != -1){
587            count++;
588            ptr++;
589        }
590    }
591    // add the sentinel count
592    count += model->l;
593
594    x_space = Malloc(svm_node, count);
595    ptr = x_space;
596    for (int i = 0; i < model->l; i++) {
597        ptr_source = model->SV[i];
598        model->SV[i] = ptr;
599        while (ptr_source->index != -1) {
600            *(ptr++) = *(ptr_source++);
601        }
602        // copy the sentinel
603        *(ptr++) = *(ptr_source++);
604    }
605    model->free_sv = 1; // XXX
606}
607
[8978]608static void print_string_null(const char* s) {}
609
[11606]610
[8978]611TSVMLearner::TSVMLearner(){
612    svm_type = NU_SVC;
613    kernel_type = RBF;
614    degree = 3;
615    gamma = 0;
616    coef0 = 0;
617    nu = 0.5;
618    cache_size = 250;
619    C = 1;
620    eps = 1e-3f;
621    p = 0.1f;
622    shrinking = 1;
623    probability = 0;
624    verbose = false;
625    nr_weight = 0;
626    weight_label = NULL;
627    weight = NULL;
628};
629
[11606]630
[8978]631PClassifier TSVMLearner::operator ()(PExampleGenerator examples, const int&){
632    svm_parameter param;
633    svm_problem prob;
634    svm_model* model;
635    svm_node* x_space;
636
637    param.svm_type = svm_type;
638    param.kernel_type = kernel_type;
639    param.degree = degree;
640    param.gamma = gamma;
641    param.coef0 = coef0;
642    param.nu = nu;
643    param.C = C;
644    param.eps = eps;
645    param.p = p;
646    param.cache_size=cache_size;
647    param.shrinking=shrinking;
648    param.probability=probability;
649    param.nr_weight = nr_weight;
[11606]650
[8978]651    if (nr_weight > 0) {
652        param.weight_label = Malloc(int, nr_weight);
653        param.weight = Malloc(double, nr_weight);
654        int i;
655        for (i=0; i<nr_weight; i++) {
656            param.weight_label[i] = weight_label[i];
657            param.weight[i] = weight[i];
658        }
659    } else {
660        param.weight_label = NULL;
661        param.weight = NULL;
662    }
663
664    int classVarType;
665    if(examples->domain->classVar)
666        classVarType=examples->domain->classVar->varType;
667    else{
668        classVarType=TValue::NONE;
669        if(svm_type!=ONE_CLASS)
670            raiseError("Domain has no class variable");
671    }
672    if(classVarType==TValue::FLOATVAR && !(svm_type==EPSILON_SVR || svm_type==NU_SVR ||svm_type==ONE_CLASS))
673        raiseError("Domain has continuous class");
674
675    if(kernel_type==PRECOMPUTED && !kernelFunc)
676        raiseError("Custom kernel function not supplied");
677
678    int numElements=getNumOfElements(examples);
679
680    if(kernel_type != PRECOMPUTED)
681        x_space = init_problem(prob, examples, numElements);
682    else // Compute the matrix using the kernelFunc
683        x_space = init_precomputed_problem(prob, examples, kernelFunc.getReference());
684
685    if(param.gamma==0)
686        param.gamma=1.0f/(float(numElements)/float(prob.l)-1);
687
688    const char* error=svm_check_parameter(&prob,&param);
689    if(error){
690        free(x_space);
691        free(prob.y);
692        free(prob.x);
693        raiseError("LibSVM parameter error: %s", error);
694    }
[9347]695
696    // If a probability model was requested LibSVM uses 5 fold
697    // cross-validation to estimate the prediction errors. This includes a
698    // random shuffle of the data. To make the results reproducible and
699    // consistent with 'svm-train' (which always learns just on one dataset
700    // in a process run) we reset the random seed. This could have unintended
701    // consequences.
702    if (param.probability)
703    {
704        srand(1);
705    }
[8978]706    svm_set_print_string_function((verbose)? NULL : &print_string_null);
707
[11606]708    model = svm_train(&prob, &param);
[8978]709
[11606]710    if ((svm_type==C_SVC || svm_type==NU_SVC) && !model->nSV) {
711        svm_free_and_destroy_model(&model);
712        if (x_space) {
713            free(x_space);
714        }
715        raiseError("LibSVM returned no support vectors");
716    }
717
[8978]718    svm_destroy_param(&param);
719    free(prob.y);
720    free(prob.x);
[11606]721
722    // Consolidate the SV so x_space can be safely freed
723    svm_model_consolidate_SV(model);
724
725    free(x_space);
726
727    PExampleTable supportVectors = extract_support_vectors(model, examples);
728
[11607]729    PDomain domain = examples->domain;
[11606]730
[11607]731    return PClassifier(createClassifier(examples->domain, model, supportVectors, examples));
[8978]732}
733
734svm_node* TSVMLearner::example_to_svm(const TExample &ex, svm_node* node, float last, int type){
735    return ::example_to_svm(ex, node, last, type);
736}
737
[9187]738svm_node* TSVMLearner::init_problem(svm_problem &problem, PExampleTable examples, int n_elements){
739    problem.l = examples->numberOfExamples();
740    problem.y = Malloc(double ,problem.l);
741    problem.x = Malloc(svm_node*, problem.l);
742    svm_node *x_space = Malloc(svm_node, n_elements);
743    svm_node *node = x_space;
744
745    for (int i = 0; i < problem.l; i++){
746        problem.x[i] = node;
747        node = example_to_svm(examples->at(i), node, i);
748        if (examples->domain->classVar)
749            if (examples->domain->classVar->varType == TValue::FLOATVAR)
750                problem.y[i] = examples->at(i).getClass().floatV;
751            else if (examples->domain->classVar->varType == TValue::INTVAR)
752                problem.y[i] = examples->at(i).getClass().intV;
753    }
754    return x_space;
755}
756
[8978]757int TSVMLearner::getNumOfElements(PExampleGenerator examples){
758    return ::getNumOfElements(examples);
759}
760
[11606]761TSVMClassifier* TSVMLearner::createClassifier(
[11607]762        PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples) {
[11606]763    if (kernel_type != PRECOMPUTED) {
764        examples = NULL;
765    }
[11607]766    return mlnew TSVMClassifier(domain, model, supportVectors, kernelFunc, examples);
[8978]767}
768
769TSVMLearner::~TSVMLearner(){
770    if(weight_label)
771        free(weight_label);
772
773    if(weight)
774            free(weight);
775}
776
777svm_node* TSVMLearnerSparse::example_to_svm(const TExample &ex, svm_node* node, float last, int type){
778    return ::example_to_svm_sparse(ex, node, last, useNonMeta);
779}
780
781int TSVMLearnerSparse::getNumOfElements(PExampleGenerator examples){
782    return ::getNumOfElements(examples, true, useNonMeta);
783}
784
[11606]785TSVMClassifier* TSVMLearnerSparse::createClassifier(
[11607]786        PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples) {
[11606]787    if (kernel_type != PRECOMPUTED) {
788        examples = NULL;
789    }
[11607]790    return mlnew TSVMClassifierSparse(domain, model, useNonMeta, supportVectors, kernelFunc, examples);
[8978]791}
792
[11606]793
794TSVMClassifier::TSVMClassifier(
[11607]795        PDomain domain, svm_model * model,
[11606]796        PExampleTable supportVectors,
[11607]797        PKernelFunc kernelFunc,
798        PExampleTable examples
799        ) : TClassifierFD(domain) {
800    this->model = model;
801    this->supportVectors = supportVectors;
802    this->kernelFunc = kernelFunc;
[11606]803    this->examples = examples;
804
[9021]805    svm_type = svm_get_svm_type(model);
806    kernel_type = model->param.kernel_type;
807
[11607]808    if (svm_type == ONE_CLASS) {
809        this->classVar = mlnew TFloatVariable("one class");
810    }
811
[8978]812    computesProbabilities = model && svm_check_probability_model(model) && \
[11607]813                (svm_type != NU_SVR && svm_type != EPSILON_SVR); // Disable prob. estimation for regression
[9021]814
[8978]815    int nr_class = svm_get_nr_class(model);
816    int i = 0;
[11606]817
818    /* Expose (copy) the model data (coef, rho, probA) to public
819     * class interface.
820     */
[11607]821    if (svm_type == C_SVC || svm_type == NU_SVC) {
822        nSV = mlnew TIntList(nr_class); // num of SVs for each class (sum(nSV) == model->l)
823        for(i = 0;i < nr_class; i++) {
824            nSV->at(i) = model->nSV[i];
825        }
826    }
[8978]827
828    coef = mlnew TFloatListList(nr_class-1);
[11607]829    for(i = 0; i < nr_class - 1; i++) {
[8978]830        TFloatList *coefs = mlnew TFloatList(model->l);
[11607]831        for(int j = 0;j < model->l; j++) {
[8978]832            coefs->at(j) = model->sv_coef[i][j];
[11607]833        }
834        coef->at(i) = coefs;
[8978]835    }
[11607]836
837    // Number of binary classifiers in the model
838    int nr_bin_cls = nr_class * (nr_class - 1) / 2;
839
840    rho = mlnew TFloatList(nr_bin_cls);
841    for(i = 0; i < nr_bin_cls; i++) {
[8978]842        rho->at(i) = model->rho[i];
[11607]843    }
844
845    if(model->probA) {
846        probA = mlnew TFloatList(nr_bin_cls);
847        if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) {
848            // Regression only has probA
849            probB = mlnew TFloatList(nr_bin_cls);
850        }
851
852        for(i=0; i<nr_bin_cls; i++) {
[8978]853            probA->at(i) = model->probA[i];
[11607]854            if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) {
[8978]855                probB->at(i) = model->probB[i];
[11607]856            }
[8978]857        }
858    }
859}
860
[11607]861
[8978]862TSVMClassifier::~TSVMClassifier(){
[11606]863    if (model) {
[8978]864        svm_free_and_destroy_model(&model);
[11606]865    }
[8978]866}
867
[11607]868
[8978]869PDistribution TSVMClassifier::classDistribution(const TExample & example){
870    if(!model)
871        raiseError("No Model");
872
873    if(!computesProbabilities)
874        return TClassifierFD::classDistribution(example);
875
876    int n_elements;
877    if (model->param.kernel_type != PRECOMPUTED)
878        n_elements = getNumOfElements(example);
879    else
880        n_elements = examples->numberOfExamples() + 2;
881
882    int svm_type = svm_get_svm_type(model);
883    int nr_class = svm_get_nr_class(model);
884
885    svm_node *x = Malloc(svm_node, n_elements);
886    try{
887        if (model->param.kernel_type != PRECOMPUTED)
888            example_to_svm(example, x, -1.0);
889        else
890            example_to_svm_precomputed(example, examples, kernelFunc, x);
891    } catch (...) {
892        free(x);
893        throw;
894    }
895
896    int *labels=(int *) malloc(nr_class*sizeof(int));
897    svm_get_labels(model, labels);
898
899    double *prob_estimates = (double *) malloc(nr_class*sizeof(double));;
900    svm_predict_probability(model, x, prob_estimates);
901
902    PDistribution dist = TDistribution::create(example.domain->classVar);
903    for(int i=0; i<nr_class; i++)
904        dist->setint(labels[i], prob_estimates[i]);
905    free(x);
906    free(prob_estimates);
907    free(labels);
908    return dist;
909}
910
911TValue TSVMClassifier::operator()(const TExample & example){
912    if(!model)
913        raiseError("No Model");
914
915    int n_elements;
916    if (model->param.kernel_type != PRECOMPUTED)
[11606]917        n_elements = getNumOfElements(example);
[8978]918    else
919        n_elements = examples->numberOfExamples() + 2;
920
921    int svm_type = svm_get_svm_type(model);
922    int nr_class = svm_get_nr_class(model);
923
924    svm_node *x = Malloc(svm_node, n_elements);
925    try {
926        if (model->param.kernel_type != PRECOMPUTED)
927            example_to_svm(example, x);
928        else
929            example_to_svm_precomputed(example, examples, kernelFunc, x);
930    } catch (...) {
931        free(x);
932        throw;
933    }
934
935    double v;
936
937    if(svm_check_probability_model(model)){
938        double *prob = (double *) malloc(nr_class*sizeof(double));
939        v = svm_predict_probability(model, x, prob);
940        free(prob);
941    } else
942        v = svm_predict(model, x);
943
944    free(x);
945    if(svm_type==NU_SVR || svm_type==EPSILON_SVR || svm_type==ONE_CLASS)
946        return TValue(v);
947    else
948        return TValue(int(v));
949}
950
951PFloatList TSVMClassifier::getDecisionValues(const TExample &example){
952    if(!model)
953        raiseError("No Model");
954
955    int n_elements;
[11606]956    if (model->param.kernel_type != PRECOMPUTED)
957        n_elements = getNumOfElements(example);
958    else
959        n_elements = examples->numberOfExamples() + 2;
[8978]960
961    int svm_type=svm_get_svm_type(model);
962    int nr_class=svm_get_nr_class(model);
963
964    svm_node *x = Malloc(svm_node, n_elements);
965    try {
966        if (model->param.kernel_type != PRECOMPUTED)
967            example_to_svm(example, x);
968        else
969            example_to_svm_precomputed(example, examples, kernelFunc, x);
970    } catch (...) {
971        free(x);
972        throw;
973    }
974
975    int nDecValues = nr_class*(nr_class-1)/2;
976    double *dec = (double*) malloc(sizeof(double)*nDecValues);
977    svm_predict_values(model, x, dec);
978    PFloatList res = mlnew TFloatList(nDecValues);
979    for(int i=0; i<nDecValues; i++){
980        res->at(i) = dec[i];
981    }
982    free(x);
983    free(dec);
984    return res;
985}
986
987svm_node *TSVMClassifier::example_to_svm(const TExample &ex, svm_node *node, float last, int type){
988    return ::example_to_svm(ex, node, last, type);
989}
990
991int TSVMClassifier::getNumOfElements(const TExample& example){
992    return ::getNumOfElements(example);
993}
[11606]994svm_node *TSVMClassifierSparse::example_to_svm(const TExample &ex, svm_node *node, float last, int){
[8978]995    return ::example_to_svm_sparse(ex, node, last, useNonMeta);
996}
997
998int TSVMClassifierSparse::getNumOfElements(const TExample& example){
999    return ::getNumOfElements(example, true, useNonMeta);
1000}
1001
1002
Note: See TracBrowser for help on using the repository browser.