source: orange/orange/OrangeWidgets/Regression/OWEarth.py @ 9343:7a1287d57379

Revision 9343:7a1287d57379, 3.7 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Added option to set the terms automatically similar to Earth in R.

Line 
1"""
2<name>Earth</name>
3<description>Multivariate Adaptive Regression Splines (MARS)</description>
4<category>Regression</category>
5<icon>icons/Earth.png<icon>
6<priority>100</priority>
7<keywords>MARS, Multivariate, Adaptive, Regression, Splines</keywords>
8"""
9
10from OWWidget import *
11import OWGUI
12import Orange
13
14from Orange.regression import earth
15from orngWrap import PreprocessedLearner
16 
17class OWEarth(OWWidget):
18    settingsList = ["name", "degree", "terms", "penalty"]
19   
20    def __init__(self, parent=None, signalManager=None,
21                 title="Earth"):
22        OWWidget.__init__(self, parent, signalManager, title, wantMainArea=False)
23       
24        self.inputs = [("Training data", Orange.data.Table, self.set_data), ("Preprocessor", PreprocessedLearner, self.set_preprocessor)]
25        self.outputs = [("Learner", earth.EarthLearner), ("Predictor", earth.EarthClassifier)]
26       
27        self.name = "Earth Learner"
28        self.degree = 1
29        self.terms = 21
30        self.penalty = 2
31       
32        self.loadSettings()
33       
34        #####
35        # GUI
36        #####
37       
38        OWGUI.lineEdit(self.controlArea, self, "name", 
39                       box="Learner/Classifier Name",
40                       tooltip="Name for the learner/predictor")
41       
42        box = OWGUI.widgetBox(self.controlArea, "Forward Pass", addSpace=True)
43        OWGUI.spin(box, self, "degree", 1, 3, step=1,
44                   label="Max. term degree", 
45                   tooltip="Maximum degree of the terms derived (number of hinge functions).")
46        s = OWGUI.spin(box, self, "terms", 1, 200, step=1,
47                       label="Max. terms",
48                       tooltip="Maximum number of terms derived in the forward pass.")
49        s.control.setSpecialValueText("Automatic")
50       
51        box = OWGUI.widgetBox(self.controlArea, "Pruning Pass", addSpace=True)
52        OWGUI.doubleSpin(box, self, "penalty", min=0.0, max=10.0, step=0.25,
53                   label="Knot penalty")
54       
55        OWGUI.button(self.controlArea, self, "&Apply",
56                     callback=self.apply)
57       
58        self.data = None
59        self.preprocessor = None
60        self.resize(300, 200)
61       
62        self.apply()
63       
64    def set_data(self, data=None):
65        self.data = data
66           
67    def set_preprocessor(self, pproc=None):
68        self.preprocessor = pproc
69       
70    def handleNewSignals(self):
71        self.apply()
72           
73    def apply(self):
74        learner = earth.EarthLearner(degree=self.degree,
75                                    terms=self.terms if self.terms >= 2 else None,
76                                    penalty=self.penalty,
77                                    name=self.name)
78        predictor = None
79        if self.preprocessor:
80            learner = self.preprocessor.wrapLearner(learner)
81       
82        self.error(0)
83        if self.data is not None:
84            try:
85                predictor = learner(self.data)
86                predictor.name = self.name
87            except Exception, ex:
88                self.error(0, "An error during learning: %r" % ex)
89           
90        self.send("Learner", learner)
91        self.send("Predictor", predictor)
92       
93    def sendReport(self):
94        self.reportSettings("Learning parameters", 
95                            [("Degree", self.degree),
96                             ("Terms", self.terms if self.terms >= 2 else "Automatic"),
97                             ("Knot penalty", "%.2f" % self.penalty)
98                             ])
99        self.reportData(self.data)
100       
101if __name__ == "__main__":
102    app = QApplication(sys.argv)
103    w = OWEarth()
104    w.set_data(Orange.data.Table("auto-mpg"))
105    w.show()
106    app.exec_()
107    w.saveSettings()
108           
109           
110       
111       
112       
113       
Note: See TracBrowser for help on using the repository browser.