Changeset 8121:992e3068e55f in orange


Ignore:
Timestamp:
07/28/11 16:40:23 (3 years ago)
Author:
ales_erjavec <ales.erjavec@…>
Branch:
default
Convert:
3c12db113aff2ffa1791c317eb1a41a4a6054018
Message:

Added multi-label EarthLearner/Classifier.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/regression/earth.py

    r8116 r8121  
    2727    setattr(obj, name, old_val) 
    2828     
    29 def base_matrix(data, best_set, dirs, cuts, betas): 
     29def base_matrix(data, best_set, dirs, cuts): 
    3030    """ Return the base matrix for the earth model. 
    3131     
     
    3535    dirs = numpy.asarray(dirs) 
    3636    cuts = numpy.asarray(cuts) 
    37     betas = numpy.asarray(betas) 
    3837     
    3938    bx = numpy.zeros((data.shape[0], best_set.shape[0])) 
     
    7877     
    7978    def __call__(self, data, weightId=None): 
     79        if not data.domain.class_var: 
     80            raise ValueError("No class var in the domain.") 
     81         
    8082        with member_set(self, "prune", False): 
    8183            # We overwrite the prune argument (will do the pruning in python). 
     
    99101    def pruning_pass(self, base_clsf, examples): 
    100102        """ Prune the terms constructed in the forward pass. 
     103        (Pure numpy reimplementation) 
    101104        """ 
    102105        n_terms = numpy.sum(base_clsf.best_set) 
     
    108111        bx = base_matrix(data, base_clsf.best_set, 
    109112                         base_clsf.dirs, base_clsf.cuts, 
    110                          base_clsf.betas) 
     113                         ) 
    111114         
    112115        bx_used = bx[:, best_set] 
     
    127130        betas, rss, rank, s = numpy.linalg.lstsq(bx_subset, y) 
    128131        return best_set, betas, rss, subsets, rss_per_subset, gcv_per_subset 
    129          
     132     
    130133         
    131134class EarthClassifier(Orange.core.ClassifierFD): 
     
    187190            return (value, dist) 
    188191     
    189       
    190     def terms(self): 
    191         """ Return the terms in the Earth model. 
    192         """ 
    193         raise NotImplementedError 
    194      
    195     def filters(self): 
    196         """ Orange.core.filter objects for each term (where the hinge 
    197         function is not 0). 
    198           
    199         """ 
    200          
    201      
    202192    def base_matrix(self, examples=None): 
    203193        """ Return the base matrix (bx) of the Earth model for the table. 
     
    222212            data = numpy.asarray(examples) 
    223213             
    224         return base_matrix(data, self.best_set, self.dirs, self.cuts, self.betas) 
     214        return base_matrix(data, self.best_set, self.dirs, self.cuts) 
    225215     
    226216    def _anova_order(self): 
     
    236226     
    237227    def format_model(self, percision=3, indent=3): 
    238         header = "%s =" % self.class_var.name 
    239         indent = " " * indent 
    240         fmt = "%." + str(percision) + "f" 
    241         terms = [([], fmt % self.betas[0])] 
    242         beta_i = 0 
    243         for i, used in enumerate(self.best_set[1:], 1): 
    244             if used: 
    245                 beta_i += 1 
    246                 beta = fmt % abs(self.betas[beta_i]) 
    247                 knots = [self._format_knot(attr.name, d, c) for d, c, attr in \ 
    248                          zip(self.dirs[i], self.cuts[i], self.domain.attributes) \ 
    249                          if d != 0] 
    250                 term_attrs = [a for a, d in zip(self.domain.attributes, self.dirs[i]) \ 
    251                               if d != 0] 
    252                 term_attrs = sorted(term_attrs) 
    253                 sign = "-" if self.betas[beta_i] < 0 else "+" 
    254                 if knots: 
    255                     terms.append((term_attrs, 
    256                                   sign + " * ".join([beta] + knots))) 
    257                 else: 
    258                     terms.append((term_attr, sign + beta)) 
    259         # Sort by len(term_attrs), then by term_attrs 
    260         terms = sorted(terms, key=lambda t: (len(t[0]), t[0])) 
    261         return "\n".join([header] + [indent + t for _, t in terms]) 
    262              
    263     def _format_knot(self, name, dir, cut): 
    264         if dir == 1: 
    265             txt = "max(0, %s - %.3f)" % (name, cut) 
    266         elif dir == -1: 
    267             txt = "max(0, %.3f - %s)" % (cut, name) 
    268         elif dir == 2: 
    269             txt = name 
    270         return txt 
    271      
    272     def _format_term(self, i): 
    273         knots = [self._format_knot(attr.name, d, c) for d, c, attr in \ 
    274                  zip(self.dirs[i], self.cuts[i], self.domain.attributes) \ 
    275                  if d != 0] 
    276         return " * ".join(knots) 
     228        return format_model(self, percision, indent) 
    277229     
    278230    def print_model(self, percision=3, indent=3): 
     
    332284        """ Return the estimated variable importance. 
    333285        """ 
    334         if self.subsets is None: 
    335             raise ValueError("No subsets. Use the learner with 'prune=True'.") 
    336          
    337         subsets = self.subsets 
    338         n_subsets = self.num_terms 
    339          
    340         rss = -numpy.diff(self.rss_per_subset) 
    341         gcv = -numpy.diff(self.gcv_per_subset) 
    342         attributes = list(self.domain.attributes) 
    343          
    344         attr2ind = dict(zip(attributes, range(len(attributes)))) 
    345         importances = numpy.zeros((len(attributes), 4)) 
    346         importances[:, 0] = range(len(attributes)) 
    347          
    348         for i in range(1, n_subsets): 
    349             term_subset = self.subsets[i, :i + 1] 
    350             used_attributes = reduce(set.union, [self.used_attributes(term) for term in term_subset], set()) 
    351             for attr in used_attributes: 
    352                 importances[attr2ind[attr]][1] += 1.0 
    353                 importances[attr2ind[attr]][2] += gcv[i - 1] 
    354                 importances[attr2ind[attr]][3] += rss[i - 1] 
    355         imp_min = numpy.min(importances[:, [2, 3]], axis=0) 
    356         imp_max = numpy.max(importances[:, [2, 3]], axis=0) 
    357         importances[:, [2, 3]] = 100.0 * (importances[:, [2, 3]] - [imp_min]) / ([imp_max - imp_min]) 
    358          
    359         importances = list(importances) 
    360         # Sort by n_subsets and gcv. 
    361         importances = sorted(importances, key=lambda row: (row[1], row[2]), 
    362                              reverse=True) 
    363         importances = numpy.array(importances) 
    364          
    365         if used_only: 
    366             importances = importances[importances[:,0] > 0.0] 
    367          
    368         res = [(attributes[int(row[0])], tuple(row[1:])) for row in importances] 
    369         return res 
    370              
    371     def plot(self): 
    372         import pylab 
    373         n_terms = self.num_terms 
    374         grid_size = int(numpy.ceil(numpy.sqrt(n_terms))) 
    375         fig = pylab.figure() 
     286        return evimp(self, used_only) 
    376287         
    377288    def __reduce__(self): 
     
    397308 
    398309def subsets_selection_xtx(X, Y): 
    399     """ 
     310    """ A numpy implementation of EvalSubsetsUsingXtx in the Earth package.  
    400311    """ 
    401312    X = numpy.asarray(X) 
     
    429340 
    430341 
     342""" 
     343Multilabel 
     344""" 
     345 
     346def is_label_attr(attr): 
     347    """ Is attribute a label. 
     348    """ 
     349    return attr.attributes.has_key("label") 
     350     
     351def data_label_indices(domain): 
     352    """ Return the indices of label attributes in data. 
     353    """ 
     354    return numpy.where(data_label_mask(domain))[0] 
     355 
     356def data_label_mask(domain): 
     357    """ Return an array of booleans indicating whether a variable in the 
     358    domain is a label. 
     359    """ 
     360    is_label = map(is_label_attr, domain.variables) 
     361    if domain.class_var: 
     362        is_label[-1] = True 
     363    return numpy.array(is_label, dtype=bool) 
     364 
     365class EarthLearnerML(Orange.core.LearnerFD): 
     366    """ Multilabel (multi response) earth learner. 
     367    """ 
     368    def __new__(cls, examples=None, weight_id=None, **kwargs): 
     369        self = Orange.core.LearnerFD.__new__(cls) 
     370        if examples is not None: 
     371            self.__init__(**kwargs) 
     372            return self.__call__(examples, weight_id) 
     373        else: 
     374            return self 
     375         
     376    def __init__(self, degree=1, terms=21, penalty= None, thresh=1e-3, 
     377                 min_span=0, new_var_penalty=0, fast_k=20, fast_beta=1, 
     378                 pruned_terms=None, scale_resp=False, store_examples=True, **kwds): 
     379        """  
     380        .. todo:: min_span, prunning_method 
     381         
     382        """ 
     383        self.degree = degree 
     384        self.terms = terms 
     385        if penalty is None: 
     386            penalty = 3 if degree > 1 else 2 
     387        self.penalty = penalty  
     388        self.thresh = thresh 
     389        self.min_span = min_span 
     390        self.new_var_penalty = new_var_penalty 
     391        self.fast_k = fast_k 
     392        self.fast_beta = fast_beta 
     393        self.pruned_terms = pruned_terms 
     394        self.scale_resp = scale_resp 
     395        self.store_examples = store_examples 
     396        self.__dict__.update(kwds) 
     397         
     398    def __call__(self, examples, weight_id=None): 
     399        label_mask = data_label_mask(examples.domain) 
     400        if not any(label_mask): 
     401            raise ValueError("The domain has no response variable.") 
     402        data = examples.to_numpy_MA("Ac")[0] 
     403        y = data[:, label_mask] 
     404        x = data[:, ~ label_mask] 
     405         
     406        # TODO: y scaling 
     407        n_terms, used, bx, dirs, cuts = _forward_pass(x, y, 
     408            degree=self.degree, terms=self.terms, penalty=self.penalty, 
     409            thresh=self.thresh, fast_k=self.fast_k, fast_beta=self.fast_beta, 
     410            new_var_penalty=self.new_var_penalty) 
     411         
     412        # discard unused terms from bx, dirs, cuts 
     413        bx = bx[:, used] 
     414        dirs = dirs[used, :] 
     415        cuts = cuts[used, :] 
     416         
     417        # pruning 
     418        used, subsets, rss_per_subset, gcv_per_subset = \ 
     419            pruning_pass(bx, y, self.penalty, 
     420                         pruned_terms=self.pruned_terms) 
     421         
     422        # Fit betas 
     423        bx_used = bx[:, used] 
     424        betas, res, rank, s = numpy.linalg.lstsq(bx_used, y) 
     425         
     426        return EarthClassifierML(examples.domain, used, dirs, cuts, betas.T, 
     427                                 subsets, rss_per_subset, gcv_per_subset, 
     428                                 examples=examples if self.store_examples else None, 
     429                                 label_mask=label_mask) 
     430     
     431 
     432class EarthClassifierML(Orange.core.ClassifierFD): 
     433    """ Multi label Earth classifier. 
     434    """ 
     435    def __init__(self, domain, best_set, dirs, cuts, betas, subsets=None, 
     436                 rss_per_subset=None, gcv_per_subset=None, examples=None, 
     437                 label_mask=None, **kwargs): 
     438        self.multi_flag = 1 
     439        self.domain = domain 
     440        self.best_set = best_set 
     441        self.dirs = dirs 
     442        self.cuts = cuts 
     443        self.betas = betas 
     444        self.subsets = subsets 
     445        self.rss_per_subset = rss_per_subset 
     446        self.gcv_per_subset = gcv_per_subset 
     447        self.examples = examples 
     448        self.label_mask = label_mask 
     449        self.__dict__.update(kwargs) 
     450         
     451    def __call__(self, example, result_type=Orange.core.GetValue): 
     452        resp_vars = [v for v, m in zip(self.domain.variables, self.label_mask)\ 
     453                     if m] 
     454        vals = self.predict(example) 
     455        vals = [var(val) for var, val in zip(resp_vars, vals)] 
     456         
     457        probs = [] 
     458        for var, val in zip(resp_vars, vals): 
     459            dist = Orange.statistics.distribution.Distribution(var) 
     460            dist[val] = 1.0 
     461            probs.append(dist) 
     462             
     463        if result_type == Orange.core.GetValue: 
     464            return vals 
     465        elif result_type == Orange.core.GetProbabilities: 
     466            return zip(vals, probs) 
     467        else: 
     468            return probs 
     469     
     470    def format_model(self, percision=3, indent=3): 
     471        """ Return a string representation of the model. 
     472        """ 
     473        return format_model(self, percision, indent) 
     474     
     475    def print_model(self, percision=3, indent=3): 
     476        """ Print the model to stdout. 
     477        """ 
     478        print self.format_model(percision, indent) 
     479         
     480    def base_matrix(self, examples=None): 
     481        """Return the base matrix (bx) of the Earth model for the table. 
     482        If table is not supplied the base matrix of the training examples  
     483        is returned. 
     484         
     485        :param examples: Input examples for the base matrix. 
     486        :type examples: :class:`Orange.data.Table`  
     487         
     488        """ 
     489        if examples is None: 
     490            examples = self.examples 
     491        (data,) = examples.to_numpy_MA("Ac") 
     492        data = data[:, ~ self.label_mask] 
     493        bx = base_matrix(data, self.best_set, self.dirs, self.cuts) 
     494        return bx 
     495     
     496    def predict(self, example): 
     497        """ Predict the response values for the example 
     498         
     499        :param example: example instance 
     500        :type example: :class:`Orange.data.Example` 
     501        """ 
     502        data = Orange.data.Table(self.domain, [example]) 
     503        bx = self.base_matrix(data) 
     504        bx_used = bx[:, self.best_set] 
     505        vals = numpy.dot(bx_used, self.betas.T).ravel() 
     506        return vals 
     507     
     508    def used_attributes(self, term=None): 
     509        """ Return the used terms for term (index). If no term is given 
     510        return all attributes in the model. 
     511         
     512        :param term: term index 
     513        :type term: int 
     514         
     515        """ 
     516        if term is None: 
     517            return reduce(set.union, [self.used_attributes(i) \ 
     518                                      for i in range(self.best_set.size)], 
     519                          set()) 
     520        attrs = [a for a, m in zip(self.domain.variables, self.label_mask) 
     521                 if not m] 
     522         
     523        used_mask = self.dirs[term, :] != 0.0 
     524        return [a for a, u in zip(attrs, used_mask) if u] 
     525     
     526    def evimp(self, used_only=True): 
     527        """ Return the estimated variable importances. 
     528         
     529        :param used_only: if True return only used attributes 
     530          
     531        """   
     532        return evimp(self, used_only) 
     533     
     534    def __reduce__(self): 
     535        return (type(self), (self.domain, self.best_set, self.dirs, 
     536                            self.cuts, self.betas), 
     537                dict(self.__dict__)) 
     538 
     539     
     540""" 
     541ctypes interface to ForwardPass and EvalSubsetsUsingXtx. 
     542 
     543""" 
     544         
     545import ctypes 
     546from numpy import ctypeslib 
     547import orange 
     548 
     549_c_orange_lib = ctypeslib.load_library(orange.__file__, "") 
     550_c_forward_pass_ = _c_orange_lib.EarthForwardPass 
     551 
     552_c_forward_pass_.argtypes = \ 
     553    [ctypes.POINTER(ctypes.c_int),  #pnTerms: 
     554     ctypeslib.ndpointer(dtype=numpy.bool, ndim=1),  #FullSet 
     555     ctypeslib.ndpointer(dtype=numpy.double, ndim=2, flags="F_CONTIGUOUS"), #bx 
     556     ctypeslib.ndpointer(dtype=numpy.int, ndim=2, flags="F_CONTIGUOUS"),    #Dirs 
     557     ctypeslib.ndpointer(dtype=numpy.double, ndim=2, flags="F_CONTIGUOUS"), #Cuts 
     558     ctypeslib.ndpointer(dtype=numpy.int, ndim=1),  #nFactorsInTerms 
     559     ctypeslib.ndpointer(dtype=numpy.int, ndim=1),  #nUses 
     560     ctypeslib.ndpointer(dtype=numpy.double, ndim=2, flags="F_CONTIGUOUS"), #x 
     561     ctypeslib.ndpointer(dtype=numpy.double, ndim=2, flags="F_CONTIGUOUS"), #y 
     562     ctypeslib.ndpointer(dtype=numpy.double, ndim=1), # Weights 
     563     ctypes.c_int,  #nCases 
     564     ctypes.c_int,  #nResp 
     565     ctypes.c_int,  #nPred 
     566     ctypes.c_int,  #nMaxDegree 
     567     ctypes.c_int,  #nMaxTerms 
     568     ctypes.c_double,   #Penalty 
     569     ctypes.c_double,   #Thresh 
     570     ctypes.c_int,  #nFastK 
     571     ctypes.c_double,   #FastBeta 
     572     ctypes.c_double,   #NewVarPenalty 
     573     ctypeslib.ndpointer(dtype=numpy.int, ndim=1),  #LinPreds 
     574     ctypes.c_bool, #UseBetaCache 
     575     ctypes.c_char_p    #sPredNames 
     576     ] 
     577     
     578def _forward_pass(x, y, degree=1, terms=21, penalty=None, thresh=0.001, 
     579                  fast_k=21, fast_beta=1, new_var_penalty=2): 
     580    """ Do forward pass. 
     581    """ 
     582    import ctypes, orange 
     583    x = numpy.asfortranarray(x, dtype="d") 
     584    y = numpy.asfortranarray(y, dtype="d") 
     585    if x.shape[0] != y.shape[0]: 
     586        raise ValueError("First dimensions of x and y must be the same.") 
     587    if y.ndim == 1: 
     588        y = y.reshape((-1, 1), order="F") 
     589    if penalty is None: 
     590        penalty = 2 
     591    n_cases = x.shape[0] 
     592    n_preds = x.shape[1] 
     593     
     594    n_resp = y.shape[1] if y.ndim == 2 else y.shape[0] 
     595     
     596    # Output variables 
     597    n_term = ctypes.c_int() 
     598    full_set = numpy.zeros((terms,), dtype=numpy.bool, order="F") 
     599    bx = numpy.zeros((n_cases, terms), dtype=numpy.double, order="F") 
     600    dirs = numpy.zeros((terms, n_preds), dtype=numpy.int, order="F") 
     601    cuts = numpy.zeros((terms, n_preds), dtype=numpy.double, order="F") 
     602    n_factors_in_terms = numpy.zeros((terms,), dtype=numpy.int, order="F") 
     603    n_uses = numpy.zeros((n_preds,), dtype=numpy.int, order="F") 
     604    weights = numpy.ones((n_cases,), dtype=numpy.double, order="F") 
     605    lin_preds = numpy.zeros((n_preds,), dtype=numpy.int, order="F") 
     606    use_beta_cache = True 
     607     
     608    _c_forward_pass_(ctypes.byref(n_term), full_set, bx, dirs, cuts, 
     609                     n_factors_in_terms, n_uses, x, y, weights, n_cases, 
     610                     n_resp, n_preds, degree, terms, penalty, thresh, 
     611                     fast_k, fast_beta, new_var_penalty, lin_preds,  
     612                     use_beta_cache, None) 
     613    return n_term.value, full_set, bx, dirs, cuts 
     614 
     615 
     616_c_eval_subsets_xtx = _c_orange_lib.EarthEvalSubsetsUsingXtx 
     617 
     618_c_eval_subsets_xtx.argtypes = \ 
     619    [ctypeslib.ndpointer(dtype=numpy.bool, ndim=2, flags="F_CONTIGUOUS"),   #PruneTerms 
     620     ctypeslib.ndpointer(dtype=numpy.double, ndim=1),   #RssVec 
     621     ctypes.c_int,  #nCases 
     622     ctypes.c_int,  #nResp 
     623     ctypes.c_int,  #nMaxTerms 
     624     ctypeslib.ndpointer(dtype=numpy.double, ndim=2, flags="F_CONTIGUOUS"),   #bx 
     625     ctypeslib.ndpointer(dtype=numpy.double, ndim=2, flags="F_CONTIGUOUS"),   #y 
     626     ctypeslib.ndpointer(dtype=numpy.double, ndim=1)  #WeightsArg 
     627     ] 
     628 
     629def _subset_selection_xtx(X, Y): 
     630    """ Subsets selection using EvalSubsetsUsingXtx in the Earth package. 
     631    """ 
     632    X = numpy.asfortranarray(X, dtype=numpy.double) 
     633    Y = numpy.asfortranarray(Y, dtype=numpy.double) 
     634    if Y.ndim == 1: 
     635        Y = Y.reshape((-1, 1), order="F") 
     636         
     637    if X.shape[0] != Y.shape[0]: 
     638        raise ValueError("First dimensions of bx and y must be the same") 
     639         
     640    var_count = X.shape[1] 
     641    resp_count = Y.shape[1] 
     642    cases = X.shape[0] 
     643    subsets = numpy.zeros((var_count, var_count), dtype=numpy.bool, 
     644                              order="F") 
     645    rss_vec = numpy.zeros((var_count,), dtype=numpy.double, order="F") 
     646    weights = numpy.ones((cases,), dtype=numpy.double, order="F") 
     647     
     648    _c_eval_subsets_xtx(subsets, rss_vec, cases, resp_count, var_count, 
     649                        X, Y, weights) 
     650     
     651    subsets_ind = numpy.zeros((var_count, var_count), dtype=int) 
     652    for i, used in enumerate(subsets.T): 
     653        subsets_ind[i, :i + 1] = numpy.where(used)[0] 
     654         
     655    return subsets_ind, rss_vec 
     656     
     657     
     658def pruning_pass(bx, y, penalty, pruned_terms=-1): 
     659    """ Do pruning pass 
     660     
     661    .. todo:: leaps 
     662     
     663    """ 
     664    subsets, rss_vec = _subset_selection_xtx(bx, y) 
     665     
     666    cases, terms = bx.shape 
     667    n_effective_params = numpy.arange(terms) + 1.0 
     668    n_effective_params += penalty * (n_effective_params - 1.0) / 2.0 
     669     
     670    gcv_vec = gcv(rss_vec, cases, n_effective_params) 
     671     
     672    min_i = numpy.argmin(gcv_vec) 
     673    used = numpy.zeros((terms), dtype=bool) 
     674     
     675    used[subsets[min_i, :min_i + 1]] = True 
     676     
     677    return used, subsets, rss_vec, gcv_vec 
     678     
     679     
     680def evimp(model, used_only=True): 
     681    """ Return the estimated variable importance for the model. 
     682    """ 
     683    if model.subsets is None: 
     684        raise ValueError("No subsets. Use the learner with 'prune=True'.") 
     685     
     686    subsets = model.subsets 
     687    n_subsets = numpy.sum(model.best_set) 
     688     
     689    rss = -numpy.diff(model.rss_per_subset) 
     690    gcv = -numpy.diff(model.gcv_per_subset) 
     691    attributes = list(model.domain.variables) 
     692     
     693    attr2ind = dict(zip(attributes, range(len(attributes)))) 
     694    importances = numpy.zeros((len(attributes), 4)) 
     695    importances[:, 0] = range(len(attributes)) 
     696     
     697    for i in range(1, n_subsets): 
     698        term_subset = subsets[i, :i + 1] 
     699        used_attributes = reduce(set.union, [model.used_attributes(term) \ 
     700                                             for term in term_subset], set()) 
     701        for attr in used_attributes: 
     702            importances[attr2ind[attr]][1] += 1.0 
     703            importances[attr2ind[attr]][2] += gcv[i - 1] 
     704            importances[attr2ind[attr]][3] += rss[i - 1] 
     705    imp_min = numpy.min(importances[:, [2, 3]], axis=0) 
     706    imp_max = numpy.max(importances[:, [2, 3]], axis=0) 
     707     
     708    #Normalize importances. 
     709    importances[:, [2, 3]] = 100.0 * (importances[:, [2, 3]] \ 
     710                            - [imp_min]) / ([imp_max - imp_min]) 
     711     
     712    importances = list(importances) 
     713    # Sort by n_subsets and gcv. 
     714    importances = sorted(importances, key=lambda row: (row[1], row[2]), 
     715                         reverse=True) 
     716    importances = numpy.array(importances) 
     717     
     718    if used_only: 
     719        importances = importances[importances[:,1] > 0.0] 
     720     
     721    res = [(attributes[int(row[0])], tuple(row[1:])) for row in importances] 
     722    return res 
     723 
     724 
    431725def plot_evimp(evimp): 
    432     """ Plot the return value from EarthClassifier.evimp. 
     726    """ Plot the return value from :obj:`EarthClassifier.evimp` call. 
    433727    """ 
    434728    import pylab 
     
    456750    fig.show() 
    457751     
    458      
     752 
     753""" 
     754Printing functions. 
     755""" 
     756 
     757def print_model(model, percision=3, indent=3): 
     758    """ Print model to stdout. 
     759    """ 
     760    print format_model(model, percision, indent) 
     761     
     762def format_model(model, percision=3, indent=3): 
     763    """ Return a formated string representation of the model. 
     764    """ 
     765    if model.class_var: 
     766        r_names = [model.class_var.name] 
     767        betas = [model.betas] 
     768    elif hasattr(model, "label_mask"): 
     769        mask = model.label_mask 
     770        r_vars = [v for v, m in zip(model.domain.variables, 
     771                                    model.label_mask) 
     772                  if m] 
     773        r_names = [v.name for v in r_vars] 
     774        betas = model.betas 
     775         
     776    resp = [] 
     777    for name, betas in zip(r_names, betas): 
     778        resp.append(_format_response(model, name, betas, 
     779                                     percision, indent)) 
     780    return "\n\n".join(resp) 
     781 
     782def _format_response(model, resp_name, betas, percision=3, indent=3): 
     783    header = "%s =" % resp_name 
     784    indent = " " * indent 
     785    fmt = "%." + str(percision) + "f" 
     786    terms = [([], fmt % betas[0])] 
     787    beta_i = 0 
     788    for i, used in enumerate(model.best_set[1:], 1): 
     789        if used: 
     790            beta_i += 1 
     791            beta = fmt % abs(betas[beta_i]) 
     792            knots = [_format_knot(model, attr.name, d, c) for d, c, attr in \ 
     793                     zip(model.dirs[i], model.cuts[i], model.domain.attributes) \ 
     794                     if d != 0] 
     795            term_attrs = [a for a, d in zip(model.domain.attributes, model.dirs[i]) \ 
     796                          if d != 0] 
     797            term_attrs = sorted(term_attrs) 
     798            sign = "-" if betas[beta_i] < 0 else "+" 
     799            if knots: 
     800                terms.append((term_attrs, 
     801                              sign + " * ".join([beta] + knots))) 
     802            else: 
     803                terms.append((term_attrs, sign + beta)) 
     804    # Sort by len(term_attrs), then by term_attrs 
     805    terms = sorted(terms, key=lambda t: (len(t[0]), t[0])) 
     806    return "\n".join([header] + [indent + t for _, t in terms]) 
     807         
     808def _format_knot(model, name, dir, cut): 
     809    if dir == 1: 
     810        txt = "max(0, %s - %.3f)" % (name, cut) 
     811    elif dir == -1: 
     812        txt = "max(0, %.3f - %s)" % (cut, name) 
     813    elif dir == 2: 
     814        txt = name 
     815    return txt 
     816 
     817def _format_term(model, i, attr_name): 
     818    knots = [_format_knot(model, attr, d, c) for d, c, attr in \ 
     819             zip(model.dirs[i], model.cuts[i], model.domain.attributes) \ 
     820             if d != 0] 
     821    return " * ".join(knots) 
Note: See TracChangeset for help on using the changeset viewer.