source: orange/source/orange/libsvm_interface.cpp @ 11703:9b8d8ab7820c

Revision 11703:9b8d8ab7820c, 25.4 KB checked in by janezd <janez.demsar@…>, 7 months ago (diff)

Removed the GPL copyright notice from all files except orangeqt.

Line 
1#include <iostream>
2#include <sstream>
3
4#include "libsvm_interface.hpp"
5#include "symmatrix.hpp"
6
7#include <algorithm>
8#include <cmath>
9
10
11// Defined in svm.cpp. If new svm or kernel types are added this should be updated.
12
13static const char *svm_type_table[] =
14{
15    "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",NULL
16};
17
18static const char *kernel_type_table[]=
19{
20    "linear","polynomial","rbf","sigmoid","precomputed",NULL
21};
22
23#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
24
25/*
26 * Save load functions for use with orange pickling.
27 * They are a copy of the standard libSVM save-load functions
28 * except that they read/write from/to std::iostream objects.
29 */
30
31int svm_save_model_alt(std::ostream& stream, const svm_model *model){
32    const svm_parameter& param = model->param;
33    stream.precision(17);
34
35    stream << "svm_type " << svm_type_table[param.svm_type] << endl;
36    stream << "kernel_type " << kernel_type_table[param.kernel_type] << endl;
37
38    if(param.kernel_type == POLY)
39        stream << "degree " << param.degree << endl;
40
41    if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
42        stream << "gamma " << param.gamma << endl;
43
44    if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
45        stream << "coef0 " << param.coef0 << endl;
46
47    int nr_class = model->nr_class;
48    int l = model->l;
49    stream << "nr_class " << nr_class << endl;
50    stream << "total_sv " << l << endl;
51    {
52        stream << "rho";
53        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
54            stream << " " << model->rho[i];
55        stream << endl;
56    }
57
58    if(model->label)
59    {
60        stream << "label";
61        for(int i=0;i<nr_class;i++)
62            stream << " " << model->label[i];
63        stream << endl;
64    }
65
66    if(model->probA) // regression has probA only
67    {
68        stream << "probA";
69        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
70            stream << " " << model->probA[i];
71        stream << endl;
72    }
73    if(model->probB)
74    {
75        stream << "probB";
76        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
77            stream << " " << model->probB[i];
78        stream << endl;
79    }
80
81    if(model->nSV)
82    {
83        stream << "nr_sv";
84        for(int i=0;i<nr_class;i++)
85            stream << " " << model->nSV[i];
86        stream << endl;
87    }
88
89    stream << "SV" << endl;
90    const double * const *sv_coef = model->sv_coef;
91    const svm_node * const *SV = model->SV;
92
93    for(int i=0;i<l;i++)
94    {
95        for(int j=0;j<nr_class-1;j++)
96            stream << sv_coef[j][i] << " ";
97
98        const svm_node *p = SV[i];
99
100        if(param.kernel_type == PRECOMPUTED)
101            stream << "0:" << (int)(p->value) << " ";
102        else
103            while(p->index != -1)
104            {
105                stream << (int)(p->index) << ":" << p->value << " ";
106                p++;
107            }
108        stream << endl;
109    }
110
111    if (!stream.fail())
112        return 0;
113    else
114        return 1;
115}
116
117int svm_save_model_alt(std::string& buffer, const svm_model *model){
118    std::ostringstream strstream;
119    int ret = svm_save_model_alt(strstream, model);
120    buffer = strstream.rdbuf()->str();
121    return ret;
122}
123
124
125svm_model *svm_load_model_alt(std::istream& stream)
126{
127    svm_model *model = Malloc(svm_model,1);
128    svm_parameter& param = model->param;
129    model->rho = NULL;
130    model->probA = NULL;
131    model->probB = NULL;
132    model->label = NULL;
133    model->nSV = NULL;
134
135#if LIBSVM_VERSION >= 313
136    // libsvm seems to ensure ordered numbers for versioning (3.0 was 300,
137    // 3.1 was 310,  3.11 was 311, there was no 3.2, ...)
138    model->sv_indices = NULL;
139#endif
140
141    char cmd[81];
142    stream.width(80);
143    while (stream.good())
144    {
145        stream >> cmd;
146
147        if(strcmp(cmd, "svm_type") == 0)
148        {
149            stream >> cmd;
150            int i;
151            for(i=0; svm_type_table[i]; i++)
152            {
153                if(strcmp(cmd, svm_type_table[i]) == 0)
154                {
155                    param.svm_type=i;
156                    break;
157                }
158            }
159            if(svm_type_table[i] == NULL)
160            {
161                fprintf(stderr, "unknown svm type.\n");
162                free(model->rho);
163                free(model->label);
164                free(model->nSV);
165                free(model);
166                return NULL;
167            }
168        }
169        else if(strcmp(cmd, "kernel_type") == 0)
170        {
171            stream >> cmd;
172            int i;
173            for(i=0;kernel_type_table[i];i++)
174            {
175                if(strcmp(kernel_type_table[i], cmd)==0)
176                {
177                    param.kernel_type=i;
178                    break;
179                }
180            }
181            if(kernel_type_table[i] == NULL)
182            {
183                fprintf(stderr,"unknown kernel function.\n");
184                free(model->rho);
185                free(model->label);
186                free(model->nSV);
187                free(model);
188                return NULL;
189            }
190        }
191        else if(strcmp(cmd,"degree")==0)
192            stream >> param.degree;
193        else if(strcmp(cmd,"gamma")==0)
194            stream >> param.gamma;
195        else if(strcmp(cmd,"coef0")==0)
196            stream >> param.coef0;
197        else if(strcmp(cmd,"nr_class")==0)
198            stream >> model->nr_class;
199        else if(strcmp(cmd,"total_sv")==0)
200            stream >> model->l;
201        else if(strcmp(cmd,"rho")==0)
202        {
203            int n = model->nr_class * (model->nr_class-1)/2;
204            model->rho = Malloc(double,n);
205            string rho_str;
206            for(int i=0;i<n;i++){
207                // Read the number into a string and then use strtod
208                // for proper handling of NaN's
209                stream >> rho_str;
210                model->rho[i] = strtod(rho_str.c_str(), NULL);
211            }
212        }
213        else if(strcmp(cmd,"label")==0)
214        {
215            int n = model->nr_class;
216            model->label = Malloc(int,n);
217            for(int i=0;i<n;i++)
218                stream >> model->label[i];
219        }
220        else if(strcmp(cmd,"probA")==0)
221        {
222            int n = model->nr_class * (model->nr_class-1)/2;
223            model->probA = Malloc(double,n);
224            for(int i=0;i<n;i++)
225                stream >> model->probA[i];
226        }
227        else if(strcmp(cmd,"probB")==0)
228        {
229            int n = model->nr_class * (model->nr_class-1)/2;
230            model->probB = Malloc(double,n);
231            for(int i=0;i<n;i++)
232                stream >> model->probB[i];
233        }
234        else if(strcmp(cmd,"nr_sv")==0)
235        {
236            int n = model->nr_class;
237            model->nSV = Malloc(int,n);
238            for(int i=0;i<n;i++)
239                stream >> model->nSV[i];
240        }
241        else if(strcmp(cmd,"SV")==0)
242        {
243            while(1)
244            {
245                int c = stream.get();
246                if(stream.eof() || c=='\n') break;
247            }
248            break;
249        }
250        else
251        {
252            fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
253            free(model->rho);
254            free(model->label);
255            free(model->nSV);
256            free(model);
257            return NULL;
258        }
259    }
260    if (stream.fail()){
261        free(model->rho);
262        free(model->label);
263        free(model->nSV);
264        free(model);
265        return NULL;
266
267    }
268
269    // read sv_coef and SV
270
271    int elements = 0;
272    long pos = stream.tellg();
273
274    char *p,*endptr,*idx,*val;
275    string str_line;
276    while (!stream.eof() && !stream.fail())
277    {
278        getline(stream, str_line);
279        elements += std::count(str_line.begin(), str_line.end(), ':');
280    }
281
282    elements += model->l;
283
284    stream.clear();
285    stream.seekg(pos, ios::beg);
286
287    int m = model->nr_class - 1;
288    int l = model->l;
289    model->sv_coef = Malloc(double *,m);
290    int i;
291    for(i=0;i<m;i++)
292        model->sv_coef[i] = Malloc(double,l);
293    model->SV = Malloc(svm_node*,l);
294    svm_node *x_space = NULL;
295    if(l>0) x_space = Malloc(svm_node,elements);
296
297    int j=0;
298    char *line;
299    for(i=0;i<l;i++)
300    {
301        getline(stream, str_line);
302        if (str_line.size() == 0)
303            continue;
304
305        line = (char *) Malloc(char, str_line.size() + 1);
306        // Copy the line for strtok.
307        strcpy(line, str_line.c_str());
308
309        model->SV[i] = &x_space[j];
310
311        p = strtok(line, " \t");
312        model->sv_coef[0][i] = strtod(p,&endptr);
313        for(int k=1;k<m;k++)
314        {
315            p = strtok(NULL, " \t");
316            model->sv_coef[k][i] = strtod(p,&endptr);
317        }
318
319        while(1)
320        {
321            idx = strtok(NULL, ":");
322            val = strtok(NULL, " \t");
323
324            if(val == NULL)
325                break;
326            x_space[j].index = (int) strtol(idx,&endptr,10);
327            x_space[j].value = strtod(val,&endptr);
328
329            ++j;
330        }
331        x_space[j++].index = -1;
332        free(line);
333    }
334
335    if (stream.fail())
336        return NULL;
337
338    model->free_sv = 1; // XXX
339    return model;
340}
341
342
343svm_model *svm_load_model_alt(std::string& stream)
344{
345    std::istringstream strstream(stream);
346    return svm_load_model_alt(strstream);
347}
348
349
350std::ostream & svm_node_vector_to_stream(std::ostream & stream, const svm_node * node) {
351    while (node->index != -1) {
352        stream << node->index << ":" << node->value << " ";
353        node++;
354    }
355    stream << node->index << ":" << node->value;
356    return stream;
357}
358
359std::ostream & operator << (std::ostream & stream, const svm_problem & problem) {
360    svm_node * node = NULL;
361    for (int i = 0; i < problem.l; i++) {
362        stream << problem.y[i] << " ";
363        svm_node_vector_to_stream(stream, problem.x[i]);
364        stream << endl;
365    }
366    return stream;
367}
368
369/*
370 * Return a formated string representing a svm data instance (svm_node *)
371 * (useful for debugging)
372 */
373std::string svm_node_to_string(const svm_node * node) {
374    std::ostringstream strstream;
375    strstream.precision(17);
376    svm_node_vector_to_stream(strstream, node);
377    return strstream.rdbuf()->str();
378}
379
380
381#ifdef _MSC_VER
382    #include <float.h>
383    #define isfinite _finite
384#endif
385
386/*!
387 * Check if the value is valid (not a special value in 'TValue').
388 */
389
390inline bool is_valid(double value) {
391    return isfinite(value) && value != numeric_limits<int>::max();
392}
393
394
395svm_node* example_to_svm(const TExample &ex, svm_node* node, double last=0.0) {
396    int index = 1;
397    double value = 0.0;
398    TExample::iterator values_end;
399
400    if (ex.domain->classVar) {
401        values_end = ex.end() - 1;
402    } else {
403        values_end = ex.end();
404    }
405
406    for(TExample::iterator iter = ex.begin(); iter != values_end; iter++, index++) {
407        if(iter->isRegular()) {
408            if(iter->varType == TValue::FLOATVAR) {
409                value = iter->floatV;
410            } else if (iter->varType == TValue::INTVAR) {
411                value = iter->intV;
412            } else {
413                continue;
414            }
415
416            // Only add non zero values (speedup due to sparseness)
417            if (value != 0 && is_valid(value)) {
418                node->index = index;
419                node->value = value;
420                node++;
421            }
422        }
423    }
424
425    // Sentinel
426    node->index = -1;
427    node->value = last;
428    node++;
429    return node;
430}
431
432class SVM_NodeSort{
433public:
434    bool operator() (const svm_node &lhs, const svm_node &rhs) {
435        return lhs.index < rhs.index;
436    }
437};
438
439svm_node* example_to_svm_sparse(const TExample &ex, svm_node* node, double last=0.0, bool include_regular=false) {
440    svm_node *first = node;
441    int index = 1;
442    double value;
443
444    if (include_regular) {
445        node = example_to_svm(ex, node);
446        // Rewind the sentinel
447        node--;
448        assert(node->index == -1);
449        index += ex.domain->variables->size();
450    }
451
452    for (TMetaValues::const_iterator iter=ex.meta.begin(); iter!=ex.meta.end(); iter++) {
453        if(iter->second.isRegular()) {
454            if(iter->second.varType == TValue::FLOATVAR) {
455                value = iter->second.floatV;
456            } else if (iter->second.varType == TValue::INTVAR) {
457                value = iter->second.intV;
458            } else {
459                continue;
460            }
461
462            if (value != 0 && is_valid(value)) {
463                // add the (- meta_id) to index; meta_ids are negative
464                node->index = index - iter->first;
465                node->value = value;
466                node++;
467            }
468        }
469    }
470
471    // sort the nodes by index (metas are not ordered)
472    sort(first, node, SVM_NodeSort());
473
474    // Sentinel
475    node->index = -1;
476    node->value = last;
477    node++;
478    return node;
479}
480
481/*
482 * Precompute Gram matrix row for ex.
483 * Used for prediction when using the PRECOMPUTED kernel.
484 */
485svm_node* example_to_svm_precomputed(const TExample &ex, PExampleGenerator examples, PKernelFunc kernel, svm_node* node) {
486    // Required node with index 0
487    node->index = 0;
488    node->value = 0.0; // Can be any value.
489    node++;
490    int k = 0;
491    PEITERATE(iter, examples){
492        node->index = ++k;
493        node->value = kernel->operator()(*iter, ex);
494        node++;
495    }
496
497    // Sentinel
498    node->index = -1;
499    node++;
500    return node;
501}
502
503int getNumOfElements(const TExample &ex, bool meta=false, bool useNonMeta=false){
504    if (!meta)
505        return std::max(ex.domain->attributes->size() + 1, 2);
506    else {
507        int count = 1; // we need one to indicate the end of a sequence
508        if (useNonMeta)
509            count += ex.domain->attributes->size();
510        for (TMetaValues::const_iterator iter=ex.meta.begin(); iter!=ex.meta.end();iter++)
511            if(iter->second.isRegular())
512                count++;
513        return std::max(count,2);
514    }
515}
516
517int getNumOfElements(PExampleGenerator &examples, bool meta=false, bool useNonMeta=false) {
518    if (!meta)
519        return getNumOfElements(*(examples->begin()), meta) * examples->numberOfExamples();
520    else {
521        int count = 0;
522        for(TExampleGenerator::iterator ex(examples->begin()); ex!=examples->end(); ++ex){
523            count += getNumOfElements(*ex, meta, useNonMeta);
524        }
525        return count;
526    }
527}
528
529svm_node* init_precomputed_problem(svm_problem &problem, PExampleTable examples, TKernelFunc &kernel){
530    int n_examples = examples->numberOfExamples();
531    int i,j;
532    PSymMatrix matrix = mlnew TSymMatrix(n_examples, 0.0);
533    for (i = 0; i < n_examples; i++)
534        for (j = 0; j <= i; j++){
535            matrix->getref(i, j) = kernel(examples->at(i), examples->at(j));
536        }
537    svm_node *x_space = Malloc(svm_node, n_examples * (n_examples + 2));
538    svm_node *node = x_space;
539
540    problem.l = n_examples;
541    problem.x = Malloc(svm_node*, n_examples);
542    problem.y = Malloc(double, n_examples);
543
544    for (i = 0; i < n_examples; i++){
545        problem.x[i] = node;
546        if (examples->domain->classVar->varType == TValue::FLOATVAR)
547            problem.y[i] = examples->at(i).getClass().floatV;
548        else
549            problem.y[i] = examples->at(i).getClass().intV;
550
551        node->index = 0;
552        node->value = i + 1; // instance indices are 1 based
553        node++;
554        for (j = 0; j < n_examples; j++){
555            node->index = j + 1;
556            node->value = matrix->getitem(i, j);
557            node++;
558        }
559        node->index = -1; // sentry
560        node++;
561    }
562    return x_space;
563}
564
565/*
566 * Extract an ExampleTable corresponding to the support vectors from the
567 * trained model.
568 */
569PExampleTable extract_support_vectors(svm_model * model, PExampleTable train_instances)
570{
571    PExampleTable vectors = mlnew TExampleTable(train_instances->domain);
572
573    for (int i = 0; i < model->l; i++) {
574        svm_node *node = model->SV[i];
575        int sv_index = -1;
576        if(model->param.kernel_type != PRECOMPUTED){
577            /* The value of the last node (with index == -1) holds the
578             * index of the training example.
579             */
580            while(node->index != -1) {
581                node++;
582            }
583            sv_index = int(node->value);
584        } else {
585            /* The value of the first node contains the training instance
586             * index (indices 1 based).
587             */
588            sv_index = int(node->value) - 1;
589        }
590        vectors->addExample(mlnew TExample(train_instances->at(sv_index)));
591    }
592
593    return vectors;
594}
595
596
597/*
598 * Consolidate model->SV[1] .. SV[l] vectors into a single contiguous
599 * memory block. The model will 'own' the new *(model->SV) array and
600 * will be freed in destroy_svm_model (model->free_sv == 1). Note that
601 * the original 'x_space' is left intact, it is the caller's
602 * responsibility to free it. However the model->SV array itself is
603 * reused (overwritten).
604 */
605
606void svm_model_consolidate_SV(svm_model * model) {
607    int count = 0;
608    svm_node * x_space = NULL;
609    svm_node * ptr = NULL;
610    svm_node * ptr_source = NULL;
611
612    // Count the number of elements.
613    for (int i = 0; i < model->l; i++) {
614        ptr = model->SV[i];
615        while (ptr->index != -1){
616            count++;
617            ptr++;
618        }
619    }
620    // add the sentinel count
621    count += model->l;
622
623    x_space = Malloc(svm_node, count);
624    ptr = x_space;
625    for (int i = 0; i < model->l; i++) {
626        ptr_source = model->SV[i];
627        model->SV[i] = ptr;
628        while (ptr_source->index != -1) {
629            *(ptr++) = *(ptr_source++);
630        }
631        // copy the sentinel
632        *(ptr++) = *(ptr_source++);
633    }
634    model->free_sv = 1; // XXX
635}
636
637static void print_string_null(const char* s) {}
638
639
640TSVMLearner::TSVMLearner(){
641    svm_type = NU_SVC;
642    kernel_type = RBF;
643    degree = 3;
644    gamma = 0;
645    coef0 = 0;
646    nu = 0.5;
647    cache_size = 250;
648    C = 1;
649    eps = 1e-3f;
650    p = 0.1f;
651    shrinking = 1;
652    probability = 0;
653    verbose = false;
654    nr_weight = 0;
655    weight_label = NULL;
656    weight = NULL;
657};
658
659
660PClassifier TSVMLearner::operator ()(PExampleGenerator examples, const int&){
661    svm_parameter param;
662    svm_problem prob;
663    svm_model* model;
664    svm_node* x_space;
665
666    PDomain domain = examples->domain;
667
668    int classVarType;
669    if (domain->classVar)
670        classVarType = domain->classVar->varType;
671    else {
672        classVarType = TValue::NONE;
673        if(svm_type != ONE_CLASS)
674            raiseError("Domain has no class variable");
675    }
676    if (classVarType == TValue::FLOATVAR && !(svm_type == EPSILON_SVR || svm_type == NU_SVR ||svm_type == ONE_CLASS))
677        raiseError("Domain has continuous class");
678
679    if (kernel_type == PRECOMPUTED && !kernelFunc)
680        raiseError("Custom kernel function not supplied");
681
682    PExampleTable train_data = mlnew TExampleTable(examples, /* owns= */ false);
683
684    if (classVarType == TValue::INTVAR && svm_type != ONE_CLASS) {
685        /* Sort the train data by the class columns so the order of
686         * classVar.values is preserved in libsvm's model.
687         */
688        vector<int> sort_columns(domain->variables->size() - 1);
689        train_data->sort(sort_columns);
690    }
691
692    // Initialize svm parameters
693    param.svm_type = svm_type;
694    param.kernel_type = kernel_type;
695    param.degree = degree;
696    param.gamma = gamma;
697    param.coef0 = coef0;
698    param.nu = nu;
699    param.C = C;
700    param.eps = eps;
701    param.p = p;
702    param.cache_size = cache_size;
703    param.shrinking = shrinking;
704    param.probability = probability;
705    param.nr_weight = nr_weight;
706
707    if (nr_weight > 0) {
708        param.weight_label = Malloc(int, nr_weight);
709        param.weight = Malloc(double, nr_weight);
710        int i;
711        for (i=0; i<nr_weight; i++) {
712            param.weight_label[i] = weight_label[i];
713            param.weight[i] = weight[i];
714        }
715    } else {
716        param.weight_label = NULL;
717        param.weight = NULL;
718    }
719
720    int numElements = getNumOfElements(train_data);
721
722    prob.x = NULL;
723    prob.y = NULL;
724
725    if (kernel_type != PRECOMPUTED)
726        x_space = init_problem(prob, train_data, numElements);
727    else // Compute the matrix using the kernelFunc
728        x_space = init_precomputed_problem(prob, train_data, kernelFunc.getReference());
729
730    if (param.gamma == 0)
731        param.gamma = 1.0f / (float(numElements) / float(prob.l) - 1);
732
733    const char* error = svm_check_parameter(&prob, &param);
734    if (error){
735        free(x_space);
736        free(prob.y);
737        free(prob.x);
738        svm_destroy_param(&param);
739        raiseError("LibSVM parameter error: %s", error);
740    }
741
742    // If a probability model was requested LibSVM uses 5 fold
743    // cross-validation to estimate the prediction errors. This includes a
744    // random shuffle of the data. To make the results reproducible and
745    // consistent with 'svm-train' (which always learns just on one dataset
746    // in a process run) we reset the random seed. This could have unintended
747    // consequences.
748    if (param.probability)
749    {
750        srand(1);
751    }
752    svm_set_print_string_function((verbose)? NULL : &print_string_null);
753
754    model = svm_train(&prob, &param);
755
756    if ((svm_type==C_SVC || svm_type==NU_SVC) && !model->nSV) {
757        svm_free_and_destroy_model(&model);
758        free(x_space);
759        free(prob.x);
760        free(prob.y);
761        svm_destroy_param(&param);
762        raiseError("LibSVM returned no support vectors");
763    }
764
765    svm_destroy_param(&param);
766    free(prob.y);
767    free(prob.x);
768
769    // Consolidate the SV so x_space can be safely freed
770    svm_model_consolidate_SV(model);
771
772    free(x_space);
773
774    PExampleTable supportVectors = extract_support_vectors(model, train_data);
775
776    return PClassifier(createClassifier(domain, model, supportVectors, train_data));
777}
778
779svm_node* TSVMLearner::example_to_svm(const TExample &ex, svm_node* node, double last){
780    return ::example_to_svm(ex, node, last);
781}
782
783svm_node* TSVMLearner::init_problem(svm_problem &problem, PExampleTable examples, int n_elements){
784    problem.l = examples->numberOfExamples();
785    problem.y = Malloc(double, problem.l);
786    problem.x = Malloc(svm_node*, problem.l);
787    svm_node *x_space = Malloc(svm_node, n_elements);
788    svm_node *node = x_space;
789
790    for (int i = 0; i < problem.l; i++){
791        problem.x[i] = node;
792        node = example_to_svm(examples->at(i), node, i);
793        if (examples->domain->classVar)
794            if (examples->domain->classVar->varType == TValue::FLOATVAR)
795                problem.y[i] = examples->at(i).getClass().floatV;
796            else if (examples->domain->classVar->varType == TValue::INTVAR)
797                problem.y[i] = examples->at(i).getClass().intV;
798    }
799
800//  cout << problem << endl;
801
802    return x_space;
803}
804
805int TSVMLearner::getNumOfElements(PExampleGenerator examples){
806    return ::getNumOfElements(examples);
807}
808
809TSVMClassifier* TSVMLearner::createClassifier(
810        PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples) {
811    PKernelFunc kfunc;
812    if (kernel_type != PRECOMPUTED) {
813        // Classifier does not need the train data and the kernelFunc.
814        examples = NULL;
815        kfunc = NULL;
816    } else {
817        kfunc = kernelFunc;
818    }
819
820    return mlnew TSVMClassifier(domain, model, supportVectors, kfunc, examples);
821}
822
823TSVMLearner::~TSVMLearner(){
824    if(weight_label)
825        free(weight_label);
826
827    if(weight)
828        free(weight);
829}
830
831svm_node* TSVMLearnerSparse::example_to_svm(const TExample &ex, svm_node* node, double last){
832    return ::example_to_svm_sparse(ex, node, last, useNonMeta);
833}
834
835int TSVMLearnerSparse::getNumOfElements(PExampleGenerator examples){
836    return ::getNumOfElements(examples, true, useNonMeta);
837}
838
839TSVMClassifier* TSVMLearnerSparse::createClassifier(
840        PDomain domain, svm_model* model, PExampleTable supportVectors, PExampleTable examples) {
841    PKernelFunc kfunc;
842    if (kernel_type != PRECOMPUTED) {
843        // Classifier does not need the train data and the kernelFunc.
844        examples = NULL;
845        kfunc = NULL;
846    } else {
847        kfunc = kernelFunc;
848    }
849    return mlnew TSVMClassifierSparse(domain, model, useNonMeta, supportVectors, kfunc, examples);
850}
851
852
853TSVMClassifier::TSVMClassifier(
854        PDomain domain, svm_model * model,
855        PExampleTable supportVectors,
856        PKernelFunc kernelFunc,
857        PExampleTable examples
858        ) : TClassifierFD(domain) {
859    this->model = model;
860    this->supportVectors = supportVectors;
861    this->kernelFunc = kernelFunc;
862    this->examples = examples;
863
864    svm_type = svm_get_svm_type(model);
865    kernel_type = model->param.kernel_type;
866
867    if (svm_type == ONE_CLASS) {
868        this->classVar = mlnew TFloatVariable("one class");
869    }
870
871    computesProbabilities = model && svm_check_probability_model(model) && \
872                (svm_type != NU_SVR && svm_type != EPSILON_SVR); // Disable prob. estimation for regression
873
874    int nr_class = svm_get_nr_class(model);
875    int i = 0;
876
877    /* Expose (copy) the model data (coef, rho, probA) to public
878     * class interface.
879     */
880    if (svm_type == C_SVC || svm_type == NU_SVC) {
881        nSV = mlnew TIntList(nr_class); // num of SVs for each class (sum(nSV) == model->l)
882        for(i = 0;i < nr_class; i++) {
883            nSV->at(i) = model->nSV[i];
884        }
885    }
886
887    coef = mlnew TFloatListList(nr_class-1);
888    for(i = 0; i < nr_class - 1; i++) {
889        TFloatList *coefs = mlnew TFloatList(model->l);
890        for(int j = 0;j < model->l; j++) {
891            coefs->at(j) = model->sv_coef[i][j];
892        }
893        coef->at(i) = coefs;
894    }
895
896    // Number of binary classifiers in the model
897    int nr_bin_cls = nr_class * (nr_class - 1) / 2;
898
899    rho = mlnew TFloatList(nr_bin_cls);
900    for(i = 0; i < nr_bin_cls; i++) {
901        rho->at(i) = model->rho[i];
902    }
903
904    if(model->probA) {
905        probA = mlnew TFloatList(nr_bin_cls);
906        if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) {
907            // Regression only has probA
908            probB = mlnew TFloatList(nr_bin_cls);
909        }
910
911        for(i=0; i<nr_bin_cls; i++) {
912            probA->at(i) = model->probA[i];
913            if (model->param.svm_type != NU_SVR && model->param.svm_type != EPSILON_SVR && model->probB) {
914                probB->at(i) = model->probB[i];
915            }
916        }
917    }
918}
919
920
921TSVMClassifier::~TSVMClassifier(){
922    if (model) {
923        svm_free_and_destroy_model(&model);
924    }
925}
926
927
928PDistribution TSVMClassifier::classDistribution(const TExample & example){
929    if(!model)
930        raiseError("No Model");
931
932    if(!computesProbabilities)
933        return TClassifierFD::classDistribution(example);
934
935    int n_elements;
936    if (model->param.kernel_type != PRECOMPUTED)
937        n_elements = getNumOfElements(example);
938    else
939        n_elements = examples->numberOfExamples() + 2;
940
941    int svm_type = svm_get_svm_type(model);
942    int nr_class = svm_get_nr_class(model);
943
944    svm_node *x = Malloc(svm_node, n_elements);
945    try{
946        if (model->param.kernel_type != PRECOMPUTED)
947            example_to_svm(example, x, -1.0);
948        else
949            example_to_svm_precomputed(example, examples, kernelFunc, x);
950    } catch (...) {
951        free(x);
952        throw;
953    }
954
955    int *labels=(int *) malloc(nr_class*sizeof(int));
956    svm_get_labels(model, labels);
957
958    double *prob_estimates = (double *) malloc(nr_class*sizeof(double));;
959    svm_predict_probability(model, x, prob_estimates);
960
961    PDistribution dist = TDistribution::create(example.domain->classVar);
962    for(int i=0; i<nr_class; i++)
963        dist->setint(labels[i], prob_estimates[i]);
964    free(x);
965    free(prob_estimates);
966    free(labels);
967    return dist;
968}
969
970TValue TSVMClassifier::operator()(const TExample & example){
971    if(!model)
972        raiseError("No Model");
973
974    int n_elements;
975    if (model->param.kernel_type != PRECOMPUTED)
976        n_elements = getNumOfElements(example);
977    else
978        n_elements = examples->numberOfExamples() + 2;
979
980    int svm_type = svm_get_svm_type(model);
981    int nr_class = svm_get_nr_class(model);
982
983    svm_node *x = Malloc(svm_node, n_elements);
984    try {
985        if (model->param.kernel_type != PRECOMPUTED)
986            example_to_svm(example, x);
987        else
988            example_to_svm_precomputed(example, examples, kernelFunc, x);
989    } catch (...) {
990        free(x);
991        throw;
992    }
993
994    double v;
995
996    if(svm_check_probability_model(model)){
997        double *prob = (double *) malloc(nr_class*sizeof(double));
998        v = svm_predict_probability(model, x, prob);
999        free(prob);
1000    } else
1001        v = svm_predict(model, x);
1002
1003    free(x);
1004    if(svm_type==NU_SVR || svm_type==EPSILON_SVR || svm_type==ONE_CLASS)
1005        return TValue(v);
1006    else
1007        return TValue(int(v));
1008}
1009
1010PFloatList TSVMClassifier::getDecisionValues(const TExample &example){
1011    if(!model)
1012        raiseError("No Model");
1013
1014    int n_elements;
1015    if (model->param.kernel_type != PRECOMPUTED)
1016        n_elements = getNumOfElements(example);
1017    else
1018        n_elements = examples->numberOfExamples() + 2;
1019
1020    int svm_type=svm_get_svm_type(model);
1021    int nr_class=svm_get_nr_class(model);
1022
1023    svm_node *x = Malloc(svm_node, n_elements);
1024    try {
1025        if (model->param.kernel_type != PRECOMPUTED)
1026            example_to_svm(example, x);
1027        else
1028            example_to_svm_precomputed(example, examples, kernelFunc, x);
1029    } catch (...) {
1030        free(x);
1031        throw;
1032    }
1033
1034    int nDecValues = nr_class*(nr_class-1)/2;
1035    double *dec = (double*) malloc(sizeof(double)*nDecValues);
1036    svm_predict_values(model, x, dec);
1037    PFloatList res = mlnew TFloatList(nDecValues);
1038    for(int i=0; i<nDecValues; i++){
1039        res->at(i) = dec[i];
1040    }
1041    free(x);
1042    free(dec);
1043    return res;
1044}
1045
1046svm_node *TSVMClassifier::example_to_svm(const TExample &ex, svm_node *node, double last){
1047    return ::example_to_svm(ex, node, last);
1048}
1049
1050int TSVMClassifier::getNumOfElements(const TExample& example){
1051    return ::getNumOfElements(example);
1052}
1053svm_node *TSVMClassifierSparse::example_to_svm(const TExample &ex, svm_node *node, double last){
1054    return ::example_to_svm_sparse(ex, node, last, useNonMeta);
1055}
1056
1057int TSVMClassifierSparse::getNumOfElements(const TExample& example){
1058    return ::getNumOfElements(example, true, useNonMeta);
1059}
1060
1061
1062#include "libsvm_interface.ppp"
Note: See TracBrowser for help on using the repository browser.