Changeset 10584:edc3a37bc3c3 in orange


Ignore:
Timestamp:
03/20/12 15:50:14 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
rebase_source:
b27e78cc27e112b769d1edb5e3d27ec79e783ce9
Message:

Fixed get_linear_svm_weights function with respect to the internal libsvm label ordering.

Location:
Orange
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • Orange/classification/svm/__init__.py

    r10581 r10584  
    291291                val_index = n_class * (n_class - 1) / 2 - (n_class - ni - 1) * (n_class - ni - 2) / 2 - (n_class - nj) 
    292292                new_values.append(mult * dec_values[val_index]) 
    293         return new_values 
     293        return Orange.core.FloatList(new_values) 
    294294         
    295295    def get_model(self): 
     
    314314        import numpy as np 
    315315        if self.svm_type not in [SVMLearner.C_SVC, SVMLearner.Nu_SVC]: 
    316             raise TypeError("Wrong svm type.") 
     316            raise TypeError("SVM classification model expected.") 
    317317         
    318318        c1 = int(self.class_var(c1)) 
     
    616616 
    617617    class_var = SVs.domain.class_var 
    618     if classifier.svm_type in [SVMLearner.C_SVC, SVMLearner.Nu_SVC]: 
    619         classes = class_var.values 
    620     else: 
    621         classes = [""] 
    622     if len(classes) > 1: 
    623         sv_ranges = [(0, classifier.nSV[0])] 
    624         for n in classifier.nSV[1:]: 
    625             sv_ranges.append((sv_ranges[-1][1], sv_ranges[-1][1] + n)) 
    626     else: 
    627         sv_ranges = [(0, len(SVs))] 
    628  
     618    if classifier.svm_type not in [SVMLearner.C_SVC, SVMLearner.Nu_SVC]: 
     619        raise TypeError("SVM classification model expected.") 
     620     
     621    classes = classifier.class_var.values 
     622     
    629623    for i in range(len(classes) - 1): 
    630624        for j in range(i + 1, len(classes)): 
     625            # Get the coef and rho values from the binary sub-classifier 
     626            # Easier then using the full coef matrix (due to libsvm internal 
     627            # class  reordering) 
     628            bin_classifier = classifier.get_binary_classifier(i, j) 
     629            n_sv0 = bin_classifier.n_SV[0] 
     630            SVs = bin_classifier.support_vectors 
    631631            w = {} 
    632             coef_ind = j - 1 
    633             for sv_ind in range(*sv_ranges[i]): 
     632             
     633            for SV, alpha in zip(SVs, bin_classifier.coef[0]): 
    634634                attributes = SVs.domain.attributes + \ 
    635                 SVs[sv_ind].getmetas(False, Orange.feature.Descriptor).keys() 
     635                SV.getmetas(False, Orange.feature.Descriptor).keys() 
    636636                for attr in attributes: 
    637637                    if attr.varType == Orange.feature.Type.Continuous: 
    638                         update_weights(w, attr, to_float(SVs[sv_ind][attr]), \ 
    639                                        classifier.coef[coef_ind][sv_ind]) 
    640             coef_ind = i 
    641             for sv_ind in range(*sv_ranges[j]): 
    642                 attributes = SVs.domain.attributes + \ 
    643                 SVs[sv_ind].getmetas(False, Orange.feature.Descriptor).keys() 
    644                 for attr in attributes: 
    645                     if attr.varType == Orange.feature.Type.Continuous: 
    646                         update_weights(w, attr, to_float(SVs[sv_ind][attr]), \ 
    647                                        classifier.coef[coef_ind][sv_ind]) 
     638                        update_weights(w, attr, to_float(SV[attr]), alpha) 
     639                 
    648640            weights.append(w) 
    649  
     641             
    650642    if sum: 
    651643        scores = defaultdict(float) 
     
    656648        for key in scores: 
    657649            scores[key] = math.sqrt(scores[key]) 
    658         return scores 
     650        return dict(scores) 
    659651    else: 
    660652        return weights 
  • Orange/testing/unit/tests/test_svm.py

    r10579 r10584  
    11import Orange 
    2 from Orange.classification.svm import SVMLearner, MeasureAttribute_SVMWeights, LinearLearner, RFE 
     2from Orange.classification.svm import SVMLearner, MeasureAttribute_SVMWeights,\ 
     3                            LinearLearner, RFE, get_linear_svm_weights, \ 
     4                            example_weighted_sum 
     5                             
    36from Orange.classification.svm.kernels import BagOfWords, RBFKernelWrapper 
    47from Orange.misc import testing 
     
    7073        svm_test_binary_classifier(self, dataset) 
    7174         
     75     
     76    @test_on_datasets(datasets=testing.CLASSIFICATION_DATASETS + ["zoo"]) 
     77    def test_linear_weights_on(self, dataset): 
     78        # Test get_linear_svm_weights 
     79        classifier = self.LEARNER(dataset) 
     80        weights = get_linear_svm_weights(classifier, sum=True) 
     81         
     82        weights = get_linear_svm_weights(classifier, sum=False) 
     83         
     84        n_class = len(classifier.class_var.values) 
     85         
     86        def class_pairs(n_class): 
     87            for i in range(n_class - 1): 
     88                for j in range(i + 1, n_class): 
     89                    yield i, j 
     90                     
     91        l_map = classifier._get_libsvm_labels_map() 
     92        # Would need to map the rho values 
     93        if l_map == sorted(l_map): 
     94            for inst in dataset[:20]: 
     95                dec_values = classifier.get_decision_values(inst) 
     96                 
     97                for dec_v, weight, rho, pair in zip(dec_values, weights, 
     98                                        classifier.rho, class_pairs(n_class)): 
     99                    t_inst = Orange.data.Instance(classifier.domain, inst)                     
     100                    dec_v1 = example_weighted_sum(t_inst, weight) - rho 
     101                    self.assertAlmostEqual(dec_v, dec_v1, 4) 
     102         
     103         
    72104@datasets_driven(datasets=datasets) 
    73105class PolySVMTestCase(testing.LearnerTestCase): 
Note: See TracChangeset for help on using the changeset viewer.