Changeset 9113:9136a8634278 in orange


Ignore:
Timestamp:
10/17/11 15:52:45 (3 years ago)
Author:
ales_erjavec <ales.erjavec@…>
Branch:
default
Convert:
b8ec9febc29433617140c5024ffd72e6f8382834
Message:

Remember of the train data are from a multilabel domain or a regular domain with class var.
Use the information when classifying new instances to return a list or single values.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/regression/pls.py

    r9112 r9113  
    262262            # Response variables are defined in the table. 
    263263            label_mask = data_label_mask(domain) 
     264            multilabel_flag = (sum(label_mask) - (1 if domain.class_var else 0)) > 0 
    264265            xVars = [v for v, label in zip(domain, label_mask) if not label] 
    265266            yVars = [v for v, label in zip(domain, label_mask) if label] 
     
    274275                yVars.append(domain.class_var) 
    275276            label_mask = [v in yVars for v in domain.variables] 
    276              
     277            multilabel_flag = True 
    277278            x_table = select_attrs(table, xVars) 
    278279            y_table = select_attrs(table, yVars) 
     
    322323                             coefs=self.coefs, muX=self.muX, muY=self.muY, \ 
    323324                             sigmaX=self.sigmaX, sigmaY=self.sigmaY, \ 
    324                              xVars=xVars, yVars=yVars) 
     325                             xVars=xVars, yVars=yVars, multilabel_flag=multilabel_flag) 
    325326 
    326327    def fit(self, X, Y): 
     
    422423    def __init__(self, label_mask=None, domain=None, \ 
    423424                 coefs=None, muX=None, muY=None, sigmaX=None, sigmaY=None, \ 
    424                  xVars=None, yVars=None): 
     425                 xVars=None, yVars=None, multilabel_flag=0): 
    425426        self.label_mask = label_mask 
    426427        self.domain = domain 
     
    429430        self.sigmaX, self.sigmaY = sigmaX, sigmaY 
    430431        self.xVars, self.yVars = xVars, yVars 
     432        self.multilabel_flag = multilabel_flag 
    431433 
    432434    def __call__(self, instance, result_type=Orange.core.GetValue): 
     
    439441        instance = Orange.data.Instance(self.domain, instance) 
    440442        ins = [instance[v].native() for v in self.xVars] 
     443         
     444        multilabel_flag = sum(data_label_mask(self.domain)) 
    441445         
    442446        if "?" in ins: # missing value -> corresponding coefficient omitted 
     
    446450        xc = (ins - self.muX) / self.sigmaX 
    447451        predicted = dot(xc, self.coefs) * self.sigmaY + self.muY 
    448         yHat = [var(val) for var, val in zip(self.yVars, predicted)] 
     452        y_hat = [var(val) for var, val in zip(self.yVars, predicted)] 
    449453        if result_type == Orange.core.GetValue: 
    450             return yHat 
     454            return y_hat if self.multilabel_flag else y_hat[0] 
    451455        else: 
    452456            from Orange.statistics.distribution import Distribution 
    453457            probs = [] 
    454             for var, val in zip(self.yVars, yHat): 
     458            for var, val in zip(self.yVars, y_hat): 
    455459                dist = Distribution(var) 
    456460                dist[val] = 1.0 
    457461                probs.append(dist) 
    458462            if result_type == Orange.core.GetBoth: 
    459                 return zip(yHat, probs)  
     463                return zip(y_hat, probs) if self.multilabel_flag else (y_hat[0], probs[0]) 
    460464            else: 
    461                 return probs 
     465                return probs if self.multilabel_flag else probs[0] 
    462466             
    463467    def print_pls_regression_coefficients(self): 
Note: See TracChangeset for help on using the changeset viewer.