Ignore:
Timestamp:
03/19/12 16:43:39 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
rebase_source:
84eb1ad3ab7dbc2cf39ac79eaf19384b1317d358
Message:

Added get_binary_classifier method to SVMClassifierWrapper.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/classification/svm/__init__.py

    r10542 r10573  
    282282            for name, val in self.__dict__.items() \ 
    283283            if name not in self.wrapped.__dict__]) 
     284         
     285    def get_binary_classifier(self, c1, c2): 
     286        """Return a binary classifier for classes `c1` and `c2`. 
     287        """ 
     288        import numpy as np 
     289        if self.svm_type not in [SVMLearner.C_SVC, SVMLearner.Nu_SVC]: 
     290            raise TypeError("Wrong svm type.") 
     291         
     292        c1 = int(self.class_var(c1)) 
     293        c2 = int(self.class_var(c2)) 
     294        n_class = len(self.class_var.values) 
     295         
     296        if c1 == c2: 
     297            raise ValueError("Different classes expected.") 
     298          
     299        if c1 > c2: 
     300            c1, c2 = c2, c1 
     301         
     302        # Index of the 1vs1 binary classifier  
     303        classifier_i = n_class * (n_class - 1) / 2 - (n_class - c1 - 1) * (n_class - c1 - 2) / 2 - (n_class - c2) 
     304         
     305        # Indices for classes in the coef structure. 
     306        class_indices = np.cumsum([0] + list(self.n_SV), dtype=int) 
     307        c1_range = range(class_indices[c1], class_indices[c1 + 1]) 
     308        c2_range = range(class_indices[c2], class_indices[c2 + 1]) 
     309         
     310        coef_array = np.array(self.coef) 
     311        coef1 = coef_array[c2 - 1, c1_range] 
     312        coef2 = coef_array[c1, c2_range] 
     313         
     314        # Support vectors for the binary classifier 
     315        sv1 = [self.support_vectors[i] for i in c1_range] 
     316        sv2 = [self.support_vectors[i] for i in c2_range] 
     317         
     318        # Rho for the classifier 
     319        rho = self.rho[classifier_i] 
     320         
     321        # Filter non zero support vectors 
     322        nonzero1 = np.abs(coef1) > 0.0 
     323        nonzero2 = np.abs(coef2) > 0.0 
     324         
     325        coef1 = coef1[nonzero1] 
     326        coef2 = coef2[nonzero2] 
     327         
     328        sv1 = [sv for sv, nz in zip(sv1, nonzero1) if nz] 
     329        sv2 = [sv for sv, nz in zip(sv2, nonzero2) if nz] 
     330         
     331        bin_class_var = Orange.feature.Discrete("%s vs %s" % \ 
     332                        (self.class_var.values[c1], self.class_var.values[c2]), 
     333                        values=["0", "1"]) 
     334         
     335        model = self._binary_libsvm_model(bin_class_var, [coef1, coef2], [rho], sv1 + sv2) 
     336         
     337        all_sv = Orange.data.Table(sv1 + sv2) 
     338        if self.kernel_type == kernels.Custom: 
     339            classifier = SVMClassifier(bin_class_var, self.examples, 
     340                                       all_sv, model, self.kernel_func) 
     341        else: 
     342            classifier = SVMClassifier(bin_class_var, self.examples, 
     343                                       all_sv, model) 
     344             
     345        return SVMClassifierWrapper(classifier) 
     346     
     347    def _binary_libsvm_model(self, class_var, coefs, rho, sv_vectors): 
     348        """Return a libsvm formated model string for binary subclassifier 
     349        """ 
     350        import itertools 
     351         
     352        model = [] 
     353         
     354        # Take the model up to nr_classes 
     355        for line in self.get_model().splitlines(): 
     356            if line.startswith("nr_class"): 
     357                break 
     358            else: 
     359                model.append(line.rstrip()) 
     360         
     361        model.append("nr_class %i" % len(class_var.values)) 
     362        model.append("total_sv %i" % len(sv_vectors)) 
     363        model.append("rho " + " ".join(str(r) for r in rho)) 
     364        model.append("label " + " ".join(str(i) for i in range(len(class_var.values)))) 
     365        # No probA and probB 
     366         
     367        model.append("nr_sv " + " ".join(str(len(c)) for c in coefs)) 
     368        model.append("SV") 
     369         
     370        def instance_to_svm(inst): 
     371            values = [(i, float(inst[v])) \ 
     372                      for i, v in enumerate(inst.domain.attributes) \ 
     373                      if not inst[v].is_special() and float(inst[v]) != 0.0] 
     374            return " ".join("%i:%f" % (i + 1, v) for i, v in values) 
     375         
     376        if self.svm_type == kernels.Custom: 
     377            raise NotImplemented("not implemented for custom kernels.") 
     378        else: 
     379            for c, sv in zip(itertools.chain(*coefs), itertools.chain(sv_vectors)): 
     380                model.append("%f %s" % (c, instance_to_svm(sv))) 
     381                 
     382        model.append("") 
     383        return "\n".join(model) 
     384         
    284385 
    285386SVMClassifierWrapper = Orange.misc.deprecated_members({ 
Note: See TracChangeset for help on using the changeset viewer.