source: orange/Orange/regression/lasso.py @ 10859:08a0a35c1687

Revision 10859:08a0a35c1687, 11.3 KB checked in by Lan Zagar <lan.zagar@…>, 2 years ago (diff)

Reimplemented lasso. Breaks compatibility.

It now uses a proximal gradient method for optimization instead of using scipy.optimize (see #1118).
The formulation is slightly different so there are new parameters (mainly lasso_lambda instead of t/s).
Improved some other things as well.

Line 
1from numpy import dot, std, array, zeros, maximum, sqrt, sign, log, abs, \
2                  ascontiguousarray, random as rnd
3from scipy.linalg import norm, eigh
4
5import Orange
6from Orange.utils import deprecated_members, deprecated_keywords
7
8
9def get_bootstrap_sample(data):
10    """Generate a bootstrap sample of a given data set.
11
12    :param data: the original data sample
13    :type data: :class:`Orange.data.Table`
14    """
15    n = len(data)
16    bootstrap = Orange.data.Table(data.domain)
17    for id in rnd.randint(0, n, n):
18        bootstrap.append(data[id])
19    return bootstrap
20
21def permute_responses(data):
22    """Permute values of the class (response) variable.
23    The independence between independent variables and the response
24    is obtained but the distribution of the response variable is kept.
25
26    :param data: Original data.
27    :type data: :class:`Orange.data.Table`
28    """
29    n = len(data)
30    perm = rnd.permutation(n)
31    perm_data = Orange.data.Table(data.domain, data)
32    for i, ins in enumerate(data):
33        perm_data[i].set_class(data[perm[i]].get_class())
34    return perm_data
35
36class LassoRegressionLearner(Orange.regression.base.BaseRegressionLearner):
37    """Fits the lasso regression model using FISTA
38    (Fast Iterative Shrinkage-Thresholding Algorithm).
39    """
40
41    def __init__(self, lasso_lambda=0.1, max_iter=20000, eps=1e-6,
42                 n_boot=0, n_perm=0, imputer=None, continuizer=None,
43                 name='Lasso'):
44        """
45        :param lasso_lambda: Regularization parameter.
46        :type lasso_lambda: float
47
48        :param max_iter: Maximum number of iterations for
49                         the optimization method.
50        :type max_iter: int
51
52        :param eps: Stop optimization when improvements are lower than eps.
53        :type eps: float
54
55        :param n_boot: Number of bootstrap samples used for non-parametric
56                       estimation of standard errors.
57        :type n_boot: int
58
59        :param n_perm: Number of permuations used for non-parametric
60                       estimation of p-values.
61        :type n_perm: int
62
63        :param name: Learner name.
64        :type name: str
65       
66        """
67        self.lasso_lambda = lasso_lambda
68        self.max_iter = max_iter
69        self.eps = eps
70        self.n_boot = n_boot
71        self.n_perm = n_perm
72        self.set_imputer(imputer=imputer)
73        self.set_continuizer(continuizer=continuizer)
74        self.name = name
75
76    def get_lipschitz(self, X):
77        """Return the Lipschitz constant of :math:`\\nabla f`,
78        where :math:`f(w) = \\frac{1}{2}||Xw-y||^2`.
79        """
80        n, m = X.shape
81        if n > m:
82            X = ascontiguousarray(X.T)
83        k = min(n, m) - 1
84        eigvals = eigh(dot(X, X.T), eigvals_only=True, eigvals=(k, k))
85        return eigvals[-1]
86
87    def fista(self, X, y, l, lipschitz, w_init=None):
88        """Fast Iterative Shrinkage-Thresholding Algorithm (FISTA)."""
89        z = w_old = zeros(X.shape[1]) if w_init is None else w_init
90        t_old, obj_old = 1, 1e400
91        XT = ascontiguousarray(X.T)
92        for i in range(self.max_iter):
93            z -= 1. / lipschitz * dot(XT, dot(X, z) - y)
94            w = maximum(0, abs(z) - l / lipschitz) * sign(z)
95            t = (1 + sqrt(1 + 4 * t_old**2)) / 2
96            z = w + (t_old - 1) / t * (w - w_old)
97            obj = ((y - dot(X, w))**2).sum() + l * norm(w, 1) 
98            if abs(obj_old - obj) / obj < self.eps:
99                stop += 1
100                if obj < obj_old and stop > log(i + 1):
101                    break
102            else:
103                stop = 0
104            w_old, t_old = w, t
105            obj_old = obj
106        return w
107
108    def __call__(self, data, weight=None):
109        """
110        :param data: Training data.
111        :type data: :class:`Orange.data.Table`
112        :param weight: Weights for instances. Not implemented yet.
113       
114        """
115        # dicrete values are continuized       
116        data = self.continuize_table(data)
117        # missing values are imputed
118        data = self.impute_table(data)
119        domain = data.domain
120        # prepare numpy matrices
121        X, y, _ = data.to_numpy()
122        n, m = X.shape
123        coefficients = zeros(m)
124        std_errors = array([float('nan')] * m)
125        p_vals = array([float('nan')] * m)
126        # standardize y
127        coef0, sigma_y = y.mean(), y.std() + 1e-6
128        y = (y - coef0) / sigma_y
129        # standardize X and remove constant vars
130        mu_x = X.mean(axis=0)
131        X -= mu_x
132        sigma_x = X.std(axis=0)
133        nz = sigma_x != 0
134        X = ascontiguousarray(X[:, nz])
135        sigma_x = sigma_x[nz]
136        X /= sigma_x
137        m = sum(nz)
138
139        # run optimization method
140        lipschitz = self.get_lipschitz(X)
141        l = 0.5 * self.lasso_lambda * n / m
142        coefficients[nz] = self.fista(X, y, l, lipschitz)
143        coefficients[nz] *= sigma_y / sigma_x
144
145        d = dict(self.__dict__)
146        d.update({'n_boot': 0, 'n_perm': 0})
147
148        # bootstrap estimator of standard error of the coefficient estimators
149        # assumption: fixed regularization parameter
150        if self.n_boot > 0:
151            coeff_b = [] # bootstrapped coefficients
152            for i in range(self.n_boot):
153                tmp_data = get_bootstrap_sample(data)
154                l = LassoRegressionLearner(**d)
155                c = l(tmp_data)
156                coeff_b.append(c.coefficients)
157            std_errors[nz] = std(coeff_b, axis=0)
158
159        # permutation test to obtain the significance of
160        # the regression coefficients
161        if self.n_perm > 0:
162            coeff_p = []
163            for i in range(self.n_perm):
164                tmp_data = permute_responses(data)
165                l = LassoRegressionLearner(**d)
166                c = l(tmp_data)
167                coeff_p.append(c.coefficients)
168            p_vals[nz] = (abs(coeff_p) > abs(coefficients)).sum(axis=0)
169            p_vals[nz] /= float(self.n_perm)
170
171        # dictionary of regression coefficients with standard errors
172        # and p-values
173        model = {}
174        for i, var in enumerate(domain.attributes):
175            model[var.name] = (coefficients[i], std_errors[i], p_vals[i])
176
177        return LassoRegression(domain=domain, class_var=domain.class_var,
178            coef0=coef0, coefficients=coefficients, std_errors=std_errors,
179            p_vals=p_vals, model=model, mu_x=mu_x)
180
181deprecated_members({"nBoot": "n_boot",
182                    "nPerm": "n_perm"},
183                   wrap_methods=["__init__"],
184                   in_place=True)(LassoRegressionLearner)
185
186class LassoRegression(Orange.classification.Classifier):
187    """Lasso regression predicts the value of the response variable
188    based on the values of independent variables.
189
190    .. attribute:: coef0
191
192        Intercept (sample mean of the response variable).   
193
194    .. attribute:: coefficients
195
196        Regression coefficients.
197
198    .. attribute:: std_errors
199
200        Standard errors of coefficient estimates for a fixed
201        regularization parameter. The standard errors are estimated
202        using the bootstrapping method.
203
204    .. attribute:: p_vals
205
206        List of p-values for the null hypotheses that the regression
207        coefficients equal 0 based on a non-parametric permutation test.
208
209    .. attribute:: model
210
211        Dictionary with the statistical properties of the model:
212        Keys - names of the independent variables
213        Values - tuples (coefficient, standard error, p-value)
214
215    .. attribute:: mu_x
216
217        Sample mean of independent variables.   
218
219    """
220    def __init__(self, domain=None, class_var=None, coef0=None,
221                 coefficients=None, std_errors=None, p_vals=None,
222                 model=None, mu_x=None):
223        self.domain = domain
224        self.class_var = class_var
225        self.coef0 = coef0
226        self.coefficients = coefficients
227        self.std_errors = std_errors
228        self.p_vals = p_vals
229        self.model = model
230        self.mu_x = mu_x
231
232    def _miss_2_0(self, x):
233        return x if x != '?' else 0
234
235    @deprecated_keywords({"resultType": "result_type"})
236    def __call__(self, instance, result_type=Orange.core.GetValue):
237        """
238        :param instance: Data instance for which the value of the response
239                         variable will be predicted.
240        :type instance: :obj:`Orange.data.Instance`
241        """
242        ins = Orange.data.Instance(self.domain, instance)
243        if '?' in ins: # missing value -> corresponding coefficient omitted
244            ins = map(self._miss_2_0, ins)
245            ins = array(ins)[:-1] - self.mu_x
246        else:
247            ins = array(ins.native())[:-1] - self.mu_x
248
249        y_hat = dot(self.coefficients, ins) + self.coef0
250        y_hat = self.class_var(y_hat)
251        dist = Orange.statistics.distribution.Continuous(self.class_var)
252        dist[y_hat] = 1.
253        if result_type == Orange.core.GetValue:
254            return y_hat
255        if result_type == Orange.core.GetProbabilities:
256            return dist
257        else:
258            return (y_hat, dist)
259
260    @deprecated_keywords({"skipZero": "skip_zero"})
261    def to_string(self, skip_zero=True):
262        """Pretty-prints a lasso regression model,
263        i.e. estimated regression coefficients with standard errors
264        and significances. Standard errors are obtained using the
265        bootstrapping method and significances by a permuation test.
266
267        :param skip_zero: If True, variables with estimated coefficient
268                          equal to 0 are omitted.
269        :type skip_zero: bool
270        """
271        labels = ('Variable', 'Coeff Est', 'Std Error', 'p')
272        lines = [' '.join(['%10s' % l for l in labels])]
273
274        fmt = '%10s ' + ' '.join(['%10.3f'] * 3) + ' %5s'
275        fmt1 = '%10s %10.3f'
276
277        def get_star(p):
278            if p < 0.001: return  '*' * 3
279            elif p < 0.01: return '*' * 2
280            elif p < 0.05: return '*'
281            elif p < 0.1: return  '.'
282            else: return ' '
283
284        stars = get_star(self.p_vals[0])
285        lines.append(fmt1 % ('Intercept', self.coef0))
286        skipped = []
287        for i in range(len(self.domain.attributes)):
288            if self.coefficients[i] == 0. and skip_zero:
289                skipped.append(self.domain.attributes[i].name)
290                continue
291            stars = get_star(self.p_vals[i])
292            lines.append(fmt % (self.domain.attributes[i].name,
293                         self.coefficients[i], self.std_errors[i],
294                         self.p_vals[i], stars))
295        lines.append('Signif. codes:  0 *** 0.001 ** 0.01 * 0.05 . 0.1 empty 1\n')
296        if skip_zero:
297            k = len(skipped)
298            if k == 0:
299                lines.append('All variables have non-zero regression coefficients.')
300            else:
301                lines.append('For %d variable%s the regression coefficient equals 0:'
302                             % (k, 's' if k > 1 else ''))
303                lines.append(', '.join(var for var in skipped))
304        return '\n'.join(lines)
305
306    def __str__(self):
307        return self.to_string(skip_zero=True)
308
309deprecated_members({"muX": "mu_x",
310                    "stdErrorsFixedT": "std_errors",
311                    "pVals": "p_vals",
312                    "dictModel": "model"},
313                   wrap_methods=["__init__"],
314                   in_place=True)(LassoRegression)
315
316if __name__ == "__main__":
317    data = Orange.data.Table('housing')
318    c = LassoRegressionLearner(data)
319    print c
Note: See TracBrowser for help on using the repository browser.