Ignore:
Timestamp:
03/30/12 15:23:22 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Using ScoreSVMWeights in RFE.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/classification/svm/__init__.py

    r10694 r10695  
    928928    Example: 
    929929     
     930        >>> table = Orange.data.Table("vehicle.tab") 
    930931        >>> score = Orange.classification.svm.ScoreSVMWeights() 
    931932        >>> svm_scores = [(score(f, table), f) for f in table.domain.features]  
     
    981982        """ 
    982983        self.learner = learner 
    983         self._cached_examples = None 
     984        self._cached_data = None 
     985        self._cached_data_crc = None 
     986        self._cached_weights = None 
     987        self._cached_classifier = None 
    984988 
    985989    def __call__(self, attr, data, weight_id=None): 
     
    10011005                                type(data.domain.class_var)) 
    10021006 
    1003         if data is self._cached_examples: 
     1007        crc = data.checksum() 
     1008        if data is self._cached_data and crc == self._cached_data_crc: 
    10041009            weights = self._cached_weights 
    10051010        else: 
    10061011            classifier = learner(data, weight_id) 
    1007             self._cached_examples = data 
     1012            self._cached_data = data 
     1013            self._cached_data_crc = data.checksum() 
     1014            self._cached_classifier = classifier 
    10081015            weights = self._extract_weights(classifier, data.domain.attributes) 
    10091016            self._cached_weights = weights 
     
    10331040        source_weights = dict.fromkeys(original_features, 0.0) 
    10341041        for f in original_features: 
    1035             if f not in weights and f in sources: 
     1042            if f in weights: 
     1043                source_weights[f] = weights[f] 
     1044            elif f not in weights and f in sources: 
    10361045                dummys = sources[f] 
    10371046                # Use averege weight   
     
    10701079    Example:: 
    10711080     
    1072         import Orange 
    1073         table = Orange.data.Table("vehicle.tab") 
    1074         l = Orange.classification.svm.SVMLearner( 
    1075             kernel_type=Orange.classification.svm.kernels.Linear,  
    1076             normalization=False) # normalization=False will not change the domain 
    1077         rfe = Orange.classification.svm.RFE(l) 
    1078         data_subset_of_features = rfe(table, 5) 
     1081        >>> table = Orange.data.Table("promoters.tab") 
     1082        >>> svm_l = Orange.classification.svm.SVMLearner( 
     1083        ...     kernel_type=Orange.classification.svm.kernels.Linear)  
     1084        ...  
     1085        >>> rfe = Orange.classification.svm.RFE(learner=svm_l) 
     1086        >>> data_with_subset_of_features = rfe(table, 10) 
     1087        >>> data_with_subset_of_features.domain 
     1088        [p-45, p-36, p-35, p-34, p-33, p-31, p-18, p-12, p-10, p-04, y] 
    10791089         
    10801090    """ 
    10811091 
    10821092    def __init__(self, learner=None): 
    1083         self.learner = learner or SVMLearner(kernel_type= 
    1084                             kernels.Linear, normalization=False) 
     1093        """ 
     1094        :param learner: A linear svm learner for use with  
     1095            :class:`ScoreSVMWeights`. 
     1096         
     1097        """ 
     1098        self.learner = learner 
    10851099 
    10861100    @Orange.utils.deprecated_keywords({"progressCallback": "progress_callback", "stopAt": "stop_at" }) 
     
    10931107        iter = 1 
    10941108        attrs = data.domain.attributes 
    1095         attrScores = {} 
     1109        attr_scores = {} 
     1110        scorer = ScoreSVMWeights(learner=self.learner) 
    10961111 
    10971112        while len(attrs) > stop_at: 
    1098             weights = get_linear_svm_weights(self.learner(data), sum=False) 
     1113            scores = [(scorer(attr, data), attr) for attr in attrs] 
    10991114            if progress_callback: 
    11001115                progress_callback(100. * iter / (len(attrs) - stop_at)) 
    1101             score = dict.fromkeys(attrs, 0) 
    1102             for w in weights: 
    1103                 for attr, wAttr in w.items(): 
    1104                     score[attr] += wAttr ** 2 
    1105             score = score.items() 
    1106             score.sort(lambda a, b:cmp(a[1], b[1])) 
    1107             numToRemove = max(int(len(attrs) * 1.0 / (iter + 1)), 1) 
    1108             for attr, s in  score[:numToRemove]: 
    1109                 attrScores[attr] = len(attrScores) 
    1110             attrs = [attr for attr, s in score[numToRemove:]] 
     1116            scores = sorted(scores) 
     1117            num_to_remove = max(int(len(attrs) * 1.0 / (iter + 1)), 1) 
     1118            for s, attr in  scores[:num_to_remove]: 
     1119                attr_scores[attr] = len(attr_scores) 
     1120            attrs = [attr for s, attr in scores[num_to_remove:]] 
    11111121            if attrs: 
    1112                 data = data.select(attrs + [data.domain.classVar]) 
     1122                data = data.select(attrs + [data.domain.class_var]) 
    11131123            iter += 1 
    1114         return attrScores 
     1124        return attr_scores 
    11151125 
    11161126    @Orange.utils.deprecated_keywords({"numSelected": "num_selected", "progressCallback": "progress_callback"}) 
     
    11201130        :param data: Data 
    11211131        :type data: Orange.data.Table 
     1132         
    11221133        :param num_selected: number of features to preserve 
    11231134        :type num_selected: int 
     
    11721183 
    11731184tableToSVMFormat = table_to_svm_format 
    1174  
    1175  
    1176 def _doctest_args(): 
    1177     """For unittest framework to test the docstrings. 
    1178     """ 
    1179     import Orange 
    1180     table = Orange.data.Table("vehicle.tab") 
    1181     extraglobs = locals() 
    1182     return {"extraglobs": extraglobs} 
Note: See TracChangeset for help on using the changeset viewer.