Changeset 8740:b61c00c29d53 in orange


Timestamp: 08/23/11 14:47:35 (3 years ago)
Author:    martin <martin@…>
Branch:    default
Convert:   c766c4d8d3d4d4458e892e3bad65dfc1df74a6ff
Message:   (none)
Location:  orange
Files:     2 edited

Legend: unchanged context lines have no prefix; added lines are prefixed with "+", removed lines with "-"; "…" marks omitted lines between hunks.
  • orange/orngABCN2.py (r8246 → r8740)

 from Orange.classification.rules import CN2UnorderedClassifier
 from orngABML import *
+
+print createDichotomousClass
  • orange/orngABML.py (r8246 → r8740)

 import warnings

-import Orange.classification.rules
 import numpy
 import math
     
…
         else:
             return Orange.core.Value(self.newClassAtt, "not " + self.classValue)
-
+
 def createDichotomousClass(domain, att, value, negate, removeAtt = None):
     # create new variable
     
…

 def addErrors(test_data, classifier):
-    """ Main task of this function is to add probabilistic errors to examples. Function
-        also computes classification accuracy as a by product."""
-    correct        = 0.0
-    clDistribution = orange.Distribution(test_data.domain.classVar)
+    """ Main task of this function is to add probabilistic errors to examples."""
     for ex_i, ex in enumerate(test_data):
-        (cl,prob) = classifier(ex,orange.GetBoth)
-        # add prob difference to ProbError
-        if prob[ex.getclass()] < 0.01:
-            print prob
-        ex.setmeta("ProbError", ex.getmeta("ProbError") + 1.-prob[ex.getclass()])
-        ex.setmeta("ProbErrorSD", ex.getmeta("ProbErrorSD") + math.pow(1.-prob[ex.getclass()],2))
-        if test_data.domain.hasmeta("CoveredRules"):
-            for r_i,r in enumerate(classifier.rules):
-                if r(ex) and (not hasattr(cl, "ruleBetas") or cl.ruleBetas[r_i]):
-                    ex.setmeta("CoveredRules", ex.getmeta("CoveredRules")+orngCN2.ruleToString(r)+";")
-        if cl == ex.getclass(): # compute classification accuracy (just for fun)
-            correct += 1.
-            clDistribution[ex.getclass()] += 1.
-    apriori = orange.Distribution(test_data.domain.classVar,test_data)
-    correct /= len(test_data)
-    for cl_v in test_data.domain.classVar:
-        clDistribution[cl_v] /= max(apriori[cl_v],1)
-    # add space between different test set in "covered rules"
-    for ex in test_data:
-        ex.setmeta("CoveredRules", ex.getmeta("CoveredRules")+";      ;")
-    return (correct,clDistribution)
-
-
-def nCrossValidation(data,learner,weightID=0,folds=5,n=4,gen=0):
+        (cl,prob) = classifier(ex,Orange.core.GetBoth)
+        ex.setmeta("ProbError", float(ex.getmeta("ProbError")) + 1.-prob[ex.getclass()])
+
+def nCrossValidation(data,learner,weightID=0,folds=5,n=4,gen=0,argument_id="Arguments"):
     """ Function performs n x fold crossvalidation. For each classifier
         test set is updated by calling function addErrors. """
     acc = 0.0
-    dist = orange.Distribution(data.domain.classVar)
-    pick = orange.MakeRandomIndicesCV(folds=folds, randseed=gen, stratified = orange.MakeRandomIndices.StratifiedIfPossible)
+    rules = {}
     for d in data:
-        d.setmeta("ProbError",0.)
-        d.setmeta("ProbErrorSD",0.)
-##        d.setmeta("Rules","")
+        rules[float(d["SerialNumberPE"])] = []
+    pick = Orange.core.MakeRandomIndicesCV(folds=folds, randseed=gen, stratified = Orange.core.MakeRandomIndices.StratifiedIfPossible)
     for n_i in range(n):
         pick.randseed = gen+10*n_i
     
…
             for data_i,e in enumerate(data):
                 try:
-                    if e["Arguments"]: # examples with arguments do not need to be tested
+                    if e[argument_id]: # examples with arguments do not need to be tested
                         selection[data_i]=folds_i+1
                 except:
     
…
             test_data = data.selectref(selection, folds_i,negate=0)
             classifier = learner(train_data,weightID)
-            classifier.setattr("data", train_data)
-            acc1,dist1 = addErrors(test_data, classifier)
-
-            print "N=%d, Folds=%d: %s %s" % (n_i+1, folds_i, acc1, dist1)
-            acc += acc1
-            for cl_i in range(len(data.domain.classVar.values)):
-                dist[cl_i] += dist1[cl_i]
-
-    for e in data:
-        avgProb = e.getmeta("ProbError")/n
-        sumSq = e.getmeta("ProbErrorSD")
-        sumProb = e.getmeta("ProbError")
-        if n>1 and (sumSq-2*avgProb*sumProb+n*math.pow(avgProb,2))/(n-1) > 0.:
-            e.setmeta("ProbErrorSD", math.sqrt((sumSq-2*avgProb*sumProb+n*math.pow(avgProb,2))/(n-1)))
-        else:
-            e.setmeta("ProbErrorSD", 0.)
-        e.setmeta("ProbError", avgProb)
-    acc = acc/n/folds
-    for cl_v in test_data.domain.classVar:
-        dist[cl_v] /= n*folds
-    return (acc,dist)
-
-
-def findProb(learner,examples,weightID=0,folds=5,n=4,gen=0):
+            addErrors(test_data, classifier)
+            # add rules
+            for d in test_data:
+                for r in classifier.rules:
+                    if r(d):
+                        rules[float(d["SerialNumberPE"])].append(r)
+    # normalize prob errors
+    for d in data:
+        d["ProbError"]=d["ProbError"]/n
+    return rules
+
+def findProb(learner,examples,weightID=0,folds=5,n=4,gen=0,thr=0.5,argument_id="Arguments"):
     """ General method for calling to find problematic example.
-        It returns all examples along with average probabilistic errors.
+        It returns all critial examples along with average probabilistic errors that ought to be higher then thr.
         Taking the one with highest error is the same as taking the most
         problematic example. """
-    newDomain = orange.Domain(examples.domain.attributes, examples.domain.classVar)
+
+    newDomain = Orange.core.Domain(examples.domain.attributes, examples.domain.classVar)
     newDomain.addmetas(examples.domain.getmetas())
-    newExamples = orange.ExampleTable(newDomain, examples)
+    newExamples = Orange.core.ExampleTable(newDomain, examples)
     if not newExamples.domain.hasmeta("ProbError"):
-        newId = orange.newmetaid()
-        newDomain.addmeta(newId, orange.FloatVariable("ProbError"))
-        newExamples = orange.ExampleTable(newDomain, examples)
-    if not newExamples.domain.hasmeta("ProbErrorSD"):
-        newId = orange.newmetaid()
-        newDomain.addmeta(newId, orange.FloatVariable("ProbErrorSD"))
-        newExamples = orange.ExampleTable(newDomain, examples)
-    if not newExamples.domain.hasmeta("CoveredRules"):
-        newId = orange.newmetaid()
-        newDomain.addmeta(newId, orange.StringVariable("CoveredRules"))
-        newExamples = orange.ExampleTable(newDomain, examples)
+        newId = Orange.core.newmetaid()
+        newDomain.addmeta(newId, Orange.core.FloatVariable("ProbError"))
+        newExamples = Orange.core.ExampleTable(newDomain, examples)
     if not newExamples.domain.hasmeta("SerialNumberPE"):
-        newId = orange.newmetaid()
-        newDomain.addmeta(newId, orange.FloatVariable("SerialNumberPE"))
-        newExamples = orange.ExampleTable(newDomain, examples)
-##        newExamples.domain.addmeta(newId, orange.FloatVariable("SerialNumberPE"))
-##        newExamples.addMetaAttribute("SerialNumberPE", 0.)
-        for i in range(len(newExamples)):
-            newExamples[i]["SerialNumberPE"] = float(i)
-    trs = nCrossValidation(newExamples,learner,weightID=weightID, folds=folds, n=n, gen=gen)
-    return newExamples
+        newId = Orange.core.newmetaid()
+        newDomain.addmeta(newId, Orange.core.FloatVariable("SerialNumberPE"))
+        newExamples = Orange.core.ExampleTable(newDomain, examples)
+    for i in range(len(newExamples)):
+        newExamples[i]["SerialNumberPE"] = float(i)
+        newExamples[i]["ProbError"] = 0.
+
+    # it returns a list of examples now: (index of example-starting with 0, example, prob error, rules covering example
+    rules = nCrossValidation(newExamples,learner,weightID=weightID, folds=folds, n=n, gen=gen, argument_id=argument_id)
+    return [(ei, examples[ei], float(e["ProbError"]), rules[float(e["SerialNumberPE"])]) for ei, e in enumerate(newExamples) if e["ProbError"] > thr]
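
For orientation (not part of the changeset): a minimal usage sketch of the reworked findProb, in the module's own Python 2 style. The learner class orngABCN2.ABCN2, the dataset name "zoo", and the orngCN2.ruleToString pretty-printer are assumptions used only for illustration; findProb itself adds the "ProbError" and "SerialNumberPE" meta attributes, so callers do not have to prepare them.

    # Usage sketch only -- the learner class, dataset name and rule pretty-printer
    # below are assumptions, not defined by this changeset.
    import Orange
    import orngABML, orngABCN2, orngCN2

    data = Orange.core.ExampleTable("zoo")   # stand-in dataset with a class attribute
    learner = orngABCN2.ABCN2()              # assumed argument-based CN2 learner

    # findProb runs n x folds cross-validation internally and now returns only the
    # critical examples: (index, example, average ProbError, rules covering it),
    # keeping entries whose average probabilistic error exceeds thr.
    critical = orngABML.findProb(learner, data, folds=5, n=4, gen=0,
                                 thr=0.5, argument_id="Arguments")

    # Sorting by error, descending, puts the most problematic example first.
    for idx, ex, prob_error, rules in sorted(critical, key=lambda t: -t[2]):
        print idx, prob_error
        for r in rules:
            print "   ", orngCN2.ruleToString(r)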