source: orange/Orange/doc/ofb/roc.py @ 9671:a7b056375472

Revision 9671:a7b056375472, 2.1 KB checked in by anze <anze.staric@…>, 2 years ago (diff)

Moved orange to Orange (part 2)

Line 
1# Description: Implementation of AUC (area under ROC curve) statistics, test of different methods through 10-fold cross validation (warning: for educational purposes only, use orngEval for estimation of AUC and similar statistics)
2# Category:    evaluation
3# Uses:        voting.tab
4# Classes:     MakeRandomIndicesCV
5# Referenced:  c_performance.htm
6
7
8import orange, orngTree
9
def aroc(data, classifiers):
    """Return the area under the ROC curve (AUC) of each classifier on data.

    Educational implementation based on pairwise comparison.  Every pair of
    examples whose classes differ is inspected; each example is scored with
    ``c(example, orange.GetProbabilities)[0]`` (presumably the predicted
    probability of the class with index 0).  A pair counts as correctly
    ranked when the example whose class equals 0 has the strictly higher
    score; a tie in scores counts as half.  The AUC is the number of
    correctly ranked pairs divided by the number of inspected pairs.

    NOTE(review): the ranking rule assumes that one example of every
    inspected pair has class 0, which holds for two-class problems only.

    Args:
        data: iterable of examples, each providing ``getclass()``.
        classifiers: callables invoked as
            ``c(example, orange.GetProbabilities)`` whose result is indexable.

    Returns:
        A list with one float per classifier, in the order given.  When the
        data contains no pair of examples with different classes (for
        instance empty data or a single-class test fold) the AUC is undefined
        and the chance-level value 0.5 is reported instead of raising
        ZeroDivisionError.
    """
    # Class values do not depend on the classifier, so fetch them once rather
    # than on every pairwise comparison.
    classes = [d.getclass() for d in data]
    n = len(classes)
    ar = []
    for c in classifiers:
        p = [c(d, orange.GetProbabilities)[0] for d in data]
        correct = 0.0
        valid = 0.0
        for i in range(n - 1):
            for j in range(i + 1, n):
                if classes[i] != classes[j]:
                    valid += 1
                    if p[i] == p[j]:
                        # Tied scores: the pair is half right.
                        correct += 0.5
                    elif classes[i] == 0:
                        if p[i] > p[j]:
                            correct += 1.0
                    elif p[j] > p[i]:
                        # Example j is taken to be the class-0 example here.
                        correct += 1.0
        if valid:
            ar.append(correct / valid)
        else:
            # No comparable pair: fall back to chance level.
            ar.append(0.5)
    return ar
31
def cross_validation(data, learners, k=10):
    """Estimate each learner's AUC by k-fold cross validation.

    Fold indices are drawn with ``orange.MakeRandomIndicesCV``.  For every
    fold, each learner is called on the examples selected with ``negate=1``
    (the training part) and the resulting classifiers are scored by ``aroc``
    on the examples selected for that fold (the test part).

    Args:
        data: example table supporting ``select(indices, fold, negate=...)``.
        learners: callables that build a classifier from training data.
        k: number of folds.

    Returns:
        A list holding, for each learner in order, the sum of its per-fold
        AUC values divided by ``k``.
    """
    totals = [0.0] * len(learners)
    fold_of = orange.MakeRandomIndicesCV(data, folds=k)
    for fold in range(k):
        training = data.select(fold_of, fold, negate=1)
        held_out = data.select(fold_of, fold)
        # Train every learner first, then score all of them together.
        models = [learner(training) for learner in learners]
        for idx, score in enumerate(aroc(held_out, models)):
            totals[idx] += score
    return [total / k for total in totals]
47
48orange.setrandseed(0)   
49# set up the learners
50bayes = orange.BayesLearner()
51tree = orngTree.TreeLearner(mForPruning=2)
52maj = orange.MajorityLearner()
53bayes.name = "bayes"
54tree.name = "tree"
55maj.name = "majority"
56learners = [bayes, tree, maj]
57
58# compute accuracies on data
59data = orange.ExampleTable("voting")
60acc = cross_validation(data, learners, k=10)
61print "Area under ROC:"
62for i in range(len(learners)):
63    print learners[i].name, "%.2f" % acc[i]
Note: See TracBrowser for help on using the repository browser.