source: orange/Orange/OrangeWidgets/Regression/OWEarth.py @ 11455:ea4f1b23ca68

Revision 11455:ea4f1b23ca68, 4.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 12 months ago (diff)

Preserve meta attributes in 'Basis matrix' output.

Line 
1"""
2<name>Earth</name>
3<description>Multivariate Adaptive Regression Splines (MARS)</description>
4<category>Regression</category>
5<icon>icons/EarthMars.svg</icon>
6<priority>100</priority>
7<tags>MARS, Multivariate, Adaptive, Regression, Splines</tags>
8"""
9
10from OWWidget import *
11import OWGUI
12import Orange
13
14from Orange.regression import earth
15from orngWrap import PreprocessedLearner
16
17
18class OWEarth(OWWidget):
19    settingsList = ["name", "degree", "terms", "penalty"]
20
21    def __init__(self, parent=None, signalManager=None,
22                 title="Earth"):
23        OWWidget.__init__(self, parent, signalManager, title,
24                          wantMainArea=False)
25
26        self.inputs = [("Data", Orange.data.Table, self.set_data),
27                       ("Preprocessor", PreprocessedLearner,
28                        self.set_preprocessor)]
29
30        self.outputs = [("Learner", earth.EarthLearner, Default),
31                        ("Predictor", earth.EarthClassifier, Default),
32                        ("Basis Matrix", Orange.data.Table)]
33
34        self.name = "Earth Learner"
35        self.degree = 1
36        self.terms = 21
37        self.penalty = 2
38
39        self.loadSettings()
40
41        #####
42        # GUI
43        #####
44
45        OWGUI.lineEdit(self.controlArea, self, "name",
46                       box="Learner/Classifier Name",
47                       tooltip="Name for the learner/predictor")
48
49        box = OWGUI.widgetBox(self.controlArea, "Forward Pass", addSpace=True)
50        OWGUI.spin(box, self, "degree", 1, 3, step=1,
51                   label="Max. term degree",
52                   tooltip="Maximum degree of the terms derived "
53                           "(number of hinge functions).")
54        s = OWGUI.spin(box, self, "terms", 1, 200, step=1,
55                       label="Max. terms",
56                       tooltip="Maximum number of terms derived in the "
57                               "forward pass.")
58        s.control.setSpecialValueText("Automatic")
59
60        box = OWGUI.widgetBox(self.controlArea, "Pruning Pass", addSpace=True)
61        OWGUI.doubleSpin(box, self, "penalty", min=0.0, max=10.0, step=0.25,
62                   label="Knot penalty")
63
64        OWGUI.button(self.controlArea, self, "&Apply",
65                     callback=self.apply)
66
67        self.data = None
68        self.preprocessor = None
69        self.resize(300, 200)
70
71        self.apply()
72
73    def set_data(self, data=None):
74        self.data = data
75
76    def set_preprocessor(self, pproc=None):
77        self.preprocessor = pproc
78
79    def handleNewSignals(self):
80        self.apply()
81
82    def apply(self):
83        learner = earth.EarthLearner(
84            degree=self.degree,
85            terms=self.terms if self.terms >= 2 else None,
86            penalty=self.penalty,
87            name=self.name)
88
89        predictor = None
90        basis_matrix = None
91        if self.preprocessor:
92            learner = self.preprocessor.wrapLearner(learner)
93
94        self.error(0)
95        if self.data is not None:
96            try:
97                predictor = learner(self.data)
98                predictor.name = self.name
99            except Exception, ex:
100                self.error(0, "An error during learning: %r" % ex)
101
102            if predictor is not None:
103                base_features = predictor.base_features()
104                basis_domain = Orange.data.Domain(
105                    base_features,
106                    self.data.domain.class_var,
107                    self.data.domain.class_vars)
108                basis_domain.add_metas(self.data.domain.get_metas())
109                basis_matrix = Orange.data.Table(basis_domain, self.data)
110
111        self.send("Learner", learner)
112        self.send("Predictor", predictor)
113        self.send("Basis Matrix", basis_matrix)
114
115    def sendReport(self):
116        self.reportSettings(
117            "Learning parameters",
118            [("Degree", self.degree),
119             ("Terms", self.terms if self.terms >= 2 else "Automatic"),
120             ("Knot penalty", "%.2f" % self.penalty)
121             ])
122
123        self.reportData(self.data)
124
125if __name__ == "__main__":
126    app = QApplication(sys.argv)
127    w = OWEarth()
128    w.set_data(Orange.data.Table("auto-mpg"))
129    w.show()
130    app.exec_()
131    w.saveSettings()
Note: See TracBrowser for help on using the repository browser.