source: orange/source/orange/libsvm/svm.cpp @ 8978:f0ba3673b0a7

Revision 8978:f0ba3673b0a7, 61.0 KB checked in by ales_erjavec <ales.erjavec@…>, 3 years ago (diff)

Refactored the code for LibSVM and LIBLINEAR. The sources are unmodified in liblinear and libsvm directories. Can be excluded from the build and instead linked to a system libraries instead.
The most significant change is in the way custom kernels are handled for SVMLearner - it now uses 'PRECOMPUTED' kernel functionality from LibSVM (as a bonus the results seem to be better then before (was there a bug in the previous code?), also faster)

Line 
1#include <math.h>
2#include <stdio.h>
3#include <stdlib.h>
4#include <ctype.h>
5#include <float.h>
6#include <string.h>
7#include <stdarg.h>
8#include "svm.h"
9int libsvm_version = LIBSVM_VERSION;
10typedef float Qfloat;
11typedef signed char schar;
12#ifndef min
13template <class T> static inline T min(T x,T y) { return (x<y)?x:y; }
14#endif
15#ifndef max
16template <class T> static inline T max(T x,T y) { return (x>y)?x:y; }
17#endif
18template <class T> static inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
19template <class S, class T> static inline void clone(T*& dst, S* src, int n)
20{
21    dst = new T[n];
22    memcpy((void *)dst,(void *)src,sizeof(T)*n);
23}
24static inline double powi(double base, int times)
25{
26    double tmp = base, ret = 1.0;
27
28    for(int t=times; t>0; t/=2)
29    {
30        if(t%2==1) ret*=tmp;
31        tmp = tmp * tmp;
32    }
33    return ret;
34}
35#define INF HUGE_VAL
36#define TAU 1e-12
37#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
38
39static void print_string_stdout(const char *s)
40{
41    fputs(s,stdout);
42    fflush(stdout);
43}
44static void (*svm_print_string) (const char *) = &print_string_stdout;
45#if 1
46static void info(const char *fmt,...)
47{
48    char buf[BUFSIZ];
49    va_list ap;
50    va_start(ap,fmt);
51    vsprintf(buf,fmt,ap);
52    va_end(ap);
53    (*svm_print_string)(buf);
54}
55#else
56static void info(const char *fmt,...) {}
57#endif
58
59//
60// Kernel Cache
61//
62// l is the number of total data items
63// size is the cache size limit in bytes
64//
65class Cache
66{
67public:
68    Cache(int l,long int size);
69    ~Cache();
70
71    // request data [0,len)
72    // return some position p where [p,len) need to be filled
73    // (p >= len if nothing needs to be filled)
74    int get_data(const int index, Qfloat **data, int len);
75    void swap_index(int i, int j); 
76private:
77    int l;
78    long int size;
79    struct head_t
80    {
81        head_t *prev, *next;    // a circular list
82        Qfloat *data;
83        int len;        // data[0,len) is cached in this entry
84    };
85
86    head_t *head;
87    head_t lru_head;
88    void lru_delete(head_t *h);
89    void lru_insert(head_t *h);
90};
91
92Cache::Cache(int l_,long int size_):l(l_),size(size_)
93{
94    head = (head_t *)calloc(l,sizeof(head_t));  // initialized to 0
95    size /= sizeof(Qfloat);
96    size -= l * sizeof(head_t) / sizeof(Qfloat);
97    size = max(size, 2 * (long int) l); // cache must be large enough for two columns
98    lru_head.next = lru_head.prev = &lru_head;
99}
100
101Cache::~Cache()
102{
103    for(head_t *h = lru_head.next; h != &lru_head; h=h->next)
104        free(h->data);
105    free(head);
106}
107
108void Cache::lru_delete(head_t *h)
109{
110    // delete from current location
111    h->prev->next = h->next;
112    h->next->prev = h->prev;
113}
114
115void Cache::lru_insert(head_t *h)
116{
117    // insert to last position
118    h->next = &lru_head;
119    h->prev = lru_head.prev;
120    h->prev->next = h;
121    h->next->prev = h;
122}
123
124int Cache::get_data(const int index, Qfloat **data, int len)
125{
126    head_t *h = &head[index];
127    if(h->len) lru_delete(h);
128    int more = len - h->len;
129
130    if(more > 0)
131    {
132        // free old space
133        while(size < more)
134        {
135            head_t *old = lru_head.next;
136            lru_delete(old);
137            free(old->data);
138            size += old->len;
139            old->data = 0;
140            old->len = 0;
141        }
142
143        // allocate new space
144        h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len);
145        size -= more;
146        swap(h->len,len);
147    }
148
149    lru_insert(h);
150    *data = h->data;
151    return len;
152}
153
154void Cache::swap_index(int i, int j)
155{
156    if(i==j) return;
157
158    if(head[i].len) lru_delete(&head[i]);
159    if(head[j].len) lru_delete(&head[j]);
160    swap(head[i].data,head[j].data);
161    swap(head[i].len,head[j].len);
162    if(head[i].len) lru_insert(&head[i]);
163    if(head[j].len) lru_insert(&head[j]);
164
165    if(i>j) swap(i,j);
166    for(head_t *h = lru_head.next; h!=&lru_head; h=h->next)
167    {
168        if(h->len > i)
169        {
170            if(h->len > j)
171                swap(h->data[i],h->data[j]);
172            else
173            {
174                // give up
175                lru_delete(h);
176                free(h->data);
177                size += h->len;
178                h->data = 0;
179                h->len = 0;
180            }
181        }
182    }
183}
184
185//
186// Kernel evaluation
187//
188// the static method k_function is for doing single kernel evaluation
189// the constructor of Kernel prepares to calculate the l*l kernel matrix
190// the member function get_Q is for getting one column from the Q Matrix
191//
192class QMatrix {
193public:
194    virtual Qfloat *get_Q(int column, int len) const = 0;
195    virtual double *get_QD() const = 0;
196    virtual void swap_index(int i, int j) const = 0;
197    virtual ~QMatrix() {}
198};
199
200class Kernel: public QMatrix {
201public:
202    Kernel(int l, svm_node * const * x, const svm_parameter& param);
203    virtual ~Kernel();
204
205    static double k_function(const svm_node *x, const svm_node *y,
206                 const svm_parameter& param);
207    virtual Qfloat *get_Q(int column, int len) const = 0;
208    virtual double *get_QD() const = 0;
209    virtual void swap_index(int i, int j) const // no so const...
210    {
211        swap(x[i],x[j]);
212        if(x_square) swap(x_square[i],x_square[j]);
213    }
214protected:
215
216    double (Kernel::*kernel_function)(int i, int j) const;
217
218private:
219    const svm_node **x;
220    double *x_square;
221
222    // svm_parameter
223    const int kernel_type;
224    const int degree;
225    const double gamma;
226    const double coef0;
227
228    static double dot(const svm_node *px, const svm_node *py);
229    double kernel_linear(int i, int j) const
230    {
231        return dot(x[i],x[j]);
232    }
233    double kernel_poly(int i, int j) const
234    {
235        return powi(gamma*dot(x[i],x[j])+coef0,degree);
236    }
237    double kernel_rbf(int i, int j) const
238    {
239        return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
240    }
241    double kernel_sigmoid(int i, int j) const
242    {
243        return tanh(gamma*dot(x[i],x[j])+coef0);
244    }
245    double kernel_precomputed(int i, int j) const
246    {
247        return x[i][(int)(x[j][0].value)].value;
248    }
249};
250
251Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)
252:kernel_type(param.kernel_type), degree(param.degree),
253 gamma(param.gamma), coef0(param.coef0)
254{
255    switch(kernel_type)
256    {
257        case LINEAR:
258            kernel_function = &Kernel::kernel_linear;
259            break;
260        case POLY:
261            kernel_function = &Kernel::kernel_poly;
262            break;
263        case RBF:
264            kernel_function = &Kernel::kernel_rbf;
265            break;
266        case SIGMOID:
267            kernel_function = &Kernel::kernel_sigmoid;
268            break;
269        case PRECOMPUTED:
270            kernel_function = &Kernel::kernel_precomputed;
271            break;
272    }
273
274    clone(x,x_,l);
275
276    if(kernel_type == RBF)
277    {
278        x_square = new double[l];
279        for(int i=0;i<l;i++)
280            x_square[i] = dot(x[i],x[i]);
281    }
282    else
283        x_square = 0;
284}
285
286Kernel::~Kernel()
287{
288    delete[] x;
289    delete[] x_square;
290}
291
292double Kernel::dot(const svm_node *px, const svm_node *py)
293{
294    double sum = 0;
295    while(px->index != -1 && py->index != -1)
296    {
297        if(px->index == py->index)
298        {
299            sum += px->value * py->value;
300            ++px;
301            ++py;
302        }
303        else
304        {
305            if(px->index > py->index)
306                ++py;
307            else
308                ++px;
309        }           
310    }
311    return sum;
312}
313
314double Kernel::k_function(const svm_node *x, const svm_node *y,
315              const svm_parameter& param)
316{
317    switch(param.kernel_type)
318    {
319        case LINEAR:
320            return dot(x,y);
321        case POLY:
322            return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
323        case RBF:
324        {
325            double sum = 0;
326            while(x->index != -1 && y->index !=-1)
327            {
328                if(x->index == y->index)
329                {
330                    double d = x->value - y->value;
331                    sum += d*d;
332                    ++x;
333                    ++y;
334                }
335                else
336                {
337                    if(x->index > y->index)
338                    {   
339                        sum += y->value * y->value;
340                        ++y;
341                    }
342                    else
343                    {
344                        sum += x->value * x->value;
345                        ++x;
346                    }
347                }
348            }
349
350            while(x->index != -1)
351            {
352                sum += x->value * x->value;
353                ++x;
354            }
355
356            while(y->index != -1)
357            {
358                sum += y->value * y->value;
359                ++y;
360            }
361           
362            return exp(-param.gamma*sum);
363        }
364        case SIGMOID:
365            return tanh(param.gamma*dot(x,y)+param.coef0);
366        case PRECOMPUTED:  //x: test (validation), y: SV
367            return x[(int)(y->value)].value;
368        default:
369            return 0;  // Unreachable
370    }
371}
372
373// An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
374// Solves:
375//
376//  min 0.5(\alpha^T Q \alpha) + p^T \alpha
377//
378//      y^T \alpha = \delta
379//      y_i = +1 or -1
380//      0 <= alpha_i <= Cp for y_i = 1
381//      0 <= alpha_i <= Cn for y_i = -1
382//
383// Given:
384//
385//  Q, p, y, Cp, Cn, and an initial feasible point \alpha
386//  l is the size of vectors and matrices
387//  eps is the stopping tolerance
388//
389// solution will be put in \alpha, objective value will be put in obj
390//
391class Solver {
392public:
393    Solver() {};
394    virtual ~Solver() {};
395
396    struct SolutionInfo {
397        double obj;
398        double rho;
399        double upper_bound_p;
400        double upper_bound_n;
401        double r;   // for Solver_NU
402    };
403
404    void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
405           double *alpha_, double Cp, double Cn, double eps,
406           SolutionInfo* si, int shrinking);
407protected:
408    int active_size;
409    schar *y;
410    double *G;      // gradient of objective function
411    enum { LOWER_BOUND, UPPER_BOUND, FREE };
412    char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE
413    double *alpha;
414    const QMatrix *Q;
415    const double *QD;
416    double eps;
417    double Cp,Cn;
418    double *p;
419    int *active_set;
420    double *G_bar;      // gradient, if we treat free variables as 0
421    int l;
422    bool unshrink;  // XXX
423
424    double get_C(int i)
425    {
426        return (y[i] > 0)? Cp : Cn;
427    }
428    void update_alpha_status(int i)
429    {
430        if(alpha[i] >= get_C(i))
431            alpha_status[i] = UPPER_BOUND;
432        else if(alpha[i] <= 0)
433            alpha_status[i] = LOWER_BOUND;
434        else alpha_status[i] = FREE;
435    }
436    bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
437    bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
438    bool is_free(int i) { return alpha_status[i] == FREE; }
439    void swap_index(int i, int j);
440    void reconstruct_gradient();
441    virtual int select_working_set(int &i, int &j);
442    virtual double calculate_rho();
443    virtual void do_shrinking();
444private:
445    bool be_shrunk(int i, double Gmax1, double Gmax2); 
446};
447
448void Solver::swap_index(int i, int j)
449{
450    Q->swap_index(i,j);
451    swap(y[i],y[j]);
452    swap(G[i],G[j]);
453    swap(alpha_status[i],alpha_status[j]);
454    swap(alpha[i],alpha[j]);
455    swap(p[i],p[j]);
456    swap(active_set[i],active_set[j]);
457    swap(G_bar[i],G_bar[j]);
458}
459
460void Solver::reconstruct_gradient()
461{
462    // reconstruct inactive elements of G from G_bar and free variables
463
464    if(active_size == l) return;
465
466    int i,j;
467    int nr_free = 0;
468
469    for(j=active_size;j<l;j++)
470        G[j] = G_bar[j] + p[j];
471
472    for(j=0;j<active_size;j++)
473        if(is_free(j))
474            nr_free++;
475
476    if(2*nr_free < active_size)
477        info("\nWarning: using -h 0 may be faster\n");
478
479    if (nr_free*l > 2*active_size*(l-active_size))
480    {
481        for(i=active_size;i<l;i++)
482        {
483            const Qfloat *Q_i = Q->get_Q(i,active_size);
484            for(j=0;j<active_size;j++)
485                if(is_free(j))
486                    G[i] += alpha[j] * Q_i[j];
487        }
488    }
489    else
490    {
491        for(i=0;i<active_size;i++)
492            if(is_free(i))
493            {
494                const Qfloat *Q_i = Q->get_Q(i,l);
495                double alpha_i = alpha[i];
496                for(j=active_size;j<l;j++)
497                    G[j] += alpha_i * Q_i[j];
498            }
499    }
500}
501
502void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
503           double *alpha_, double Cp, double Cn, double eps,
504           SolutionInfo* si, int shrinking)
505{
506    this->l = l;
507    this->Q = &Q;
508    QD=Q.get_QD();
509    clone(p, p_,l);
510    clone(y, y_,l);
511    clone(alpha,alpha_,l);
512    this->Cp = Cp;
513    this->Cn = Cn;
514    this->eps = eps;
515    unshrink = false;
516
517    // initialize alpha_status
518    {
519        alpha_status = new char[l];
520        for(int i=0;i<l;i++)
521            update_alpha_status(i);
522    }
523
524    // initialize active set (for shrinking)
525    {
526        active_set = new int[l];
527        for(int i=0;i<l;i++)
528            active_set[i] = i;
529        active_size = l;
530    }
531
532    // initialize gradient
533    {
534        G = new double[l];
535        G_bar = new double[l];
536        int i;
537        for(i=0;i<l;i++)
538        {
539            G[i] = p[i];
540            G_bar[i] = 0;
541        }
542        for(i=0;i<l;i++)
543            if(!is_lower_bound(i))
544            {
545                const Qfloat *Q_i = Q.get_Q(i,l);
546                double alpha_i = alpha[i];
547                int j;
548                for(j=0;j<l;j++)
549                    G[j] += alpha_i*Q_i[j];
550                if(is_upper_bound(i))
551                    for(j=0;j<l;j++)
552                        G_bar[j] += get_C(i) * Q_i[j];
553            }
554    }
555
556    // optimization step
557
558    int iter = 0;
559    int counter = min(l,1000)+1;
560
561    while(1)
562    {
563        // show progress and do shrinking
564
565        if(--counter == 0)
566        {
567            counter = min(l,1000);
568            if(shrinking) do_shrinking();
569            info(".");
570        }
571
572        int i,j;
573        if(select_working_set(i,j)!=0)
574        {
575            // reconstruct the whole gradient
576            reconstruct_gradient();
577            // reset active set size and check
578            active_size = l;
579            info("*");
580            if(select_working_set(i,j)!=0)
581                break;
582            else
583                counter = 1;    // do shrinking next iteration
584        }
585       
586        ++iter;
587
588        // update alpha[i] and alpha[j], handle bounds carefully
589       
590        const Qfloat *Q_i = Q.get_Q(i,active_size);
591        const Qfloat *Q_j = Q.get_Q(j,active_size);
592
593        double C_i = get_C(i);
594        double C_j = get_C(j);
595
596        double old_alpha_i = alpha[i];
597        double old_alpha_j = alpha[j];
598
599        if(y[i]!=y[j])
600        {
601            double quad_coef = QD[i]+QD[j]+2*Q_i[j];
602            if (quad_coef <= 0)
603                quad_coef = TAU;
604            double delta = (-G[i]-G[j])/quad_coef;
605            double diff = alpha[i] - alpha[j];
606            alpha[i] += delta;
607            alpha[j] += delta;
608           
609            if(diff > 0)
610            {
611                if(alpha[j] < 0)
612                {
613                    alpha[j] = 0;
614                    alpha[i] = diff;
615                }
616            }
617            else
618            {
619                if(alpha[i] < 0)
620                {
621                    alpha[i] = 0;
622                    alpha[j] = -diff;
623                }
624            }
625            if(diff > C_i - C_j)
626            {
627                if(alpha[i] > C_i)
628                {
629                    alpha[i] = C_i;
630                    alpha[j] = C_i - diff;
631                }
632            }
633            else
634            {
635                if(alpha[j] > C_j)
636                {
637                    alpha[j] = C_j;
638                    alpha[i] = C_j + diff;
639                }
640            }
641        }
642        else
643        {
644            double quad_coef = QD[i]+QD[j]-2*Q_i[j];
645            if (quad_coef <= 0)
646                quad_coef = TAU;
647            double delta = (G[i]-G[j])/quad_coef;
648            double sum = alpha[i] + alpha[j];
649            alpha[i] -= delta;
650            alpha[j] += delta;
651
652            if(sum > C_i)
653            {
654                if(alpha[i] > C_i)
655                {
656                    alpha[i] = C_i;
657                    alpha[j] = sum - C_i;
658                }
659            }
660            else
661            {
662                if(alpha[j] < 0)
663                {
664                    alpha[j] = 0;
665                    alpha[i] = sum;
666                }
667            }
668            if(sum > C_j)
669            {
670                if(alpha[j] > C_j)
671                {
672                    alpha[j] = C_j;
673                    alpha[i] = sum - C_j;
674                }
675            }
676            else
677            {
678                if(alpha[i] < 0)
679                {
680                    alpha[i] = 0;
681                    alpha[j] = sum;
682                }
683            }
684        }
685
686        // update G
687
688        double delta_alpha_i = alpha[i] - old_alpha_i;
689        double delta_alpha_j = alpha[j] - old_alpha_j;
690       
691        for(int k=0;k<active_size;k++)
692        {
693            G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
694        }
695
696        // update alpha_status and G_bar
697
698        {
699            bool ui = is_upper_bound(i);
700            bool uj = is_upper_bound(j);
701            update_alpha_status(i);
702            update_alpha_status(j);
703            int k;
704            if(ui != is_upper_bound(i))
705            {
706                Q_i = Q.get_Q(i,l);
707                if(ui)
708                    for(k=0;k<l;k++)
709                        G_bar[k] -= C_i * Q_i[k];
710                else
711                    for(k=0;k<l;k++)
712                        G_bar[k] += C_i * Q_i[k];
713            }
714
715            if(uj != is_upper_bound(j))
716            {
717                Q_j = Q.get_Q(j,l);
718                if(uj)
719                    for(k=0;k<l;k++)
720                        G_bar[k] -= C_j * Q_j[k];
721                else
722                    for(k=0;k<l;k++)
723                        G_bar[k] += C_j * Q_j[k];
724            }
725        }
726    }
727
728    // calculate rho
729
730    si->rho = calculate_rho();
731
732    // calculate objective value
733    {
734        double v = 0;
735        int i;
736        for(i=0;i<l;i++)
737            v += alpha[i] * (G[i] + p[i]);
738
739        si->obj = v/2;
740    }
741
742    // put back the solution
743    {
744        for(int i=0;i<l;i++)
745            alpha_[active_set[i]] = alpha[i];
746    }
747
748    // juggle everything back
749    /*{
750        for(int i=0;i<l;i++)
751            while(active_set[i] != i)
752                swap_index(i,active_set[i]);
753                // or Q.swap_index(i,active_set[i]);
754    }*/
755
756    si->upper_bound_p = Cp;
757    si->upper_bound_n = Cn;
758
759    info("\noptimization finished, #iter = %d\n",iter);
760
761    delete[] p;
762    delete[] y;
763    delete[] alpha;
764    delete[] alpha_status;
765    delete[] active_set;
766    delete[] G;
767    delete[] G_bar;
768}
769
770// return 1 if already optimal, return 0 otherwise
771int Solver::select_working_set(int &out_i, int &out_j)
772{
773    // return i,j such that
774    // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
775    // j: minimizes the decrease of obj value
776    //    (if quadratic coefficeint <= 0, replace it with tau)
777    //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
778   
779    double Gmax = -INF;
780    double Gmax2 = -INF;
781    int Gmax_idx = -1;
782    int Gmin_idx = -1;
783    double obj_diff_min = INF;
784
785    for(int t=0;t<active_size;t++)
786        if(y[t]==+1)   
787        {
788            if(!is_upper_bound(t))
789                if(-G[t] >= Gmax)
790                {
791                    Gmax = -G[t];
792                    Gmax_idx = t;
793                }
794        }
795        else
796        {
797            if(!is_lower_bound(t))
798                if(G[t] >= Gmax)
799                {
800                    Gmax = G[t];
801                    Gmax_idx = t;
802                }
803        }
804
805    int i = Gmax_idx;
806    const Qfloat *Q_i = NULL;
807    if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
808        Q_i = Q->get_Q(i,active_size);
809
810    for(int j=0;j<active_size;j++)
811    {
812        if(y[j]==+1)
813        {
814            if (!is_lower_bound(j))
815            {
816                double grad_diff=Gmax+G[j];
817                if (G[j] >= Gmax2)
818                    Gmax2 = G[j];
819                if (grad_diff > 0)
820                {
821                    double obj_diff; 
822                    double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
823                    if (quad_coef > 0)
824                        obj_diff = -(grad_diff*grad_diff)/quad_coef;
825                    else
826                        obj_diff = -(grad_diff*grad_diff)/TAU;
827
828                    if (obj_diff <= obj_diff_min)
829                    {
830                        Gmin_idx=j;
831                        obj_diff_min = obj_diff;
832                    }
833                }
834            }
835        }
836        else
837        {
838            if (!is_upper_bound(j))
839            {
840                double grad_diff= Gmax-G[j];
841                if (-G[j] >= Gmax2)
842                    Gmax2 = -G[j];
843                if (grad_diff > 0)
844                {
845                    double obj_diff; 
846                    double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
847                    if (quad_coef > 0)
848                        obj_diff = -(grad_diff*grad_diff)/quad_coef;
849                    else
850                        obj_diff = -(grad_diff*grad_diff)/TAU;
851
852                    if (obj_diff <= obj_diff_min)
853                    {
854                        Gmin_idx=j;
855                        obj_diff_min = obj_diff;
856                    }
857                }
858            }
859        }
860    }
861
862    if(Gmax+Gmax2 < eps)
863        return 1;
864
865    out_i = Gmax_idx;
866    out_j = Gmin_idx;
867    return 0;
868}
869
870bool Solver::be_shrunk(int i, double Gmax1, double Gmax2)
871{
872    if(is_upper_bound(i))
873    {
874        if(y[i]==+1)
875            return(-G[i] > Gmax1);
876        else
877            return(-G[i] > Gmax2);
878    }
879    else if(is_lower_bound(i))
880    {
881        if(y[i]==+1)
882            return(G[i] > Gmax2);
883        else   
884            return(G[i] > Gmax1);
885    }
886    else
887        return(false);
888}
889
890void Solver::do_shrinking()
891{
892    int i;
893    double Gmax1 = -INF;        // max { -y_i * grad(f)_i | i in I_up(\alpha) }
894    double Gmax2 = -INF;        // max { y_i * grad(f)_i | i in I_low(\alpha) }
895
896    // find maximal violating pair first
897    for(i=0;i<active_size;i++)
898    {
899        if(y[i]==+1)   
900        {
901            if(!is_upper_bound(i)) 
902            {
903                if(-G[i] >= Gmax1)
904                    Gmax1 = -G[i];
905            }
906            if(!is_lower_bound(i)) 
907            {
908                if(G[i] >= Gmax2)
909                    Gmax2 = G[i];
910            }
911        }
912        else   
913        {
914            if(!is_upper_bound(i)) 
915            {
916                if(-G[i] >= Gmax2)
917                    Gmax2 = -G[i];
918            }
919            if(!is_lower_bound(i)) 
920            {
921                if(G[i] >= Gmax1)
922                    Gmax1 = G[i];
923            }
924        }
925    }
926
927    if(unshrink == false && Gmax1 + Gmax2 <= eps*10) 
928    {
929        unshrink = true;
930        reconstruct_gradient();
931        active_size = l;
932        info("*");
933    }
934
935    for(i=0;i<active_size;i++)
936        if (be_shrunk(i, Gmax1, Gmax2))
937        {
938            active_size--;
939            while (active_size > i)
940            {
941                if (!be_shrunk(active_size, Gmax1, Gmax2))
942                {
943                    swap_index(i,active_size);
944                    break;
945                }
946                active_size--;
947            }
948        }
949}
950
951double Solver::calculate_rho()
952{
953    double r;
954    int nr_free = 0;
955    double ub = INF, lb = -INF, sum_free = 0;
956    for(int i=0;i<active_size;i++)
957    {
958        double yG = y[i]*G[i];
959
960        if(is_upper_bound(i))
961        {
962            if(y[i]==-1)
963                ub = min(ub,yG);
964            else
965                lb = max(lb,yG);
966        }
967        else if(is_lower_bound(i))
968        {
969            if(y[i]==+1)
970                ub = min(ub,yG);
971            else
972                lb = max(lb,yG);
973        }
974        else
975        {
976            ++nr_free;
977            sum_free += yG;
978        }
979    }
980
981    if(nr_free>0)
982        r = sum_free/nr_free;
983    else
984        r = (ub+lb)/2;
985
986    return r;
987}
988
989//
990// Solver for nu-svm classification and regression
991//
992// additional constraint: e^T \alpha = constant
993//
994class Solver_NU : public Solver
995{
996public:
997    Solver_NU() {}
998    void Solve(int l, const QMatrix& Q, const double *p, const schar *y,
999           double *alpha, double Cp, double Cn, double eps,
1000           SolutionInfo* si, int shrinking)
1001    {
1002        this->si = si;
1003        Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
1004    }
1005private:
1006    SolutionInfo *si;
1007    int select_working_set(int &i, int &j);
1008    double calculate_rho();
1009    bool be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4);
1010    void do_shrinking();
1011};
1012
1013// return 1 if already optimal, return 0 otherwise
1014int Solver_NU::select_working_set(int &out_i, int &out_j)
1015{
1016    // return i,j such that y_i = y_j and
1017    // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
1018    // j: minimizes the decrease of obj value
1019    //    (if quadratic coefficeint <= 0, replace it with tau)
1020    //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
1021
1022    double Gmaxp = -INF;
1023    double Gmaxp2 = -INF;
1024    int Gmaxp_idx = -1;
1025
1026    double Gmaxn = -INF;
1027    double Gmaxn2 = -INF;
1028    int Gmaxn_idx = -1;
1029
1030    int Gmin_idx = -1;
1031    double obj_diff_min = INF;
1032
1033    for(int t=0;t<active_size;t++)
1034        if(y[t]==+1)
1035        {
1036            if(!is_upper_bound(t))
1037                if(-G[t] >= Gmaxp)
1038                {
1039                    Gmaxp = -G[t];
1040                    Gmaxp_idx = t;
1041                }
1042        }
1043        else
1044        {
1045            if(!is_lower_bound(t))
1046                if(G[t] >= Gmaxn)
1047                {
1048                    Gmaxn = G[t];
1049                    Gmaxn_idx = t;
1050                }
1051        }
1052
1053    int ip = Gmaxp_idx;
1054    int in = Gmaxn_idx;
1055    const Qfloat *Q_ip = NULL;
1056    const Qfloat *Q_in = NULL;
1057    if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1
1058        Q_ip = Q->get_Q(ip,active_size);
1059    if(in != -1)
1060        Q_in = Q->get_Q(in,active_size);
1061
1062    for(int j=0;j<active_size;j++)
1063    {
1064        if(y[j]==+1)
1065        {
1066            if (!is_lower_bound(j)) 
1067            {
1068                double grad_diff=Gmaxp+G[j];
1069                if (G[j] >= Gmaxp2)
1070                    Gmaxp2 = G[j];
1071                if (grad_diff > 0)
1072                {
1073                    double obj_diff; 
1074                    double quad_coef = QD[ip]+QD[j]-2*Q_ip[j];
1075                    if (quad_coef > 0)
1076                        obj_diff = -(grad_diff*grad_diff)/quad_coef;
1077                    else
1078                        obj_diff = -(grad_diff*grad_diff)/TAU;
1079
1080                    if (obj_diff <= obj_diff_min)
1081                    {
1082                        Gmin_idx=j;
1083                        obj_diff_min = obj_diff;
1084                    }
1085                }
1086            }
1087        }
1088        else
1089        {
1090            if (!is_upper_bound(j))
1091            {
1092                double grad_diff=Gmaxn-G[j];
1093                if (-G[j] >= Gmaxn2)
1094                    Gmaxn2 = -G[j];
1095                if (grad_diff > 0)
1096                {
1097                    double obj_diff; 
1098                    double quad_coef = QD[in]+QD[j]-2*Q_in[j];
1099                    if (quad_coef > 0)
1100                        obj_diff = -(grad_diff*grad_diff)/quad_coef;
1101                    else
1102                        obj_diff = -(grad_diff*grad_diff)/TAU;
1103
1104                    if (obj_diff <= obj_diff_min)
1105                    {
1106                        Gmin_idx=j;
1107                        obj_diff_min = obj_diff;
1108                    }
1109                }
1110            }
1111        }
1112    }
1113
1114    if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
1115        return 1;
1116
1117    if (y[Gmin_idx] == +1)
1118        out_i = Gmaxp_idx;
1119    else
1120        out_i = Gmaxn_idx;
1121    out_j = Gmin_idx;
1122
1123    return 0;
1124}
1125
1126bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
1127{
1128    if(is_upper_bound(i))
1129    {
1130        if(y[i]==+1)
1131            return(-G[i] > Gmax1);
1132        else   
1133            return(-G[i] > Gmax4);
1134    }
1135    else if(is_lower_bound(i))
1136    {
1137        if(y[i]==+1)
1138            return(G[i] > Gmax2);
1139        else   
1140            return(G[i] > Gmax3);
1141    }
1142    else
1143        return(false);
1144}
1145
1146void Solver_NU::do_shrinking()
1147{
1148    double Gmax1 = -INF;    // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
1149    double Gmax2 = -INF;    // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
1150    double Gmax3 = -INF;    // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
1151    double Gmax4 = -INF;    // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
1152
1153    // find maximal violating pair first
1154    int i;
1155    for(i=0;i<active_size;i++)
1156    {
1157        if(!is_upper_bound(i))
1158        {
1159            if(y[i]==+1)
1160            {
1161                if(-G[i] > Gmax1) Gmax1 = -G[i];
1162            }
1163            else    if(-G[i] > Gmax4) Gmax4 = -G[i];
1164        }
1165        if(!is_lower_bound(i))
1166        {
1167            if(y[i]==+1)
1168            {   
1169                if(G[i] > Gmax2) Gmax2 = G[i];
1170            }
1171            else    if(G[i] > Gmax3) Gmax3 = G[i];
1172        }
1173    }
1174
1175    if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) 
1176    {
1177        unshrink = true;
1178        reconstruct_gradient();
1179        active_size = l;
1180    }
1181
1182    for(i=0;i<active_size;i++)
1183        if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
1184        {
1185            active_size--;
1186            while (active_size > i)
1187            {
1188                if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
1189                {
1190                    swap_index(i,active_size);
1191                    break;
1192                }
1193                active_size--;
1194            }
1195        }
1196}
1197
1198double Solver_NU::calculate_rho()
1199{
1200    int nr_free1 = 0,nr_free2 = 0;
1201    double ub1 = INF, ub2 = INF;
1202    double lb1 = -INF, lb2 = -INF;
1203    double sum_free1 = 0, sum_free2 = 0;
1204
1205    for(int i=0;i<active_size;i++)
1206    {
1207        if(y[i]==+1)
1208        {
1209            if(is_upper_bound(i))
1210                lb1 = max(lb1,G[i]);
1211            else if(is_lower_bound(i))
1212                ub1 = min(ub1,G[i]);
1213            else
1214            {
1215                ++nr_free1;
1216                sum_free1 += G[i];
1217            }
1218        }
1219        else
1220        {
1221            if(is_upper_bound(i))
1222                lb2 = max(lb2,G[i]);
1223            else if(is_lower_bound(i))
1224                ub2 = min(ub2,G[i]);
1225            else
1226            {
1227                ++nr_free2;
1228                sum_free2 += G[i];
1229            }
1230        }
1231    }
1232
1233    double r1,r2;
1234    if(nr_free1 > 0)
1235        r1 = sum_free1/nr_free1;
1236    else
1237        r1 = (ub1+lb1)/2;
1238   
1239    if(nr_free2 > 0)
1240        r2 = sum_free2/nr_free2;
1241    else
1242        r2 = (ub2+lb2)/2;
1243   
1244    si->r = (r1+r2)/2;
1245    return (r1-r2)/2;
1246}
1247
1248//
1249// Q matrices for various formulations
1250//
1251class SVC_Q: public Kernel
1252{ 
1253public:
1254    SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
1255    :Kernel(prob.l, prob.x, param)
1256    {
1257        clone(y,y_,prob.l);
1258        cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1259        QD = new double[prob.l];
1260        for(int i=0;i<prob.l;i++)
1261            QD[i] = (this->*kernel_function)(i,i);
1262    }
1263   
1264    Qfloat *get_Q(int i, int len) const
1265    {
1266        Qfloat *data;
1267        int start, j;
1268        if((start = cache->get_data(i,&data,len)) < len)
1269        {
1270            for(j=start;j<len;j++)
1271                data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
1272        }
1273        return data;
1274    }
1275
1276    double *get_QD() const
1277    {
1278        return QD;
1279    }
1280
1281    void swap_index(int i, int j) const
1282    {
1283        cache->swap_index(i,j);
1284        Kernel::swap_index(i,j);
1285        swap(y[i],y[j]);
1286        swap(QD[i],QD[j]);
1287    }
1288
1289    ~SVC_Q()
1290    {
1291        delete[] y;
1292        delete cache;
1293        delete[] QD;
1294    }
1295private:
1296    schar *y;
1297    Cache *cache;
1298    double *QD;
1299};
1300
1301class ONE_CLASS_Q: public Kernel
1302{
1303public:
1304    ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
1305    :Kernel(prob.l, prob.x, param)
1306    {
1307        cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1308        QD = new double[prob.l];
1309        for(int i=0;i<prob.l;i++)
1310            QD[i] = (this->*kernel_function)(i,i);
1311    }
1312   
1313    Qfloat *get_Q(int i, int len) const
1314    {
1315        Qfloat *data;
1316        int start, j;
1317        if((start = cache->get_data(i,&data,len)) < len)
1318        {
1319            for(j=start;j<len;j++)
1320                data[j] = (Qfloat)(this->*kernel_function)(i,j);
1321        }
1322        return data;
1323    }
1324
1325    double *get_QD() const
1326    {
1327        return QD;
1328    }
1329
1330    void swap_index(int i, int j) const
1331    {
1332        cache->swap_index(i,j);
1333        Kernel::swap_index(i,j);
1334        swap(QD[i],QD[j]);
1335    }
1336
1337    ~ONE_CLASS_Q()
1338    {
1339        delete cache;
1340        delete[] QD;
1341    }
1342private:
1343    Cache *cache;
1344    double *QD;
1345};
1346
1347class SVR_Q: public Kernel
1348{ 
1349public:
1350    SVR_Q(const svm_problem& prob, const svm_parameter& param)
1351    :Kernel(prob.l, prob.x, param)
1352    {
1353        l = prob.l;
1354        cache = new Cache(l,(long int)(param.cache_size*(1<<20)));
1355        QD = new double[2*l];
1356        sign = new schar[2*l];
1357        index = new int[2*l];
1358        for(int k=0;k<l;k++)
1359        {
1360            sign[k] = 1;
1361            sign[k+l] = -1;
1362            index[k] = k;
1363            index[k+l] = k;
1364            QD[k] = (this->*kernel_function)(k,k);
1365            QD[k+l] = QD[k];
1366        }
1367        buffer[0] = new Qfloat[2*l];
1368        buffer[1] = new Qfloat[2*l];
1369        next_buffer = 0;
1370    }
1371
1372    void swap_index(int i, int j) const
1373    {
1374        swap(sign[i],sign[j]);
1375        swap(index[i],index[j]);
1376        swap(QD[i],QD[j]);
1377    }
1378   
1379    Qfloat *get_Q(int i, int len) const
1380    {
1381        Qfloat *data;
1382        int j, real_i = index[i];
1383        if(cache->get_data(real_i,&data,l) < l)
1384        {
1385            for(j=0;j<l;j++)
1386                data[j] = (Qfloat)(this->*kernel_function)(real_i,j);
1387        }
1388
1389        // reorder and copy
1390        Qfloat *buf = buffer[next_buffer];
1391        next_buffer = 1 - next_buffer;
1392        schar si = sign[i];
1393        for(j=0;j<len;j++)
1394            buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]];
1395        return buf;
1396    }
1397
1398    double *get_QD() const
1399    {
1400        return QD;
1401    }
1402
1403    ~SVR_Q()
1404    {
1405        delete cache;
1406        delete[] sign;
1407        delete[] index;
1408        delete[] buffer[0];
1409        delete[] buffer[1];
1410        delete[] QD;
1411    }
1412private:
1413    int l;
1414    Cache *cache;
1415    schar *sign;
1416    int *index;
1417    mutable int next_buffer;
1418    Qfloat *buffer[2];
1419    double *QD;
1420};
1421
1422//
1423// construct and solve various formulations
1424//
1425static void solve_c_svc(
1426    const svm_problem *prob, const svm_parameter* param,
1427    double *alpha, Solver::SolutionInfo* si, double Cp, double Cn)
1428{
1429    int l = prob->l;
1430    double *minus_ones = new double[l];
1431    schar *y = new schar[l];
1432
1433    int i;
1434
1435    for(i=0;i<l;i++)
1436    {
1437        alpha[i] = 0;
1438        minus_ones[i] = -1;
1439        if(prob->y[i] > 0) y[i] = +1; else y[i] = -1;
1440    }
1441
1442    Solver s;
1443    s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
1444        alpha, Cp, Cn, param->eps, si, param->shrinking);
1445
1446    double sum_alpha=0;
1447    for(i=0;i<l;i++)
1448        sum_alpha += alpha[i];
1449
1450    if (Cp==Cn)
1451        info("nu = %f\n", sum_alpha/(Cp*prob->l));
1452
1453    for(i=0;i<l;i++)
1454        alpha[i] *= y[i];
1455
1456    delete[] minus_ones;
1457    delete[] y;
1458}
1459
1460static void solve_nu_svc(
1461    const svm_problem *prob, const svm_parameter *param,
1462    double *alpha, Solver::SolutionInfo* si)
1463{
1464    int i;
1465    int l = prob->l;
1466    double nu = param->nu;
1467
1468    schar *y = new schar[l];
1469
1470    for(i=0;i<l;i++)
1471        if(prob->y[i]>0)
1472            y[i] = +1;
1473        else
1474            y[i] = -1;
1475
1476    double sum_pos = nu*l/2;
1477    double sum_neg = nu*l/2;
1478
1479    for(i=0;i<l;i++)
1480        if(y[i] == +1)
1481        {
1482            alpha[i] = min(1.0,sum_pos);
1483            sum_pos -= alpha[i];
1484        }
1485        else
1486        {
1487            alpha[i] = min(1.0,sum_neg);
1488            sum_neg -= alpha[i];
1489        }
1490
1491    double *zeros = new double[l];
1492
1493    for(i=0;i<l;i++)
1494        zeros[i] = 0;
1495
1496    Solver_NU s;
1497    s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
1498        alpha, 1.0, 1.0, param->eps, si,  param->shrinking);
1499    double r = si->r;
1500
1501    info("C = %f\n",1/r);
1502
1503    for(i=0;i<l;i++)
1504        alpha[i] *= y[i]/r;
1505
1506    si->rho /= r;
1507    si->obj /= (r*r);
1508    si->upper_bound_p = 1/r;
1509    si->upper_bound_n = 1/r;
1510
1511    delete[] y;
1512    delete[] zeros;
1513}
1514
1515static void solve_one_class(
1516    const svm_problem *prob, const svm_parameter *param,
1517    double *alpha, Solver::SolutionInfo* si)
1518{
1519    int l = prob->l;
1520    double *zeros = new double[l];
1521    schar *ones = new schar[l];
1522    int i;
1523
1524    int n = (int)(param->nu*prob->l);   // # of alpha's at upper bound
1525
1526    for(i=0;i<n;i++)
1527        alpha[i] = 1;
1528    if(n<prob->l)
1529        alpha[n] = param->nu * prob->l - n;
1530    for(i=n+1;i<l;i++)
1531        alpha[i] = 0;
1532
1533    for(i=0;i<l;i++)
1534    {
1535        zeros[i] = 0;
1536        ones[i] = 1;
1537    }
1538
1539    Solver s;
1540    s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
1541        alpha, 1.0, 1.0, param->eps, si, param->shrinking);
1542
1543    delete[] zeros;
1544    delete[] ones;
1545}
1546
1547static void solve_epsilon_svr(
1548    const svm_problem *prob, const svm_parameter *param,
1549    double *alpha, Solver::SolutionInfo* si)
1550{
1551    int l = prob->l;
1552    double *alpha2 = new double[2*l];
1553    double *linear_term = new double[2*l];
1554    schar *y = new schar[2*l];
1555    int i;
1556
1557    for(i=0;i<l;i++)
1558    {
1559        alpha2[i] = 0;
1560        linear_term[i] = param->p - prob->y[i];
1561        y[i] = 1;
1562
1563        alpha2[i+l] = 0;
1564        linear_term[i+l] = param->p + prob->y[i];
1565        y[i+l] = -1;
1566    }
1567
1568    Solver s;
1569    s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1570        alpha2, param->C, param->C, param->eps, si, param->shrinking);
1571
1572    double sum_alpha = 0;
1573    for(i=0;i<l;i++)
1574    {
1575        alpha[i] = alpha2[i] - alpha2[i+l];
1576        sum_alpha += fabs(alpha[i]);
1577    }
1578    info("nu = %f\n",sum_alpha/(param->C*l));
1579
1580    delete[] alpha2;
1581    delete[] linear_term;
1582    delete[] y;
1583}
1584
1585static void solve_nu_svr(
1586    const svm_problem *prob, const svm_parameter *param,
1587    double *alpha, Solver::SolutionInfo* si)
1588{
1589    int l = prob->l;
1590    double C = param->C;
1591    double *alpha2 = new double[2*l];
1592    double *linear_term = new double[2*l];
1593    schar *y = new schar[2*l];
1594    int i;
1595
1596    double sum = C * param->nu * l / 2;
1597    for(i=0;i<l;i++)
1598    {
1599        alpha2[i] = alpha2[i+l] = min(sum,C);
1600        sum -= alpha2[i];
1601
1602        linear_term[i] = - prob->y[i];
1603        y[i] = 1;
1604
1605        linear_term[i+l] = prob->y[i];
1606        y[i+l] = -1;
1607    }
1608
1609    Solver_NU s;
1610    s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1611        alpha2, C, C, param->eps, si, param->shrinking);
1612
1613    info("epsilon = %f\n",-si->r);
1614
1615    for(i=0;i<l;i++)
1616        alpha[i] = alpha2[i] - alpha2[i+l];
1617
1618    delete[] alpha2;
1619    delete[] linear_term;
1620    delete[] y;
1621}
1622
1623//
1624// decision_function
1625//
1626struct decision_function
1627{
1628    double *alpha;
1629    double rho; 
1630};
1631
1632static decision_function svm_train_one(
1633    const svm_problem *prob, const svm_parameter *param,
1634    double Cp, double Cn)
1635{
1636    double *alpha = Malloc(double,prob->l);
1637    Solver::SolutionInfo si;
1638    switch(param->svm_type)
1639    {
1640        case C_SVC:
1641            solve_c_svc(prob,param,alpha,&si,Cp,Cn);
1642            break;
1643        case NU_SVC:
1644            solve_nu_svc(prob,param,alpha,&si);
1645            break;
1646        case ONE_CLASS:
1647            solve_one_class(prob,param,alpha,&si);
1648            break;
1649        case EPSILON_SVR:
1650            solve_epsilon_svr(prob,param,alpha,&si);
1651            break;
1652        case NU_SVR:
1653            solve_nu_svr(prob,param,alpha,&si);
1654            break;
1655    }
1656
1657    info("obj = %f, rho = %f\n",si.obj,si.rho);
1658
1659    // output SVs
1660
1661    int nSV = 0;
1662    int nBSV = 0;
1663    for(int i=0;i<prob->l;i++)
1664    {
1665        if(fabs(alpha[i]) > 0)
1666        {
1667            ++nSV;
1668            if(prob->y[i] > 0)
1669            {
1670                if(fabs(alpha[i]) >= si.upper_bound_p)
1671                    ++nBSV;
1672            }
1673            else
1674            {
1675                if(fabs(alpha[i]) >= si.upper_bound_n)
1676                    ++nBSV;
1677            }
1678        }
1679    }
1680
1681    info("nSV = %d, nBSV = %d\n",nSV,nBSV);
1682
1683    decision_function f;
1684    f.alpha = alpha;
1685    f.rho = si.rho;
1686    return f;
1687}
1688
1689// Platt's binary SVM Probablistic Output: an improvement from Lin et al.
1690static void sigmoid_train(
1691    int l, const double *dec_values, const double *labels, 
1692    double& A, double& B)
1693{
1694    double prior1=0, prior0 = 0;
1695    int i;
1696
1697    for (i=0;i<l;i++)
1698        if (labels[i] > 0) prior1+=1;
1699        else prior0+=1;
1700   
1701    int max_iter=100;   // Maximal number of iterations
1702    double min_step=1e-10;  // Minimal step taken in line search
1703    double sigma=1e-12; // For numerically strict PD of Hessian
1704    double eps=1e-5;
1705    double hiTarget=(prior1+1.0)/(prior1+2.0);
1706    double loTarget=1/(prior0+2.0);
1707    double *t=Malloc(double,l);
1708    double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
1709    double newA,newB,newf,d1,d2;
1710    int iter; 
1711   
1712    // Initial Point and Initial Fun Value
1713    A=0.0; B=log((prior0+1.0)/(prior1+1.0));
1714    double fval = 0.0;
1715
1716    for (i=0;i<l;i++)
1717    {
1718        if (labels[i]>0) t[i]=hiTarget;
1719        else t[i]=loTarget;
1720        fApB = dec_values[i]*A+B;
1721        if (fApB>=0)
1722            fval += t[i]*fApB + log(1+exp(-fApB));
1723        else
1724            fval += (t[i] - 1)*fApB +log(1+exp(fApB));
1725    }
1726    for (iter=0;iter<max_iter;iter++)
1727    {
1728        // Update Gradient and Hessian (use H' = H + sigma I)
1729        h11=sigma; // numerically ensures strict PD
1730        h22=sigma;
1731        h21=0.0;g1=0.0;g2=0.0;
1732        for (i=0;i<l;i++)
1733        {
1734            fApB = dec_values[i]*A+B;
1735            if (fApB >= 0)
1736            {
1737                p=exp(-fApB)/(1.0+exp(-fApB));
1738                q=1.0/(1.0+exp(-fApB));
1739            }
1740            else
1741            {
1742                p=1.0/(1.0+exp(fApB));
1743                q=exp(fApB)/(1.0+exp(fApB));
1744            }
1745            d2=p*q;
1746            h11+=dec_values[i]*dec_values[i]*d2;
1747            h22+=d2;
1748            h21+=dec_values[i]*d2;
1749            d1=t[i]-p;
1750            g1+=dec_values[i]*d1;
1751            g2+=d1;
1752        }
1753
1754        // Stopping Criteria
1755        if (fabs(g1)<eps && fabs(g2)<eps)
1756            break;
1757
1758        // Finding Newton direction: -inv(H') * g
1759        det=h11*h22-h21*h21;
1760        dA=-(h22*g1 - h21 * g2) / det;
1761        dB=-(-h21*g1+ h11 * g2) / det;
1762        gd=g1*dA+g2*dB;
1763
1764
1765        stepsize = 1;       // Line Search
1766        while (stepsize >= min_step)
1767        {
1768            newA = A + stepsize * dA;
1769            newB = B + stepsize * dB;
1770
1771            // New function value
1772            newf = 0.0;
1773            for (i=0;i<l;i++)
1774            {
1775                fApB = dec_values[i]*newA+newB;
1776                if (fApB >= 0)
1777                    newf += t[i]*fApB + log(1+exp(-fApB));
1778                else
1779                    newf += (t[i] - 1)*fApB +log(1+exp(fApB));
1780            }
1781            // Check sufficient decrease
1782            if (newf<fval+0.0001*stepsize*gd)
1783            {
1784                A=newA;B=newB;fval=newf;
1785                break;
1786            }
1787            else
1788                stepsize = stepsize / 2.0;
1789        }
1790
1791        if (stepsize < min_step)
1792        {
1793            info("Line search fails in two-class probability estimates\n");
1794            break;
1795        }
1796    }
1797
1798    if (iter>=max_iter)
1799        info("Reaching maximal iterations in two-class probability estimates\n");
1800    free(t);
1801}
1802
1803static double sigmoid_predict(double decision_value, double A, double B)
1804{
1805    double fApB = decision_value*A+B;
1806    // 1-p used later; avoid catastrophic cancellation
1807    if (fApB >= 0)
1808        return exp(-fApB)/(1.0+exp(-fApB));
1809    else
1810        return 1.0/(1+exp(fApB)) ;
1811}
1812
1813// Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
1814static void multiclass_probability(int k, double **r, double *p)
1815{
1816    int t,j;
1817    int iter = 0, max_iter=max(100,k);
1818    double **Q=Malloc(double *,k);
1819    double *Qp=Malloc(double,k);
1820    double pQp, eps=0.005/k;
1821   
1822    for (t=0;t<k;t++)
1823    {
1824        p[t]=1.0/k;  // Valid if k = 1
1825        Q[t]=Malloc(double,k);
1826        Q[t][t]=0;
1827        for (j=0;j<t;j++)
1828        {
1829            Q[t][t]+=r[j][t]*r[j][t];
1830            Q[t][j]=Q[j][t];
1831        }
1832        for (j=t+1;j<k;j++)
1833        {
1834            Q[t][t]+=r[j][t]*r[j][t];
1835            Q[t][j]=-r[j][t]*r[t][j];
1836        }
1837    }
1838    for (iter=0;iter<max_iter;iter++)
1839    {
1840        // stopping condition, recalculate QP,pQP for numerical accuracy
1841        pQp=0;
1842        for (t=0;t<k;t++)
1843        {
1844            Qp[t]=0;
1845            for (j=0;j<k;j++)
1846                Qp[t]+=Q[t][j]*p[j];
1847            pQp+=p[t]*Qp[t];
1848        }
1849        double max_error=0;
1850        for (t=0;t<k;t++)
1851        {
1852            double error=fabs(Qp[t]-pQp);
1853            if (error>max_error)
1854                max_error=error;
1855        }
1856        if (max_error<eps) break;
1857       
1858        for (t=0;t<k;t++)
1859        {
1860            double diff=(-Qp[t]+pQp)/Q[t][t];
1861            p[t]+=diff;
1862            pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
1863            for (j=0;j<k;j++)
1864            {
1865                Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
1866                p[j]/=(1+diff);
1867            }
1868        }
1869    }
1870    if (iter>=max_iter)
1871        info("Exceeds max_iter in multiclass_prob\n");
1872    for(t=0;t<k;t++) free(Q[t]);
1873    free(Q);
1874    free(Qp);
1875}
1876
1877// Cross-validation decision values for probability estimates
1878static void svm_binary_svc_probability(
1879    const svm_problem *prob, const svm_parameter *param,
1880    double Cp, double Cn, double& probA, double& probB)
1881{
1882    int i;
1883    int nr_fold = 5;
1884    int *perm = Malloc(int,prob->l);
1885    double *dec_values = Malloc(double,prob->l);
1886
1887    // random shuffle
1888    for(i=0;i<prob->l;i++) perm[i]=i;
1889    for(i=0;i<prob->l;i++)
1890    {
1891        int j = i+rand()%(prob->l-i);
1892        swap(perm[i],perm[j]);
1893    }
1894    for(i=0;i<nr_fold;i++)
1895    {
1896        int begin = i*prob->l/nr_fold;
1897        int end = (i+1)*prob->l/nr_fold;
1898        int j,k;
1899        struct svm_problem subprob;
1900
1901        subprob.l = prob->l-(end-begin);
1902        subprob.x = Malloc(struct svm_node*,subprob.l);
1903        subprob.y = Malloc(double,subprob.l);
1904           
1905        k=0;
1906        for(j=0;j<begin;j++)
1907        {
1908            subprob.x[k] = prob->x[perm[j]];
1909            subprob.y[k] = prob->y[perm[j]];
1910            ++k;
1911        }
1912        for(j=end;j<prob->l;j++)
1913        {
1914            subprob.x[k] = prob->x[perm[j]];
1915            subprob.y[k] = prob->y[perm[j]];
1916            ++k;
1917        }
1918        int p_count=0,n_count=0;
1919        for(j=0;j<k;j++)
1920            if(subprob.y[j]>0)
1921                p_count++;
1922            else
1923                n_count++;
1924
1925        if(p_count==0 && n_count==0)
1926            for(j=begin;j<end;j++)
1927                dec_values[perm[j]] = 0;
1928        else if(p_count > 0 && n_count == 0)
1929            for(j=begin;j<end;j++)
1930                dec_values[perm[j]] = 1;
1931        else if(p_count == 0 && n_count > 0)
1932            for(j=begin;j<end;j++)
1933                dec_values[perm[j]] = -1;
1934        else
1935        {
1936            svm_parameter subparam = *param;
1937            subparam.probability=0;
1938            subparam.C=1.0;
1939            subparam.nr_weight=2;
1940            subparam.weight_label = Malloc(int,2);
1941            subparam.weight = Malloc(double,2);
1942            subparam.weight_label[0]=+1;
1943            subparam.weight_label[1]=-1;
1944            subparam.weight[0]=Cp;
1945            subparam.weight[1]=Cn;
1946            struct svm_model *submodel = svm_train(&subprob,&subparam);
1947            for(j=begin;j<end;j++)
1948            {
1949                svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]])); 
1950                // ensure +1 -1 order; reason not using CV subroutine
1951                dec_values[perm[j]] *= submodel->label[0];
1952            }       
1953            svm_free_and_destroy_model(&submodel);
1954            svm_destroy_param(&subparam);
1955        }
1956        free(subprob.x);
1957        free(subprob.y);
1958    }       
1959    sigmoid_train(prob->l,dec_values,prob->y,probA,probB);
1960    free(dec_values);
1961    free(perm);
1962}
1963
1964// Return parameter of a Laplace distribution
1965static double svm_svr_probability(
1966    const svm_problem *prob, const svm_parameter *param)
1967{
1968    int i;
1969    int nr_fold = 5;
1970    double *ymv = Malloc(double,prob->l);
1971    double mae = 0;
1972
1973    svm_parameter newparam = *param;
1974    newparam.probability = 0;
1975    svm_cross_validation(prob,&newparam,nr_fold,ymv);
1976    for(i=0;i<prob->l;i++)
1977    {
1978        ymv[i]=prob->y[i]-ymv[i];
1979        mae += fabs(ymv[i]);
1980    }       
1981    mae /= prob->l;
1982    double std=sqrt(2*mae*mae);
1983    int count=0;
1984    mae=0;
1985    for(i=0;i<prob->l;i++)
1986        if (fabs(ymv[i]) > 5*std) 
1987            count=count+1;
1988        else 
1989            mae+=fabs(ymv[i]);
1990    mae /= (prob->l-count);
1991    info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae);
1992    free(ymv);
1993    return mae;
1994}
1995
1996
1997// label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
1998// perm, length l, must be allocated before calling this subroutine
1999static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
2000{
2001    int l = prob->l;
2002    int max_nr_class = 16;
2003    int nr_class = 0;
2004    int *label = Malloc(int,max_nr_class);
2005    int *count = Malloc(int,max_nr_class);
2006    int *data_label = Malloc(int,l);   
2007    int i;
2008
2009    for(i=0;i<l;i++)
2010    {
2011        int this_label = (int)prob->y[i];
2012        int j;
2013        for(j=0;j<nr_class;j++)
2014        {
2015            if(this_label == label[j])
2016            {
2017                ++count[j];
2018                break;
2019            }
2020        }
2021        data_label[i] = j;
2022        if(j == nr_class)
2023        {
2024            if(nr_class == max_nr_class)
2025            {
2026                max_nr_class *= 2;
2027                label = (int *)realloc(label,max_nr_class*sizeof(int));
2028                count = (int *)realloc(count,max_nr_class*sizeof(int));
2029            }
2030            label[nr_class] = this_label;
2031            count[nr_class] = 1;
2032            ++nr_class;
2033        }
2034    }
2035
2036    int *start = Malloc(int,nr_class);
2037    start[0] = 0;
2038    for(i=1;i<nr_class;i++)
2039        start[i] = start[i-1]+count[i-1];
2040    for(i=0;i<l;i++)
2041    {
2042        perm[start[data_label[i]]] = i;
2043        ++start[data_label[i]];
2044    }
2045    start[0] = 0;
2046    for(i=1;i<nr_class;i++)
2047        start[i] = start[i-1]+count[i-1];
2048
2049    *nr_class_ret = nr_class;
2050    *label_ret = label;
2051    *start_ret = start;
2052    *count_ret = count;
2053    free(data_label);
2054}
2055
2056//
2057// Interface functions
2058//
2059svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
2060{
2061    svm_model *model = Malloc(svm_model,1);
2062    model->param = *param;
2063    model->free_sv = 0; // XXX
2064
2065    if(param->svm_type == ONE_CLASS ||
2066       param->svm_type == EPSILON_SVR ||
2067       param->svm_type == NU_SVR)
2068    {
2069        // regression or one-class-svm
2070        model->nr_class = 2;
2071        model->label = NULL;
2072        model->nSV = NULL;
2073        model->probA = NULL; model->probB = NULL;
2074        model->sv_coef = Malloc(double *,1);
2075
2076        if(param->probability && 
2077           (param->svm_type == EPSILON_SVR ||
2078            param->svm_type == NU_SVR))
2079        {
2080            model->probA = Malloc(double,1);
2081            model->probA[0] = svm_svr_probability(prob,param);
2082        }
2083
2084        decision_function f = svm_train_one(prob,param,0,0);
2085        model->rho = Malloc(double,1);
2086        model->rho[0] = f.rho;
2087
2088        int nSV = 0;
2089        int i;
2090        for(i=0;i<prob->l;i++)
2091            if(fabs(f.alpha[i]) > 0) ++nSV;
2092        model->l = nSV;
2093        model->SV = Malloc(svm_node *,nSV);
2094        model->sv_coef[0] = Malloc(double,nSV);
2095        int j = 0;
2096        for(i=0;i<prob->l;i++)
2097            if(fabs(f.alpha[i]) > 0)
2098            {
2099                model->SV[j] = prob->x[i];
2100                model->sv_coef[0][j] = f.alpha[i];
2101                ++j;
2102            }       
2103
2104        free(f.alpha);
2105    }
2106    else
2107    {
2108        // classification
2109        int l = prob->l;
2110        int nr_class;
2111        int *label = NULL;
2112        int *start = NULL;
2113        int *count = NULL;
2114        int *perm = Malloc(int,l);
2115
2116        // group training data of the same class
2117        svm_group_classes(prob,&nr_class,&label,&start,&count,perm);       
2118        svm_node **x = Malloc(svm_node *,l);
2119        int i;
2120        for(i=0;i<l;i++)
2121            x[i] = prob->x[perm[i]];
2122
2123        // calculate weighted C
2124
2125        double *weighted_C = Malloc(double, nr_class);
2126        for(i=0;i<nr_class;i++)
2127            weighted_C[i] = param->C;
2128        for(i=0;i<param->nr_weight;i++)
2129        {   
2130            int j;
2131            for(j=0;j<nr_class;j++)
2132                if(param->weight_label[i] == label[j])
2133                    break;
2134            if(j == nr_class)
2135                fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]);
2136            else
2137                weighted_C[j] *= param->weight[i];
2138        }
2139
2140        // train k*(k-1)/2 models
2141       
2142        bool *nonzero = Malloc(bool,l);
2143        for(i=0;i<l;i++)
2144            nonzero[i] = false;
2145        decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);
2146
2147        double *probA=NULL,*probB=NULL;
2148        if (param->probability)
2149        {
2150            probA=Malloc(double,nr_class*(nr_class-1)/2);
2151            probB=Malloc(double,nr_class*(nr_class-1)/2);
2152        }
2153
2154        int p = 0;
2155        for(i=0;i<nr_class;i++)
2156            for(int j=i+1;j<nr_class;j++)
2157            {
2158                svm_problem sub_prob;
2159                int si = start[i], sj = start[j];
2160                int ci = count[i], cj = count[j];
2161                sub_prob.l = ci+cj;
2162                sub_prob.x = Malloc(svm_node *,sub_prob.l);
2163                sub_prob.y = Malloc(double,sub_prob.l);
2164                int k;
2165                for(k=0;k<ci;k++)
2166                {
2167                    sub_prob.x[k] = x[si+k];
2168                    sub_prob.y[k] = +1;
2169                }
2170                for(k=0;k<cj;k++)
2171                {
2172                    sub_prob.x[ci+k] = x[sj+k];
2173                    sub_prob.y[ci+k] = -1;
2174                }
2175
2176                if(param->probability)
2177                    svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);
2178
2179                f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
2180                for(k=0;k<ci;k++)
2181                    if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
2182                        nonzero[si+k] = true;
2183                for(k=0;k<cj;k++)
2184                    if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
2185                        nonzero[sj+k] = true;
2186                free(sub_prob.x);
2187                free(sub_prob.y);
2188                ++p;
2189            }
2190
2191        // build output
2192
2193        model->nr_class = nr_class;
2194       
2195        model->label = Malloc(int,nr_class);
2196        for(i=0;i<nr_class;i++)
2197            model->label[i] = label[i];
2198       
2199        model->rho = Malloc(double,nr_class*(nr_class-1)/2);
2200        for(i=0;i<nr_class*(nr_class-1)/2;i++)
2201            model->rho[i] = f[i].rho;
2202
2203        if(param->probability)
2204        {
2205            model->probA = Malloc(double,nr_class*(nr_class-1)/2);
2206            model->probB = Malloc(double,nr_class*(nr_class-1)/2);
2207            for(i=0;i<nr_class*(nr_class-1)/2;i++)
2208            {
2209                model->probA[i] = probA[i];
2210                model->probB[i] = probB[i];
2211            }
2212        }
2213        else
2214        {
2215            model->probA=NULL;
2216            model->probB=NULL;
2217        }
2218
2219        int total_sv = 0;
2220        int *nz_count = Malloc(int,nr_class);
2221        model->nSV = Malloc(int,nr_class);
2222        for(i=0;i<nr_class;i++)
2223        {
2224            int nSV = 0;
2225            for(int j=0;j<count[i];j++)
2226                if(nonzero[start[i]+j])
2227                {   
2228                    ++nSV;
2229                    ++total_sv;
2230                }
2231            model->nSV[i] = nSV;
2232            nz_count[i] = nSV;
2233        }
2234       
2235        info("Total nSV = %d\n",total_sv);
2236
2237        model->l = total_sv;
2238        model->SV = Malloc(svm_node *,total_sv);
2239        p = 0;
2240        for(i=0;i<l;i++)
2241            if(nonzero[i]) model->SV[p++] = x[i];
2242
2243        int *nz_start = Malloc(int,nr_class);
2244        nz_start[0] = 0;
2245        for(i=1;i<nr_class;i++)
2246            nz_start[i] = nz_start[i-1]+nz_count[i-1];
2247
2248        model->sv_coef = Malloc(double *,nr_class-1);
2249        for(i=0;i<nr_class-1;i++)
2250            model->sv_coef[i] = Malloc(double,total_sv);
2251
2252        p = 0;
2253        for(i=0;i<nr_class;i++)
2254            for(int j=i+1;j<nr_class;j++)
2255            {
2256                // classifier (i,j): coefficients with
2257                // i are in sv_coef[j-1][nz_start[i]...],
2258                // j are in sv_coef[i][nz_start[j]...]
2259
2260                int si = start[i];
2261                int sj = start[j];
2262                int ci = count[i];
2263                int cj = count[j];
2264               
2265                int q = nz_start[i];
2266                int k;
2267                for(k=0;k<ci;k++)
2268                    if(nonzero[si+k])
2269                        model->sv_coef[j-1][q++] = f[p].alpha[k];
2270                q = nz_start[j];
2271                for(k=0;k<cj;k++)
2272                    if(nonzero[sj+k])
2273                        model->sv_coef[i][q++] = f[p].alpha[ci+k];
2274                ++p;
2275            }
2276       
2277        free(label);
2278        free(probA);
2279        free(probB);
2280        free(count);
2281        free(perm);
2282        free(start);
2283        free(x);
2284        free(weighted_C);
2285        free(nonzero);
2286        for(i=0;i<nr_class*(nr_class-1)/2;i++)
2287            free(f[i].alpha);
2288        free(f);
2289        free(nz_count);
2290        free(nz_start);
2291    }
2292    return model;
2293}
2294
2295// Stratified cross validation
2296void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
2297{
2298    int i;
2299    int *fold_start = Malloc(int,nr_fold+1);
2300    int l = prob->l;
2301    int *perm = Malloc(int,l);
2302    int nr_class;
2303
2304    // stratified cv may not give leave-one-out rate
2305    // Each class to l folds -> some folds may have zero elements
2306    if((param->svm_type == C_SVC ||
2307        param->svm_type == NU_SVC) && nr_fold < l)
2308    {
2309        int *start = NULL;
2310        int *label = NULL;
2311        int *count = NULL;
2312        svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
2313
2314        // random shuffle and then data grouped by fold using the array perm
2315        int *fold_count = Malloc(int,nr_fold);
2316        int c;
2317        int *index = Malloc(int,l);
2318        for(i=0;i<l;i++)
2319            index[i]=perm[i];
2320        for (c=0; c<nr_class; c++) 
2321            for(i=0;i<count[c];i++)
2322            {
2323                int j = i+rand()%(count[c]-i);
2324                swap(index[start[c]+j],index[start[c]+i]);
2325            }
2326        for(i=0;i<nr_fold;i++)
2327        {
2328            fold_count[i] = 0;
2329            for (c=0; c<nr_class;c++)
2330                fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
2331        }
2332        fold_start[0]=0;
2333        for (i=1;i<=nr_fold;i++)
2334            fold_start[i] = fold_start[i-1]+fold_count[i-1];
2335        for (c=0; c<nr_class;c++)
2336            for(i=0;i<nr_fold;i++)
2337            {
2338                int begin = start[c]+i*count[c]/nr_fold;
2339                int end = start[c]+(i+1)*count[c]/nr_fold;
2340                for(int j=begin;j<end;j++)
2341                {
2342                    perm[fold_start[i]] = index[j];
2343                    fold_start[i]++;
2344                }
2345            }
2346        fold_start[0]=0;
2347        for (i=1;i<=nr_fold;i++)
2348            fold_start[i] = fold_start[i-1]+fold_count[i-1];
2349        free(start);   
2350        free(label);
2351        free(count);   
2352        free(index);
2353        free(fold_count);
2354    }
2355    else
2356    {
2357        for(i=0;i<l;i++) perm[i]=i;
2358        for(i=0;i<l;i++)
2359        {
2360            int j = i+rand()%(l-i);
2361            swap(perm[i],perm[j]);
2362        }
2363        for(i=0;i<=nr_fold;i++)
2364            fold_start[i]=i*l/nr_fold;
2365    }
2366
2367    for(i=0;i<nr_fold;i++)
2368    {
2369        int begin = fold_start[i];
2370        int end = fold_start[i+1];
2371        int j,k;
2372        struct svm_problem subprob;
2373
2374        subprob.l = l-(end-begin);
2375        subprob.x = Malloc(struct svm_node*,subprob.l);
2376        subprob.y = Malloc(double,subprob.l);
2377           
2378        k=0;
2379        for(j=0;j<begin;j++)
2380        {
2381            subprob.x[k] = prob->x[perm[j]];
2382            subprob.y[k] = prob->y[perm[j]];
2383            ++k;
2384        }
2385        for(j=end;j<l;j++)
2386        {
2387            subprob.x[k] = prob->x[perm[j]];
2388            subprob.y[k] = prob->y[perm[j]];
2389            ++k;
2390        }
2391        struct svm_model *submodel = svm_train(&subprob,param);
2392        if(param->probability && 
2393           (param->svm_type == C_SVC || param->svm_type == NU_SVC))
2394        {
2395            double *prob_estimates=Malloc(double,svm_get_nr_class(submodel));
2396            for(j=begin;j<end;j++)
2397                target[perm[j]] = svm_predict_probability(submodel,prob->x[perm[j]],prob_estimates);
2398            free(prob_estimates);           
2399        }
2400        else
2401            for(j=begin;j<end;j++)
2402                target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]);
2403        svm_free_and_destroy_model(&submodel);
2404        free(subprob.x);
2405        free(subprob.y);
2406    }       
2407    free(fold_start);
2408    free(perm); 
2409}
2410
2411
2412int svm_get_svm_type(const svm_model *model)
2413{
2414    return model->param.svm_type;
2415}
2416
2417int svm_get_nr_class(const svm_model *model)
2418{
2419    return model->nr_class;
2420}
2421
2422void svm_get_labels(const svm_model *model, int* label)
2423{
2424    if (model->label != NULL)
2425        for(int i=0;i<model->nr_class;i++)
2426            label[i] = model->label[i];
2427}
2428
2429double svm_get_svr_probability(const svm_model *model)
2430{
2431    if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
2432        model->probA!=NULL)
2433        return model->probA[0];
2434    else
2435    {
2436        fprintf(stderr,"Model doesn't contain information for SVR probability inference\n");
2437        return 0;
2438    }
2439}
2440
2441double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values)
2442{
2443    if(model->param.svm_type == ONE_CLASS ||
2444       model->param.svm_type == EPSILON_SVR ||
2445       model->param.svm_type == NU_SVR)
2446    {
2447        double *sv_coef = model->sv_coef[0];
2448        double sum = 0;
2449        for(int i=0;i<model->l;i++)
2450            sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
2451        sum -= model->rho[0];
2452        *dec_values = sum;
2453
2454        if(model->param.svm_type == ONE_CLASS)
2455            return (sum>0)?1:-1;
2456        else
2457            return sum;
2458    }
2459    else
2460    {
2461        int i;
2462        int nr_class = model->nr_class;
2463        int l = model->l;
2464       
2465        double *kvalue = Malloc(double,l);
2466        for(i=0;i<l;i++)
2467            kvalue[i] = Kernel::k_function(x,model->SV[i],model->param);
2468
2469        int *start = Malloc(int,nr_class);
2470        start[0] = 0;
2471        for(i=1;i<nr_class;i++)
2472            start[i] = start[i-1]+model->nSV[i-1];
2473
2474        int *vote = Malloc(int,nr_class);
2475        for(i=0;i<nr_class;i++)
2476            vote[i] = 0;
2477
2478        int p=0;
2479        for(i=0;i<nr_class;i++)
2480            for(int j=i+1;j<nr_class;j++)
2481            {
2482                double sum = 0;
2483                int si = start[i];
2484                int sj = start[j];
2485                int ci = model->nSV[i];
2486                int cj = model->nSV[j];
2487               
2488                int k;
2489                double *coef1 = model->sv_coef[j-1];
2490                double *coef2 = model->sv_coef[i];
2491                for(k=0;k<ci;k++)
2492                    sum += coef1[si+k] * kvalue[si+k];
2493                for(k=0;k<cj;k++)
2494                    sum += coef2[sj+k] * kvalue[sj+k];
2495                sum -= model->rho[p];
2496                dec_values[p] = sum;
2497
2498                if(dec_values[p] > 0)
2499                    ++vote[i];
2500                else
2501                    ++vote[j];
2502                p++;
2503            }
2504
2505        int vote_max_idx = 0;
2506        for(i=1;i<nr_class;i++)
2507            if(vote[i] > vote[vote_max_idx])
2508                vote_max_idx = i;
2509
2510        free(kvalue);
2511        free(start);
2512        free(vote);
2513        return model->label[vote_max_idx];
2514    }
2515}
2516
2517double svm_predict(const svm_model *model, const svm_node *x)
2518{
2519    int nr_class = model->nr_class;
2520    double *dec_values;
2521    if(model->param.svm_type == ONE_CLASS ||
2522       model->param.svm_type == EPSILON_SVR ||
2523       model->param.svm_type == NU_SVR)
2524        dec_values = Malloc(double, 1);
2525    else 
2526        dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2527    double pred_result = svm_predict_values(model, x, dec_values);
2528    free(dec_values);
2529    return pred_result;
2530}
2531
2532double svm_predict_probability(
2533    const svm_model *model, const svm_node *x, double *prob_estimates)
2534{
2535    if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
2536        model->probA!=NULL && model->probB!=NULL)
2537    {
2538        int i;
2539        int nr_class = model->nr_class;
2540        double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2541        svm_predict_values(model, x, dec_values);
2542
2543        double min_prob=1e-7;
2544        double **pairwise_prob=Malloc(double *,nr_class);
2545        for(i=0;i<nr_class;i++)
2546            pairwise_prob[i]=Malloc(double,nr_class);
2547        int k=0;
2548        for(i=0;i<nr_class;i++)
2549            for(int j=i+1;j<nr_class;j++)
2550            {
2551                pairwise_prob[i][j]=min(max(sigmoid_predict(dec_values[k],model->probA[k],model->probB[k]),min_prob),1-min_prob);
2552                pairwise_prob[j][i]=1-pairwise_prob[i][j];
2553                k++;
2554            }
2555        multiclass_probability(nr_class,pairwise_prob,prob_estimates);
2556
2557        int prob_max_idx = 0;
2558        for(i=1;i<nr_class;i++)
2559            if(prob_estimates[i] > prob_estimates[prob_max_idx])
2560                prob_max_idx = i;
2561        for(i=0;i<nr_class;i++)
2562            free(pairwise_prob[i]);
2563        free(dec_values);
2564        free(pairwise_prob);         
2565        return model->label[prob_max_idx];
2566    }
2567    else 
2568        return svm_predict(model, x);
2569}
2570
2571static const char *svm_type_table[] =
2572{
2573    "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",NULL
2574};
2575
2576static const char *kernel_type_table[]=
2577{
2578    "linear","polynomial","rbf","sigmoid","precomputed",NULL
2579};
2580
2581int svm_save_model(const char *model_file_name, const svm_model *model)
2582{
2583    FILE *fp = fopen(model_file_name,"w");
2584    if(fp==NULL) return -1;
2585
2586    const svm_parameter& param = model->param;
2587
2588    fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]);
2589    fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]);
2590
2591    if(param.kernel_type == POLY)
2592        fprintf(fp,"degree %d\n", param.degree);
2593
2594    if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
2595        fprintf(fp,"gamma %g\n", param.gamma);
2596
2597    if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
2598        fprintf(fp,"coef0 %g\n", param.coef0);
2599
2600    int nr_class = model->nr_class;
2601    int l = model->l;
2602    fprintf(fp, "nr_class %d\n", nr_class);
2603    fprintf(fp, "total_sv %d\n",l);
2604   
2605    {
2606        fprintf(fp, "rho");
2607        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2608            fprintf(fp," %g",model->rho[i]);
2609        fprintf(fp, "\n");
2610    }
2611   
2612    if(model->label)
2613    {
2614        fprintf(fp, "label");
2615        for(int i=0;i<nr_class;i++)
2616            fprintf(fp," %d",model->label[i]);
2617        fprintf(fp, "\n");
2618    }
2619
2620    if(model->probA) // regression has probA only
2621    {
2622        fprintf(fp, "probA");
2623        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2624            fprintf(fp," %g",model->probA[i]);
2625        fprintf(fp, "\n");
2626    }
2627    if(model->probB)
2628    {
2629        fprintf(fp, "probB");
2630        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2631            fprintf(fp," %g",model->probB[i]);
2632        fprintf(fp, "\n");
2633    }
2634
2635    if(model->nSV)
2636    {
2637        fprintf(fp, "nr_sv");
2638        for(int i=0;i<nr_class;i++)
2639            fprintf(fp," %d",model->nSV[i]);
2640        fprintf(fp, "\n");
2641    }
2642
2643    fprintf(fp, "SV\n");
2644    const double * const *sv_coef = model->sv_coef;
2645    const svm_node * const *SV = model->SV;
2646
2647    for(int i=0;i<l;i++)
2648    {
2649        for(int j=0;j<nr_class-1;j++)
2650            fprintf(fp, "%.16g ",sv_coef[j][i]);
2651
2652        const svm_node *p = SV[i];
2653
2654        if(param.kernel_type == PRECOMPUTED)
2655            fprintf(fp,"0:%d ",(int)(p->value));
2656        else
2657            while(p->index != -1)
2658            {
2659                fprintf(fp,"%d:%.8g ",p->index,p->value);
2660                p++;
2661            }
2662        fprintf(fp, "\n");
2663    }
2664    if (ferror(fp) != 0 || fclose(fp) != 0) return -1;
2665    else return 0;
2666}
2667
2668static char *line = NULL;
2669static int max_line_len;
2670
2671static char* readline(FILE *input)
2672{
2673    int len;
2674
2675    if(fgets(line,max_line_len,input) == NULL)
2676        return NULL;
2677
2678    while(strrchr(line,'\n') == NULL)
2679    {
2680        max_line_len *= 2;
2681        line = (char *) realloc(line,max_line_len);
2682        len = (int) strlen(line);
2683        if(fgets(line+len,max_line_len-len,input) == NULL)
2684            break;
2685    }
2686    return line;
2687}
2688
2689svm_model *svm_load_model(const char *model_file_name)
2690{
2691    FILE *fp = fopen(model_file_name,"rb");
2692    if(fp==NULL) return NULL;
2693   
2694    // read parameters
2695
2696    svm_model *model = Malloc(svm_model,1);
2697    svm_parameter& param = model->param;
2698    model->rho = NULL;
2699    model->probA = NULL;
2700    model->probB = NULL;
2701    model->label = NULL;
2702    model->nSV = NULL;
2703
2704    char cmd[81];
2705    while(1)
2706    {
2707        fscanf(fp,"%80s",cmd);
2708
2709        if(strcmp(cmd,"svm_type")==0)
2710        {
2711            fscanf(fp,"%80s",cmd);
2712            int i;
2713            for(i=0;svm_type_table[i];i++)
2714            {
2715                if(strcmp(svm_type_table[i],cmd)==0)
2716                {
2717                    param.svm_type=i;
2718                    break;
2719                }
2720            }
2721            if(svm_type_table[i] == NULL)
2722            {
2723                fprintf(stderr,"unknown svm type.\n");
2724                free(model->rho);
2725                free(model->label);
2726                free(model->nSV);
2727                free(model);
2728                return NULL;
2729            }
2730        }
2731        else if(strcmp(cmd,"kernel_type")==0)
2732        {       
2733            fscanf(fp,"%80s",cmd);
2734            int i;
2735            for(i=0;kernel_type_table[i];i++)
2736            {
2737                if(strcmp(kernel_type_table[i],cmd)==0)
2738                {
2739                    param.kernel_type=i;
2740                    break;
2741                }
2742            }
2743            if(kernel_type_table[i] == NULL)
2744            {
2745                fprintf(stderr,"unknown kernel function.\n");
2746                free(model->rho);
2747                free(model->label);
2748                free(model->nSV);
2749                free(model);
2750                return NULL;
2751            }
2752        }
2753        else if(strcmp(cmd,"degree")==0)
2754            fscanf(fp,"%d",&param.degree);
2755        else if(strcmp(cmd,"gamma")==0)
2756            fscanf(fp,"%lf",&param.gamma);
2757        else if(strcmp(cmd,"coef0")==0)
2758            fscanf(fp,"%lf",&param.coef0);
2759        else if(strcmp(cmd,"nr_class")==0)
2760            fscanf(fp,"%d",&model->nr_class);
2761        else if(strcmp(cmd,"total_sv")==0)
2762            fscanf(fp,"%d",&model->l);
2763        else if(strcmp(cmd,"rho")==0)
2764        {
2765            int n = model->nr_class * (model->nr_class-1)/2;
2766            model->rho = Malloc(double,n);
2767            for(int i=0;i<n;i++)
2768                fscanf(fp,"%lf",&model->rho[i]);
2769        }
2770        else if(strcmp(cmd,"label")==0)
2771        {
2772            int n = model->nr_class;
2773            model->label = Malloc(int,n);
2774            for(int i=0;i<n;i++)
2775                fscanf(fp,"%d",&model->label[i]);
2776        }
2777        else if(strcmp(cmd,"probA")==0)
2778        {
2779            int n = model->nr_class * (model->nr_class-1)/2;
2780            model->probA = Malloc(double,n);
2781            for(int i=0;i<n;i++)
2782                fscanf(fp,"%lf",&model->probA[i]);
2783        }
2784        else if(strcmp(cmd,"probB")==0)
2785        {
2786            int n = model->nr_class * (model->nr_class-1)/2;
2787            model->probB = Malloc(double,n);
2788            for(int i=0;i<n;i++)
2789                fscanf(fp,"%lf",&model->probB[i]);
2790        }
2791        else if(strcmp(cmd,"nr_sv")==0)
2792        {
2793            int n = model->nr_class;
2794            model->nSV = Malloc(int,n);
2795            for(int i=0;i<n;i++)
2796                fscanf(fp,"%d",&model->nSV[i]);
2797        }
2798        else if(strcmp(cmd,"SV")==0)
2799        {
2800            while(1)
2801            {
2802                int c = getc(fp);
2803                if(c==EOF || c=='\n') break;   
2804            }
2805            break;
2806        }
2807        else
2808        {
2809            fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
2810            free(model->rho);
2811            free(model->label);
2812            free(model->nSV);
2813            free(model);
2814            return NULL;
2815        }
2816    }
2817
2818    // read sv_coef and SV
2819
2820    int elements = 0;
2821    long pos = ftell(fp);
2822
2823    max_line_len = 1024;
2824    line = Malloc(char,max_line_len);
2825    char *p,*endptr,*idx,*val;
2826
2827    while(readline(fp)!=NULL)
2828    {
2829        p = strtok(line,":");
2830        while(1)
2831        {
2832            p = strtok(NULL,":");
2833            if(p == NULL)
2834                break;
2835            ++elements;
2836        }
2837    }
2838    elements += model->l;
2839
2840    fseek(fp,pos,SEEK_SET);
2841
2842    int m = model->nr_class - 1;
2843    int l = model->l;
2844    model->sv_coef = Malloc(double *,m);
2845    int i;
2846    for(i=0;i<m;i++)
2847        model->sv_coef[i] = Malloc(double,l);
2848    model->SV = Malloc(svm_node*,l);
2849    svm_node *x_space = NULL;
2850    if(l>0) x_space = Malloc(svm_node,elements);
2851
2852    int j=0;
2853    for(i=0;i<l;i++)
2854    {
2855        readline(fp);
2856        model->SV[i] = &x_space[j];
2857
2858        p = strtok(line, " \t");
2859        model->sv_coef[0][i] = strtod(p,&endptr);
2860        for(int k=1;k<m;k++)
2861        {
2862            p = strtok(NULL, " \t");
2863            model->sv_coef[k][i] = strtod(p,&endptr);
2864        }
2865
2866        while(1)
2867        {
2868            idx = strtok(NULL, ":");
2869            val = strtok(NULL, " \t");
2870
2871            if(val == NULL)
2872                break;
2873            x_space[j].index = (int) strtol(idx,&endptr,10);
2874            x_space[j].value = strtod(val,&endptr);
2875
2876            ++j;
2877        }
2878        x_space[j++].index = -1;
2879    }
2880    free(line);
2881
2882    if (ferror(fp) != 0 || fclose(fp) != 0)
2883        return NULL;
2884
2885    model->free_sv = 1; // XXX
2886    return model;
2887}
2888
2889void svm_free_model_content(svm_model* model_ptr)
2890{
2891    if(model_ptr->free_sv && model_ptr->l > 0 && model_ptr->SV != NULL)
2892        free((void *)(model_ptr->SV[0]));
2893    if(model_ptr->sv_coef)
2894    {
2895        for(int i=0;i<model_ptr->nr_class-1;i++)
2896            free(model_ptr->sv_coef[i]);
2897    }
2898
2899    free(model_ptr->SV);
2900    model_ptr->SV = NULL;
2901
2902    free(model_ptr->sv_coef);
2903    model_ptr->sv_coef = NULL;
2904
2905    free(model_ptr->rho);
2906    model_ptr->rho = NULL;
2907
2908    free(model_ptr->label);
2909    model_ptr->label= NULL;
2910
2911    free(model_ptr->probA);
2912    model_ptr->probA = NULL;
2913
2914    free(model_ptr->probB);
2915    model_ptr->probB= NULL;
2916
2917    free(model_ptr->nSV);
2918    model_ptr->nSV = NULL;
2919}
2920
2921void svm_free_and_destroy_model(svm_model** model_ptr_ptr)
2922{
2923    if(model_ptr_ptr != NULL && *model_ptr_ptr != NULL)
2924    {
2925        svm_free_model_content(*model_ptr_ptr);
2926        free(*model_ptr_ptr);
2927        *model_ptr_ptr = NULL;
2928    }
2929}
2930
2931void svm_destroy_param(svm_parameter* param)
2932{
2933    free(param->weight_label);
2934    free(param->weight);
2935}
2936
2937const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
2938{
2939    // svm_type
2940
2941    int svm_type = param->svm_type;
2942    if(svm_type != C_SVC &&
2943       svm_type != NU_SVC &&
2944       svm_type != ONE_CLASS &&
2945       svm_type != EPSILON_SVR &&
2946       svm_type != NU_SVR)
2947        return "unknown svm type";
2948   
2949    // kernel_type, degree
2950   
2951    int kernel_type = param->kernel_type;
2952    if(kernel_type != LINEAR &&
2953       kernel_type != POLY &&
2954       kernel_type != RBF &&
2955       kernel_type != SIGMOID &&
2956       kernel_type != PRECOMPUTED)
2957        return "unknown kernel type";
2958
2959    if(param->gamma < 0)
2960        return "gamma < 0";
2961
2962    if(param->degree < 0)
2963        return "degree of polynomial kernel < 0";
2964
2965    // cache_size,eps,C,nu,p,shrinking
2966
2967    if(param->cache_size <= 0)
2968        return "cache_size <= 0";
2969
2970    if(param->eps <= 0)
2971        return "eps <= 0";
2972
2973    if(svm_type == C_SVC ||
2974       svm_type == EPSILON_SVR ||
2975       svm_type == NU_SVR)
2976        if(param->C <= 0)
2977            return "C <= 0";
2978
2979    if(svm_type == NU_SVC ||
2980       svm_type == ONE_CLASS ||
2981       svm_type == NU_SVR)
2982        if(param->nu <= 0 || param->nu > 1)
2983            return "nu <= 0 or nu > 1";
2984
2985    if(svm_type == EPSILON_SVR)
2986        if(param->p < 0)
2987            return "p < 0";
2988
2989    if(param->shrinking != 0 &&
2990       param->shrinking != 1)
2991        return "shrinking != 0 and shrinking != 1";
2992
2993    if(param->probability != 0 &&
2994       param->probability != 1)
2995        return "probability != 0 and probability != 1";
2996
2997    if(param->probability == 1 &&
2998       svm_type == ONE_CLASS)
2999        return "one-class SVM probability output not supported yet";
3000
3001
3002    // check whether nu-svc is feasible
3003   
3004    if(svm_type == NU_SVC)
3005    {
3006        int l = prob->l;
3007        int max_nr_class = 16;
3008        int nr_class = 0;
3009        int *label = Malloc(int,max_nr_class);
3010        int *count = Malloc(int,max_nr_class);
3011
3012        int i;
3013        for(i=0;i<l;i++)
3014        {
3015            int this_label = (int)prob->y[i];
3016            int j;
3017            for(j=0;j<nr_class;j++)
3018                if(this_label == label[j])
3019                {
3020                    ++count[j];
3021                    break;
3022                }
3023            if(j == nr_class)
3024            {
3025                if(nr_class == max_nr_class)
3026                {
3027                    max_nr_class *= 2;
3028                    label = (int *)realloc(label,max_nr_class*sizeof(int));
3029                    count = (int *)realloc(count,max_nr_class*sizeof(int));
3030                }
3031                label[nr_class] = this_label;
3032                count[nr_class] = 1;
3033                ++nr_class;
3034            }
3035        }
3036   
3037        for(i=0;i<nr_class;i++)
3038        {
3039            int n1 = count[i];
3040            for(int j=i+1;j<nr_class;j++)
3041            {
3042                int n2 = count[j];
3043                if(param->nu*(n1+n2)/2 > min(n1,n2))
3044                {
3045                    free(label);
3046                    free(count);
3047                    return "specified nu is infeasible";
3048                }
3049            }
3050        }
3051        free(label);
3052        free(count);
3053    }
3054
3055    return NULL;
3056}
3057
3058int svm_check_probability_model(const svm_model *model)
3059{
3060    return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
3061        model->probA!=NULL && model->probB!=NULL) ||
3062        ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
3063         model->probA!=NULL);
3064}
3065
3066void svm_set_print_string_function(void (*print_func)(const char *))
3067{
3068    if(print_func == NULL)
3069        svm_print_string = &print_string_stdout;
3070    else
3071        svm_print_string = print_func;
3072}
Note: See TracBrowser for help on using the repository browser.