Ignore:
Timestamp:
01/11/12 12:31:25 (2 years ago)
Author:
ales_erjavec <ales.erjavec@…>
Branch:
default
Convert:
9bac6f61bf567a3f5a2e21c1b2f38c45c0a6ada9
Message:

No longer stores signal ids in the learner instances (the ids makes the output unpicklable and can break the widget if the same learner is used in multiple TestLearner widgets).
Fixed recomputeCM ('self.learners.values()' and 'self.result.learners' order mismatch).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/OrangeWidgets/Evaluate/OWTestLearners.py

    r9510 r9528  
    2121class Learner: 
    2222    def __init__(self, learner, id): 
    23         learner.id = id 
    2423        self.learner = learner 
    2524        self.name = learner.name 
     
    268267            for id in ids: 
    269268                self.learners[id].results = None 
     269                self.learners[id].scores = [] 
    270270            return 
    271271        # test which learners can accept the given data set 
    272272        # e.g., regressions can't deal with classification data 
    273273        learners = [] 
     274        used_ids = [] 
    274275        n = len(self.data.domain.attributes)*2 
    275276        indices = orange.MakeRandomIndices2(p0=min(n, len(self.data)), stratified=orange.MakeRandomIndices2.StratifiedIfPossible) 
     
    286287                if predictor(new[0]).varType == new.domain.classVar.varType: 
    287288                    learners.append(learner) 
     289                    used_ids.append(l.id) 
    288290                else: 
    289291                    l.scores = [] 
     292                    l.results = None 
     293                 
    290294            except Exception, ex: 
    291295                learner_exceptions.append((l, ex)) 
    292296                l.scores = [] 
    293  
     297                l.results = None 
     298         
    294299        if learner_exceptions: 
    295300            text = "\n".join("Learner %s ends with exception: %s" % (l.name, str(ex)) \ 
     
    345350            if l.learner in learners: 
    346351                l.results = res 
     352            else: 
     353                l.results = None 
    347354 
    348355        self.error(range(len(self.stat))) 
    349356        scores = [] 
    350          
    351          
    352357         
    353358        for i, s in enumerate(self.stat): 
    354359            if s.cmBased: 
    355360                try: 
    356 #                    scores.append(eval("orngStat." + s.f)) 
    357361                    scores.append(dispatch(s, res, cm)) 
    358362                except Exception, ex: 
     
    364368                for res_one in orngStat.split_by_classifiers(res): 
    365369                    try: 
    366 #                        scores_one.append(eval("orngStat." + s.f)[0]) 
    367370                        scores_one.extend(dispatch(s, res_one, cm)) 
    368371                    except Exception, ex: 
     
    372375                scores.append(scores_one) 
    373376                 
    374         for i, l in enumerate(learners): 
    375             self.learners[l.id].scores = [s[i] if s else None for s in scores] 
     377        for i, (id, l) in enumerate(zip(used_ids, learners)): 
     378            self.learners[id].scores = [s[i] if s else None for s in scores] 
    376379             
    377380        self.sendResults() 
     
    385388        for (indx, score) in scores: 
    386389            for (i, l) in enumerate([l for l in self.learners.values() if l.scores]): 
    387                 l.scores[indx] = score[i] 
     390                learner_indx = self.results.learners.index(l.learner) 
     391                l.scores[indx] = score[learner_indx] 
     392                 
    388393        self.paintscores() 
    389394         
    390  
     395    def clearScores(self, ids=None): 
     396        if ids is None: 
     397            ids = self.learners.keys() 
     398             
     399        for id in ids: 
     400            self.learners[id].scores = [] 
     401            self.learners[id].results = None 
     402         
    391403    # handle input signals 
    392  
    393404    def setData(self, data): 
    394405        """handle input train data set""" 
     
    398409        if not self.data: 
    399410            # data was removed, remove the scores 
    400             for l in self.learners.values(): 
    401                 l.scores = [] 
    402                 l.results = None 
     411            self.clearScores() 
    403412            self.send("Evaluation Results", None) 
    404413        else: 
    405414            # new data has arrived 
     415            self.clearScores() 
     416             
    406417            self.data = orange.Filter_hasClassValue(self.data) 
    407418            self.statLayout.setCurrentWidget(self.cbox if self.isclassification() else self.rbox) 
     
    467478                res = self.learners[id].results 
    468479                if res and res.numberOfLearners > 1: 
    469                     indx = [l.id for l in res.learners].index(id) 
     480                    old_learner = self.learners[id].learner 
     481                    indx = res.learners.index(old_learner) 
    470482                    res.remove(indx) 
    471483                    del res.learners[indx] 
     
    487499        # and remember the corresponding index 
    488500 
    489         valid = [(l.results, [x.id for x in l.results.learners].index(l.id)) 
     501        valid = [(l.results, l.results.learners.index(l.learner)) 
    490502                 for l in self.learners.values() if l.scores and l.results] 
    491503             
     
    496508        # find the result set for a largest number of learners 
    497509        # and remove this set from the list of result sets 
    498         rlist = dict([(l.results,1) for l in self.learners.values() if l.scores]).keys() 
     510        rlist = dict([(l.results,1) for l in self.learners.values() if l.scores and l.results]).keys() 
    499511        rlen = [r.numberOfLearners for r in rlist] 
    500512        results = rlist.pop(rlen.index(max(rlen))) 
    501513         
    502514        for (i, l) in enumerate(results.learners): 
    503             if not l.id in self.learners: 
     515            if not l in [l.learner for l in self.learners.values()]: 
    504516                results.remove(i) 
    505517                del results.learners[i] 
     518             
    506519        for r in rlist: 
    507520            for (i, l) in enumerate(r.learners): 
     521                learner_id = [l1.id for l1 in self.learners.values() if l1.learner is l][0] 
    508522                if (r, i) in valid: 
    509523                    results.add(r, i) 
    510524                    results.learners.append(r.learners[i]) 
    511                     self.learners[r.learners[i].id].results = results 
     525                    self.learners[learner_id].results = results 
    512526        self.send("Evaluation Results", results) 
    513527        self.results = results 
Note: See TracChangeset for help on using the changeset viewer.