Changeset 3591:b183ad69732c in orange


Ignore:
Timestamp:
04/26/07 10:43:18 (7 years ago)
Author:
ales_erjavec <ales.erjavec@…>
Branch:
default
Convert:
fe3417cba0fef37d746ded76cc13f4f2b642a407
Message:

svm regression problem after the last update (svm coef, nSV..)

Location:
source/orange
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • source/orange/svm.cpp

    r3535 r3591  
    4646 *-added code to compute custom kernel in Kernel::k_function 
    4747 *-handle CUSTOM in svm_check_parameter 
     48 *-moved svm_model definition into svm.hpp 
     49 *-added custom to kernel_type_table 
    4850/*########################################## 
    4951##########################################*/ 
     
    427429        case CUSTOM: 
    428430        { 
    429             while(y->index!=-1) // the value in the last svm_node is the index of the training example 
     431            /*while(y->index!=-1)   // the value in the last svm_node is the index of the training example 
    430432                ++y; 
    431433            while(x->index!=-1) 
    432                 ++x; 
     434                ++x;*/ 
    433435            if(!param.classifier) 
    434436                return param.learner->kernelFunc->operator()(param.learner->tempExamples->at((int)x->value), 
     
    18081810// svm_model 
    18091811// 
    1810 struct svm_model 
    1811 { 
    1812     svm_parameter param;    // parameter 
    1813     int nr_class;       // number of classes, = 2 in regression/one class svm 
    1814     int l;          // total #SV 
    1815     svm_node **SV;      // SVs (SV[l]) 
    1816     double **sv_coef;   // coefficients for SVs in decision functions (sv_coef[n-1][l]) 
    1817     double *rho;        // constants in decision functions (rho[n*(n-1)/2]) 
    1818     double *probA;          // pariwise probability information 
    1819     double *probB; 
    1820  
    1821     // for classification only 
    1822  
    1823     int *label;     // label of each class (label[n]) 
    1824     int *nSV;       // number of SVs for each class (nSV[n]) 
    1825                 // nSV[0] + nSV[1] + ... + nSV[n-1] = l 
    1826     // XXX 
    1827     int free_sv;        // 1 if svm_model is created by svm_load_model 
    1828                 // 0 if svm_model is created by svm_train 
    1829 }; 
     1812 
    18301813 
    18311814// Platt's binary SVM Probablistic Output: an improvement from Lin et al. 
     
    27252708const char *kernel_type_table[]= 
    27262709{ 
    2727     "linear","polynomial","rbf","sigmoid","precomputed",NULL 
     2710    "linear","polynomial","rbf","sigmoid","custom", "precomputed",NULL 
    27282711}; 
    27292712 
     
    31593142//#_i_nclu_sde "svm.ppp" 
    31603143 
     3144/* 
     3145Save load functions for use with orange pickling. 
     3146They are a copy of the standard libSVM save-load functions 
     3147except that they read/write from/to a temp file that is deleted 
     3148at the end with fclose 
     3149*/ 
     3150#include "slist.hpp" 
     3151int svm_save_model_alt(TCharBuffer& buffer, const svm_model *model){ 
     3152    FILE *fp = tmpfile(); 
     3153    if(fp==NULL) return -1; 
     3154 
     3155    const svm_parameter& param = model->param; 
     3156 
     3157    fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]); 
     3158    fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]); 
     3159 
     3160    if(param.kernel_type == POLY) 
     3161        fprintf(fp,"degree %d\n", param.degree); 
     3162 
     3163    if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID) 
     3164        fprintf(fp,"gamma %g\n", param.gamma); 
     3165 
     3166    if(param.kernel_type == POLY || param.kernel_type == SIGMOID) 
     3167        fprintf(fp,"coef0 %g\n", param.coef0); 
     3168 
     3169    int nr_class = model->nr_class; 
     3170    int l = model->l; 
     3171    fprintf(fp, "nr_class %d\n", nr_class); 
     3172    fprintf(fp, "total_sv %d\n",l); 
     3173     
     3174    { 
     3175        fprintf(fp, "rho"); 
     3176        for(int i=0;i<nr_class*(nr_class-1)/2;i++) 
     3177            fprintf(fp," %g",model->rho[i]); 
     3178        fprintf(fp, "\n"); 
     3179    } 
     3180     
     3181    if(model->label) 
     3182    { 
     3183        fprintf(fp, "label"); 
     3184        for(int i=0;i<nr_class;i++) 
     3185            fprintf(fp," %d",model->label[i]); 
     3186        fprintf(fp, "\n"); 
     3187    } 
     3188 
     3189    if(model->probA) // regression has probA only 
     3190    { 
     3191        fprintf(fp, "probA"); 
     3192        for(int i=0;i<nr_class*(nr_class-1)/2;i++) 
     3193            fprintf(fp," %g",model->probA[i]); 
     3194        fprintf(fp, "\n"); 
     3195    } 
     3196    if(model->probB) 
     3197    { 
     3198        fprintf(fp, "probB"); 
     3199        for(int i=0;i<nr_class*(nr_class-1)/2;i++) 
     3200            fprintf(fp," %g",model->probB[i]); 
     3201        fprintf(fp, "\n"); 
     3202    } 
     3203 
     3204    if(model->nSV) 
     3205    { 
     3206        fprintf(fp, "nr_sv"); 
     3207        for(int i=0;i<nr_class;i++) 
     3208            fprintf(fp," %d",model->nSV[i]); 
     3209        fprintf(fp, "\n"); 
     3210    } 
     3211 
     3212    fprintf(fp, "SV\n"); 
     3213    const double * const *sv_coef = model->sv_coef; 
     3214    const svm_node * const *SV = model->SV; 
     3215 
     3216    for(int i=0;i<l;i++) 
     3217    { 
     3218        for(int j=0;j<nr_class-1;j++) 
     3219            fprintf(fp, "%.16g ",sv_coef[j][i]); 
     3220 
     3221        const svm_node *p = SV[i]; 
     3222 
     3223        if(param.kernel_type == PRECOMPUTED) 
     3224            fprintf(fp,"0:%d ",(int)(p->value)); 
     3225        else 
     3226            while(p->index != -1) 
     3227            { 
     3228                fprintf(fp,"%d:%.8g ",p->index,p->value); 
     3229                p++; 
     3230            } 
     3231        fprintf(fp, "\n"); 
     3232    } 
     3233    if (ferror(fp) != 0 || fclose(fp) != 0) return -1; 
     3234 
     3235    fseek(fp, SEEK_SET, 0); 
     3236    string tmpbuf; 
     3237    char str[512]; 
     3238    while(fgets(str, 512, fp)){ 
     3239        tmpbuf+=str; 
     3240    } 
     3241    //if(!feof(fp)) 
     3242    //  printf("Error saving svm_model"); 
     3243    buffer.writeInt(tmpbuf.size()+1); 
     3244    printf(tmpbuf.c_str()); 
     3245    buffer.writeBuf((void*)tmpbuf.c_str(), tmpbuf.size()+1); 
     3246    fclose(fp); 
     3247    return 0; 
     3248} 
     3249 
     3250svm_model *svm_load_model_alt(TCharBuffer& buffer) 
     3251{ 
     3252    FILE *fp = tmpfile(); 
     3253    if(fp==NULL) return NULL; 
     3254    int bufflen=buffer.readInt(); 
     3255    char *tmpstr=(char*)malloc(sizeof(char)*bufflen); 
     3256    buffer.readBuf(tmpstr, bufflen); 
     3257    fwrite(tmpstr, sizeof(char), bufflen, fp); 
     3258    fseek(fp, SEEK_SET, 0); 
     3259    free(tmpstr); 
     3260     
     3261    // read parameters 
     3262 
     3263    svm_model *model = Malloc(svm_model,1); 
     3264    svm_parameter& param = model->param; 
     3265    model->rho = NULL; 
     3266    model->probA = NULL; 
     3267    model->probB = NULL; 
     3268    model->label = NULL; 
     3269    model->nSV = NULL; 
     3270 
     3271    char cmd[81]; 
     3272    while(1) 
     3273    { 
     3274        fscanf(fp,"%80s",cmd); 
     3275 
     3276        if(strcmp(cmd,"svm_type")==0) 
     3277        { 
     3278            fscanf(fp,"%80s",cmd); 
     3279            int i; 
     3280            for(i=0;svm_type_table[i];i++) 
     3281            { 
     3282                if(strcmp(svm_type_table[i],cmd)==0) 
     3283                { 
     3284                    param.svm_type=i; 
     3285                    break; 
     3286                } 
     3287            } 
     3288            if(svm_type_table[i] == NULL) 
     3289            { 
     3290                fprintf(stderr,"unknown svm type.\n"); 
     3291                free(model->rho); 
     3292                free(model->label); 
     3293                free(model->nSV); 
     3294                free(model); 
     3295                fclose(fp); 
     3296                return NULL; 
     3297            } 
     3298        } 
     3299        else if(strcmp(cmd,"kernel_type")==0) 
     3300        {        
     3301            fscanf(fp,"%80s",cmd); 
     3302            int i; 
     3303            for(i=0;kernel_type_table[i];i++) 
     3304            { 
     3305                if(strcmp(kernel_type_table[i],cmd)==0) 
     3306                { 
     3307                    param.kernel_type=i; 
     3308                    break; 
     3309                } 
     3310            } 
     3311            if(kernel_type_table[i] == NULL) 
     3312            { 
     3313                fprintf(stderr,"unknown kernel function.\n"); 
     3314                free(model->rho); 
     3315                free(model->label); 
     3316                free(model->nSV); 
     3317                free(model); 
     3318                fclose(fp); 
     3319                return NULL; 
     3320            } 
     3321        } 
     3322        else if(strcmp(cmd,"degree")==0) 
     3323            fscanf(fp,"%d",&param.degree); 
     3324        else if(strcmp(cmd,"gamma")==0) 
     3325            fscanf(fp,"%lf",&param.gamma); 
     3326        else if(strcmp(cmd,"coef0")==0) 
     3327            fscanf(fp,"%lf",&param.coef0); 
     3328        else if(strcmp(cmd,"nr_class")==0) 
     3329            fscanf(fp,"%d",&model->nr_class); 
     3330        else if(strcmp(cmd,"total_sv")==0) 
     3331            fscanf(fp,"%d",&model->l); 
     3332        else if(strcmp(cmd,"rho")==0) 
     3333        { 
     3334            int n = model->nr_class * (model->nr_class-1)/2; 
     3335            model->rho = Malloc(double,n); 
     3336            for(int i=0;i<n;i++) 
     3337                fscanf(fp,"%lf",&model->rho[i]); 
     3338        } 
     3339        else if(strcmp(cmd,"label")==0) 
     3340        { 
     3341            int n = model->nr_class; 
     3342            model->label = Malloc(int,n); 
     3343            for(int i=0;i<n;i++) 
     3344                fscanf(fp,"%d",&model->label[i]); 
     3345        } 
     3346        else if(strcmp(cmd,"probA")==0) 
     3347        { 
     3348            int n = model->nr_class * (model->nr_class-1)/2; 
     3349            model->probA = Malloc(double,n); 
     3350            for(int i=0;i<n;i++) 
     3351                fscanf(fp,"%lf",&model->probA[i]); 
     3352        } 
     3353        else if(strcmp(cmd,"probB")==0) 
     3354        { 
     3355            int n = model->nr_class * (model->nr_class-1)/2; 
     3356            model->probB = Malloc(double,n); 
     3357            for(int i=0;i<n;i++) 
     3358                fscanf(fp,"%lf",&model->probB[i]); 
     3359        } 
     3360        else if(strcmp(cmd,"nr_sv")==0) 
     3361        { 
     3362            int n = model->nr_class; 
     3363            model->nSV = Malloc(int,n); 
     3364            for(int i=0;i<n;i++) 
     3365                fscanf(fp,"%d",&model->nSV[i]); 
     3366        } 
     3367        else if(strcmp(cmd,"SV")==0) 
     3368        { 
     3369            while(1) 
     3370            { 
     3371                int c = getc(fp); 
     3372                if(c==(char)EOF || c=='\n') break;   
     3373            } 
     3374            break; 
     3375        } 
     3376        else 
     3377        { 
     3378            fprintf(stderr,"unknown text in model file: [%s]\n",cmd); 
     3379            free(model->rho); 
     3380            free(model->label); 
     3381            free(model->nSV); 
     3382            free(model); 
     3383            fclose(fp); 
     3384            return NULL; 
     3385        } 
     3386    } 
     3387 
     3388    // read sv_coef and SV 
     3389 
     3390    int elements = 0; 
     3391    long pos = ftell(fp); 
     3392 
     3393    while(1) 
     3394    { 
     3395        int c = fgetc(fp); 
     3396        switch(c) 
     3397        { 
     3398            case '\n': 
     3399                // count the '-1' element 
     3400            case ':': 
     3401                ++elements; 
     3402                break; 
     3403            case (char)EOF: 
     3404                goto out; 
     3405            default: 
     3406                ; 
     3407        } 
     3408    } 
     3409out: 
     3410    fseek(fp,pos,SEEK_SET); 
     3411 
     3412    int m = model->nr_class - 1; 
     3413    int l = model->l; 
     3414    model->sv_coef = Malloc(double *,m); 
     3415    int i; 
     3416    for(i=0;i<m;i++) 
     3417        model->sv_coef[i] = Malloc(double,l); 
     3418    model->SV = Malloc(svm_node*,l); 
     3419    svm_node *x_space=NULL; 
     3420    if(l>0) x_space = Malloc(svm_node,elements); 
     3421 
     3422    int j=0; 
     3423    for(i=0;i<l;i++) 
     3424    { 
     3425        model->SV[i] = &x_space[j]; 
     3426        for(int k=0;k<m;k++) 
     3427            fscanf(fp,"%lf",&model->sv_coef[k][i]); 
     3428        while(1) 
     3429        { 
     3430            int c; 
     3431            do { 
     3432                c = getc(fp); 
     3433                if(c=='\n') goto out2; 
     3434            } while(isspace(c)); 
     3435            ungetc(c,fp); 
     3436            fscanf(fp,"%d:%lf",&(x_space[j].index),&(x_space[j].value)); 
     3437            ++j; 
     3438        }    
     3439out2: 
     3440        x_space[j++].index = -1; 
     3441    } 
     3442    if (ferror(fp) != 0 || fclose(fp) != 0) return NULL; 
     3443 
     3444    model->free_sv = 1; // XXX 
     3445 
     3446    printf("%i\n",model->param.kernel_type); 
     3447    return model; 
     3448} 
     3449     
    31613450svm_node* example_to_svm(const TExample &ex, svm_node* node, float last=0.0, int type=0){ 
    3162     //if(type==0) 
     3451    if(type==0) 
    31633452        for(int i=0;i<ex.domain->attributes->size();i++){ 
    31643453            if(ex[i].isRegular()){ 
     
    31743463            } 
    31753464        } 
    3176     //if(type==0 || type==1){ 
    3177         node->index=-1; 
    3178         node->value=last; 
    3179         node++; 
    3180     /*} 
    3181     else{ 
    3182         node->index=-1; 
    3183         const TExample *e=&ex; 
    3184         memcpy((void *)&node->value, (void *)&e, sizeof(TExample*)); 
    3185         node++; 
    3186     }*/ 
     3465    if(type == 1){ /*one dummy attr so we can pickle the classifier and keep the SV index in the training table*/ 
     3466        node->index=1; 
     3467        node->value=last; 
     3468        node++; 
     3469    } 
     3470    node->index=-1; 
     3471    node->value=last; 
     3472    node++; 
    31873473    return node; 
    31883474} 
     
    32693555    PEITERATE(iter, examples){ 
    32703556        prob.x[k]=node; 
    3271         node=example_to_svm(*iter, node, k);//, (param.kernel_type==CUSTOM)? 1:0); 
     3557        node=example_to_svm(*iter, node, k, (param.kernel_type==CUSTOM)? 1:0); 
    32723558        switch(classVarType){ 
    32733559            case TValue::FLOATVAR:{ 
     
    33093595    examples=_examples; 
    33103596    model->param.classifier=this; 
    3311     kernelFunc=model->param.learner->kernelFunc; 
     3597    if (model->param.learner) // if the model constructed at unpickling  
     3598        kernelFunc=model->param.learner->kernelFunc; 
     3599    model->param.learner=NULL; 
    33123600    currentExample=NULL; 
    33133601    computesProbabilities = model && model->param.svm_type!=EPSILON_SVR && 
     
    33243612        supportVectors->addExample(mlnew TExample(examples->at(int(node->value)))); 
    33253613    } 
    3326     nSV=mlnew TFloatList(nr_class); // num of SVs for each class (sum = model->l) 
    3327     for(i=0;i<nr_class;i++) 
    3328         nSV->at(i)=model->nSV[i]; 
     3614    int svm_type=model->param.svm_type; 
     3615    if (svm_type==C_SVC || svm_type==NU_SVC){ 
     3616        nSV=mlnew TFloatList(nr_class); // num of SVs for each class (sum = model->l) 
     3617        for(i=0;i<nr_class;i++) 
     3618            nSV->at(i)=model->nSV[i]; 
     3619    } 
    33293620 
    33303621    coef=mlnew TFloatListList(nr_class-1); 
     
    33843675    int nr_class=svm_get_nr_class(model); 
    33853676    svm_node *x=Malloc(svm_node, exlen+1); 
    3386     example_to_svm(example, x, -1.0);//, (model->param.kernel_type==CUSTOM)? 2:0); 
     3677    example_to_svm(example, x, -1.0, (model->param.kernel_type==CUSTOM)? 1:0); 
    33873678    double v; 
    33883679    if(model->param.probability){ 
  • source/orange/svm.hpp

    r3435 r3591  
    9090}; 
    9191 
     92struct svm_model 
     93{ 
     94    svm_parameter param;    // parameter 
     95    int nr_class;       // number of classes, = 2 in regression/one class svm 
     96    int l;          // total #SV 
     97    svm_node **SV;      // SVs (SV[l]) 
     98    double **sv_coef;   // coefficients for SVs in decision functions (sv_coef[n-1][l]) 
     99    double *rho;        // constants in decision functions (rho[n*(n-1)/2]) 
     100    double *probA;          // pariwise probability information 
     101    double *probB; 
     102 
     103    // for classification only 
     104 
     105    int *label;     // label of each class (label[n]) 
     106    int *nSV;       // number of SVs for each class (nSV[n]) 
     107                // nSV[0] + nSV[1] + ... + nSV[n-1] = l 
     108    // XXX 
     109    int free_sv;        // 1 if svm_model is created by svm_load_model 
     110                // 0 if svm_model is created by svm_train 
     111}; 
     112 
    92113struct svm_model *svm_train(const struct svm_problem *prob, const struct svm_parameter *param); 
    93114void svm_cross_validation(const struct svm_problem *prob, const struct svm_parameter *param, int nr_fold, double *target); 
     
    129150#include "examples.hpp" 
    130151#include "distance.hpp" 
     152#include "slist.hpp" 
     153 
     154svm_model *svm_load_model_alt(TCharBuffer& buffer); 
     155int svm_save_model_alt(TCharBuffer& buffer, const svm_model *model); 
    131156 
    132157WRAPPER(ExampleGenerator) 
     
    196221    const TExample *currentExample; 
    197222 
     223    svm_model* getModel(){return model;}; 
     224 
    198225private: 
    199226    svm_model *model; 
Note: See TracChangeset for help on using the changeset viewer.