Changeset 8150:f7cc2be1d407 in orange


Ignore:
Timestamp:
08/04/11 16:25:17 (3 years ago)
Author:
ales_erjavec <ales.erjavec@…>
Branch:
default
Convert:
d2fad15a3eb1e4a12aaae0e22e6462d10008f464
Message:

Added ScoreEarthImportance feature scoring class (scoring.Measure interface).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/regression/earth.py

    r8144 r8150  
    744744 
    745745 
     746""" 
     747High level interface for measuring variable importance 
     748(compatible with Orange.feature.scoring module). 
     749 
     750""" 
     751from Orange.feature import scoring 
     752from Orange.misc import _orange__new__ 
     753 
     754class ScoreEarthImportance(scoring.Measure): 
     755    """ Score features based on there importance in the Earth model using 
     756    ``bagged_evimp``'s function return value. 
     757    """ 
     758    # Return types   
     759    NSUBSETS = 0 
     760    RSS = 1 
     761    GCV = 2 
     762     
     763#    _cache = weakref.WeakKeyDictionary() 
     764    __new__ = _orange__new__(scoring.Measure) 
     765         
     766    def __init__(self, t=10, score_what="nsubsets", cached=True): 
     767        """ 
     768        :param t: Number of earth models to train on the data 
     769            (using BaggedLearner). 
     770             
     771        :param score_what: What to return as a score. 
     772            Can be one of: "nsubsets", "rss", "gcv" or class constants 
     773            NSUBSETS, RSS, GCV. 
     774             
     775        """ 
     776        self.t = t 
     777        if isinstance(score_what, basestring): 
     778            score_what = {"nsubsets":self.NSUBSETS, "rss":self.RSS, 
     779                          "gcv":self.GCV}.get(score_what, None) 
     780                           
     781        if score_what not in range(3): 
     782            raise ValueError("Invalid  'score_what' parameter.") 
     783 
     784        self.score_what = score_what 
     785        self.cached = cached 
     786        self._cache_ref = None 
     787         
     788    def __call__(self, attr, data, weight_id=None): 
     789        ref = self._cache_ref 
     790        if ref is not None and ref is data: 
     791            evimp = self._cached_evimp 
     792        else: 
     793            from Orange.ensemble.bagging import BaggedLearner 
     794            bc = BaggedLearner(EarthLearner(degree=2, terms=10))(data, weight_id) 
     795            evimp = bagged_evimp(bc, used_only=False) 
     796            self._cache_ref = data #weakref.ref(data) 
     797            self._cached_evimp = evimp 
     798             
     799        evimp = dict(evimp) 
     800        score = evimp.get(attr, None) 
     801         
     802        if score is None: 
     803            raise ValueError("Attribute %r not in the domain." % attr) 
     804        else: 
     805            return score[self.score_what] 
     806     
     807     
     808class ScoreRSS(scoring.Measure): 
     809    __new__ = _orange__new__(scoring.Measure) 
     810    def __init__(self): 
     811        self._cache_data = None 
     812        self._cache_rss = None 
     813         
     814    def __call__(self, attr, data, weight_id=None): 
     815        ref = self._cache_data 
     816        if ref is not None and ref is data: 
     817            rss = self._cache_rss 
     818        else: 
     819            x, y = data.to_numpy_MA("1A/c") 
     820            subsets, rss = subsets_selection_xtx_numpy(x, y) 
     821            rss_diff = -numpy.diff(rss) 
     822            rss = numpy.zeros_like(rss) 
     823            for s_size in range(1, subsets.shape[0]): 
     824                subset = subsets[s_size, :s_size + 1] 
     825                rss[subset] += rss_diff[s_size - 1] 
     826            rss = rss[1:] #Drop the intercept 
     827            self._cache_data = data 
     828            self._cache_rss = rss 
     829#        print sorted(zip(rss, data.domain.attributes)) 
     830        index = list(data.domain.attributes).index(attr) 
     831        return rss[index] 
     832         
     833     
    746834#from Orange.misc import member_set 
    747835#  
Note: See TracChangeset for help on using the changeset viewer.