Changeset 10638:38fd2d78200c in orange


Ignore:
Timestamp:
03/26/12 12:47:44 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Added support for regression models in get_linear_svm_weights.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/classification/svm/__init__.py

    r10637 r10638  
    806806 
    807807    SVs = classifier.support_vectors 
    808     weights = [] 
    809  
    810808    class_var = SVs.domain.class_var 
    811     if classifier.svm_type not in [SVMLearner.C_SVC, SVMLearner.Nu_SVC]: 
    812         raise TypeError("SVM classification model expected.") 
    813      
    814     classes = classifier.class_var.values 
    815      
    816     for i in range(len(classes) - 1): 
    817         for j in range(i + 1, len(classes)): 
    818             # Get the coef and rho values from the binary sub-classifier 
    819             # Easier then using the full coef matrix (due to libsvm internal 
    820             # class  reordering) 
    821             bin_classifier = classifier.get_binary_classifier(i, j) 
    822             n_sv0 = bin_classifier.n_SV[0] 
    823             SVs = bin_classifier.support_vectors 
    824             w = {} 
    825              
    826             for alpha, sv_ind in bin_classifier.coef[0]: 
    827                 SV = SVs[sv_ind] 
    828                 attributes = SVs.domain.attributes + \ 
    829                 SV.getmetas(False, Orange.feature.Descriptor).keys() 
    830                 for attr in attributes: 
    831                     if attr.varType == Orange.feature.Type.Continuous: 
    832                         update_weights(w, attr, to_float(SV[attr]), alpha) 
     809     
     810    if classifier.svm_type in [SVMLearner.C_SVC, SVMLearner.Nu_SVC]: 
     811        weights = []     
     812        classes = classifier.class_var.values 
     813        for i in range(len(classes) - 1): 
     814            for j in range(i + 1, len(classes)): 
     815                # Get the coef and rho values from the binary sub-classifier 
     816                # Easier then using the full coef matrix (due to libsvm internal 
     817                # class  reordering) 
     818                bin_classifier = classifier.get_binary_classifier(i, j) 
     819                n_sv0 = bin_classifier.n_SV[0] 
     820                SVs = bin_classifier.support_vectors 
     821                w = {} 
    833822                 
    834             weights.append(w) 
    835              
    836     if sum: 
    837         scores = defaultdict(float) 
    838  
    839         for w in weights: 
    840             for attr, w_attr in w.items(): 
    841                 scores[attr] += w_attr ** 2 
    842         for key in scores: 
    843             scores[key] = math.sqrt(scores[key]) 
    844         return dict(scores) 
     823                for alpha, sv_ind in bin_classifier.coef[0]: 
     824                    SV = SVs[sv_ind] 
     825                    attributes = SVs.domain.attributes + \ 
     826                    SV.getmetas(False, Orange.feature.Descriptor).keys() 
     827                    for attr in attributes: 
     828                        if attr.varType == Orange.feature.Type.Continuous: 
     829                            update_weights(w, attr, to_float(SV[attr]), coef) 
     830                     
     831                weights.append(w) 
     832        if sum: 
     833            scores = defaultdict(float) 
     834            for w in weights: 
     835                for attr, w_attr in w.items(): 
     836                    scores[attr] += w_attr ** 2 
     837            for key in scores: 
     838                scores[key] = math.sqrt(scores[key]) 
     839            weights = dict(scores) 
    845840    else: 
    846         return weights 
    847  
     841#        raise TypeError("SVM classification model expected.") 
     842        weights = {} 
     843        for coef, sv_ind in classifier.coef[0]: 
     844            SV = SVs[sv_ind] 
     845            attributes = SVs.domain.attributes + \ 
     846            SV.getmetas(False, Orange.feature.Descriptor).keys() 
     847            for attr in attributes: 
     848                if attr.varType == Orange.feature.Type.Continuous: 
     849                    update_weights(weights, attr, to_float(SV[attr]), coef) 
     850            
     851    return weights  
     852     
    848853getLinearSVMWeights = get_linear_svm_weights 
    849854 
Note: See TracChangeset for help on using the changeset viewer.