Changeset 7518:04dc4b55eb90 in orange


Ignore:
Timestamp:
02/04/11 20:04:39 (3 years ago)
Author:
miha <miha.stajdohar@…>
Branch:
default
Convert:
7833a295bb770d8a36ed22ff4b61704b469be043
Message:
 
File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/optmization/__init__.py

    r7515 r7518  
    1  
     1"""  
     2.. index:: optimization 
     3 
     4Wrappers for Tuning Parameters and Thresholds 
     5 
     6Classes for two very useful purposes: tuning learning algorithm's parameters 
     7using internal validation and tuning the threshold for classification into 
     8positive class. 
     9 
     10================= 
     11Tuning parameters 
     12================= 
     13 
     14.. autoclass:: Orange.optimization.TuneParameters 
     15   :members: 
     16 
     17.. autoclass:: Orange.optimization.Tune1Parameter 
     18   :members: 
     19   
     20.. autoclass:: Orange.optimization.TuneMParameters 
     21   :members:  
     22    
     23========================== 
     24Setting Optimal Thresholds 
     25========================== 
     26   
     27.. autoclass:: Orange.optimization.ThresholdLearner 
     28   :members:  
     29      
     30.. autoclass:: Orange.optimization.ThresholdClassifier 
     31   :members:  
     32 
     33.. autoclass:: Orange.optimization.PreprocessedLearner 
     34   :members:  
     35 
     36""" 
     37 
     38import Orange.core 
     39 
     40# The class needs to be given 
     41#   object     - the learning algorithm to be fitted 
     42#   evaluate   - statistics to evaluate (default: orngStat.CA) 
     43#   folds      - the number of folds for internal cross validation 
     44#   compare    - function to compare (default: cmp - the bigger the better) 
     45#   returnWhat - tells whether to return values of parameters, a fitted 
     46#                learner, the best classifier or None. "object" is left 
     47#                with optimal parameters in any case 
     48class TuneParameters(Orange.core.Learner): 
     49    """Tune 
     50     
     51    """ 
     52     
     53    returnNone=0 
     54    returnParameters=1 
     55    returnLearner=2 
     56    returnClassifier=3 
     57     
     58    def __new__(cls, examples = None, weightID = 0, **argkw): 
     59        self = Orange.core.Learner.__new__(cls, **argkw) 
     60        self.__dict__.update(argkw) 
     61        if examples: 
     62            return self.__call__(examples, weightID) 
     63        else: 
     64            return self 
     65 
     66    def findobj(self, name): 
     67        import string 
     68        names=string.split(name, ".") 
     69        lastobj=self.object 
     70        for i in names[:-1]: 
     71            lastobj=getattr(lastobj, i) 
     72        return lastobj, names[-1] 
     73         
     74# Same arguments as TuneParameters, plus: 
     75#   parameter  - a string or a list of strings with parameter(s) to fit 
     76#   values     - possible values of the parameter 
     77#                (eg <object>.<parameter> = <value>[i]) 
     78class Tune1Parameter(TuneParameters): 
     79    def __call__(self, table, weight=None, verbose=0): 
     80        import orngTest, orngStat, orngMisc 
     81 
     82        verbose = verbose or getattr(self, "verbose", 0) 
     83        evaluate = getattr(self, "evaluate", orngStat.CA) 
     84        folds = getattr(self, "folds", 5) 
     85        compare = getattr(self, "compare", cmp) 
     86        returnWhat = getattr(self, "returnWhat", Tune1Parameter.returnClassifier) 
     87 
     88        if (type(self.parameter)==list) or (type(self.parameter)==tuple): 
     89            to_set = [self.findobj(ld) for ld in self.parameter] 
     90        else: 
     91            to_set = [self.findobj(self.parameter)] 
     92 
     93        cvind = Orange.core.MakeRandomIndicesCV(table, folds) 
     94        findBest = orngMisc.BestOnTheFly(seed = table.checksum(), callCompareOn1st = True) 
     95        tableAndWeight = weight and (table, weight) or table 
     96        for par in self.values: 
     97            for i in to_set: 
     98                setattr(i[0], i[1], par) 
     99            res = evaluate(orngTest.testWithIndices([self.object], tableAndWeight, cvind)) 
     100            findBest.candidate((res, par)) 
     101            if verbose==2: 
     102                print '*** orngWrap  %s: %s:' % (par, res) 
     103 
     104        bestpar = findBest.winner()[1] 
     105        for i in to_set: 
     106            setattr(i[0], i[1], bestpar) 
     107 
     108        if verbose: 
     109            print "*** Optimal parameter: %s = %s" % (self.parameter, bestpar) 
     110 
     111        if returnWhat==Tune1Parameter.returnNone: 
     112            return None 
     113        elif returnWhat==Tune1Parameter.returnParameters: 
     114            return bestpar 
     115        elif returnWhat==Tune1Parameter.returnLearner: 
     116            return self.object 
     117        else: 
     118            classifier = self.object(table) 
     119            classifier.setattr("fittedParameter", bestpar) 
     120            return classifier 
     121 
     122 
     123# Same arguments as TuneParameters, plus 
     124#   parameters - a list of tuples with parameters to be fitted and the 
     125#                corresponding possible values, [(parameter(s), values), ...] 
     126#                (eg <object>.<parameter[j]> = <value[j]>[i]) 
     127class TuneMParameters(TuneParameters): 
     128    def __call__(self, table, weight=None, verbose=0): 
     129        import orngTest, orngStat, orngMisc 
     130 
     131        evaluate = getattr(self, "evaluate", orngStat.CA) 
     132        folds = getattr(self, "folds", 5) 
     133        compare = getattr(self, "compare", cmp) 
     134        verbose = verbose or getattr(self, "verbose", 0) 
     135        returnWhat = getattr(self, "returnWhat", Tune1Parameter.returnClassifier) 
     136        progressCallback = getattr(self, "progressCallback", lambda i: None) 
     137         
     138        to_set = [] 
     139        parnames = [] 
     140        for par in self.parameters: 
     141            if (type(par[0])==list) or (type(par[0])==tuple): 
     142                to_set.append([self.findobj(ld) for ld in par[0]]) 
     143                parnames.append(par[0]) 
     144            else: 
     145                to_set.append([self.findobj(par[0])]) 
     146                parnames.append([par[0]]) 
     147 
     148 
     149        cvind = Orange.core.MakeRandomIndicesCV(table, folds) 
     150        findBest = orngMisc.BestOnTheFly(seed = table.checksum(), callCompareOn1st = True) 
     151        tableAndWeight = weight and (table, weight) or table 
     152        numOfTests = sum([len(x[1]) for x in self.parameters]) 
     153        milestones = set(range(0, numOfTests, max(numOfTests / 100, 1))) 
     154        for itercount, valueindices in enumerate(orngMisc.LimitedCounter([len(x[1]) for x in self.parameters])): 
     155            values = [self.parameters[i][1][x] for i,x in enumerate(valueindices)] 
     156            for pi, value in enumerate(values): 
     157                for i, par in enumerate(to_set[pi]): 
     158                    setattr(par[0], par[1], value) 
     159                    if verbose==2: 
     160                        print "%s: %s" % (parnames[pi][i], value) 
     161                         
     162            res = evaluate(orngTest.testWithIndices([self.object], tableAndWeight, cvind)) 
     163            if itercount in milestones: 
     164                progressCallback(100.0 * itercount / numOfTests) 
     165             
     166            findBest.candidate((res, values)) 
     167            if verbose==2: 
     168                print "===> Result: %s\n" % res 
     169 
     170        bestpar = findBest.winner()[1] 
     171        if verbose: 
     172            print "*** Optimal set of parameters: ", 
     173        for pi, value in enumerate(bestpar): 
     174            for i, par in enumerate(to_set[pi]): 
     175                setattr(par[0], par[1], value) 
     176                if verbose: 
     177                    print "%s: %s" % (parnames[pi][i], value), 
     178        if verbose: 
     179            print 
     180 
     181        if returnWhat==Tune1Parameter.returnNone: 
     182            return None 
     183        elif returnWhat==Tune1Parameter.returnParameters: 
     184            return bestpar 
     185        elif returnWhat==Tune1Parameter.returnLearner: 
     186            return self.object 
     187        else: 
     188            classifier = self.object(table) 
     189            classifier.fittedParameters = bestpar 
     190            return classifier 
     191 
     192 
     193 
     194 
     195class ThresholdLearner(Orange.core.Learner): 
     196    def __new__(cls, examples = None, weightID = 0, **kwds): 
     197        self = Orange.core.Learner.__new__(cls, **kwds) 
     198        self.__dict__.update(kwds) 
     199        if examples: 
     200            return self.__call__(examples, weightID) 
     201        else: 
     202            return self 
     203 
     204    def __call__(self, examples, weightID = 0): 
     205        if not hasattr(self, "learner"): 
     206            raise "learner not set" 
     207         
     208        classifier = self.learner(examples, weightID) 
     209        threshold, optCA, curve = Orange.core.ThresholdCA(classifier, examples, weightID) 
     210        if getattr(self, "storeCurve", 0): 
     211            return ThresholdClassifier(classifier, threshold, curve = curve) 
     212        else: 
     213            return ThresholdClassifier(classifier, threshold) 
     214 
     215class ThresholdClassifier(Orange.core.Classifier): 
     216    def __init__(self, classifier, threshold, **kwds): 
     217        self.classifier = classifier 
     218        self.threshold = threshold 
     219        self.__dict__.update(kwds) 
     220 
     221    def __call__(self, example, what = Orange.core.Classifier.GetValue): 
     222        probs = self.classifier(example, self.GetProbabilities) 
     223        if what == self.GetProbabilities: 
     224            return probs 
     225        value = Orange.core.Value(self.classifier.classVar, probs[1]>self.threshold) 
     226        if what == Orange.core.Classifier.GetValue: 
     227            return value 
     228        else: 
     229            return (value, probs) 
     230 
     231def ThresholdLearner_fixed(learner, threshold, examples = None, weightId = 0, **kwds): 
     232    lr = apply(ThresholdLearner_fixed_Class, (learner, threshold), kwds) 
     233    if examples: 
     234        return lr(examples, weightId) 
     235    else: 
     236        return lr 
     237     
     238class ThresholdLearner_fixed(Orange.core.Learner): 
     239    def __new__(cls, examples = None, weightID = 0, **kwds): 
     240        self = Orange.core.Learner.__new__(cls, **kwds) 
     241        self.__dict__.update(kwds) 
     242        if examples: 
     243            return self.__call__(examples, weightID) 
     244        else: 
     245            return self 
     246 
     247    def __call__(self, examples, weightID = 0): 
     248        if not hasattr(self, "learner"): 
     249            raise "learner not set" 
     250        if not hasattr(self, "threshold"): 
     251            raise "threshold not set" 
     252        if len(examples.domain.classVar.values)!=2: 
     253            raise "ThresholdLearner handles binary classes only" 
     254         
     255        return ThresholdClassifier(self.learner(examples, weightID), self.threshold) 
     256 
     257class PreprocessedLearner(object): 
     258    def __new__(cls, preprocessor = None, learner = None): 
     259        self = object.__new__(cls) 
     260        if learner is not None: 
     261            self.__init__(preprocessor) 
     262            return self.wrapLearner(learner) 
     263        else: 
     264            return self 
     265         
     266    def __init__(self, preprocessor = None, learner = None): 
     267        if isinstance(preprocessor, list): 
     268            self.preprocessors = preprocessor 
     269        elif preprocessor is not None: 
     270            self.preprocessors = [preprocessor] 
     271        else: 
     272            self.preprocessors = [] 
     273        #self.preprocessors = [Orange.core.Preprocessor_addClassNoise(proportion=0.8)] 
     274        if learner: 
     275            self.wrapLearner(learner) 
     276         
     277    def processData(self, data, weightId = None): 
     278        hadWeight = hasWeight = weightId is not None 
     279        for preprocessor in self.preprocessors: 
     280            t = preprocessor(data, weightId) if hasWeight else preprocessor(data) 
     281            if isinstance(t, tuple): 
     282                data, weightId = t 
     283                hasWeight = True 
     284            else: 
     285                data = t 
     286        if hadWeight: 
     287            return data, weightId 
     288        else: 
     289            return data 
     290 
     291    def wrapLearner(self, learner): 
     292        class WrappedLearner(learner.__class__): 
     293            preprocessor = self 
     294            wrappedLearner = learner 
     295            name = getattr(learner, "name", "") 
     296            def __call__(self, data, weightId=0, getData = False): 
     297                t = self.preprocessor.processData(data, weightId or 0) 
     298                processed, procW = t if isinstance(t, tuple) else (t, 0) 
     299                classifier = self.wrappedLearner(processed, procW) 
     300                if getData: 
     301                    return classifier, processed 
     302                else: 
     303                    return classifier # super(WrappedLearner, self).__call__(processed, procW) 
     304                 
     305            def __reduce__(self): 
     306                return PreprocessedLearner, (self.preprocessor.preprocessors, self.wrappedLearner) 
     307             
     308            def __getattr__(self, name): 
     309                return getattr(learner, name) 
     310             
     311        return WrappedLearner() 
Note: See TracChangeset for help on using the changeset viewer.