Changeset 10579:5d12fb993765 in orange


Ignore:
Timestamp:
03/20/12 11:23:40 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
rebase_source:
5bbaeea1f58986385cf8605acb58b5707acb1985
Message:

Test multiclass prediction from the decision values.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/testing/unit/tests/test_svm.py

    r10574 r10579  
    66import orange 
    77 
     8import copy 
     9 
    810import pickle 
     11import numpy as np 
    912 
    10                  
     13def multiclass_from1sv1(dec_values, class_var): 
     14    n_class = len(class_var.values) 
     15    votes = [0] * n_class 
     16    p = 0 
     17    for i in range(n_class - 1): 
     18        for j in range(i + 1, n_class): 
     19            val = dec_values[p] 
     20            if val > 0: 
     21                votes[i] += 1 
     22            else: 
     23                votes[j] += 1 
     24            p += 1 
     25    max_i = np.argmax(votes) 
     26    return class_var(int(max_i)) 
     27     
     28     
    1129def svm_test_binary_classifier(self, data): 
    1230    if isinstance(data.domain.class_var, Orange.feature.Discrete): 
     
    2240         
    2341        indices = Orange.data.sample.SubsetIndices2(p0=0.2) 
    24         sample = data.select(indices(data)) 
     42        sample = data.select(indices(data), 0) 
     43         
     44        learner = copy.copy(self.LEARNER) 
     45        learner.probability = False  
     46        classifier_no_prob = learner(data) 
    2547         
    2648        for inst in sample: 
     
    3153                self.assertAlmostEqual(v1, v2, places=3) 
    3254                self.assertAlmostEqual(v1, v3, places=3) 
     55             
     56            prediction_1 = classifier_no_prob(inst) 
     57            d_val = classifier_no_prob.get_decision_values(inst) 
     58            prediciton_2 = multiclass_from1sv1(d_val, classifier_no_prob.class_var) 
     59            self.assertEqual(prediction_1, prediciton_2) 
     60             
    3361 
    3462datasets = testing.CLASSIFICATION_DATASETS + testing.REGRESSION_DATASETS 
Note: See TracChangeset for help on using the changeset viewer.