source: orange/Orange/OrangeWidgets/Regression/OWLinearRegression.py @ 10961:e38223aa3c8a

Revision 10961:e38223aa3c8a, 7.8 KB checked in by Ales Erjavec <ales.erjavec@…>, 21 months ago (diff)

Fixed Linear Regression widget's parameters for Lasso regression.

Line 
1"""
2<name>Linear Regression</name>
3<description>Linear Regression</name>
4<icon>icons/LinearRegression.png</icon>
5<priority>10</priority>
6<category>Regression</category>
7<keywords>linear, model, ridge, regression, lasso, least, absolute, shrinkage</keywords>
8
9"""
10
11import sys
12from OWWidget import *
13
14import Orange
15from Orange.regression import linear, lasso
16from orngWrap import PreprocessedLearner
17from Orange import feature as variable
18
19
20class OWLinearRegression(OWWidget):
21    settingsList = ["name", "intercept", "use_ridge", "ridge_lambda",
22                    "use_lasso", "lasso_lambda", "eps"]
23
24    def __init__(self, parent=None, signalManager=None,
25                 title="Linear Regression"):
26        OWWidget.__init__(self, parent, signalManager, title,
27                          wantMainArea=False)
28
29        self.inputs = [("Data", Orange.data.Table, self.set_data),
30                       ("Preprocessor", PreprocessedLearner,
31                        self.set_preprocessor)]
32
33        self.outputs = [("Learner", Orange.core.Learner),
34                        ("Predictor", Orange.core.Classifier),
35                        ("Model Statistics", Orange.data.Table)]
36
37        ##########
38        # Settings
39        ##########
40
41        self.name = "Linear Regression"
42        self.intercept = True
43        self.use_ridge = False
44        self.ridge_lambda = 1.0
45        self.use_lasso = False
46        self.lasso_lambda = 0.1
47        self.eps = 1e-6
48
49        self.loadSettings()
50
51        #####
52        # GUI
53        #####
54
55        OWGUI.lineEdit(self.controlArea, self, "name",
56                       box="Learner/predictor name",
57                       tooltip="Name of the learner/predictor")
58
59        OWGUI.checkBox(self.controlArea, self, 'intercept', 'Intercept')
60
61        bbox = OWGUI.radioButtonsInBox(self.controlArea, self, "use_lasso", [],
62                                       box=None,
63                                       callback=self.on_method_changed
64                                       )
65
66        rb = OWGUI.appendRadioButton(bbox, self, "use_lasso",
67                                     label="Ordinary/Ridge Linear Regression",
68                                     tooltip="",
69                                     insertInto=bbox)
70
71        self.lm_box = box = OWGUI.indentedBox(
72            self.controlArea, sep=OWGUI.checkButtonOffsetHint(rb)
73            )
74
75        self.lm_box.setEnabled(not self.use_lasso)
76
77        OWGUI.doubleSpin(box, self, "ridge_lambda", 0.1, 100, step=0.1,
78                         label="Ridge lambda", checked="use_ridge",
79                         tooltip="Ridge lambda for ridge regression")
80
81        rb = OWGUI.appendRadioButton(bbox, self, "use_lasso",
82                                     label="LASSO Regression",
83                                     tooltip="",
84                                     insertInto=bbox)
85
86        self.lasso_box = box = OWGUI.indentedBox(
87             self.controlArea, sep=OWGUI.checkButtonOffsetHint(rb)
88             )
89
90        self.lasso_box.setEnabled(self.use_lasso)
91
92        OWGUI.doubleSpin(box, self, "lasso_lambda", 0.0, 100.0, 1e-2,
93                         label="Lasso lambda",
94                         )
95
96        OWGUI.doubleSpin(box, self, "eps", 0.0, 0.01, 1e-7,
97                         label="Tolerance",
98                         tooltip="Numerical tolerance."
99                         )
100
101        OWGUI.rubber(self.controlArea)
102
103        OWGUI.button(self.controlArea, self, "&Apply",
104                     callback=self.apply,
105                     tooltip="Send the learner/classifier on output",
106                     autoDefault=True)
107
108        self.data = None
109        self.preprocessor = None
110        self.resize(300, 100)
111        self.apply()
112
113    def set_data(self, data=None):
114        if not self.isDataWithClass(data, Orange.core.VarTypes.Continuous, 
115                                    checkMissing=True):
116            data = None
117        self.data = data
118
119    def set_preprocessor(self, pproc=None):
120        self.preprocessor = pproc
121
122    def handleNewSignals(self):
123        self.apply()
124
125    def on_method_changed(self):
126        self.lm_box.setEnabled(not self.use_lasso)
127        self.lasso_box.setEnabled(self.use_lasso)
128
129    def apply(self):
130        if self.use_lasso:
131            self.apply_lasso()
132        else:
133            self.apply_ridge()
134
135    def apply_ridge(self):
136        if self.use_ridge:
137            learner = linear.LinearRegressionLearner(name=self.name,
138                intercept=self.intercept, ridgeLambda=self.ridge_lambda)
139        else:
140            learner = linear.LinearRegressionLearner(name=self.name,
141                intercept=self.intercept)
142        predictor = None
143        if self.preprocessor:
144            learner = self.preprocessor.wrapLearner(learner)
145
146        self.error(0)
147        if self.data is not None:
148            try:
149                predictor = learner(self.data)
150                predictor.name = self.name
151            except Exception, ex:
152                self.error(0, "An error during learning: %r" % ex)
153
154        self.send("Learner", learner)
155        self.send("Predictor", predictor)
156        self.send("Model Statistics", self.statistics_olr(predictor))
157
158    def apply_lasso(self):
159        learner = lasso.LassoRegressionLearner(
160            lasso_lambda=self.lasso_lambda, eps=self.eps,
161            n_boot=0, n_perm=0,
162            name=self.name
163            )
164
165        predictor = None
166
167        if self.preprocessor is not None:
168            learner = self.preprocessor.wrapLearner(learner)
169
170        self.error(0)
171        try:
172            if self.data is not None:
173                ll = lasso.LassoRegressionLearner(
174                    lasso_lambda=self.lasso_lambda, eps=self.eps,
175                    n_boot=10, n_perm=10
176                    )
177                predictor = ll(self.data)
178                predictor.name = self.name
179        except Exception, ex:
180            self.error(0, "An error during learning: %r" % ex)
181
182        self.send("Learner", learner)
183        self.send("Predictor", predictor)
184        self.send("Model Statistics", self.statistics_lasso(predictor))
185
186    def statistics_olr(self, m):
187        if m is None:
188            return None
189
190        columns = [variable.String("Variable"),
191                   variable.Continuous("Coeff Est"),
192                   variable.Continuous("Std Error"),
193                   variable.Continuous("t-value"),
194                   variable.Continuous("p")]
195
196        domain = Orange.data.Domain(columns, None)
197        vars = ["Intercept"] if m.intercept else []
198        vars.extend([a.name for a in m.domain.attributes])
199        stats = []
200        geti = lambda list, i: list[i] if list is not None else "?"
201
202        for i, var in enumerate(vars):
203            coef = m.coefficients[i]
204            std_err = geti(m.std_error, i)
205            t_val = geti(m.t_scores, i)
206            p = geti(m.p_vals, i)
207            stats.append([var, coef, std_err, t_val, p])
208
209        return Orange.data.Table(domain, stats)
210
211    def statistics_lasso(self, m):
212        if m is None:
213            return None
214
215        columns = [variable.String("Variable"),
216                   variable.Continuous("Coeff Est"),
217                   variable.Continuous("Std Error"),
218                   variable.Continuous("p")]
219
220        domain = Orange.data.Domain(columns, None)
221        vars = []
222        vars.extend([a.name for a in m.domain.attributes])
223        stats = [["Intercept", m.coef0, "?", "?"]]
224        geti = lambda list, i: list[i] if list is not None else "?"
225
226        for i, var in enumerate(vars):
227            coef = m.coefficients[i]
228            std_err = geti(m.std_errors, i)
229            p = geti(m.p_vals, i)
230            stats.append([var, coef, std_err, p])
231
232        return Orange.data.Table(domain, stats)
233
234
235if __name__ == "__main__":
236    app = QApplication(sys.argv)
237    w = OWLinearRegression()
238    w.set_data(Orange.data.Table("housing"))
239    w.show()
240    app.exec_()
241#    w.saveSettings()
Note: See TracBrowser for help on using the repository browser.