Changeset 8141:1e2e8b53c6fd in orange


Ignore:
Timestamp:
08/02/11 17:03:29 (3 years ago)
Author:
ales_erjavec <ales.erjavec@…>
Branch:
default
Convert:
0c110d79887d2e883079c9d932e0184d43cbec48
Message:

Added y scaling.
Check parameters in forward_pass (to prevent a call to exit from earth C code).
Added bagged_evimp function.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/regression/earth.py

    r8135 r8141  
    4545.. autofunction:: plot_evimp 
    4646 
     47.. autofunction:: bagged_evimp 
     48 
    4749""" 
    4850 
     
    6668    def __init__(self, degree=1, terms=21, penalty= None, thresh=1e-3, 
    6769                 min_span=0, new_var_penalty=0, fast_k=20, fast_beta=1, 
    68                  pruned_terms=None, scale_resp=False, store_examples=True, 
     70                 pruned_terms=None, scale_resp=True, store_examples=True, 
    6971                 multi_label=False, **kwds): 
    7072        """ Initialize the learner instance. 
     
    132134        x = data[:, ~ label_mask] 
    133135         
     136        if self.scale_resp: 
     137            y = y - numpy.mean(y, axis=0) 
     138            y = y / numpy.std(y, axis=1) 
     139             
     140             
    134141        # TODO: y scaling 
    135142        n_terms, used, bx, dirs, cuts = forward_pass(x, y, 
     
    456463    use_beta_cache = True 
    457464     
     465    # These tests are performed in ForwardPass, and if they fail the function 
     466    # calls exit. So we must check it here and raise a exception to avoid a 
     467    # process shutdown. 
     468    if n_cases < 8: 
     469        raise ValueError("Need at least 8 data instances.") 
     470    if n_cases > 1e8: 
     471        raise ValueError("To many data instances.") 
     472    if n_resp < 1: 
     473        raise ValueError("No response column.") 
     474    if n_resp > 1e6: 
     475        raise ValueError("To many response columns.") 
     476    if n_preds < 1: 
     477        raise ValueError("No predictor columns.") 
     478    if n_preds > 1e5: 
     479        raise ValueError("To many predictor columns.") 
     480    if degree <= 0 or degree > 100: 
     481        raise ValueError("Invalid 'degree'.") 
     482    if terms < 3 or terms > 10000: 
     483        raise ValueError("'terms' must be in >= 3 and <= 10000.") 
     484    if penalty < 0 and penalty != -1: 
     485        raise ValueError("Invalid 'penalty' (the only legal negative value is -1).") 
     486    if penalty > 1000: 
     487        raise ValueError("Invalid 'penalty' (must be <= 1000).") 
     488    if thresh < 0.0 or thresh >= 1.0: 
     489        raise ValueError("Invalid 'thresh' (must be in [0.0, 1.0) ).") 
     490    if fast_beta < 0 or fast_beta > 1000: 
     491        raise ValueError("Invalid 'fast_beta' (must be in [0, 1000] ).") 
     492    if new_var_penalty < 0 or new_var_penalty > 10: 
     493        raise ValueError("Invalid 'new_var_penalty' (must be in [0, 10] ).") 
     494    if (numpy.var(y, axis=0) <= 1e-8).any(): 
     495        raise ValueError("Variance of y is zero (or near zero).") 
     496      
    458497    _c_forward_pass_(ctypes.byref(n_term), full_set, bx, dirs, cuts, 
    459498                     n_factors_in_terms, n_uses, x, y, weights, n_cases, 
     
    598637     
    599638    axes1.yaxis.set_label_text("nsubsets") 
    600     axes2.yaxis.set_label_text("normalizes gcc or rss") 
     639    axes2.yaxis.set_label_text("normalized gcv or rss") 
    601640 
    602641    axes1.legend([l1, l2, l3], ["nsubsets", "gcv", "rss"]) 
     
    604643    fig.show() 
    605644     
     645def bagged_evimp(classifier, used_only=True): 
     646    """ Extract combined (average) evimp from an instance of BaggedClassifier 
     647     
     648    Example: :: 
     649        >>> from Orange.ensemble.bagging import BaggedLearner 
     650        >>> bc = BaggedLearner(EarthLearner(degree=3, terms=10), data) 
     651        >>> bagged_evimp(bc) 
     652         
     653    """ 
     654    def assert_type(object, class_): 
     655        if not isinstance(object, class_): 
     656            raise TypeError("Instance of %r expected." % (class_)) 
     657    from collections import defaultdict 
     658    from Orange.ensemble.bagging import BaggedClassifier 
     659     
     660    assert_type(classifier, BaggedClassifier) 
     661    bagged_imp = defaultdict(list) 
     662     
     663    for c in classifier.classifiers: 
     664        assert_type(c, EarthClassifier) 
     665        imp = evimp(c, used_only=False) 
     666        for attr, score in imp: 
     667            bagged_imp[attr].append(score) 
     668             
     669    for attr, scores in bagged_imp.items(): 
     670        scores = numpy.average(scores, axis=0) 
     671        bagged_imp[attr] = tuple(scores) 
     672     
     673    bagged_imp = sorted(bagged_imp.items(), key=lambda t:t[1][0], 
     674                        reverse=True)     
     675    if used_only: 
     676        bagged_imp = [(a, r) for a, r in bagged_imp if r[0] > 0] 
     677    return bagged_imp 
    606678 
    607679""" 
Note: See TracChangeset for help on using the changeset viewer.