Changeset 8116:edf2c6e55587 in orange


Ignore:
Timestamp:
07/27/11 10:52:41 (3 years ago)
Author:
ales_erjavec <ales.erjavec@…>
Branch:
default
Convert:
b7543f28c2e124a7b3b54ef01ba8d9a6734d1303
Message:

Some code cleanup.
Added plot_evimp function.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/regression/earth.py

    r8114 r8116  
    11"""\ 
    2 EARTH (Multivariate_adaptive_regression_splines - MARS) (``earth``) 
     2EARTH (Multivariate adaptive regression splines - MARS) (``earth``) 
    33 
    44  
     
    5959         
    6060    return bx 
     61 
    6162 
    6263class EarthLearner(BaseEarthLearner): 
     
    154155    @property 
    155156    def num_terms(self): 
     157        """ Number of terms in the model (including the intercept). 
     158        """ 
    156159        return numpy.sum(numpy.asarray(self.best_set, dtype=int)) 
    157160     
    158161    @property 
    159162    def max_terms(self): 
     163        """ Maximum number of terms (as specified in the learning step). 
     164        """ 
    160165        return self.best_set.size 
    161166     
    162167    @property 
    163168    def num_preds(self): 
     169        """ Number of predictors (variables) included in the model. 
     170        """ 
    164171        return len(self.used_attributes(term)) 
    165172     
     
    186193        raise NotImplementedError 
    187194     
     195    def filters(self): 
     196        """ Orange.core.filter objects for each term (where the hinge 
     197        function is not 0). 
     198          
     199        """ 
     200         
     201     
    188202    def base_matrix(self, examples=None): 
    189203        """ Return the base matrix (bx) of the Earth model for the table. 
     
    191205        is returned. 
    192206         
     207         
    193208        :param examples: Input examples for the base matrix. 
    194209        :type examples: Orange.data.Table  
     
    208223             
    209224        return base_matrix(data, self.best_set, self.dirs, self.cuts, self.betas) 
     225     
     226    def _anova_order(self): 
     227        """ Return indices that sort the terms into the 'ANOVA' format. 
     228        """ 
     229        terms = [([], 0)] # intercept 
     230        for i, used in enumerate(self.best_set[1:], 1): 
     231            attrs = sorted(self.used_attributes(i)) 
     232            if used and attrs: 
     233                terms.append((attrs, i)) 
     234        terms = sotred(terms, key=lambda t:(len(t[0]), t[0])) 
     235        return [i for _, i in terms] 
    210236     
    211237    def format_model(self, percision=3, indent=3): 
     
    253279        print self.format_model(percision, indent) 
    254280         
    255     def format_summary(self): 
    256         pass 
    257      
    258     def print_summary(self): 
    259         """ Print model summary 
    260         """ 
    261         print self.format_summary() 
    262          
    263281    def predict(self, ex): 
    264282        """ Return the predicted value (float) for example. 
     
    297315             
    298316    def used_attributes(self, term=None): 
    299         """ Return a list of used attributes. If term is given return only 
    300         attributes used in that single term. 
     317        """ Return a list of used attributes. If term (index) is given 
     318        return only attributes used in that single term. 
    301319         
    302320        """ 
     
    306324            terms = [term] 
    307325        attrs = set() 
    308 #        print terms 
    309326        for ti in terms: 
    310327            attri = numpy.where(self.dirs[ti] != 0.0)[0] 
    311 #            print attri 
    312328            attrs.update([self.domain.attributes[i] for i in attri]) 
    313329        return attrs 
    314330         
    315331    def evimp(self, used_only=True): 
    316         """ Return the variable importance. 
     332        """ Return the estimated variable importance. 
    317333        """ 
    318334        if self.subsets is None: 
     
    337353                importances[attr2ind[attr]][2] += gcv[i - 1] 
    338354                importances[attr2ind[attr]][3] += rss[i - 1] 
     355        imp_min = numpy.min(importances[:, [2, 3]], axis=0) 
     356        imp_max = numpy.max(importances[:, [2, 3]], axis=0) 
     357        importances[:, [2, 3]] = 100.0 * (importances[:, [2, 3]] - [imp_min]) / ([imp_max - imp_min]) 
    339358         
    340359        importances = list(importances) 
     
    353372        import pylab 
    354373        n_terms = self.num_terms 
     374        grid_size = int(numpy.ceil(numpy.sqrt(n_terms))) 
     375        fig = pylab.figure() 
    355376         
    356377    def __reduce__(self): 
     
    361382                {}) 
    362383                                  
    363          
    364  
     384     
    365385def gcv(rss, n, n_effective_params): 
    366386    """ Return the generalized cross validation. 
     
    395415        else: 
    396416            rss_vec[subset_size] = numpy.sum((Y - numpy.dot(X_work, b)) ** 2) 
    397 #        print "Subset size", subset_size, "Rss", rss_vec[subset_size] 
     417             
    398418        XtX = numpy.dot(X_work.T, X_work) 
    399419        iXtX = numpy.linalg.pinv(XtX) 
    400420        diag = numpy.diag(iXtX) 
    401421         
    402          
    403422        if subset_size == 0: 
    404423            break 
     
    407426        delete_i = numpy.argmin(delta_rss[1:]) + 1 # Keep the intercept 
    408427        del working_set[delete_i] 
    409 #        print delete_i 
    410428    return subsets, rss_vec 
    411429 
    412      
     430 
     431def plot_evimp(evimp): 
     432    """ Plot the return value from EarthClassifier.evimp. 
     433    """ 
     434    import pylab 
     435    fig = pylab.figure() 
     436    axes1 = fig.add_subplot(111) 
     437    attrs = [a for a, _ in evimp] 
     438    imp = [s for _, s in evimp] 
     439    imp = numpy.array(imp) 
     440    X = range(len(attrs)) 
     441    l1 = axes1.plot(X, imp[:,0], "b-",) 
     442    axes2 = axes1.twinx() 
     443     
     444    l2 = axes2.plot(X, imp[:,1], "g-",) 
     445    l3 = axes2.plot(X, imp[:,2], "r-",) 
     446     
     447    x_axis = axes1.xaxis 
     448    x_axis.set_ticks(X) 
     449    x_axis.set_ticklabels([a.name for a in attrs], rotation=45) 
     450     
     451    axes1.yaxis.set_label_text("nsubsets") 
     452    axes2.yaxis.set_label_text("normalizes gcc or rss") 
     453 
     454    axes1.legend([l1, l2, l3], ["nsubsets", "gcv", "rss"]) 
     455    axes1.set_title("Variable importance") 
     456    fig.show() 
     457     
     458     
Note: See TracChangeset for help on using the changeset viewer.