Changeset 10574:f7272ba4865d in orange


Ignore:
Timestamp:
03/19/12 18:01:40 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
rebase_source:
c778ff6dea4cf273365795e0c003e3cbbaff72c6
Message:

Added support for custom kernels in get_binary_classifier.

Location:
Orange
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • Orange/classification/svm/__init__.py

    r10573 r10574  
    329329        sv2 = [sv for sv, nz in zip(sv2, nonzero2) if nz] 
    330330         
     331        sv_indices1 = [i for i, nz in zip(c1_range, nonzero1) if nz] 
     332        sv_indices2 = [i for i, nz in zip(c2_range, nonzero2) if nz] 
     333         
    331334        bin_class_var = Orange.feature.Discrete("%s vs %s" % \ 
    332335                        (self.class_var.values[c1], self.class_var.values[c2]), 
    333336                        values=["0", "1"]) 
    334337         
    335         model = self._binary_libsvm_model(bin_class_var, [coef1, coef2], [rho], sv1 + sv2) 
     338        model = self._binary_libsvm_model(bin_class_var, [coef1, coef2], 
     339                                          [rho], sv_indices1 + sv_indices2) 
    336340         
    337341        all_sv = Orange.data.Table(sv1 + sv2) 
     
    345349        return SVMClassifierWrapper(classifier) 
    346350     
    347     def _binary_libsvm_model(self, class_var, coefs, rho, sv_vectors): 
     351    def _binary_libsvm_model(self, class_var, coefs, rho, sv_indices): 
    348352        """Return a libsvm formated model string for binary subclassifier 
    349353        """ 
     
    360364         
    361365        model.append("nr_class %i" % len(class_var.values)) 
    362         model.append("total_sv %i" % len(sv_vectors)) 
     366        model.append("total_sv %i" % len(sv_indices)) 
    363367        model.append("rho " + " ".join(str(r) for r in rho)) 
    364368        model.append("label " + " ".join(str(i) for i in range(len(class_var.values)))) 
     
    374378            return " ".join("%i:%f" % (i + 1, v) for i, v in values) 
    375379         
    376         if self.svm_type == kernels.Custom: 
    377             raise NotImplemented("not implemented for custom kernels.") 
     380        if self.kernel_type == kernels.Custom: 
     381            SV = self.get_model().split("SV\n", 1)[1] 
     382            # Get the sv indices (the last entry in the SV entrys) 
     383            indices = [int(s.split(":")[-1]) for s in SV.splitlines() if s.strip()] 
     384            for c, sv_i in zip(itertools.chain(*coefs), itertools.chain(sv_indices)): 
     385                model.append("%f 0:%i" % (c, indices[sv_i])) 
    378386        else: 
    379             for c, sv in zip(itertools.chain(*coefs), itertools.chain(sv_vectors)): 
    380                 model.append("%f %s" % (c, instance_to_svm(sv))) 
     387            for c, sv_i in zip(itertools.chain(*coefs), itertools.chain(sv_indices)): 
     388                model.append("%f %s" % (c, instance_to_svm(self.support_vectors[sv_i]))) 
    381389                 
    382390        model.append("") 
  • Orange/testing/unit/tests/test_svm.py

    r10573 r10574  
    9696        """ Test custom kernel wrapper 
    9797        """ 
    98         # Need the data for ExamplesDistanceConstructor_Euclidean   
     98        if data.domain.has_continuous_attributes(): 
     99            dist = orange.ExamplesDistanceConstructor_Euclidean(data) 
     100        else: 
     101            dist = orange.ExamplesDistanceConstructor_Hamming(data) 
    99102        self.learner = self.LEARNER(kernel_type=SVMLearner.Custom, 
    100                                     kernel_func=RBFKernelWrapper(orange.ExamplesDistanceConstructor_Euclidean(data), gamma=0.5)) 
     103                                    kernel_func=RBFKernelWrapper(dist, gamma=0.5)) 
    101104 
    102105        testing.LearnerTestCase.test_learner_on(self, data) 
     106        svm_test_binary_classifier(self, data) 
    103107 
    104108 
Note: See TracChangeset for help on using the changeset viewer.