Ignore:
Timestamp:
01/09/12 17:41:08 (2 years ago)
Author:
lanz <lan.zagar@…>
Branch:
default
Convert:
c228716ea07062e32d54b33dee5b05d4d3c443fd
Message:

Updated PLS to use the multi-target data format.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/regression/pls.py

    r9248 r9523  
    226226        if x_vars is None and y_vars is None: 
    227227            # Response variables are defined in the table. 
    228             label_mask = data_label_mask(domain) 
    229             multilabel_flag = (sum(label_mask) - (1 if domain.class_var else 0)) > 0 
    230             x_vars = [v for v, label in zip(domain, label_mask) if not label] 
    231             y_vars = [v for v, label in zip(domain, label_mask) if label] 
     228            x_vars = domain.features 
     229            if domain.class_var: 
     230                y_vars = [domain.class_var] 
     231                multitarget = False 
     232            elif domain.class_vars: 
     233                y_vars = domain.class_vars 
     234                multitarget = True 
     235            else: 
     236                raise TypeError('Class-less domain (x-vars and y-vars needed).') 
    232237            x_table = select_attrs(table, x_vars) 
    233238            y_table = select_attrs(table, y_vars) 
    234              
    235239        elif x_vars and y_vars: 
    236240            # independent and response variables are passed by the caller 
    237             if domain.class_var and domain.class_var not in y_vars: 
    238                 # if the original table contains class variable 
    239                 # add it to the y_vars 
    240                 y_vars.append(domain.class_var) 
    241             label_mask = [v in y_vars for v in domain.variables] 
    242             multilabel_flag = True 
    243             x_table = select_attrs(table, x_vars) 
    244             y_table = select_attrs(table, y_vars) 
     241            multitarget = True 
    245242        else: 
    246243            raise ValueError("Both x_vars and y_vars must be defined.") 
     244 
     245        x_table = select_attrs(table, x_vars) 
     246        y_table = select_attrs(table, y_vars) 
    247247 
    248248        # dicrete values are continuized         
     
    258258         
    259259        domain = Orange.data.Domain(x_vars + y_vars, False) 
    260         label_mask = [False for _ in x_vars] + [True for _ in y_vars] 
    261          
    262         x = x_table.toNumpy()[0] 
    263         y = y_table.toNumpy()[0] 
     260         
     261        x = x_table.to_numpy()[0] 
     262        y = y_table.to_numpy()[0] 
    264263         
    265264        kwargs = self.fit(x, y) 
    266         return PLSRegression(label_mask=label_mask, domain=domain, \ 
    267 #                             coefs=self.coefs, muX=self.muX, muY=self.muY, \ 
    268 #                             sigmaX=self.sigmaX, sigmaY=self.sigmaY, \ 
    269                              x_vars=x_vars, y_vars=y_vars, 
    270                              multilabel_flag=multilabel_flag, **kwargs) 
     265        return PLSRegression(domain=domain, x_vars=x_vars, y_vars=y_vars, 
     266                             **kwargs) 
    271267 
    272268    def fit(self, X, Y): 
     
    401397         
    402398    """ 
    403     def __init__(self, label_mask=None, domain=None, \ 
    404                  coefs=None, mu_x=None, mu_y=None, sigma_x=None, sigma_y=None, \ 
    405                  x_vars=None, y_vars=None, multilabel_flag=0, **kwargs): 
    406         self.label_mask = label_mask 
     399    def __init__(self, domain=None, multitarget=False, coefs=None, sigma_x=None, sigma_y=None, 
     400                 mu_x=None, mu_y=None, x_vars=None, y_vars=None, **kwargs): 
    407401        self.domain = domain 
     402        self.multitarget = multitarget 
    408403        self.coefs = coefs 
    409404        self.mu_x, self.mu_y = mu_x, mu_y 
    410405        self.sigma_x, self.sigma_y = sigma_x, sigma_y 
    411406        self.x_vars, self.y_vars = x_vars, y_vars 
    412         self.multilabel_flag = multilabel_flag 
    413         if not multilabel_flag and y_vars: 
    414             self.class_var = y_vars[0] 
    415407             
    416408        for name, val in kwargs.items(): 
     
    435427        y_hat = [var(val) for var, val in zip(self.y_vars, predicted)] 
    436428        if result_type == Orange.core.GetValue: 
    437             return y_hat if self.multilabel_flag else y_hat[0] 
     429            return y_hat if self.multitarget else y_hat[0] 
    438430        else: 
    439431            from Orange.statistics.distribution import Distribution 
     
    444436                probs.append(dist) 
    445437            if result_type == Orange.core.GetBoth: 
    446                 return (y_hat, probs) if self.multilabel_flag else (y_hat[0], probs[0]) 
     438                return (y_hat, probs) if self.multitarget else (y_hat[0], probs[0]) 
    447439            else: 
    448                 return probs if self.multilabel_flag else probs[0] 
     440                return probs if self.multitarget else probs[0] 
    449441             
    450     def print_pls_regression_coefficients(self): 
     442    def to_string(self): 
    451443        """ Pretty-prints the coefficient of the PLS regression model. 
    452444        """        
    453445        x_vars, y_vars = [x.name for x in self.x_vars], [y.name for y in self.y_vars] 
    454         print " " * 7 + "%-6s " * len(y_vars) % tuple(y_vars) 
    455446        fmt = "%-6s " + "%-5.3f  " * len(y_vars) 
    456         for i, coef in enumerate(self.coefs): 
    457             print fmt % tuple([x_vars[i]] + list(coef)) 
     447        first = [" " * 7 + "%-6s " * len(y_vars) % tuple(y_vars)] 
     448        lines = [fmt % tuple([x_vars[i]] + list(coef)) 
     449                 for i, coef in enumerate(self.coefs)] 
     450        return '\n'.join(first + lines) 
    458451             
     452    def __str__(self): 
     453        return self.to_string() 
     454 
    459455    """ 
    460456    def transform(self, X, Y=None): 
     
    487483    from Orange.regression import pls 
    488484 
    489     table = Orange.data.Table("test-pls.tab") 
     485    data = Orange.data.Table("test-pls.tab") 
    490486    l = pls.PLSRegressionLearner() 
    491487 
    492     x = [var for var in table.domain if var.name[0]=="X"] 
    493     y = [var for var in table.domain if var.name[0]=="Y"] 
     488    x = [var for var in data.domain.features if var.name[0]=="X"] 
     489    y = [var for var in data.domain.class_vars if var.name[0]=="Y"] 
    494490    print x, y 
    495 #    c = l(table, x_vars=x, y_vars=y) 
    496     c = l(table) 
    497     c.print_pls_regression_coefficients() 
     491#    c = l(data, x_vars=x, y_vars=y) 
     492    c = l(data) 
     493 
     494    print c 
Note: See TracChangeset for help on using the changeset viewer.