source: orange/source/orange/libsvm/svm.cpp @ 11664:4bb73da41192

Revision 11664:4bb73da41192, 63.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 8 months ago (diff)

Updated the included LIBSVM to version 3.17.

RevLine 
[8978]1#include <math.h>
2#include <stdio.h>
3#include <stdlib.h>
4#include <ctype.h>
5#include <float.h>
6#include <string.h>
7#include <stdarg.h>
[11664]8#include <limits.h>
9#include <locale.h>
[8978]10#include "svm.h"
11int libsvm_version = LIBSVM_VERSION;
12typedef float Qfloat;
13typedef signed char schar;
14#ifndef min
15template <class T> static inline T min(T x,T y) { return (x<y)?x:y; }
16#endif
17#ifndef max
18template <class T> static inline T max(T x,T y) { return (x>y)?x:y; }
19#endif
20template <class T> static inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
21template <class S, class T> static inline void clone(T*& dst, S* src, int n)
22{
23    dst = new T[n];
24    memcpy((void *)dst,(void *)src,sizeof(T)*n);
25}
26static inline double powi(double base, int times)
27{
28    double tmp = base, ret = 1.0;
29
30    for(int t=times; t>0; t/=2)
31    {
32        if(t%2==1) ret*=tmp;
33        tmp = tmp * tmp;
34    }
35    return ret;
36}
37#define INF HUGE_VAL
38#define TAU 1e-12
39#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
40
41static void print_string_stdout(const char *s)
42{
43    fputs(s,stdout);
44    fflush(stdout);
45}
46static void (*svm_print_string) (const char *) = &print_string_stdout;
47#if 1
48static void info(const char *fmt,...)
49{
50    char buf[BUFSIZ];
51    va_list ap;
52    va_start(ap,fmt);
53    vsprintf(buf,fmt,ap);
54    va_end(ap);
55    (*svm_print_string)(buf);
56}
57#else
58static void info(const char *fmt,...) {}
59#endif
60
61//
62// Kernel Cache
63//
64// l is the number of total data items
65// size is the cache size limit in bytes
66//
67class Cache
68{
69public:
70    Cache(int l,long int size);
71    ~Cache();
72
73    // request data [0,len)
74    // return some position p where [p,len) need to be filled
75    // (p >= len if nothing needs to be filled)
76    int get_data(const int index, Qfloat **data, int len);
77    void swap_index(int i, int j); 
78private:
79    int l;
80    long int size;
81    struct head_t
82    {
83        head_t *prev, *next;    // a circular list
84        Qfloat *data;
85        int len;        // data[0,len) is cached in this entry
86    };
87
88    head_t *head;
89    head_t lru_head;
90    void lru_delete(head_t *h);
91    void lru_insert(head_t *h);
92};
93
94Cache::Cache(int l_,long int size_):l(l_),size(size_)
95{
96    head = (head_t *)calloc(l,sizeof(head_t));  // initialized to 0
97    size /= sizeof(Qfloat);
98    size -= l * sizeof(head_t) / sizeof(Qfloat);
99    size = max(size, 2 * (long int) l); // cache must be large enough for two columns
100    lru_head.next = lru_head.prev = &lru_head;
101}
102
103Cache::~Cache()
104{
105    for(head_t *h = lru_head.next; h != &lru_head; h=h->next)
106        free(h->data);
107    free(head);
108}
109
110void Cache::lru_delete(head_t *h)
111{
112    // delete from current location
113    h->prev->next = h->next;
114    h->next->prev = h->prev;
115}
116
117void Cache::lru_insert(head_t *h)
118{
119    // insert to last position
120    h->next = &lru_head;
121    h->prev = lru_head.prev;
122    h->prev->next = h;
123    h->next->prev = h;
124}
125
126int Cache::get_data(const int index, Qfloat **data, int len)
127{
128    head_t *h = &head[index];
129    if(h->len) lru_delete(h);
130    int more = len - h->len;
131
132    if(more > 0)
133    {
134        // free old space
135        while(size < more)
136        {
137            head_t *old = lru_head.next;
138            lru_delete(old);
139            free(old->data);
140            size += old->len;
141            old->data = 0;
142            old->len = 0;
143        }
144
145        // allocate new space
146        h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len);
147        size -= more;
148        swap(h->len,len);
149    }
150
151    lru_insert(h);
152    *data = h->data;
153    return len;
154}
155
156void Cache::swap_index(int i, int j)
157{
158    if(i==j) return;
159
160    if(head[i].len) lru_delete(&head[i]);
161    if(head[j].len) lru_delete(&head[j]);
162    swap(head[i].data,head[j].data);
163    swap(head[i].len,head[j].len);
164    if(head[i].len) lru_insert(&head[i]);
165    if(head[j].len) lru_insert(&head[j]);
166
167    if(i>j) swap(i,j);
168    for(head_t *h = lru_head.next; h!=&lru_head; h=h->next)
169    {
170        if(h->len > i)
171        {
172            if(h->len > j)
173                swap(h->data[i],h->data[j]);
174            else
175            {
176                // give up
177                lru_delete(h);
178                free(h->data);
179                size += h->len;
180                h->data = 0;
181                h->len = 0;
182            }
183        }
184    }
185}
186
187//
188// Kernel evaluation
189//
190// the static method k_function is for doing single kernel evaluation
191// the constructor of Kernel prepares to calculate the l*l kernel matrix
192// the member function get_Q is for getting one column from the Q Matrix
193//
194class QMatrix {
195public:
196    virtual Qfloat *get_Q(int column, int len) const = 0;
197    virtual double *get_QD() const = 0;
198    virtual void swap_index(int i, int j) const = 0;
199    virtual ~QMatrix() {}
200};
201
202class Kernel: public QMatrix {
203public:
204    Kernel(int l, svm_node * const * x, const svm_parameter& param);
205    virtual ~Kernel();
206
207    static double k_function(const svm_node *x, const svm_node *y,
208                 const svm_parameter& param);
209    virtual Qfloat *get_Q(int column, int len) const = 0;
210    virtual double *get_QD() const = 0;
211    virtual void swap_index(int i, int j) const // no so const...
212    {
213        swap(x[i],x[j]);
214        if(x_square) swap(x_square[i],x_square[j]);
215    }
216protected:
217
218    double (Kernel::*kernel_function)(int i, int j) const;
219
220private:
221    const svm_node **x;
222    double *x_square;
223
224    // svm_parameter
225    const int kernel_type;
226    const int degree;
227    const double gamma;
228    const double coef0;
229
230    static double dot(const svm_node *px, const svm_node *py);
231    double kernel_linear(int i, int j) const
232    {
233        return dot(x[i],x[j]);
234    }
235    double kernel_poly(int i, int j) const
236    {
237        return powi(gamma*dot(x[i],x[j])+coef0,degree);
238    }
239    double kernel_rbf(int i, int j) const
240    {
241        return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
242    }
243    double kernel_sigmoid(int i, int j) const
244    {
245        return tanh(gamma*dot(x[i],x[j])+coef0);
246    }
247    double kernel_precomputed(int i, int j) const
248    {
249        return x[i][(int)(x[j][0].value)].value;
250    }
251};
252
253Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)
254:kernel_type(param.kernel_type), degree(param.degree),
255 gamma(param.gamma), coef0(param.coef0)
256{
257    switch(kernel_type)
258    {
259        case LINEAR:
260            kernel_function = &Kernel::kernel_linear;
261            break;
262        case POLY:
263            kernel_function = &Kernel::kernel_poly;
264            break;
265        case RBF:
266            kernel_function = &Kernel::kernel_rbf;
267            break;
268        case SIGMOID:
269            kernel_function = &Kernel::kernel_sigmoid;
270            break;
271        case PRECOMPUTED:
272            kernel_function = &Kernel::kernel_precomputed;
273            break;
274    }
275
276    clone(x,x_,l);
277
278    if(kernel_type == RBF)
279    {
280        x_square = new double[l];
281        for(int i=0;i<l;i++)
282            x_square[i] = dot(x[i],x[i]);
283    }
284    else
285        x_square = 0;
286}
287
288Kernel::~Kernel()
289{
290    delete[] x;
291    delete[] x_square;
292}
293
294double Kernel::dot(const svm_node *px, const svm_node *py)
295{
296    double sum = 0;
297    while(px->index != -1 && py->index != -1)
298    {
299        if(px->index == py->index)
300        {
301            sum += px->value * py->value;
302            ++px;
303            ++py;
304        }
305        else
306        {
307            if(px->index > py->index)
308                ++py;
309            else
310                ++px;
311        }           
312    }
313    return sum;
314}
315
316double Kernel::k_function(const svm_node *x, const svm_node *y,
317              const svm_parameter& param)
318{
319    switch(param.kernel_type)
320    {
321        case LINEAR:
322            return dot(x,y);
323        case POLY:
324            return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
325        case RBF:
326        {
327            double sum = 0;
328            while(x->index != -1 && y->index !=-1)
329            {
330                if(x->index == y->index)
331                {
332                    double d = x->value - y->value;
333                    sum += d*d;
334                    ++x;
335                    ++y;
336                }
337                else
338                {
339                    if(x->index > y->index)
340                    {   
341                        sum += y->value * y->value;
342                        ++y;
343                    }
344                    else
345                    {
346                        sum += x->value * x->value;
347                        ++x;
348                    }
349                }
350            }
351
352            while(x->index != -1)
353            {
354                sum += x->value * x->value;
355                ++x;
356            }
357
358            while(y->index != -1)
359            {
360                sum += y->value * y->value;
361                ++y;
362            }
363           
364            return exp(-param.gamma*sum);
365        }
366        case SIGMOID:
367            return tanh(param.gamma*dot(x,y)+param.coef0);
368        case PRECOMPUTED:  //x: test (validation), y: SV
369            return x[(int)(y->value)].value;
370        default:
371            return 0;  // Unreachable
372    }
373}
374
375// An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
376// Solves:
377//
378//  min 0.5(\alpha^T Q \alpha) + p^T \alpha
379//
380//      y^T \alpha = \delta
381//      y_i = +1 or -1
382//      0 <= alpha_i <= Cp for y_i = 1
383//      0 <= alpha_i <= Cn for y_i = -1
384//
385// Given:
386//
387//  Q, p, y, Cp, Cn, and an initial feasible point \alpha
388//  l is the size of vectors and matrices
389//  eps is the stopping tolerance
390//
391// solution will be put in \alpha, objective value will be put in obj
392//
393class Solver {
394public:
395    Solver() {};
396    virtual ~Solver() {};
397
398    struct SolutionInfo {
399        double obj;
400        double rho;
401        double upper_bound_p;
402        double upper_bound_n;
403        double r;   // for Solver_NU
404    };
405
406    void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
407           double *alpha_, double Cp, double Cn, double eps,
408           SolutionInfo* si, int shrinking);
409protected:
410    int active_size;
411    schar *y;
412    double *G;      // gradient of objective function
413    enum { LOWER_BOUND, UPPER_BOUND, FREE };
414    char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE
415    double *alpha;
416    const QMatrix *Q;
417    const double *QD;
418    double eps;
419    double Cp,Cn;
420    double *p;
421    int *active_set;
422    double *G_bar;      // gradient, if we treat free variables as 0
423    int l;
424    bool unshrink;  // XXX
425
426    double get_C(int i)
427    {
428        return (y[i] > 0)? Cp : Cn;
429    }
430    void update_alpha_status(int i)
431    {
432        if(alpha[i] >= get_C(i))
433            alpha_status[i] = UPPER_BOUND;
434        else if(alpha[i] <= 0)
435            alpha_status[i] = LOWER_BOUND;
436        else alpha_status[i] = FREE;
437    }
438    bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
439    bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
440    bool is_free(int i) { return alpha_status[i] == FREE; }
441    void swap_index(int i, int j);
442    void reconstruct_gradient();
443    virtual int select_working_set(int &i, int &j);
444    virtual double calculate_rho();
445    virtual void do_shrinking();
446private:
447    bool be_shrunk(int i, double Gmax1, double Gmax2); 
448};
449
450void Solver::swap_index(int i, int j)
451{
452    Q->swap_index(i,j);
453    swap(y[i],y[j]);
454    swap(G[i],G[j]);
455    swap(alpha_status[i],alpha_status[j]);
456    swap(alpha[i],alpha[j]);
457    swap(p[i],p[j]);
458    swap(active_set[i],active_set[j]);
459    swap(G_bar[i],G_bar[j]);
460}
461
462void Solver::reconstruct_gradient()
463{
464    // reconstruct inactive elements of G from G_bar and free variables
465
466    if(active_size == l) return;
467
468    int i,j;
469    int nr_free = 0;
470
471    for(j=active_size;j<l;j++)
472        G[j] = G_bar[j] + p[j];
473
474    for(j=0;j<active_size;j++)
475        if(is_free(j))
476            nr_free++;
477
478    if(2*nr_free < active_size)
[11664]479        info("\nWARNING: using -h 0 may be faster\n");
[8978]480
481    if (nr_free*l > 2*active_size*(l-active_size))
482    {
483        for(i=active_size;i<l;i++)
484        {
485            const Qfloat *Q_i = Q->get_Q(i,active_size);
486            for(j=0;j<active_size;j++)
487                if(is_free(j))
488                    G[i] += alpha[j] * Q_i[j];
489        }
490    }
491    else
492    {
493        for(i=0;i<active_size;i++)
494            if(is_free(i))
495            {
496                const Qfloat *Q_i = Q->get_Q(i,l);
497                double alpha_i = alpha[i];
498                for(j=active_size;j<l;j++)
499                    G[j] += alpha_i * Q_i[j];
500            }
501    }
502}
503
504void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
505           double *alpha_, double Cp, double Cn, double eps,
506           SolutionInfo* si, int shrinking)
507{
508    this->l = l;
509    this->Q = &Q;
510    QD=Q.get_QD();
511    clone(p, p_,l);
512    clone(y, y_,l);
513    clone(alpha,alpha_,l);
514    this->Cp = Cp;
515    this->Cn = Cn;
516    this->eps = eps;
517    unshrink = false;
518
519    // initialize alpha_status
520    {
521        alpha_status = new char[l];
522        for(int i=0;i<l;i++)
523            update_alpha_status(i);
524    }
525
526    // initialize active set (for shrinking)
527    {
528        active_set = new int[l];
529        for(int i=0;i<l;i++)
530            active_set[i] = i;
531        active_size = l;
532    }
533
534    // initialize gradient
535    {
536        G = new double[l];
537        G_bar = new double[l];
538        int i;
539        for(i=0;i<l;i++)
540        {
541            G[i] = p[i];
542            G_bar[i] = 0;
543        }
544        for(i=0;i<l;i++)
545            if(!is_lower_bound(i))
546            {
547                const Qfloat *Q_i = Q.get_Q(i,l);
548                double alpha_i = alpha[i];
549                int j;
550                for(j=0;j<l;j++)
551                    G[j] += alpha_i*Q_i[j];
552                if(is_upper_bound(i))
553                    for(j=0;j<l;j++)
554                        G_bar[j] += get_C(i) * Q_i[j];
555            }
556    }
557
558    // optimization step
559
560    int iter = 0;
[11664]561    int max_iter = max(10000000, l>INT_MAX/100 ? INT_MAX : 100*l);
[8978]562    int counter = min(l,1000)+1;
[11664]563   
564    while(iter < max_iter)
[8978]565    {
566        // show progress and do shrinking
567
568        if(--counter == 0)
569        {
570            counter = min(l,1000);
571            if(shrinking) do_shrinking();
572            info(".");
573        }
574
575        int i,j;
576        if(select_working_set(i,j)!=0)
577        {
578            // reconstruct the whole gradient
579            reconstruct_gradient();
580            // reset active set size and check
581            active_size = l;
582            info("*");
583            if(select_working_set(i,j)!=0)
584                break;
585            else
586                counter = 1;    // do shrinking next iteration
587        }
588       
589        ++iter;
590
591        // update alpha[i] and alpha[j], handle bounds carefully
592       
593        const Qfloat *Q_i = Q.get_Q(i,active_size);
594        const Qfloat *Q_j = Q.get_Q(j,active_size);
595
596        double C_i = get_C(i);
597        double C_j = get_C(j);
598
599        double old_alpha_i = alpha[i];
600        double old_alpha_j = alpha[j];
601
602        if(y[i]!=y[j])
603        {
604            double quad_coef = QD[i]+QD[j]+2*Q_i[j];
605            if (quad_coef <= 0)
606                quad_coef = TAU;
607            double delta = (-G[i]-G[j])/quad_coef;
608            double diff = alpha[i] - alpha[j];
609            alpha[i] += delta;
610            alpha[j] += delta;
611           
612            if(diff > 0)
613            {
614                if(alpha[j] < 0)
615                {
616                    alpha[j] = 0;
617                    alpha[i] = diff;
618                }
619            }
620            else
621            {
622                if(alpha[i] < 0)
623                {
624                    alpha[i] = 0;
625                    alpha[j] = -diff;
626                }
627            }
628            if(diff > C_i - C_j)
629            {
630                if(alpha[i] > C_i)
631                {
632                    alpha[i] = C_i;
633                    alpha[j] = C_i - diff;
634                }
635            }
636            else
637            {
638                if(alpha[j] > C_j)
639                {
640                    alpha[j] = C_j;
641                    alpha[i] = C_j + diff;
642                }
643            }
644        }
645        else
646        {
647            double quad_coef = QD[i]+QD[j]-2*Q_i[j];
648            if (quad_coef <= 0)
649                quad_coef = TAU;
650            double delta = (G[i]-G[j])/quad_coef;
651            double sum = alpha[i] + alpha[j];
652            alpha[i] -= delta;
653            alpha[j] += delta;
654
655            if(sum > C_i)
656            {
657                if(alpha[i] > C_i)
658                {
659                    alpha[i] = C_i;
660                    alpha[j] = sum - C_i;
661                }
662            }
663            else
664            {
665                if(alpha[j] < 0)
666                {
667                    alpha[j] = 0;
668                    alpha[i] = sum;
669                }
670            }
671            if(sum > C_j)
672            {
673                if(alpha[j] > C_j)
674                {
675                    alpha[j] = C_j;
676                    alpha[i] = sum - C_j;
677                }
678            }
679            else
680            {
681                if(alpha[i] < 0)
682                {
683                    alpha[i] = 0;
684                    alpha[j] = sum;
685                }
686            }
687        }
688
689        // update G
690
691        double delta_alpha_i = alpha[i] - old_alpha_i;
692        double delta_alpha_j = alpha[j] - old_alpha_j;
693       
694        for(int k=0;k<active_size;k++)
695        {
696            G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
697        }
698
699        // update alpha_status and G_bar
700
701        {
702            bool ui = is_upper_bound(i);
703            bool uj = is_upper_bound(j);
704            update_alpha_status(i);
705            update_alpha_status(j);
706            int k;
707            if(ui != is_upper_bound(i))
708            {
709                Q_i = Q.get_Q(i,l);
710                if(ui)
711                    for(k=0;k<l;k++)
712                        G_bar[k] -= C_i * Q_i[k];
713                else
714                    for(k=0;k<l;k++)
715                        G_bar[k] += C_i * Q_i[k];
716            }
717
718            if(uj != is_upper_bound(j))
719            {
720                Q_j = Q.get_Q(j,l);
721                if(uj)
722                    for(k=0;k<l;k++)
723                        G_bar[k] -= C_j * Q_j[k];
724                else
725                    for(k=0;k<l;k++)
726                        G_bar[k] += C_j * Q_j[k];
727            }
728        }
729    }
730
[11664]731    if(iter >= max_iter)
732    {
733        if(active_size < l)
734        {
735            // reconstruct the whole gradient to calculate objective value
736            reconstruct_gradient();
737            active_size = l;
738            info("*");
739        }
740        fprintf(stderr,"\nWARNING: reaching max number of iterations\n");
741    }
742
[8978]743    // calculate rho
744
745    si->rho = calculate_rho();
746
747    // calculate objective value
748    {
749        double v = 0;
750        int i;
751        for(i=0;i<l;i++)
752            v += alpha[i] * (G[i] + p[i]);
753
754        si->obj = v/2;
755    }
756
757    // put back the solution
758    {
759        for(int i=0;i<l;i++)
760            alpha_[active_set[i]] = alpha[i];
761    }
762
763    // juggle everything back
764    /*{
765        for(int i=0;i<l;i++)
766            while(active_set[i] != i)
767                swap_index(i,active_set[i]);
768                // or Q.swap_index(i,active_set[i]);
769    }*/
770
771    si->upper_bound_p = Cp;
772    si->upper_bound_n = Cn;
773
774    info("\noptimization finished, #iter = %d\n",iter);
775
776    delete[] p;
777    delete[] y;
778    delete[] alpha;
779    delete[] alpha_status;
780    delete[] active_set;
781    delete[] G;
782    delete[] G_bar;
783}
784
785// return 1 if already optimal, return 0 otherwise
786int Solver::select_working_set(int &out_i, int &out_j)
787{
788    // return i,j such that
789    // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
790    // j: minimizes the decrease of obj value
791    //    (if quadratic coefficeint <= 0, replace it with tau)
792    //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
793   
794    double Gmax = -INF;
795    double Gmax2 = -INF;
796    int Gmax_idx = -1;
797    int Gmin_idx = -1;
798    double obj_diff_min = INF;
799
800    for(int t=0;t<active_size;t++)
801        if(y[t]==+1)   
802        {
803            if(!is_upper_bound(t))
804                if(-G[t] >= Gmax)
805                {
806                    Gmax = -G[t];
807                    Gmax_idx = t;
808                }
809        }
810        else
811        {
812            if(!is_lower_bound(t))
813                if(G[t] >= Gmax)
814                {
815                    Gmax = G[t];
816                    Gmax_idx = t;
817                }
818        }
819
820    int i = Gmax_idx;
821    const Qfloat *Q_i = NULL;
822    if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
823        Q_i = Q->get_Q(i,active_size);
824
825    for(int j=0;j<active_size;j++)
826    {
827        if(y[j]==+1)
828        {
829            if (!is_lower_bound(j))
830            {
831                double grad_diff=Gmax+G[j];
832                if (G[j] >= Gmax2)
833                    Gmax2 = G[j];
834                if (grad_diff > 0)
835                {
836                    double obj_diff; 
837                    double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
838                    if (quad_coef > 0)
839                        obj_diff = -(grad_diff*grad_diff)/quad_coef;
840                    else
841                        obj_diff = -(grad_diff*grad_diff)/TAU;
842
843                    if (obj_diff <= obj_diff_min)
844                    {
845                        Gmin_idx=j;
846                        obj_diff_min = obj_diff;
847                    }
848                }
849            }
850        }
851        else
852        {
853            if (!is_upper_bound(j))
854            {
855                double grad_diff= Gmax-G[j];
856                if (-G[j] >= Gmax2)
857                    Gmax2 = -G[j];
858                if (grad_diff > 0)
859                {
860                    double obj_diff; 
861                    double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
862                    if (quad_coef > 0)
863                        obj_diff = -(grad_diff*grad_diff)/quad_coef;
864                    else
865                        obj_diff = -(grad_diff*grad_diff)/TAU;
866
867                    if (obj_diff <= obj_diff_min)
868                    {
869                        Gmin_idx=j;
870                        obj_diff_min = obj_diff;
871                    }
872                }
873            }
874        }
875    }
876
877    if(Gmax+Gmax2 < eps)
878        return 1;
879
880    out_i = Gmax_idx;
881    out_j = Gmin_idx;
882    return 0;
883}
884
885bool Solver::be_shrunk(int i, double Gmax1, double Gmax2)
886{
887    if(is_upper_bound(i))
888    {
889        if(y[i]==+1)
890            return(-G[i] > Gmax1);
891        else
892            return(-G[i] > Gmax2);
893    }
894    else if(is_lower_bound(i))
895    {
896        if(y[i]==+1)
897            return(G[i] > Gmax2);
898        else   
899            return(G[i] > Gmax1);
900    }
901    else
902        return(false);
903}
904
905void Solver::do_shrinking()
906{
907    int i;
908    double Gmax1 = -INF;        // max { -y_i * grad(f)_i | i in I_up(\alpha) }
909    double Gmax2 = -INF;        // max { y_i * grad(f)_i | i in I_low(\alpha) }
910
911    // find maximal violating pair first
912    for(i=0;i<active_size;i++)
913    {
914        if(y[i]==+1)   
915        {
916            if(!is_upper_bound(i)) 
917            {
918                if(-G[i] >= Gmax1)
919                    Gmax1 = -G[i];
920            }
921            if(!is_lower_bound(i)) 
922            {
923                if(G[i] >= Gmax2)
924                    Gmax2 = G[i];
925            }
926        }
927        else   
928        {
929            if(!is_upper_bound(i)) 
930            {
931                if(-G[i] >= Gmax2)
932                    Gmax2 = -G[i];
933            }
934            if(!is_lower_bound(i)) 
935            {
936                if(G[i] >= Gmax1)
937                    Gmax1 = G[i];
938            }
939        }
940    }
941
942    if(unshrink == false && Gmax1 + Gmax2 <= eps*10) 
943    {
944        unshrink = true;
945        reconstruct_gradient();
946        active_size = l;
947        info("*");
948    }
949
950    for(i=0;i<active_size;i++)
951        if (be_shrunk(i, Gmax1, Gmax2))
952        {
953            active_size--;
954            while (active_size > i)
955            {
956                if (!be_shrunk(active_size, Gmax1, Gmax2))
957                {
958                    swap_index(i,active_size);
959                    break;
960                }
961                active_size--;
962            }
963        }
964}
965
966double Solver::calculate_rho()
967{
968    double r;
969    int nr_free = 0;
970    double ub = INF, lb = -INF, sum_free = 0;
971    for(int i=0;i<active_size;i++)
972    {
973        double yG = y[i]*G[i];
974
975        if(is_upper_bound(i))
976        {
977            if(y[i]==-1)
978                ub = min(ub,yG);
979            else
980                lb = max(lb,yG);
981        }
982        else if(is_lower_bound(i))
983        {
984            if(y[i]==+1)
985                ub = min(ub,yG);
986            else
987                lb = max(lb,yG);
988        }
989        else
990        {
991            ++nr_free;
992            sum_free += yG;
993        }
994    }
995
996    if(nr_free>0)
997        r = sum_free/nr_free;
998    else
999        r = (ub+lb)/2;
1000
1001    return r;
1002}
1003
1004//
1005// Solver for nu-svm classification and regression
1006//
1007// additional constraint: e^T \alpha = constant
1008//
[11664]1009class Solver_NU: public Solver
[8978]1010{
1011public:
1012    Solver_NU() {}
1013    void Solve(int l, const QMatrix& Q, const double *p, const schar *y,
1014           double *alpha, double Cp, double Cn, double eps,
1015           SolutionInfo* si, int shrinking)
1016    {
1017        this->si = si;
1018        Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
1019    }
1020private:
1021    SolutionInfo *si;
1022    int select_working_set(int &i, int &j);
1023    double calculate_rho();
1024    bool be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4);
1025    void do_shrinking();
1026};
1027
1028// return 1 if already optimal, return 0 otherwise
1029int Solver_NU::select_working_set(int &out_i, int &out_j)
1030{
1031    // return i,j such that y_i = y_j and
1032    // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
1033    // j: minimizes the decrease of obj value
1034    //    (if quadratic coefficeint <= 0, replace it with tau)
1035    //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
1036
1037    double Gmaxp = -INF;
1038    double Gmaxp2 = -INF;
1039    int Gmaxp_idx = -1;
1040
1041    double Gmaxn = -INF;
1042    double Gmaxn2 = -INF;
1043    int Gmaxn_idx = -1;
1044
1045    int Gmin_idx = -1;
1046    double obj_diff_min = INF;
1047
1048    for(int t=0;t<active_size;t++)
1049        if(y[t]==+1)
1050        {
1051            if(!is_upper_bound(t))
1052                if(-G[t] >= Gmaxp)
1053                {
1054                    Gmaxp = -G[t];
1055                    Gmaxp_idx = t;
1056                }
1057        }
1058        else
1059        {
1060            if(!is_lower_bound(t))
1061                if(G[t] >= Gmaxn)
1062                {
1063                    Gmaxn = G[t];
1064                    Gmaxn_idx = t;
1065                }
1066        }
1067
1068    int ip = Gmaxp_idx;
1069    int in = Gmaxn_idx;
1070    const Qfloat *Q_ip = NULL;
1071    const Qfloat *Q_in = NULL;
1072    if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1
1073        Q_ip = Q->get_Q(ip,active_size);
1074    if(in != -1)
1075        Q_in = Q->get_Q(in,active_size);
1076
1077    for(int j=0;j<active_size;j++)
1078    {
1079        if(y[j]==+1)
1080        {
1081            if (!is_lower_bound(j)) 
1082            {
1083                double grad_diff=Gmaxp+G[j];
1084                if (G[j] >= Gmaxp2)
1085                    Gmaxp2 = G[j];
1086                if (grad_diff > 0)
1087                {
1088                    double obj_diff; 
1089                    double quad_coef = QD[ip]+QD[j]-2*Q_ip[j];
1090                    if (quad_coef > 0)
1091                        obj_diff = -(grad_diff*grad_diff)/quad_coef;
1092                    else
1093                        obj_diff = -(grad_diff*grad_diff)/TAU;
1094
1095                    if (obj_diff <= obj_diff_min)
1096                    {
1097                        Gmin_idx=j;
1098                        obj_diff_min = obj_diff;
1099                    }
1100                }
1101            }
1102        }
1103        else
1104        {
1105            if (!is_upper_bound(j))
1106            {
1107                double grad_diff=Gmaxn-G[j];
1108                if (-G[j] >= Gmaxn2)
1109                    Gmaxn2 = -G[j];
1110                if (grad_diff > 0)
1111                {
1112                    double obj_diff; 
1113                    double quad_coef = QD[in]+QD[j]-2*Q_in[j];
1114                    if (quad_coef > 0)
1115                        obj_diff = -(grad_diff*grad_diff)/quad_coef;
1116                    else
1117                        obj_diff = -(grad_diff*grad_diff)/TAU;
1118
1119                    if (obj_diff <= obj_diff_min)
1120                    {
1121                        Gmin_idx=j;
1122                        obj_diff_min = obj_diff;
1123                    }
1124                }
1125            }
1126        }
1127    }
1128
1129    if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
1130        return 1;
1131
1132    if (y[Gmin_idx] == +1)
1133        out_i = Gmaxp_idx;
1134    else
1135        out_i = Gmaxn_idx;
1136    out_j = Gmin_idx;
1137
1138    return 0;
1139}
1140
1141bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
1142{
1143    if(is_upper_bound(i))
1144    {
1145        if(y[i]==+1)
1146            return(-G[i] > Gmax1);
1147        else   
1148            return(-G[i] > Gmax4);
1149    }
1150    else if(is_lower_bound(i))
1151    {
1152        if(y[i]==+1)
1153            return(G[i] > Gmax2);
1154        else   
1155            return(G[i] > Gmax3);
1156    }
1157    else
1158        return(false);
1159}
1160
1161void Solver_NU::do_shrinking()
1162{
1163    double Gmax1 = -INF;    // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
1164    double Gmax2 = -INF;    // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
1165    double Gmax3 = -INF;    // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
1166    double Gmax4 = -INF;    // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
1167
1168    // find maximal violating pair first
1169    int i;
1170    for(i=0;i<active_size;i++)
1171    {
1172        if(!is_upper_bound(i))
1173        {
1174            if(y[i]==+1)
1175            {
1176                if(-G[i] > Gmax1) Gmax1 = -G[i];
1177            }
1178            else    if(-G[i] > Gmax4) Gmax4 = -G[i];
1179        }
1180        if(!is_lower_bound(i))
1181        {
1182            if(y[i]==+1)
1183            {   
1184                if(G[i] > Gmax2) Gmax2 = G[i];
1185            }
1186            else    if(G[i] > Gmax3) Gmax3 = G[i];
1187        }
1188    }
1189
1190    if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) 
1191    {
1192        unshrink = true;
1193        reconstruct_gradient();
1194        active_size = l;
1195    }
1196
1197    for(i=0;i<active_size;i++)
1198        if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
1199        {
1200            active_size--;
1201            while (active_size > i)
1202            {
1203                if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
1204                {
1205                    swap_index(i,active_size);
1206                    break;
1207                }
1208                active_size--;
1209            }
1210        }
1211}
1212
1213double Solver_NU::calculate_rho()
1214{
1215    int nr_free1 = 0,nr_free2 = 0;
1216    double ub1 = INF, ub2 = INF;
1217    double lb1 = -INF, lb2 = -INF;
1218    double sum_free1 = 0, sum_free2 = 0;
1219
1220    for(int i=0;i<active_size;i++)
1221    {
1222        if(y[i]==+1)
1223        {
1224            if(is_upper_bound(i))
1225                lb1 = max(lb1,G[i]);
1226            else if(is_lower_bound(i))
1227                ub1 = min(ub1,G[i]);
1228            else
1229            {
1230                ++nr_free1;
1231                sum_free1 += G[i];
1232            }
1233        }
1234        else
1235        {
1236            if(is_upper_bound(i))
1237                lb2 = max(lb2,G[i]);
1238            else if(is_lower_bound(i))
1239                ub2 = min(ub2,G[i]);
1240            else
1241            {
1242                ++nr_free2;
1243                sum_free2 += G[i];
1244            }
1245        }
1246    }
1247
1248    double r1,r2;
1249    if(nr_free1 > 0)
1250        r1 = sum_free1/nr_free1;
1251    else
1252        r1 = (ub1+lb1)/2;
1253   
1254    if(nr_free2 > 0)
1255        r2 = sum_free2/nr_free2;
1256    else
1257        r2 = (ub2+lb2)/2;
1258   
1259    si->r = (r1+r2)/2;
1260    return (r1-r2)/2;
1261}
1262
1263//
1264// Q matrices for various formulations
1265//
1266class SVC_Q: public Kernel
1267{ 
1268public:
1269    SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
1270    :Kernel(prob.l, prob.x, param)
1271    {
1272        clone(y,y_,prob.l);
1273        cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1274        QD = new double[prob.l];
1275        for(int i=0;i<prob.l;i++)
1276            QD[i] = (this->*kernel_function)(i,i);
1277    }
1278   
1279    Qfloat *get_Q(int i, int len) const
1280    {
1281        Qfloat *data;
1282        int start, j;
1283        if((start = cache->get_data(i,&data,len)) < len)
1284        {
1285            for(j=start;j<len;j++)
1286                data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
1287        }
1288        return data;
1289    }
1290
1291    double *get_QD() const
1292    {
1293        return QD;
1294    }
1295
1296    void swap_index(int i, int j) const
1297    {
1298        cache->swap_index(i,j);
1299        Kernel::swap_index(i,j);
1300        swap(y[i],y[j]);
1301        swap(QD[i],QD[j]);
1302    }
1303
1304    ~SVC_Q()
1305    {
1306        delete[] y;
1307        delete cache;
1308        delete[] QD;
1309    }
1310private:
1311    schar *y;
1312    Cache *cache;
1313    double *QD;
1314};
1315
1316class ONE_CLASS_Q: public Kernel
1317{
1318public:
1319    ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
1320    :Kernel(prob.l, prob.x, param)
1321    {
1322        cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1323        QD = new double[prob.l];
1324        for(int i=0;i<prob.l;i++)
1325            QD[i] = (this->*kernel_function)(i,i);
1326    }
1327   
1328    Qfloat *get_Q(int i, int len) const
1329    {
1330        Qfloat *data;
1331        int start, j;
1332        if((start = cache->get_data(i,&data,len)) < len)
1333        {
1334            for(j=start;j<len;j++)
1335                data[j] = (Qfloat)(this->*kernel_function)(i,j);
1336        }
1337        return data;
1338    }
1339
1340    double *get_QD() const
1341    {
1342        return QD;
1343    }
1344
1345    void swap_index(int i, int j) const
1346    {
1347        cache->swap_index(i,j);
1348        Kernel::swap_index(i,j);
1349        swap(QD[i],QD[j]);
1350    }
1351
1352    ~ONE_CLASS_Q()
1353    {
1354        delete cache;
1355        delete[] QD;
1356    }
1357private:
1358    Cache *cache;
1359    double *QD;
1360};
1361
1362class SVR_Q: public Kernel
1363{ 
1364public:
1365    SVR_Q(const svm_problem& prob, const svm_parameter& param)
1366    :Kernel(prob.l, prob.x, param)
1367    {
1368        l = prob.l;
1369        cache = new Cache(l,(long int)(param.cache_size*(1<<20)));
1370        QD = new double[2*l];
1371        sign = new schar[2*l];
1372        index = new int[2*l];
1373        for(int k=0;k<l;k++)
1374        {
1375            sign[k] = 1;
1376            sign[k+l] = -1;
1377            index[k] = k;
1378            index[k+l] = k;
1379            QD[k] = (this->*kernel_function)(k,k);
1380            QD[k+l] = QD[k];
1381        }
1382        buffer[0] = new Qfloat[2*l];
1383        buffer[1] = new Qfloat[2*l];
1384        next_buffer = 0;
1385    }
1386
1387    void swap_index(int i, int j) const
1388    {
1389        swap(sign[i],sign[j]);
1390        swap(index[i],index[j]);
1391        swap(QD[i],QD[j]);
1392    }
1393   
1394    Qfloat *get_Q(int i, int len) const
1395    {
1396        Qfloat *data;
1397        int j, real_i = index[i];
1398        if(cache->get_data(real_i,&data,l) < l)
1399        {
1400            for(j=0;j<l;j++)
1401                data[j] = (Qfloat)(this->*kernel_function)(real_i,j);
1402        }
1403
1404        // reorder and copy
1405        Qfloat *buf = buffer[next_buffer];
1406        next_buffer = 1 - next_buffer;
1407        schar si = sign[i];
1408        for(j=0;j<len;j++)
1409            buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]];
1410        return buf;
1411    }
1412
1413    double *get_QD() const
1414    {
1415        return QD;
1416    }
1417
1418    ~SVR_Q()
1419    {
1420        delete cache;
1421        delete[] sign;
1422        delete[] index;
1423        delete[] buffer[0];
1424        delete[] buffer[1];
1425        delete[] QD;
1426    }
1427private:
1428    int l;
1429    Cache *cache;
1430    schar *sign;
1431    int *index;
1432    mutable int next_buffer;
1433    Qfloat *buffer[2];
1434    double *QD;
1435};
1436
1437//
1438// construct and solve various formulations
1439//
1440static void solve_c_svc(
1441    const svm_problem *prob, const svm_parameter* param,
1442    double *alpha, Solver::SolutionInfo* si, double Cp, double Cn)
1443{
1444    int l = prob->l;
1445    double *minus_ones = new double[l];
1446    schar *y = new schar[l];
1447
1448    int i;
1449
1450    for(i=0;i<l;i++)
1451    {
1452        alpha[i] = 0;
1453        minus_ones[i] = -1;
1454        if(prob->y[i] > 0) y[i] = +1; else y[i] = -1;
1455    }
1456
1457    Solver s;
1458    s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
1459        alpha, Cp, Cn, param->eps, si, param->shrinking);
1460
1461    double sum_alpha=0;
1462    for(i=0;i<l;i++)
1463        sum_alpha += alpha[i];
1464
1465    if (Cp==Cn)
1466        info("nu = %f\n", sum_alpha/(Cp*prob->l));
1467
1468    for(i=0;i<l;i++)
1469        alpha[i] *= y[i];
1470
1471    delete[] minus_ones;
1472    delete[] y;
1473}
1474
1475static void solve_nu_svc(
1476    const svm_problem *prob, const svm_parameter *param,
1477    double *alpha, Solver::SolutionInfo* si)
1478{
1479    int i;
1480    int l = prob->l;
1481    double nu = param->nu;
1482
1483    schar *y = new schar[l];
1484
1485    for(i=0;i<l;i++)
1486        if(prob->y[i]>0)
1487            y[i] = +1;
1488        else
1489            y[i] = -1;
1490
1491    double sum_pos = nu*l/2;
1492    double sum_neg = nu*l/2;
1493
1494    for(i=0;i<l;i++)
1495        if(y[i] == +1)
1496        {
1497            alpha[i] = min(1.0,sum_pos);
1498            sum_pos -= alpha[i];
1499        }
1500        else
1501        {
1502            alpha[i] = min(1.0,sum_neg);
1503            sum_neg -= alpha[i];
1504        }
1505
1506    double *zeros = new double[l];
1507
1508    for(i=0;i<l;i++)
1509        zeros[i] = 0;
1510
1511    Solver_NU s;
1512    s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
1513        alpha, 1.0, 1.0, param->eps, si,  param->shrinking);
1514    double r = si->r;
1515
1516    info("C = %f\n",1/r);
1517
1518    for(i=0;i<l;i++)
1519        alpha[i] *= y[i]/r;
1520
1521    si->rho /= r;
1522    si->obj /= (r*r);
1523    si->upper_bound_p = 1/r;
1524    si->upper_bound_n = 1/r;
1525
1526    delete[] y;
1527    delete[] zeros;
1528}
1529
1530static void solve_one_class(
1531    const svm_problem *prob, const svm_parameter *param,
1532    double *alpha, Solver::SolutionInfo* si)
1533{
1534    int l = prob->l;
1535    double *zeros = new double[l];
1536    schar *ones = new schar[l];
1537    int i;
1538
1539    int n = (int)(param->nu*prob->l);   // # of alpha's at upper bound
1540
1541    for(i=0;i<n;i++)
1542        alpha[i] = 1;
1543    if(n<prob->l)
1544        alpha[n] = param->nu * prob->l - n;
1545    for(i=n+1;i<l;i++)
1546        alpha[i] = 0;
1547
1548    for(i=0;i<l;i++)
1549    {
1550        zeros[i] = 0;
1551        ones[i] = 1;
1552    }
1553
1554    Solver s;
1555    s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
1556        alpha, 1.0, 1.0, param->eps, si, param->shrinking);
1557
1558    delete[] zeros;
1559    delete[] ones;
1560}
1561
1562static void solve_epsilon_svr(
1563    const svm_problem *prob, const svm_parameter *param,
1564    double *alpha, Solver::SolutionInfo* si)
1565{
1566    int l = prob->l;
1567    double *alpha2 = new double[2*l];
1568    double *linear_term = new double[2*l];
1569    schar *y = new schar[2*l];
1570    int i;
1571
1572    for(i=0;i<l;i++)
1573    {
1574        alpha2[i] = 0;
1575        linear_term[i] = param->p - prob->y[i];
1576        y[i] = 1;
1577
1578        alpha2[i+l] = 0;
1579        linear_term[i+l] = param->p + prob->y[i];
1580        y[i+l] = -1;
1581    }
1582
1583    Solver s;
1584    s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1585        alpha2, param->C, param->C, param->eps, si, param->shrinking);
1586
1587    double sum_alpha = 0;
1588    for(i=0;i<l;i++)
1589    {
1590        alpha[i] = alpha2[i] - alpha2[i+l];
1591        sum_alpha += fabs(alpha[i]);
1592    }
1593    info("nu = %f\n",sum_alpha/(param->C*l));
1594
1595    delete[] alpha2;
1596    delete[] linear_term;
1597    delete[] y;
1598}
1599
1600static void solve_nu_svr(
1601    const svm_problem *prob, const svm_parameter *param,
1602    double *alpha, Solver::SolutionInfo* si)
1603{
1604    int l = prob->l;
1605    double C = param->C;
1606    double *alpha2 = new double[2*l];
1607    double *linear_term = new double[2*l];
1608    schar *y = new schar[2*l];
1609    int i;
1610
1611    double sum = C * param->nu * l / 2;
1612    for(i=0;i<l;i++)
1613    {
1614        alpha2[i] = alpha2[i+l] = min(sum,C);
1615        sum -= alpha2[i];
1616
1617        linear_term[i] = - prob->y[i];
1618        y[i] = 1;
1619
1620        linear_term[i+l] = prob->y[i];
1621        y[i+l] = -1;
1622    }
1623
1624    Solver_NU s;
1625    s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1626        alpha2, C, C, param->eps, si, param->shrinking);
1627
1628    info("epsilon = %f\n",-si->r);
1629
1630    for(i=0;i<l;i++)
1631        alpha[i] = alpha2[i] - alpha2[i+l];
1632
1633    delete[] alpha2;
1634    delete[] linear_term;
1635    delete[] y;
1636}
1637
1638//
1639// decision_function
1640//
1641struct decision_function
1642{
1643    double *alpha;
1644    double rho; 
1645};
1646
1647static decision_function svm_train_one(
1648    const svm_problem *prob, const svm_parameter *param,
1649    double Cp, double Cn)
1650{
1651    double *alpha = Malloc(double,prob->l);
1652    Solver::SolutionInfo si;
1653    switch(param->svm_type)
1654    {
1655        case C_SVC:
1656            solve_c_svc(prob,param,alpha,&si,Cp,Cn);
1657            break;
1658        case NU_SVC:
1659            solve_nu_svc(prob,param,alpha,&si);
1660            break;
1661        case ONE_CLASS:
1662            solve_one_class(prob,param,alpha,&si);
1663            break;
1664        case EPSILON_SVR:
1665            solve_epsilon_svr(prob,param,alpha,&si);
1666            break;
1667        case NU_SVR:
1668            solve_nu_svr(prob,param,alpha,&si);
1669            break;
1670    }
1671
1672    info("obj = %f, rho = %f\n",si.obj,si.rho);
1673
1674    // output SVs
1675
1676    int nSV = 0;
1677    int nBSV = 0;
1678    for(int i=0;i<prob->l;i++)
1679    {
1680        if(fabs(alpha[i]) > 0)
1681        {
1682            ++nSV;
1683            if(prob->y[i] > 0)
1684            {
1685                if(fabs(alpha[i]) >= si.upper_bound_p)
1686                    ++nBSV;
1687            }
1688            else
1689            {
1690                if(fabs(alpha[i]) >= si.upper_bound_n)
1691                    ++nBSV;
1692            }
1693        }
1694    }
1695
1696    info("nSV = %d, nBSV = %d\n",nSV,nBSV);
1697
1698    decision_function f;
1699    f.alpha = alpha;
1700    f.rho = si.rho;
1701    return f;
1702}
1703
1704// Platt's binary SVM Probablistic Output: an improvement from Lin et al.
1705static void sigmoid_train(
1706    int l, const double *dec_values, const double *labels, 
1707    double& A, double& B)
1708{
1709    double prior1=0, prior0 = 0;
1710    int i;
1711
1712    for (i=0;i<l;i++)
1713        if (labels[i] > 0) prior1+=1;
1714        else prior0+=1;
1715   
1716    int max_iter=100;   // Maximal number of iterations
1717    double min_step=1e-10;  // Minimal step taken in line search
1718    double sigma=1e-12; // For numerically strict PD of Hessian
1719    double eps=1e-5;
1720    double hiTarget=(prior1+1.0)/(prior1+2.0);
1721    double loTarget=1/(prior0+2.0);
1722    double *t=Malloc(double,l);
1723    double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
1724    double newA,newB,newf,d1,d2;
1725    int iter; 
1726   
1727    // Initial Point and Initial Fun Value
1728    A=0.0; B=log((prior0+1.0)/(prior1+1.0));
1729    double fval = 0.0;
1730
1731    for (i=0;i<l;i++)
1732    {
1733        if (labels[i]>0) t[i]=hiTarget;
1734        else t[i]=loTarget;
1735        fApB = dec_values[i]*A+B;
1736        if (fApB>=0)
1737            fval += t[i]*fApB + log(1+exp(-fApB));
1738        else
1739            fval += (t[i] - 1)*fApB +log(1+exp(fApB));
1740    }
1741    for (iter=0;iter<max_iter;iter++)
1742    {
1743        // Update Gradient and Hessian (use H' = H + sigma I)
1744        h11=sigma; // numerically ensures strict PD
1745        h22=sigma;
1746        h21=0.0;g1=0.0;g2=0.0;
1747        for (i=0;i<l;i++)
1748        {
1749            fApB = dec_values[i]*A+B;
1750            if (fApB >= 0)
1751            {
1752                p=exp(-fApB)/(1.0+exp(-fApB));
1753                q=1.0/(1.0+exp(-fApB));
1754            }
1755            else
1756            {
1757                p=1.0/(1.0+exp(fApB));
1758                q=exp(fApB)/(1.0+exp(fApB));
1759            }
1760            d2=p*q;
1761            h11+=dec_values[i]*dec_values[i]*d2;
1762            h22+=d2;
1763            h21+=dec_values[i]*d2;
1764            d1=t[i]-p;
1765            g1+=dec_values[i]*d1;
1766            g2+=d1;
1767        }
1768
1769        // Stopping Criteria
1770        if (fabs(g1)<eps && fabs(g2)<eps)
1771            break;
1772
1773        // Finding Newton direction: -inv(H') * g
1774        det=h11*h22-h21*h21;
1775        dA=-(h22*g1 - h21 * g2) / det;
1776        dB=-(-h21*g1+ h11 * g2) / det;
1777        gd=g1*dA+g2*dB;
1778
1779
1780        stepsize = 1;       // Line Search
1781        while (stepsize >= min_step)
1782        {
1783            newA = A + stepsize * dA;
1784            newB = B + stepsize * dB;
1785
1786            // New function value
1787            newf = 0.0;
1788            for (i=0;i<l;i++)
1789            {
1790                fApB = dec_values[i]*newA+newB;
1791                if (fApB >= 0)
1792                    newf += t[i]*fApB + log(1+exp(-fApB));
1793                else
1794                    newf += (t[i] - 1)*fApB +log(1+exp(fApB));
1795            }
1796            // Check sufficient decrease
1797            if (newf<fval+0.0001*stepsize*gd)
1798            {
1799                A=newA;B=newB;fval=newf;
1800                break;
1801            }
1802            else
1803                stepsize = stepsize / 2.0;
1804        }
1805
1806        if (stepsize < min_step)
1807        {
1808            info("Line search fails in two-class probability estimates\n");
1809            break;
1810        }
1811    }
1812
1813    if (iter>=max_iter)
1814        info("Reaching maximal iterations in two-class probability estimates\n");
1815    free(t);
1816}
1817
1818static double sigmoid_predict(double decision_value, double A, double B)
1819{
1820    double fApB = decision_value*A+B;
1821    // 1-p used later; avoid catastrophic cancellation
1822    if (fApB >= 0)
1823        return exp(-fApB)/(1.0+exp(-fApB));
1824    else
1825        return 1.0/(1+exp(fApB)) ;
1826}
1827
1828// Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
1829static void multiclass_probability(int k, double **r, double *p)
1830{
1831    int t,j;
1832    int iter = 0, max_iter=max(100,k);
1833    double **Q=Malloc(double *,k);
1834    double *Qp=Malloc(double,k);
1835    double pQp, eps=0.005/k;
1836   
1837    for (t=0;t<k;t++)
1838    {
1839        p[t]=1.0/k;  // Valid if k = 1
1840        Q[t]=Malloc(double,k);
1841        Q[t][t]=0;
1842        for (j=0;j<t;j++)
1843        {
1844            Q[t][t]+=r[j][t]*r[j][t];
1845            Q[t][j]=Q[j][t];
1846        }
1847        for (j=t+1;j<k;j++)
1848        {
1849            Q[t][t]+=r[j][t]*r[j][t];
1850            Q[t][j]=-r[j][t]*r[t][j];
1851        }
1852    }
1853    for (iter=0;iter<max_iter;iter++)
1854    {
1855        // stopping condition, recalculate QP,pQP for numerical accuracy
1856        pQp=0;
1857        for (t=0;t<k;t++)
1858        {
1859            Qp[t]=0;
1860            for (j=0;j<k;j++)
1861                Qp[t]+=Q[t][j]*p[j];
1862            pQp+=p[t]*Qp[t];
1863        }
1864        double max_error=0;
1865        for (t=0;t<k;t++)
1866        {
1867            double error=fabs(Qp[t]-pQp);
1868            if (error>max_error)
1869                max_error=error;
1870        }
1871        if (max_error<eps) break;
1872       
1873        for (t=0;t<k;t++)
1874        {
1875            double diff=(-Qp[t]+pQp)/Q[t][t];
1876            p[t]+=diff;
1877            pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
1878            for (j=0;j<k;j++)
1879            {
1880                Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
1881                p[j]/=(1+diff);
1882            }
1883        }
1884    }
1885    if (iter>=max_iter)
1886        info("Exceeds max_iter in multiclass_prob\n");
1887    for(t=0;t<k;t++) free(Q[t]);
1888    free(Q);
1889    free(Qp);
1890}
1891
1892// Cross-validation decision values for probability estimates
1893static void svm_binary_svc_probability(
1894    const svm_problem *prob, const svm_parameter *param,
1895    double Cp, double Cn, double& probA, double& probB)
1896{
1897    int i;
1898    int nr_fold = 5;
1899    int *perm = Malloc(int,prob->l);
1900    double *dec_values = Malloc(double,prob->l);
1901
1902    // random shuffle
1903    for(i=0;i<prob->l;i++) perm[i]=i;
1904    for(i=0;i<prob->l;i++)
1905    {
1906        int j = i+rand()%(prob->l-i);
1907        swap(perm[i],perm[j]);
1908    }
1909    for(i=0;i<nr_fold;i++)
1910    {
1911        int begin = i*prob->l/nr_fold;
1912        int end = (i+1)*prob->l/nr_fold;
1913        int j,k;
1914        struct svm_problem subprob;
1915
1916        subprob.l = prob->l-(end-begin);
1917        subprob.x = Malloc(struct svm_node*,subprob.l);
1918        subprob.y = Malloc(double,subprob.l);
1919           
1920        k=0;
1921        for(j=0;j<begin;j++)
1922        {
1923            subprob.x[k] = prob->x[perm[j]];
1924            subprob.y[k] = prob->y[perm[j]];
1925            ++k;
1926        }
1927        for(j=end;j<prob->l;j++)
1928        {
1929            subprob.x[k] = prob->x[perm[j]];
1930            subprob.y[k] = prob->y[perm[j]];
1931            ++k;
1932        }
1933        int p_count=0,n_count=0;
1934        for(j=0;j<k;j++)
1935            if(subprob.y[j]>0)
1936                p_count++;
1937            else
1938                n_count++;
1939
1940        if(p_count==0 && n_count==0)
1941            for(j=begin;j<end;j++)
1942                dec_values[perm[j]] = 0;
1943        else if(p_count > 0 && n_count == 0)
1944            for(j=begin;j<end;j++)
1945                dec_values[perm[j]] = 1;
1946        else if(p_count == 0 && n_count > 0)
1947            for(j=begin;j<end;j++)
1948                dec_values[perm[j]] = -1;
1949        else
1950        {
1951            svm_parameter subparam = *param;
1952            subparam.probability=0;
1953            subparam.C=1.0;
1954            subparam.nr_weight=2;
1955            subparam.weight_label = Malloc(int,2);
1956            subparam.weight = Malloc(double,2);
1957            subparam.weight_label[0]=+1;
1958            subparam.weight_label[1]=-1;
1959            subparam.weight[0]=Cp;
1960            subparam.weight[1]=Cn;
1961            struct svm_model *submodel = svm_train(&subprob,&subparam);
1962            for(j=begin;j<end;j++)
1963            {
1964                svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]])); 
1965                // ensure +1 -1 order; reason not using CV subroutine
1966                dec_values[perm[j]] *= submodel->label[0];
1967            }       
1968            svm_free_and_destroy_model(&submodel);
1969            svm_destroy_param(&subparam);
1970        }
1971        free(subprob.x);
1972        free(subprob.y);
1973    }       
1974    sigmoid_train(prob->l,dec_values,prob->y,probA,probB);
1975    free(dec_values);
1976    free(perm);
1977}
1978
1979// Return parameter of a Laplace distribution
1980static double svm_svr_probability(
1981    const svm_problem *prob, const svm_parameter *param)
1982{
1983    int i;
1984    int nr_fold = 5;
1985    double *ymv = Malloc(double,prob->l);
1986    double mae = 0;
1987
1988    svm_parameter newparam = *param;
1989    newparam.probability = 0;
1990    svm_cross_validation(prob,&newparam,nr_fold,ymv);
1991    for(i=0;i<prob->l;i++)
1992    {
1993        ymv[i]=prob->y[i]-ymv[i];
1994        mae += fabs(ymv[i]);
1995    }       
1996    mae /= prob->l;
1997    double std=sqrt(2*mae*mae);
1998    int count=0;
1999    mae=0;
2000    for(i=0;i<prob->l;i++)
2001        if (fabs(ymv[i]) > 5*std) 
2002            count=count+1;
2003        else 
2004            mae+=fabs(ymv[i]);
2005    mae /= (prob->l-count);
2006    info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae);
2007    free(ymv);
2008    return mae;
2009}
2010
2011
2012// label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
2013// perm, length l, must be allocated before calling this subroutine
2014static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
2015{
2016    int l = prob->l;
2017    int max_nr_class = 16;
2018    int nr_class = 0;
2019    int *label = Malloc(int,max_nr_class);
2020    int *count = Malloc(int,max_nr_class);
2021    int *data_label = Malloc(int,l);   
2022    int i;
2023
2024    for(i=0;i<l;i++)
2025    {
2026        int this_label = (int)prob->y[i];
2027        int j;
2028        for(j=0;j<nr_class;j++)
2029        {
2030            if(this_label == label[j])
2031            {
2032                ++count[j];
2033                break;
2034            }
2035        }
2036        data_label[i] = j;
2037        if(j == nr_class)
2038        {
2039            if(nr_class == max_nr_class)
2040            {
2041                max_nr_class *= 2;
2042                label = (int *)realloc(label,max_nr_class*sizeof(int));
2043                count = (int *)realloc(count,max_nr_class*sizeof(int));
2044            }
2045            label[nr_class] = this_label;
2046            count[nr_class] = 1;
2047            ++nr_class;
2048        }
2049    }
2050
[11664]2051    //
2052    // Labels are ordered by their first occurrence in the training set.
2053    // However, for two-class sets with -1/+1 labels and -1 appears first,
2054    // we swap labels to ensure that internally the binary SVM has positive data corresponding to the +1 instances.
2055    //
2056    if (nr_class == 2 && label[0] == -1 && label[1] == 1)
2057    {
2058        swap(label[0],label[1]);
2059        swap(count[0],count[1]);
2060        for(i=0;i<l;i++)
2061        {
2062            if(data_label[i] == 0)
2063                data_label[i] = 1;
2064            else
2065                data_label[i] = 0;
2066        }
2067    }
2068
[8978]2069    int *start = Malloc(int,nr_class);
2070    start[0] = 0;
2071    for(i=1;i<nr_class;i++)
2072        start[i] = start[i-1]+count[i-1];
2073    for(i=0;i<l;i++)
2074    {
2075        perm[start[data_label[i]]] = i;
2076        ++start[data_label[i]];
2077    }
2078    start[0] = 0;
2079    for(i=1;i<nr_class;i++)
2080        start[i] = start[i-1]+count[i-1];
2081
2082    *nr_class_ret = nr_class;
2083    *label_ret = label;
2084    *start_ret = start;
2085    *count_ret = count;
2086    free(data_label);
2087}
2088
2089//
2090// Interface functions
2091//
2092svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
2093{
2094    svm_model *model = Malloc(svm_model,1);
2095    model->param = *param;
2096    model->free_sv = 0; // XXX
2097
2098    if(param->svm_type == ONE_CLASS ||
2099       param->svm_type == EPSILON_SVR ||
2100       param->svm_type == NU_SVR)
2101    {
2102        // regression or one-class-svm
2103        model->nr_class = 2;
2104        model->label = NULL;
2105        model->nSV = NULL;
2106        model->probA = NULL; model->probB = NULL;
2107        model->sv_coef = Malloc(double *,1);
2108
2109        if(param->probability && 
2110           (param->svm_type == EPSILON_SVR ||
2111            param->svm_type == NU_SVR))
2112        {
2113            model->probA = Malloc(double,1);
2114            model->probA[0] = svm_svr_probability(prob,param);
2115        }
2116
2117        decision_function f = svm_train_one(prob,param,0,0);
2118        model->rho = Malloc(double,1);
2119        model->rho[0] = f.rho;
2120
2121        int nSV = 0;
2122        int i;
2123        for(i=0;i<prob->l;i++)
2124            if(fabs(f.alpha[i]) > 0) ++nSV;
2125        model->l = nSV;
2126        model->SV = Malloc(svm_node *,nSV);
2127        model->sv_coef[0] = Malloc(double,nSV);
[11664]2128        model->sv_indices = Malloc(int,nSV);
[8978]2129        int j = 0;
2130        for(i=0;i<prob->l;i++)
2131            if(fabs(f.alpha[i]) > 0)
2132            {
2133                model->SV[j] = prob->x[i];
2134                model->sv_coef[0][j] = f.alpha[i];
[11664]2135                model->sv_indices[j] = i+1;
[8978]2136                ++j;
2137            }       
2138
2139        free(f.alpha);
2140    }
2141    else
2142    {
2143        // classification
2144        int l = prob->l;
2145        int nr_class;
2146        int *label = NULL;
2147        int *start = NULL;
2148        int *count = NULL;
2149        int *perm = Malloc(int,l);
2150
2151        // group training data of the same class
[11664]2152        svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
2153        if(nr_class == 1) 
2154            info("WARNING: training data in only one class. See README for details.\n");
2155       
[8978]2156        svm_node **x = Malloc(svm_node *,l);
2157        int i;
2158        for(i=0;i<l;i++)
2159            x[i] = prob->x[perm[i]];
2160
2161        // calculate weighted C
2162
2163        double *weighted_C = Malloc(double, nr_class);
2164        for(i=0;i<nr_class;i++)
2165            weighted_C[i] = param->C;
2166        for(i=0;i<param->nr_weight;i++)
2167        {   
2168            int j;
2169            for(j=0;j<nr_class;j++)
2170                if(param->weight_label[i] == label[j])
2171                    break;
2172            if(j == nr_class)
[11664]2173                fprintf(stderr,"WARNING: class label %d specified in weight is not found\n", param->weight_label[i]);
[8978]2174            else
2175                weighted_C[j] *= param->weight[i];
2176        }
2177
2178        // train k*(k-1)/2 models
2179       
2180        bool *nonzero = Malloc(bool,l);
2181        for(i=0;i<l;i++)
2182            nonzero[i] = false;
2183        decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);
2184
2185        double *probA=NULL,*probB=NULL;
2186        if (param->probability)
2187        {
2188            probA=Malloc(double,nr_class*(nr_class-1)/2);
2189            probB=Malloc(double,nr_class*(nr_class-1)/2);
2190        }
2191
2192        int p = 0;
2193        for(i=0;i<nr_class;i++)
2194            for(int j=i+1;j<nr_class;j++)
2195            {
2196                svm_problem sub_prob;
2197                int si = start[i], sj = start[j];
2198                int ci = count[i], cj = count[j];
2199                sub_prob.l = ci+cj;
2200                sub_prob.x = Malloc(svm_node *,sub_prob.l);
2201                sub_prob.y = Malloc(double,sub_prob.l);
2202                int k;
2203                for(k=0;k<ci;k++)
2204                {
2205                    sub_prob.x[k] = x[si+k];
2206                    sub_prob.y[k] = +1;
2207                }
2208                for(k=0;k<cj;k++)
2209                {
2210                    sub_prob.x[ci+k] = x[sj+k];
2211                    sub_prob.y[ci+k] = -1;
2212                }
2213
2214                if(param->probability)
2215                    svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);
2216
2217                f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
2218                for(k=0;k<ci;k++)
2219                    if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
2220                        nonzero[si+k] = true;
2221                for(k=0;k<cj;k++)
2222                    if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
2223                        nonzero[sj+k] = true;
2224                free(sub_prob.x);
2225                free(sub_prob.y);
2226                ++p;
2227            }
2228
2229        // build output
2230
2231        model->nr_class = nr_class;
2232       
2233        model->label = Malloc(int,nr_class);
2234        for(i=0;i<nr_class;i++)
2235            model->label[i] = label[i];
2236       
2237        model->rho = Malloc(double,nr_class*(nr_class-1)/2);
2238        for(i=0;i<nr_class*(nr_class-1)/2;i++)
2239            model->rho[i] = f[i].rho;
2240
2241        if(param->probability)
2242        {
2243            model->probA = Malloc(double,nr_class*(nr_class-1)/2);
2244            model->probB = Malloc(double,nr_class*(nr_class-1)/2);
2245            for(i=0;i<nr_class*(nr_class-1)/2;i++)
2246            {
2247                model->probA[i] = probA[i];
2248                model->probB[i] = probB[i];
2249            }
2250        }
2251        else
2252        {
2253            model->probA=NULL;
2254            model->probB=NULL;
2255        }
2256
2257        int total_sv = 0;
2258        int *nz_count = Malloc(int,nr_class);
2259        model->nSV = Malloc(int,nr_class);
2260        for(i=0;i<nr_class;i++)
2261        {
2262            int nSV = 0;
2263            for(int j=0;j<count[i];j++)
2264                if(nonzero[start[i]+j])
2265                {   
2266                    ++nSV;
2267                    ++total_sv;
2268                }
2269            model->nSV[i] = nSV;
2270            nz_count[i] = nSV;
2271        }
2272       
2273        info("Total nSV = %d\n",total_sv);
2274
2275        model->l = total_sv;
2276        model->SV = Malloc(svm_node *,total_sv);
[11664]2277        model->sv_indices = Malloc(int,total_sv);
[8978]2278        p = 0;
2279        for(i=0;i<l;i++)
[11664]2280            if(nonzero[i])
2281            {
2282                model->SV[p] = x[i];
2283                model->sv_indices[p++] = perm[i] + 1;
2284            }
[8978]2285
2286        int *nz_start = Malloc(int,nr_class);
2287        nz_start[0] = 0;
2288        for(i=1;i<nr_class;i++)
2289            nz_start[i] = nz_start[i-1]+nz_count[i-1];
2290
2291        model->sv_coef = Malloc(double *,nr_class-1);
2292        for(i=0;i<nr_class-1;i++)
2293            model->sv_coef[i] = Malloc(double,total_sv);
2294
2295        p = 0;
2296        for(i=0;i<nr_class;i++)
2297            for(int j=i+1;j<nr_class;j++)
2298            {
2299                // classifier (i,j): coefficients with
2300                // i are in sv_coef[j-1][nz_start[i]...],
2301                // j are in sv_coef[i][nz_start[j]...]
2302
2303                int si = start[i];
2304                int sj = start[j];
2305                int ci = count[i];
2306                int cj = count[j];
2307               
2308                int q = nz_start[i];
2309                int k;
2310                for(k=0;k<ci;k++)
2311                    if(nonzero[si+k])
2312                        model->sv_coef[j-1][q++] = f[p].alpha[k];
2313                q = nz_start[j];
2314                for(k=0;k<cj;k++)
2315                    if(nonzero[sj+k])
2316                        model->sv_coef[i][q++] = f[p].alpha[ci+k];
2317                ++p;
2318            }
2319       
2320        free(label);
2321        free(probA);
2322        free(probB);
2323        free(count);
2324        free(perm);
2325        free(start);
2326        free(x);
2327        free(weighted_C);
2328        free(nonzero);
2329        for(i=0;i<nr_class*(nr_class-1)/2;i++)
2330            free(f[i].alpha);
2331        free(f);
2332        free(nz_count);
2333        free(nz_start);
2334    }
2335    return model;
2336}
2337
2338// Stratified cross validation
2339void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
2340{
2341    int i;
[11664]2342    int *fold_start;
[8978]2343    int l = prob->l;
2344    int *perm = Malloc(int,l);
2345    int nr_class;
[11664]2346    if (nr_fold > l)
2347    {
2348        nr_fold = l;
2349        fprintf(stderr,"WARNING: # folds > # data. Will use # folds = # data instead (i.e., leave-one-out cross validation)\n");
2350    }
2351    fold_start = Malloc(int,nr_fold+1);
[8978]2352    // stratified cv may not give leave-one-out rate
2353    // Each class to l folds -> some folds may have zero elements
2354    if((param->svm_type == C_SVC ||
2355        param->svm_type == NU_SVC) && nr_fold < l)
2356    {
2357        int *start = NULL;
2358        int *label = NULL;
2359        int *count = NULL;
2360        svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
2361
2362        // random shuffle and then data grouped by fold using the array perm
2363        int *fold_count = Malloc(int,nr_fold);
2364        int c;
2365        int *index = Malloc(int,l);
2366        for(i=0;i<l;i++)
2367            index[i]=perm[i];
2368        for (c=0; c<nr_class; c++) 
2369            for(i=0;i<count[c];i++)
2370            {
2371                int j = i+rand()%(count[c]-i);
2372                swap(index[start[c]+j],index[start[c]+i]);
2373            }
2374        for(i=0;i<nr_fold;i++)
2375        {
2376            fold_count[i] = 0;
2377            for (c=0; c<nr_class;c++)
2378                fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
2379        }
2380        fold_start[0]=0;
2381        for (i=1;i<=nr_fold;i++)
2382            fold_start[i] = fold_start[i-1]+fold_count[i-1];
2383        for (c=0; c<nr_class;c++)
2384            for(i=0;i<nr_fold;i++)
2385            {
2386                int begin = start[c]+i*count[c]/nr_fold;
2387                int end = start[c]+(i+1)*count[c]/nr_fold;
2388                for(int j=begin;j<end;j++)
2389                {
2390                    perm[fold_start[i]] = index[j];
2391                    fold_start[i]++;
2392                }
2393            }
2394        fold_start[0]=0;
2395        for (i=1;i<=nr_fold;i++)
2396            fold_start[i] = fold_start[i-1]+fold_count[i-1];
2397        free(start);   
2398        free(label);
2399        free(count);   
2400        free(index);
2401        free(fold_count);
2402    }
2403    else
2404    {
2405        for(i=0;i<l;i++) perm[i]=i;
2406        for(i=0;i<l;i++)
2407        {
2408            int j = i+rand()%(l-i);
2409            swap(perm[i],perm[j]);
2410        }
2411        for(i=0;i<=nr_fold;i++)
2412            fold_start[i]=i*l/nr_fold;
2413    }
2414
2415    for(i=0;i<nr_fold;i++)
2416    {
2417        int begin = fold_start[i];
2418        int end = fold_start[i+1];
2419        int j,k;
2420        struct svm_problem subprob;
2421
2422        subprob.l = l-(end-begin);
2423        subprob.x = Malloc(struct svm_node*,subprob.l);
2424        subprob.y = Malloc(double,subprob.l);
2425           
2426        k=0;
2427        for(j=0;j<begin;j++)
2428        {
2429            subprob.x[k] = prob->x[perm[j]];
2430            subprob.y[k] = prob->y[perm[j]];
2431            ++k;
2432        }
2433        for(j=end;j<l;j++)
2434        {
2435            subprob.x[k] = prob->x[perm[j]];
2436            subprob.y[k] = prob->y[perm[j]];
2437            ++k;
2438        }
2439        struct svm_model *submodel = svm_train(&subprob,param);
2440        if(param->probability && 
2441           (param->svm_type == C_SVC || param->svm_type == NU_SVC))
2442        {
2443            double *prob_estimates=Malloc(double,svm_get_nr_class(submodel));
2444            for(j=begin;j<end;j++)
2445                target[perm[j]] = svm_predict_probability(submodel,prob->x[perm[j]],prob_estimates);
2446            free(prob_estimates);           
2447        }
2448        else
2449            for(j=begin;j<end;j++)
2450                target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]);
2451        svm_free_and_destroy_model(&submodel);
2452        free(subprob.x);
2453        free(subprob.y);
2454    }       
2455    free(fold_start);
2456    free(perm); 
2457}
2458
2459
2460int svm_get_svm_type(const svm_model *model)
2461{
2462    return model->param.svm_type;
2463}
2464
2465int svm_get_nr_class(const svm_model *model)
2466{
2467    return model->nr_class;
2468}
2469
2470void svm_get_labels(const svm_model *model, int* label)
2471{
2472    if (model->label != NULL)
2473        for(int i=0;i<model->nr_class;i++)
2474            label[i] = model->label[i];
2475}
2476
[11664]2477void svm_get_sv_indices(const svm_model *model, int* indices)
2478{
2479    if (model->sv_indices != NULL)
2480        for(int i=0;i<model->l;i++)
2481            indices[i] = model->sv_indices[i];
2482}
2483
2484int svm_get_nr_sv(const svm_model *model)
2485{
2486    return model->l;
2487}
2488
[8978]2489double svm_get_svr_probability(const svm_model *model)
2490{
2491    if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
2492        model->probA!=NULL)
2493        return model->probA[0];
2494    else
2495    {
2496        fprintf(stderr,"Model doesn't contain information for SVR probability inference\n");
2497        return 0;
2498    }
2499}
2500
2501double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values)
2502{
[11664]2503    int i;
[8978]2504    if(model->param.svm_type == ONE_CLASS ||
2505       model->param.svm_type == EPSILON_SVR ||
2506       model->param.svm_type == NU_SVR)
2507    {
2508        double *sv_coef = model->sv_coef[0];
2509        double sum = 0;
[11664]2510        for(i=0;i<model->l;i++)
[8978]2511            sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
2512        sum -= model->rho[0];
2513        *dec_values = sum;
2514
2515        if(model->param.svm_type == ONE_CLASS)
2516            return (sum>0)?1:-1;
2517        else
2518            return sum;
2519    }
2520    else
2521    {
2522        int nr_class = model->nr_class;
2523        int l = model->l;
2524       
2525        double *kvalue = Malloc(double,l);
2526        for(i=0;i<l;i++)
2527            kvalue[i] = Kernel::k_function(x,model->SV[i],model->param);
2528
2529        int *start = Malloc(int,nr_class);
2530        start[0] = 0;
2531        for(i=1;i<nr_class;i++)
2532            start[i] = start[i-1]+model->nSV[i-1];
2533
2534        int *vote = Malloc(int,nr_class);
2535        for(i=0;i<nr_class;i++)
2536            vote[i] = 0;
2537
2538        int p=0;
2539        for(i=0;i<nr_class;i++)
2540            for(int j=i+1;j<nr_class;j++)
2541            {
2542                double sum = 0;
2543                int si = start[i];
2544                int sj = start[j];
2545                int ci = model->nSV[i];
2546                int cj = model->nSV[j];
2547               
2548                int k;
2549                double *coef1 = model->sv_coef[j-1];
2550                double *coef2 = model->sv_coef[i];
2551                for(k=0;k<ci;k++)
2552                    sum += coef1[si+k] * kvalue[si+k];
2553                for(k=0;k<cj;k++)
2554                    sum += coef2[sj+k] * kvalue[sj+k];
2555                sum -= model->rho[p];
2556                dec_values[p] = sum;
2557
2558                if(dec_values[p] > 0)
2559                    ++vote[i];
2560                else
2561                    ++vote[j];
2562                p++;
2563            }
2564
2565        int vote_max_idx = 0;
2566        for(i=1;i<nr_class;i++)
2567            if(vote[i] > vote[vote_max_idx])
2568                vote_max_idx = i;
2569
2570        free(kvalue);
2571        free(start);
2572        free(vote);
2573        return model->label[vote_max_idx];
2574    }
2575}
2576
2577double svm_predict(const svm_model *model, const svm_node *x)
2578{
2579    int nr_class = model->nr_class;
2580    double *dec_values;
2581    if(model->param.svm_type == ONE_CLASS ||
2582       model->param.svm_type == EPSILON_SVR ||
2583       model->param.svm_type == NU_SVR)
2584        dec_values = Malloc(double, 1);
2585    else 
2586        dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2587    double pred_result = svm_predict_values(model, x, dec_values);
2588    free(dec_values);
2589    return pred_result;
2590}
2591
2592double svm_predict_probability(
2593    const svm_model *model, const svm_node *x, double *prob_estimates)
2594{
2595    if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
2596        model->probA!=NULL && model->probB!=NULL)
2597    {
2598        int i;
2599        int nr_class = model->nr_class;
2600        double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2601        svm_predict_values(model, x, dec_values);
2602
2603        double min_prob=1e-7;
2604        double **pairwise_prob=Malloc(double *,nr_class);
2605        for(i=0;i<nr_class;i++)
2606            pairwise_prob[i]=Malloc(double,nr_class);
2607        int k=0;
2608        for(i=0;i<nr_class;i++)
2609            for(int j=i+1;j<nr_class;j++)
2610            {
2611                pairwise_prob[i][j]=min(max(sigmoid_predict(dec_values[k],model->probA[k],model->probB[k]),min_prob),1-min_prob);
2612                pairwise_prob[j][i]=1-pairwise_prob[i][j];
2613                k++;
2614            }
2615        multiclass_probability(nr_class,pairwise_prob,prob_estimates);
2616
2617        int prob_max_idx = 0;
2618        for(i=1;i<nr_class;i++)
2619            if(prob_estimates[i] > prob_estimates[prob_max_idx])
2620                prob_max_idx = i;
2621        for(i=0;i<nr_class;i++)
2622            free(pairwise_prob[i]);
2623        free(dec_values);
2624        free(pairwise_prob);         
2625        return model->label[prob_max_idx];
2626    }
2627    else 
2628        return svm_predict(model, x);
2629}
2630
2631static const char *svm_type_table[] =
2632{
2633    "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",NULL
2634};
2635
2636static const char *kernel_type_table[]=
2637{
2638    "linear","polynomial","rbf","sigmoid","precomputed",NULL
2639};
2640
2641int svm_save_model(const char *model_file_name, const svm_model *model)
2642{
2643    FILE *fp = fopen(model_file_name,"w");
2644    if(fp==NULL) return -1;
2645
[11664]2646    char *old_locale = strdup(setlocale(LC_ALL, NULL));
2647    setlocale(LC_ALL, "C");
2648
[8978]2649    const svm_parameter& param = model->param;
2650
2651    fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]);
2652    fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]);
2653
2654    if(param.kernel_type == POLY)
2655        fprintf(fp,"degree %d\n", param.degree);
2656
2657    if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
2658        fprintf(fp,"gamma %g\n", param.gamma);
2659
2660    if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
2661        fprintf(fp,"coef0 %g\n", param.coef0);
2662
2663    int nr_class = model->nr_class;
2664    int l = model->l;
2665    fprintf(fp, "nr_class %d\n", nr_class);
2666    fprintf(fp, "total_sv %d\n",l);
2667   
2668    {
2669        fprintf(fp, "rho");
2670        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2671            fprintf(fp," %g",model->rho[i]);
2672        fprintf(fp, "\n");
2673    }
2674   
2675    if(model->label)
2676    {
2677        fprintf(fp, "label");
2678        for(int i=0;i<nr_class;i++)
2679            fprintf(fp," %d",model->label[i]);
2680        fprintf(fp, "\n");
2681    }
2682
2683    if(model->probA) // regression has probA only
2684    {
2685        fprintf(fp, "probA");
2686        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2687            fprintf(fp," %g",model->probA[i]);
2688        fprintf(fp, "\n");
2689    }
2690    if(model->probB)
2691    {
2692        fprintf(fp, "probB");
2693        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2694            fprintf(fp," %g",model->probB[i]);
2695        fprintf(fp, "\n");
2696    }
2697
2698    if(model->nSV)
2699    {
2700        fprintf(fp, "nr_sv");
2701        for(int i=0;i<nr_class;i++)
2702            fprintf(fp," %d",model->nSV[i]);
2703        fprintf(fp, "\n");
2704    }
2705
2706    fprintf(fp, "SV\n");
2707    const double * const *sv_coef = model->sv_coef;
2708    const svm_node * const *SV = model->SV;
2709
2710    for(int i=0;i<l;i++)
2711    {
2712        for(int j=0;j<nr_class-1;j++)
2713            fprintf(fp, "%.16g ",sv_coef[j][i]);
2714
2715        const svm_node *p = SV[i];
2716
2717        if(param.kernel_type == PRECOMPUTED)
2718            fprintf(fp,"0:%d ",(int)(p->value));
2719        else
2720            while(p->index != -1)
2721            {
2722                fprintf(fp,"%d:%.8g ",p->index,p->value);
2723                p++;
2724            }
2725        fprintf(fp, "\n");
2726    }
[11664]2727
2728    setlocale(LC_ALL, old_locale);
2729    free(old_locale);
2730
[8978]2731    if (ferror(fp) != 0 || fclose(fp) != 0) return -1;
2732    else return 0;
2733}
2734
2735static char *line = NULL;
2736static int max_line_len;
2737
2738static char* readline(FILE *input)
2739{
2740    int len;
2741
2742    if(fgets(line,max_line_len,input) == NULL)
2743        return NULL;
2744
2745    while(strrchr(line,'\n') == NULL)
2746    {
2747        max_line_len *= 2;
2748        line = (char *) realloc(line,max_line_len);
2749        len = (int) strlen(line);
2750        if(fgets(line+len,max_line_len-len,input) == NULL)
2751            break;
2752    }
2753    return line;
2754}
2755
2756svm_model *svm_load_model(const char *model_file_name)
2757{
2758    FILE *fp = fopen(model_file_name,"rb");
2759    if(fp==NULL) return NULL;
[11664]2760
2761    char *old_locale = strdup(setlocale(LC_ALL, NULL));
2762    setlocale(LC_ALL, "C");
2763
[8978]2764    // read parameters
2765
2766    svm_model *model = Malloc(svm_model,1);
2767    svm_parameter& param = model->param;
2768    model->rho = NULL;
2769    model->probA = NULL;
2770    model->probB = NULL;
[11664]2771    model->sv_indices = NULL;
[8978]2772    model->label = NULL;
2773    model->nSV = NULL;
2774
2775    char cmd[81];
2776    while(1)
2777    {
2778        fscanf(fp,"%80s",cmd);
2779
2780        if(strcmp(cmd,"svm_type")==0)
2781        {
2782            fscanf(fp,"%80s",cmd);
2783            int i;
2784            for(i=0;svm_type_table[i];i++)
2785            {
2786                if(strcmp(svm_type_table[i],cmd)==0)
2787                {
2788                    param.svm_type=i;
2789                    break;
2790                }
2791            }
2792            if(svm_type_table[i] == NULL)
2793            {
2794                fprintf(stderr,"unknown svm type.\n");
[11664]2795               
2796                setlocale(LC_ALL, old_locale);
2797                free(old_locale);
[8978]2798                free(model->rho);
2799                free(model->label);
2800                free(model->nSV);
2801                free(model);
2802                return NULL;
2803            }
2804        }
2805        else if(strcmp(cmd,"kernel_type")==0)
2806        {       
2807            fscanf(fp,"%80s",cmd);
2808            int i;
2809            for(i=0;kernel_type_table[i];i++)
2810            {
2811                if(strcmp(kernel_type_table[i],cmd)==0)
2812                {
2813                    param.kernel_type=i;
2814                    break;
2815                }
2816            }
2817            if(kernel_type_table[i] == NULL)
2818            {
2819                fprintf(stderr,"unknown kernel function.\n");
[11664]2820               
2821                setlocale(LC_ALL, old_locale);
2822                free(old_locale);
[8978]2823                free(model->rho);
2824                free(model->label);
2825                free(model->nSV);
2826                free(model);
2827                return NULL;
2828            }
2829        }
2830        else if(strcmp(cmd,"degree")==0)
2831            fscanf(fp,"%d",&param.degree);
2832        else if(strcmp(cmd,"gamma")==0)
2833            fscanf(fp,"%lf",&param.gamma);
2834        else if(strcmp(cmd,"coef0")==0)
2835            fscanf(fp,"%lf",&param.coef0);
2836        else if(strcmp(cmd,"nr_class")==0)
2837            fscanf(fp,"%d",&model->nr_class);
2838        else if(strcmp(cmd,"total_sv")==0)
2839            fscanf(fp,"%d",&model->l);
2840        else if(strcmp(cmd,"rho")==0)
2841        {
2842            int n = model->nr_class * (model->nr_class-1)/2;
2843            model->rho = Malloc(double,n);
2844            for(int i=0;i<n;i++)
2845                fscanf(fp,"%lf",&model->rho[i]);
2846        }
2847        else if(strcmp(cmd,"label")==0)
2848        {
2849            int n = model->nr_class;
2850            model->label = Malloc(int,n);
2851            for(int i=0;i<n;i++)
2852                fscanf(fp,"%d",&model->label[i]);
2853        }
2854        else if(strcmp(cmd,"probA")==0)
2855        {
2856            int n = model->nr_class * (model->nr_class-1)/2;
2857            model->probA = Malloc(double,n);
2858            for(int i=0;i<n;i++)
2859                fscanf(fp,"%lf",&model->probA[i]);
2860        }
2861        else if(strcmp(cmd,"probB")==0)
2862        {
2863            int n = model->nr_class * (model->nr_class-1)/2;
2864            model->probB = Malloc(double,n);
2865            for(int i=0;i<n;i++)
2866                fscanf(fp,"%lf",&model->probB[i]);
2867        }
2868        else if(strcmp(cmd,"nr_sv")==0)
2869        {
2870            int n = model->nr_class;
2871            model->nSV = Malloc(int,n);
2872            for(int i=0;i<n;i++)
2873                fscanf(fp,"%d",&model->nSV[i]);
2874        }
2875        else if(strcmp(cmd,"SV")==0)
2876        {
2877            while(1)
2878            {
2879                int c = getc(fp);
2880                if(c==EOF || c=='\n') break;   
2881            }
2882            break;
2883        }
2884        else
2885        {
2886            fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
[11664]2887           
2888            setlocale(LC_ALL, old_locale);
2889            free(old_locale);
[8978]2890            free(model->rho);
2891            free(model->label);
2892            free(model->nSV);
2893            free(model);
2894            return NULL;
2895        }
2896    }
2897
2898    // read sv_coef and SV
2899
2900    int elements = 0;
2901    long pos = ftell(fp);
2902
2903    max_line_len = 1024;
2904    line = Malloc(char,max_line_len);
2905    char *p,*endptr,*idx,*val;
2906
2907    while(readline(fp)!=NULL)
2908    {
2909        p = strtok(line,":");
2910        while(1)
2911        {
2912            p = strtok(NULL,":");
2913            if(p == NULL)
2914                break;
2915            ++elements;
2916        }
2917    }
2918    elements += model->l;
2919
2920    fseek(fp,pos,SEEK_SET);
2921
2922    int m = model->nr_class - 1;
2923    int l = model->l;
2924    model->sv_coef = Malloc(double *,m);
2925    int i;
2926    for(i=0;i<m;i++)
2927        model->sv_coef[i] = Malloc(double,l);
2928    model->SV = Malloc(svm_node*,l);
2929    svm_node *x_space = NULL;
2930    if(l>0) x_space = Malloc(svm_node,elements);
2931
2932    int j=0;
2933    for(i=0;i<l;i++)
2934    {
2935        readline(fp);
2936        model->SV[i] = &x_space[j];
2937
2938        p = strtok(line, " \t");
2939        model->sv_coef[0][i] = strtod(p,&endptr);
2940        for(int k=1;k<m;k++)
2941        {
2942            p = strtok(NULL, " \t");
2943            model->sv_coef[k][i] = strtod(p,&endptr);
2944        }
2945
2946        while(1)
2947        {
2948            idx = strtok(NULL, ":");
2949            val = strtok(NULL, " \t");
2950
2951            if(val == NULL)
2952                break;
2953            x_space[j].index = (int) strtol(idx,&endptr,10);
2954            x_space[j].value = strtod(val,&endptr);
2955
2956            ++j;
2957        }
2958        x_space[j++].index = -1;
2959    }
2960    free(line);
2961
[11664]2962    setlocale(LC_ALL, old_locale);
2963    free(old_locale);
2964
[8978]2965    if (ferror(fp) != 0 || fclose(fp) != 0)
2966        return NULL;
2967
2968    model->free_sv = 1; // XXX
2969    return model;
2970}
2971
2972void svm_free_model_content(svm_model* model_ptr)
2973{
2974    if(model_ptr->free_sv && model_ptr->l > 0 && model_ptr->SV != NULL)
2975        free((void *)(model_ptr->SV[0]));
2976    if(model_ptr->sv_coef)
2977    {
2978        for(int i=0;i<model_ptr->nr_class-1;i++)
2979            free(model_ptr->sv_coef[i]);
2980    }
2981
2982    free(model_ptr->SV);
2983    model_ptr->SV = NULL;
2984
2985    free(model_ptr->sv_coef);
2986    model_ptr->sv_coef = NULL;
2987
2988    free(model_ptr->rho);
2989    model_ptr->rho = NULL;
2990
2991    free(model_ptr->label);
2992    model_ptr->label= NULL;
2993
2994    free(model_ptr->probA);
2995    model_ptr->probA = NULL;
2996
2997    free(model_ptr->probB);
2998    model_ptr->probB= NULL;
2999
[11664]3000    free(model_ptr->sv_indices);
3001    model_ptr->sv_indices = NULL;
3002
[8978]3003    free(model_ptr->nSV);
3004    model_ptr->nSV = NULL;
3005}
3006
3007void svm_free_and_destroy_model(svm_model** model_ptr_ptr)
3008{
3009    if(model_ptr_ptr != NULL && *model_ptr_ptr != NULL)
3010    {
3011        svm_free_model_content(*model_ptr_ptr);
3012        free(*model_ptr_ptr);
3013        *model_ptr_ptr = NULL;
3014    }
3015}
3016
3017void svm_destroy_param(svm_parameter* param)
3018{
3019    free(param->weight_label);
3020    free(param->weight);
3021}
3022
3023const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
3024{
3025    // svm_type
3026
3027    int svm_type = param->svm_type;
3028    if(svm_type != C_SVC &&
3029       svm_type != NU_SVC &&
3030       svm_type != ONE_CLASS &&
3031       svm_type != EPSILON_SVR &&
3032       svm_type != NU_SVR)
3033        return "unknown svm type";
3034   
3035    // kernel_type, degree
3036   
3037    int kernel_type = param->kernel_type;
3038    if(kernel_type != LINEAR &&
3039       kernel_type != POLY &&
3040       kernel_type != RBF &&
3041       kernel_type != SIGMOID &&
3042       kernel_type != PRECOMPUTED)
3043        return "unknown kernel type";
3044
3045    if(param->gamma < 0)
3046        return "gamma < 0";
3047
3048    if(param->degree < 0)
3049        return "degree of polynomial kernel < 0";
3050
3051    // cache_size,eps,C,nu,p,shrinking
3052
3053    if(param->cache_size <= 0)
3054        return "cache_size <= 0";
3055
3056    if(param->eps <= 0)
3057        return "eps <= 0";
3058
3059    if(svm_type == C_SVC ||
3060       svm_type == EPSILON_SVR ||
3061       svm_type == NU_SVR)
3062        if(param->C <= 0)
3063            return "C <= 0";
3064
3065    if(svm_type == NU_SVC ||
3066       svm_type == ONE_CLASS ||
3067       svm_type == NU_SVR)
3068        if(param->nu <= 0 || param->nu > 1)
3069            return "nu <= 0 or nu > 1";
3070
3071    if(svm_type == EPSILON_SVR)
3072        if(param->p < 0)
3073            return "p < 0";
3074
3075    if(param->shrinking != 0 &&
3076       param->shrinking != 1)
3077        return "shrinking != 0 and shrinking != 1";
3078
3079    if(param->probability != 0 &&
3080       param->probability != 1)
3081        return "probability != 0 and probability != 1";
3082
3083    if(param->probability == 1 &&
3084       svm_type == ONE_CLASS)
3085        return "one-class SVM probability output not supported yet";
3086
3087
3088    // check whether nu-svc is feasible
3089   
3090    if(svm_type == NU_SVC)
3091    {
3092        int l = prob->l;
3093        int max_nr_class = 16;
3094        int nr_class = 0;
3095        int *label = Malloc(int,max_nr_class);
3096        int *count = Malloc(int,max_nr_class);
3097
3098        int i;
3099        for(i=0;i<l;i++)
3100        {
3101            int this_label = (int)prob->y[i];
3102            int j;
3103            for(j=0;j<nr_class;j++)
3104                if(this_label == label[j])
3105                {
3106                    ++count[j];
3107                    break;
3108                }
3109            if(j == nr_class)
3110            {
3111                if(nr_class == max_nr_class)
3112                {
3113                    max_nr_class *= 2;
3114                    label = (int *)realloc(label,max_nr_class*sizeof(int));
3115                    count = (int *)realloc(count,max_nr_class*sizeof(int));
3116                }
3117                label[nr_class] = this_label;
3118                count[nr_class] = 1;
3119                ++nr_class;
3120            }
3121        }
3122   
3123        for(i=0;i<nr_class;i++)
3124        {
3125            int n1 = count[i];
3126            for(int j=i+1;j<nr_class;j++)
3127            {
3128                int n2 = count[j];
3129                if(param->nu*(n1+n2)/2 > min(n1,n2))
3130                {
3131                    free(label);
3132                    free(count);
3133                    return "specified nu is infeasible";
3134                }
3135            }
3136        }
3137        free(label);
3138        free(count);
3139    }
3140
3141    return NULL;
3142}
3143
3144int svm_check_probability_model(const svm_model *model)
3145{
3146    return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
3147        model->probA!=NULL && model->probB!=NULL) ||
3148        ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
3149         model->probA!=NULL);
3150}
3151
3152void svm_set_print_string_function(void (*print_func)(const char *))
3153{
3154    if(print_func == NULL)
3155        svm_print_string = &print_string_stdout;
3156    else
3157        svm_print_string = print_func;
3158}
Note: See TracBrowser for help on using the repository browser.