source: orange/Orange/OrangeWidgets/Regression/OWEarth.py @ 11455:ea4f1b23ca68

Revision 11455:ea4f1b23ca68, 4.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 12 months ago (diff)

Preserve meta attributes in 'Basis matrix' output.

RevLine 
[9223]1"""
2<name>Earth</name>
3<description>Multivariate Adaptive Regression Splines (MARS)</description>
4<category>Regression</category>
[11217]5<icon>icons/EarthMars.svg</icon>
[9223]6<priority>100</priority>
[10776]7<tags>MARS, Multivariate, Adaptive, Regression, Splines</tags>
[9223]8"""
[8151]9
10from OWWidget import *
11import OWGUI
12import Orange
13
14from Orange.regression import earth
[8152]15from orngWrap import PreprocessedLearner
[11455]16
17
[8151]18class OWEarth(OWWidget):
[8833]19    settingsList = ["name", "degree", "terms", "penalty"]
[10776]20
[8151]21    def __init__(self, parent=None, signalManager=None,
[9142]22                 title="Earth"):
[11455]23        OWWidget.__init__(self, parent, signalManager, title,
24                          wantMainArea=False)
[10776]25
26        self.inputs = [("Data", Orange.data.Table, self.set_data),
[11455]27                       ("Preprocessor", PreprocessedLearner,
28                        self.set_preprocessor)]
[10776]29
30        self.outputs = [("Learner", earth.EarthLearner, Default),
31                        ("Predictor", earth.EarthClassifier, Default),
32                        ("Basis Matrix", Orange.data.Table)]
33
[8151]34        self.name = "Earth Learner"
[8152]35        self.degree = 1
[8151]36        self.terms = 21
37        self.penalty = 2
[10776]38
[8152]39        self.loadSettings()
[10776]40
[8151]41        #####
42        # GUI
43        #####
[10776]44
[11455]45        OWGUI.lineEdit(self.controlArea, self, "name",
[8151]46                       box="Learner/Classifier Name",
47                       tooltip="Name for the learner/predictor")
[10776]48
[8151]49        box = OWGUI.widgetBox(self.controlArea, "Forward Pass", addSpace=True)
50        OWGUI.spin(box, self, "degree", 1, 3, step=1,
[11455]51                   label="Max. term degree",
52                   tooltip="Maximum degree of the terms derived "
53                           "(number of hinge functions).")
[9343]54        s = OWGUI.spin(box, self, "terms", 1, 200, step=1,
55                       label="Max. terms",
[11455]56                       tooltip="Maximum number of terms derived in the "
57                               "forward pass.")
[9343]58        s.control.setSpecialValueText("Automatic")
[10776]59
[9142]60        box = OWGUI.widgetBox(self.controlArea, "Pruning Pass", addSpace=True)
[8151]61        OWGUI.doubleSpin(box, self, "penalty", min=0.0, max=10.0, step=0.25,
62                   label="Knot penalty")
[10776]63
[8151]64        OWGUI.button(self.controlArea, self, "&Apply",
65                     callback=self.apply)
[10776]66
[8151]67        self.data = None
68        self.preprocessor = None
69        self.resize(300, 200)
[10776]70
[8151]71        self.apply()
[10776]72
[8151]73    def set_data(self, data=None):
74        self.data = data
[10776]75
[8151]76    def set_preprocessor(self, pproc=None):
77        self.preprocessor = pproc
[10776]78
[8151]79    def handleNewSignals(self):
80        self.apply()
[10776]81
[8151]82    def apply(self):
[11455]83        learner = earth.EarthLearner(
84            degree=self.degree,
85            terms=self.terms if self.terms >= 2 else None,
86            penalty=self.penalty,
87            name=self.name)
88
[8151]89        predictor = None
[10776]90        basis_matrix = None
[8151]91        if self.preprocessor:
92            learner = self.preprocessor.wrapLearner(learner)
[10776]93
[8151]94        self.error(0)
95        if self.data is not None:
96            try:
97                predictor = learner(self.data)
98                predictor.name = self.name
99            except Exception, ex:
100                self.error(0, "An error during learning: %r" % ex)
[10776]101
[11455]102            if predictor is not None:
[10776]103                base_features = predictor.base_features()
[11455]104                basis_domain = Orange.data.Domain(
105                    base_features,
106                    self.data.domain.class_var,
107                    self.data.domain.class_vars)
108                basis_domain.add_metas(self.data.domain.get_metas())
[10776]109                basis_matrix = Orange.data.Table(basis_domain, self.data)
110
[8151]111        self.send("Learner", learner)
112        self.send("Predictor", predictor)
[10776]113        self.send("Basis Matrix", basis_matrix)
114
[9304]115    def sendReport(self):
[11455]116        self.reportSettings(
117            "Learning parameters",
118            [("Degree", self.degree),
119             ("Terms", self.terms if self.terms >= 2 else "Automatic"),
120             ("Knot penalty", "%.2f" % self.penalty)
121             ])
122
[9304]123        self.reportData(self.data)
[10776]124
[8151]125if __name__ == "__main__":
126    app = QApplication(sys.argv)
127    w = OWEarth()
128    w.set_data(Orange.data.Table("auto-mpg"))
129    w.show()
130    app.exec_()
[8152]131    w.saveSettings()
Note: See TracBrowser for help on using the repository browser.