Changeset 7521:5723c90704af in orange


Ignore:
Timestamp:
02/04/11 20:05:47 (3 years ago)
Author:
miha <miha.stajdohar@…>
Branch:
default
Convert:
3f57f0e675f9466ac5e751d4abc2a17ed13138a9
Message:

moved tuning and thresholding classes to Orange.optimization

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/orngWrap.py

    r6929 r7521  
    1 import orange 
    2  
    3 # The class needs to be given 
    4 #   object     - the learning algorithm to be fitted 
    5 #   evaluate   - statistics to evaluate (default: orngStat.CA) 
    6 #   folds      - the number of folds for internal cross validation 
    7 #   compare    - function to compare (default: cmp - the bigger the better) 
    8 #   returnWhat - tells whether to return values of parameters, a fitted 
    9 #                learner, the best classifier or None. "object" is left 
    10 #                with optimal parameters in any case 
    11 class TuneParameters(orange.Learner): 
    12     returnNone=0 
    13     returnParameters=1 
    14     returnLearner=2 
    15     returnClassifier=3 
    16      
    17     def __new__(cls, examples = None, weightID = 0, **argkw): 
    18         self = orange.Learner.__new__(cls, **argkw) 
    19         self.__dict__.update(argkw) 
    20         if examples: 
    21             return self.__call__(examples, weightID) 
    22         else: 
    23             return self 
    24  
    25     def findobj(self, name): 
    26         import string 
    27         names=string.split(name, ".") 
    28         lastobj=self.object 
    29         for i in names[:-1]: 
    30             lastobj=getattr(lastobj, i) 
    31         return lastobj, names[-1] 
    32          
    33 # Same arguments as TuneParameters, plus: 
    34 #   parameter  - a string or a list of strings with parameter(s) to fit 
    35 #   values     - possible values of the parameter 
    36 #                (eg <object>.<parameter> = <value>[i]) 
    37 class Tune1Parameter(TuneParameters): 
    38     def __call__(self, table, weight=None, verbose=0): 
    39         import orngTest, orngStat, orngMisc 
    40  
    41         verbose = verbose or getattr(self, "verbose", 0) 
    42         evaluate = getattr(self, "evaluate", orngStat.CA) 
    43         folds = getattr(self, "folds", 5) 
    44         compare = getattr(self, "compare", cmp) 
    45         returnWhat = getattr(self, "returnWhat", Tune1Parameter.returnClassifier) 
    46  
    47         if (type(self.parameter)==list) or (type(self.parameter)==tuple): 
    48             to_set = [self.findobj(ld) for ld in self.parameter] 
    49         else: 
    50             to_set = [self.findobj(self.parameter)] 
    51  
    52         cvind = orange.MakeRandomIndicesCV(table, folds) 
    53         findBest = orngMisc.BestOnTheFly(seed = table.checksum(), callCompareOn1st = True) 
    54         tableAndWeight = weight and (table, weight) or table 
    55         for par in self.values: 
    56             for i in to_set: 
    57                 setattr(i[0], i[1], par) 
    58             res = evaluate(orngTest.testWithIndices([self.object], tableAndWeight, cvind)) 
    59             findBest.candidate((res, par)) 
    60             if verbose==2: 
    61                 print '*** orngWrap  %s: %s:' % (par, res) 
    62  
    63         bestpar = findBest.winner()[1] 
    64         for i in to_set: 
    65             setattr(i[0], i[1], bestpar) 
    66  
    67         if verbose: 
    68             print "*** Optimal parameter: %s = %s" % (self.parameter, bestpar) 
    69  
    70         if returnWhat==Tune1Parameter.returnNone: 
    71             return None 
    72         elif returnWhat==Tune1Parameter.returnParameters: 
    73             return bestpar 
    74         elif returnWhat==Tune1Parameter.returnLearner: 
    75             return self.object 
    76         else: 
    77             classifier = self.object(table) 
    78             classifier.setattr("fittedParameter", bestpar) 
    79             return classifier 
    80  
    81  
    82 # Same arguments as TuneParameters, plus 
    83 #   parameters - a list of tuples with parameters to be fitted and the 
    84 #                corresponding possible values, [(parameter(s), values), ...] 
    85 #                (eg <object>.<parameter[j]> = <value[j]>[i]) 
    86 class TuneMParameters(TuneParameters): 
    87     def __call__(self, table, weight=None, verbose=0): 
    88         import orngTest, orngStat, orngMisc 
    89  
    90         evaluate = getattr(self, "evaluate", orngStat.CA) 
    91         folds = getattr(self, "folds", 5) 
    92         compare = getattr(self, "compare", cmp) 
    93         verbose = verbose or getattr(self, "verbose", 0) 
    94         returnWhat = getattr(self, "returnWhat", Tune1Parameter.returnClassifier) 
    95         progressCallback = getattr(self, "progressCallback", lambda i: None) 
    96          
    97         to_set = [] 
    98         parnames = [] 
    99         for par in self.parameters: 
    100             if (type(par[0])==list) or (type(par[0])==tuple): 
    101                 to_set.append([self.findobj(ld) for ld in par[0]]) 
    102                 parnames.append(par[0]) 
    103             else: 
    104                 to_set.append([self.findobj(par[0])]) 
    105                 parnames.append([par[0]]) 
    106  
    107  
    108         cvind = orange.MakeRandomIndicesCV(table, folds) 
    109         findBest = orngMisc.BestOnTheFly(seed = table.checksum(), callCompareOn1st = True) 
    110         tableAndWeight = weight and (table, weight) or table 
    111         numOfTests = sum([len(x[1]) for x in self.parameters]) 
    112         milestones = set(range(0, numOfTests, max(numOfTests / 100, 1))) 
    113         for itercount, valueindices in enumerate(orngMisc.LimitedCounter([len(x[1]) for x in self.parameters])): 
    114             values = [self.parameters[i][1][x] for i,x in enumerate(valueindices)] 
    115             for pi, value in enumerate(values): 
    116                 for i, par in enumerate(to_set[pi]): 
    117                     setattr(par[0], par[1], value) 
    118                     if verbose==2: 
    119                         print "%s: %s" % (parnames[pi][i], value) 
    120                          
    121             res = evaluate(orngTest.testWithIndices([self.object], tableAndWeight, cvind)) 
    122             if itercount in milestones: 
    123                 progressCallback(100.0 * itercount / numOfTests) 
    124              
    125             findBest.candidate((res, values)) 
    126             if verbose==2: 
    127                 print "===> Result: %s\n" % res 
    128  
    129         bestpar = findBest.winner()[1] 
    130         if verbose: 
    131             print "*** Optimal set of parameters: ", 
    132         for pi, value in enumerate(bestpar): 
    133             for i, par in enumerate(to_set[pi]): 
    134                 setattr(par[0], par[1], value) 
    135                 if verbose: 
    136                     print "%s: %s" % (parnames[pi][i], value), 
    137         if verbose: 
    138             print 
    139  
    140         if returnWhat==Tune1Parameter.returnNone: 
    141             return None 
    142         elif returnWhat==Tune1Parameter.returnParameters: 
    143             return bestpar 
    144         elif returnWhat==Tune1Parameter.returnLearner: 
    145             return self.object 
    146         else: 
    147             classifier = self.object(table) 
    148             classifier.fittedParameters = bestpar 
    149             return classifier 
    150  
    151  
    152  
    153  
    154 class ThresholdLearner(orange.Learner): 
    155     def __new__(cls, examples = None, weightID = 0, **kwds): 
    156         self = orange.Learner.__new__(cls, **kwds) 
    157         self.__dict__.update(kwds) 
    158         if examples: 
    159             return self.__call__(examples, weightID) 
    160         else: 
    161             return self 
    162  
    163     def __call__(self, examples, weightID = 0): 
    164         if not hasattr(self, "learner"): 
    165             raise "learner not set" 
    166          
    167         classifier = self.learner(examples, weightID) 
    168         threshold, optCA, curve = orange.ThresholdCA(classifier, examples, weightID) 
    169         if getattr(self, "storeCurve", 0): 
    170             return ThresholdClassifier(classifier, threshold, curve = curve) 
    171         else: 
    172             return ThresholdClassifier(classifier, threshold) 
    173  
    174 class ThresholdClassifier(orange.Classifier): 
    175     def __init__(self, classifier, threshold, **kwds): 
    176         self.classifier = classifier 
    177         self.threshold = threshold 
    178         self.__dict__.update(kwds) 
    179  
    180     def __call__(self, example, what = orange.Classifier.GetValue): 
    181         probs = self.classifier(example, self.GetProbabilities) 
    182         if what == self.GetProbabilities: 
    183             return probs 
    184         value = orange.Value(self.classifier.classVar, probs[1]>self.threshold) 
    185         if what == orange.Classifier.GetValue: 
    186             return value 
    187         else: 
    188             return (value, probs) 
    189  
    190 def ThresholdLearner_fixed(learner, threshold, examples = None, weightId = 0, **kwds): 
    191     lr = apply(ThresholdLearner_fixed_Class, (learner, threshold), kwds) 
    192     if examples: 
    193         return lr(examples, weightId) 
    194     else: 
    195         return lr 
    196      
    197 class ThresholdLearner_fixed(orange.Learner): 
    198     def __new__(cls, examples = None, weightID = 0, **kwds): 
    199         self = orange.Learner.__new__(cls, **kwds) 
    200         self.__dict__.update(kwds) 
    201         if examples: 
    202             return self.__call__(examples, weightID) 
    203         else: 
    204             return self 
    205  
    206     def __call__(self, examples, weightID = 0): 
    207         if not hasattr(self, "learner"): 
    208             raise "learner not set" 
    209         if not hasattr(self, "threshold"): 
    210             raise "threshold not set" 
    211         if len(examples.domain.classVar.values)!=2: 
    212             raise "ThresholdLearner handles binary classes only" 
    213          
    214         return ThresholdClassifier(self.learner(examples, weightID), self.threshold) 
    215  
    216 class PreprocessedLearner(object): 
    217     def __new__(cls, preprocessor = None, learner = None): 
    218         self = object.__new__(cls) 
    219         if learner is not None: 
    220             self.__init__(preprocessor) 
    221             return self.wrapLearner(learner) 
    222         else: 
    223             return self 
    224          
    225     def __init__(self, preprocessor = None, learner = None): 
    226         if isinstance(preprocessor, list): 
    227             self.preprocessors = preprocessor 
    228         elif preprocessor is not None: 
    229             self.preprocessors = [preprocessor] 
    230         else: 
    231             self.preprocessors = [] 
    232         #self.preprocessors = [orange.Preprocessor_addClassNoise(proportion=0.8)] 
    233         if learner: 
    234             self.wrapLearner(learner) 
    235          
    236     def processData(self, data, weightId = None): 
    237         import orange 
    238         hadWeight = hasWeight = weightId is not None 
    239         for preprocessor in self.preprocessors: 
    240             t = preprocessor(data, weightId) if hasWeight else preprocessor(data) 
    241             if isinstance(t, tuple): 
    242                 data, weightId = t 
    243                 hasWeight = True 
    244             else: 
    245                 data = t 
    246         if hadWeight: 
    247             return data, weightId 
    248         else: 
    249             return data 
    250  
    251     def wrapLearner(self, learner): 
    252         class WrappedLearner(learner.__class__): 
    253             preprocessor = self 
    254             wrappedLearner = learner 
    255             name = getattr(learner, "name", "") 
    256             def __call__(self, data, weightId=0, getData = False): 
    257                 t = self.preprocessor.processData(data, weightId or 0) 
    258                 processed, procW = t if isinstance(t, tuple) else (t, 0) 
    259                 classifier = self.wrappedLearner(processed, procW) 
    260                 if getData: 
    261                     return classifier, processed 
    262                 else: 
    263                     return classifier # super(WrappedLearner, self).__call__(processed, procW) 
    264                  
    265             def __reduce__(self): 
    266                 return PreprocessedLearner, (self.preprocessor.preprocessors, self.wrappedLearner) 
    267              
    268             def __getattr__(self, name): 
    269                 return getattr(learner, name) 
    270              
    271         return WrappedLearner() 
     1from Orange.optimization import * 
Note: See TracChangeset for help on using the changeset viewer.