Ignore:
Timestamp:
06/07/12 15:24:13 (23 months ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Another speedup of earth basis computation.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/regression/earth.py

    r10897 r10907  
    359359        dirs = self.dirs[self.best_set] 
    360360        cuts = self.cuts[self.best_set] 
    361  
     361        # For faster domain translation all the features share 
     362        # this _instance_cache. 
     363        _instance_cache = {} 
    362364        for dir, cut in zip(dirs[1:], cuts[1:]):  # Drop the intercept (first column) 
    363365            hinge = [_format_knot(self, attr.name, dir1, cut1) \ 
     
    367369            term_name = " * ".join(hinge) 
    368370            term = Orange.feature.Continuous(term_name) 
    369             term.get_value_from = term_computer(term, self.domain, dir, cut) 
     371            term.get_value_from = term_computer( 
     372                term, self.domain, dir, cut, 
     373                _instance_cache=_instance_cache 
     374            ) 
     375 
    370376            terms.append(term) 
    371377        return terms 
     
    492498 
    493499    """ 
    494     def __init__(self, term_var=None, domain=None, dir=None, cut=None): 
     500    def __init__(self, term_var=None, domain=None, dir=None, cut=None, 
     501                 _instance_cache=None): 
    495502        self.class_var = term_var 
    496503        self.domain = domain 
     504 
    497505        self.dir = dir 
    498506        self.cut = cut 
    499507 
     508        if dir is not None: 
     509            self.mask = self.dir != 0 
     510            self.masked_dir = self.dir[self.mask] 
     511            self.masked_cut = self.cut[self.mask] 
     512        else: 
     513            # backcompat. with old pickled format. 
     514            self.mask = self.masked_dir = self.masked_cut = None 
     515 
     516        self._instance_cache = _instance_cache 
     517 
    500518    def __call__(self, instance, return_what=Orange.core.GetValue): 
    501         instance = Orange.data.Table(self.domain, [instance]) 
    502         (instance,) = instance.to_numpy_MA("A") 
    503         instance = instance[0] 
    504  
    505         mask = self.dir != 0 
    506         dir = self.dir[mask] 
    507         cut = self.cut[mask] 
    508  
    509         values = instance[mask] - cut 
    510         values *= dir 
    511  
    512         values = numpy.where(values > 0, values, 0) 
    513         value = numpy.prod(values.filled(0)) 
    514  
    515         return self.class_var(value if value is not numpy.ma.masked else "?") 
     519        instance = self._instance_as_masked_array(instance) 
     520 
     521        if self.mask is None: 
     522            self.mask = self.dir != 0 
     523            self.masked_dir = self.dir[self.mask] 
     524            self.masked_cut = self.cut[self.mask] 
     525 
     526        values = instance[self.mask] 
     527        if numpy.ma.is_masked(values): 
     528            # Can't compute the term. 
     529            return self.class_var("?") 
     530 
     531        # Works faster with plain arrays 
     532        values = numpy.array(values) 
     533        values -= self.masked_cut 
     534        values *= self.masked_dir 
     535 
     536        values[values < 0] = 0 
     537        value = numpy.prod(values) 
     538 
     539        return self.class_var(value) 
     540 
     541    def _instance_as_masked_array(self, instance): 
     542        array = None 
     543        if self._instance_cache is not None: 
     544            array = self._instance_cache.get(instance, None) 
     545 
     546        if array is None: 
     547            table = Orange.data.Table(self.domain, [instance]) 
     548            (array,) = table.to_numpy_MA("A") 
     549            array = array[0] 
     550 
     551            if self._instance_cache is not None: 
     552                self._instance_cache.clear() 
     553                self._instance_cache[instance] = array 
     554        return array 
     555 
     556    def __reduce__(self): 
     557        return (type(self), (self.class_var, self.domain, self.dir, self.cut), 
     558                dict(self.__dict__.items())) 
    516559 
    517560 
Note: See TracChangeset for help on using the changeset viewer.