Changeset 10328:9bb531f21ac5 in orange


Ignore:
Timestamp:
02/21/12 15:49:26 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Using the new multi-target domain definition for multiresponse learning. Changed the base class to Orange.regression.base.BaseRegressionLearner.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/regression/earth.py

    r10140 r10328  
    2020    >>> import Orange 
    2121    >>> data = Orange.data.Table("housing") 
    22     >>> c = Orange.regression.earth.EarthLearner(data, degree=2) 
     22    >>> c = Orange.regression.earth.EarthLearner(data, degree=2, terms=10) 
    2323    >>> print c 
    2424    MEDV = 
     
    8989        new_vars.append(new) 
    9090    return new_vars 
    91      
    92 class EarthLearner(Orange.core.LearnerFD): 
    93     """ Earth learner class. Supports both regression and classification 
     91 
     92def select_attrs(table, features, class_var=None, 
     93                 class_vars=None, metas=None): 
     94    """ Select only ``attributes`` from the ``table``. 
     95    """ 
     96    if class_vars is None: 
     97        domain = Domain(features, class_var) 
     98    else: 
     99        domain = Domain(features, class_var, class_vars=class_vars) 
     100    if metas: 
     101        domain.add_metas(metas) 
     102    return Table(domain, table) 
     103     
     104class EarthLearner(Orange.regression.base.BaseRegressionLearner): 
     105    """Earth learner class. Supports both regression and classification 
    94106    problems. In case of classification the class values are expanded into  
    95107    continuous indicator columns (one for each value if the number of  
     
    100112    """ 
    101113    def __new__(cls, instances=None, weight_id=None, **kwargs): 
    102         self = Orange.core.LearnerFD.__new__(cls) 
     114        self = Orange.regression.base.BaseRegressionLearner.__new__(cls) 
    103115        if instances is not None: 
    104116            self.__init__(**kwargs) 
     
    110122                 min_span=0, new_var_penalty=0, fast_k=20, fast_beta=1, 
    111123                 pruned_terms=None, scale_resp=True, store_instances=True, 
    112                  multi_label=False, **kwds): 
    113         """ Initialize the learner instance. 
     124                **kwds): 
     125        """Initialize the learner instance. 
    114126         
    115127        :param degree: Maximum degree (num. of hinge functions per term) 
     
    145157            (default True). 
    146158        :type store_instances: bool 
    147         :param multi_label: If True build a multi label model (default False). 
    148         :type multi_label: bool   
    149159          
    150160        .. todo:: min_span, prunning_method (need Leaps like functionality, 
     
    152162         
    153163        """ 
     164         
     165        super(EarthLearner, self).__init__() 
     166         
    154167        self.degree = degree 
    155168        self.terms = terms 
     
    165178        self.scale_resp = scale_resp 
    166179        self.store_instances = store_instances 
    167         self.multi_label = multi_label 
    168180        self.__dict__.update(kwds) 
    169181         
    170         impute = Preprocessor_impute() 
    171         cont = Preprocessor_continuize(multinomialTreatment= 
    172                                        DomainContinuizer.AsOrdinal) 
    173         self.preproc = Preprocessor_preprocessorList(preprocessors=\ 
    174                                                      [impute, cont]) 
     182        self.continuizer.class_treatment = DomainContinuizer.Ignore 
    175183         
    176184    def __call__(self, instances, weight_id=None): 
    177         instances = self.preproc(instances) 
     185         
    178186        expanded_class = None 
    179         if self.multi_label: 
    180             label_mask = data_label_mask(instances.domain) 
    181             data = instances.to_numpy_MA("Ac")[0] 
    182             y = data[:, label_mask] 
    183             x = data[:, ~ label_mask] 
    184         else: 
    185             # Expand a discrete class with indicator columns 
     187        multitarget = False 
     188         
     189        if instances.domain.class_var: 
     190            instances = self.impute_table(instances) 
     191            instances = self.continuize_table(instances) 
     192             
    186193            if is_discrete(instances.domain.class_var): 
     194                # Expand a discrete class with indicator columns 
    187195                expanded_class = expand_discrete(instances.domain.class_var) 
    188                 y = Table(Domain(expanded_class, None), instances) 
    189                 y = y.to_numpy_MA("A")[0] 
    190                 x = instances.to_numpy_MA("A")[0] 
    191                 label_mask = [False] * x.shape[1] + [True] * y.shape[1] 
    192                 label_mask = numpy.array(label_mask) 
     196                y_table = select_attrs(instances, expanded_class) 
     197                (y, ) = y_table.to_numpy_MA("A") 
     198                (x, ) = instances.to_numpy_MA("A") 
    193199            elif is_continuous(instances.domain.class_var): 
    194                 label_mask = numpy.zeros(len(instances.domain.variables), 
    195                                          dtype=bool) 
    196                 label_mask[-1] = True 
    197200                x, y, _ = instances.to_numpy_MA() 
    198201                y = y.reshape((-1, 1)) 
    199202            else: 
    200203                raise ValueError("Cannot handle the response.") 
     204        elif instances.domain.class_vars: 
     205            # Multi-target domain 
     206            if not all(map(is_continuous, instances.domain.class_vars)): 
     207                raise TypeError("Only continuous multi-target classes are supported.") 
     208            x_table = select_attrs(instances, instances.domain.attributes) 
     209            y_table = select_attrs(instances, instances.domain.class_vars) 
     210             
     211            # Impute and continuize only the x_table 
     212            x_table = self.impute_table(x_table) 
     213            x_table = self.continuize_table(x_table) 
     214            domain = Domain(x_table.domain.attributes, 
     215                            class_vars=instances.domain.class_vars) 
     216             
     217            (x, ) = x_table.to_numpy_MA("A") 
     218            (y, ) = y_table.to_numpy_MA("A") 
     219             
     220            multitarget = True 
     221        else: 
     222            raise ValueError("Class variable expected.") 
    201223         
    202224        if self.scale_resp and y.shape[1] == 1: 
     
    233255                               subsets, rss_per_subset, gcv_per_subset, 
    234256                               instances=instances if self.store_instances else None, 
    235                                label_mask=label_mask, multi_flag=self.multi_label, 
    236                                expanded_class=expanded_class) 
    237      
    238      
     257                               multitarget=multitarget, 
     258                               expanded_class=expanded_class 
     259                               ) 
     260 
     261 
    239262def soft_max(values): 
    240263    values = numpy.asarray(values) 
     
    247270    def __init__(self, domain, best_set, dirs, cuts, betas, subsets=None, 
    248271                 rss_per_subset=None, gcv_per_subset=None, instances=None, 
    249                  label_mask=None, multi_flag=False, expanded_class=None, 
     272                 multitarget=False, expanded_class=None, 
    250273                 original_domain=None, **kwargs): 
    251         self.multi_flag = multi_flag 
     274        self.multitarget = multitarget 
    252275        self.domain = domain 
    253276        self.class_var = domain.class_var 
     277        if self.multitarget: 
     278            self.class_vars = domain.class_vars 
     279             
    254280        self.best_set = best_set 
    255281        self.dirs = dirs 
     
    260286        self.gcv_per_subset = gcv_per_subset 
    261287        self.instances = instances 
    262         self.label_mask = label_mask 
    263288        self.expanded_class = expanded_class 
    264289        self.original_domain = original_domain 
     
    266291         
    267292    def __call__(self, instance, result_type=Orange.core.GetValue): 
    268         if self.multi_flag: 
    269             resp_vars = [v for v, m in zip(self.domain.variables, 
    270                                            self.label_mask) if m] 
     293        if self.multitarget and self.domain.class_vars: 
     294            resp_vars = list(self.domain.class_vars) 
    271295        elif is_discrete(self.class_var): 
    272296            resp_vars = self.expanded_class 
     
    279303        from Orange.statistics.distribution import Distribution 
    280304         
    281         if not self.multi_flag and is_discrete(self.class_var): 
     305        if not self.multitarget and is_discrete(self.class_var): 
    282306            dist = Distribution(self.class_var) 
    283307            if len(self.class_var.values) == 2: 
     
    297321                probs.append(dist) 
    298322             
    299         if not self.multi_flag: 
     323        if not self.multitarget: 
    300324            vals, probs = vals[0], probs[0] 
    301325             
     
    320344        if instances is None: 
    321345            instances = self.instances 
    322         (data,) = instances.to_numpy_MA("Ac") 
    323         data = data[:, ~ self.label_mask] 
     346        instances = select_attrs(instances, self.domain.attributes) 
     347        (data,) = instances.to_numpy_MA("A") 
    324348        bx = base_matrix(data, self.best_set, self.dirs, self.cuts) 
    325349        return bx 
     
    350374                                      for i in range(self.best_set.size)], 
    351375                          set()) 
    352         attrs = [a for a, m in zip(self.domain.variables, self.label_mask) 
    353                  if not m] 
     376             
     377        attrs = self.domain.attributes 
    354378         
    355379        used_mask = self.dirs[term, :] != 0.0 
     
    673697        X_work = X[:, working_set] 
    674698        b, rsd, rank = linalg.qr_lstsq(X_work, Y) 
    675 #        print rsd 
    676699        rss_vec[subset_size] = numpy.sum(rsd ** 2) 
    677700        XtX = numpy.dot(X_work.T, X_work) 
     
    718741    """ Return a formated string representation of the earth model. 
    719742    """ 
    720     if model.multi_flag: 
    721         r_vars = [v for v, m in zip(model.domain.variables, 
    722                                     model.label_mask) 
    723                   if m] 
     743    if model.multitarget: 
     744        r_vars = list(model.domain.class_vars) 
    724745    elif is_discrete(model.class_var): 
    725746        r_vars = model.expanded_class 
     
    746767            beta_i += 1 
    747768            beta = fmt % abs(betas[beta_i]) 
    748             knots = [_format_knot(model, attr.name, d, c) for d, c, attr in \ 
     769            knots = [_format_knot(model, attr.name, d, c, percision) \ 
     770                     for d, c, attr in \ 
    749771                     zip(model.dirs[i], model.cuts[i], model.domain.attributes) \ 
    750772                     if d != 0] 
     
    762784    return "\n".join([header] + [indent + t for _, t in terms]) 
    763785         
    764 def _format_knot(model, name, dir, cut): 
     786def _format_knot(model, name, dir, cut, percision=3): 
     787    fmt = "%%.%if" % percision 
    765788    if dir == 1: 
    766         txt = "max(0, %s - %.3f)" % (name, cut) 
     789        txt = ("max(0, %s - " + fmt + ")") % (name, cut) 
    767790    elif dir == -1: 
    768         txt = "max(0, %.3f - %s)" % (cut, name) 
     791        txt = ("max(0, " + fmt + " - %s)") % (cut, name) 
    769792    elif dir == 2: 
    770793        txt = name 
    771794    return txt 
    772795 
    773 def _format_term(model, i, attr_name): 
    774     knots = [_format_knot(model, attr, d, c) for d, c, attr in \ 
    775              zip(model.dirs[i], model.cuts[i], model.domain.attributes) \ 
    776              if d != 0] 
    777     return " * ".join(knots) 
    778796     
    779797"""\ 
     
    881899    .. image:: files/earth-evimp.png 
    882900      
    883     The left axis is the nsubsets measure an on the right are the normalized 
     901    The left axis is the nsubsets measure and on the right are the normalized 
    884902    RSS and GCV. 
    885903     
     
    10811099            self._cache_data = data 
    10821100            self._cache_rss = rss 
    1083 #        print sorted(zip(rss, data.domain.attributes)) 
     1101             
    10841102        index = list(data.domain.attributes).index(attr) 
    10851103        return rss[index] 
Note: See TracChangeset for help on using the changeset viewer.