Changeset 10896:9e5b3b2fa757 in orange


Ignore:
Timestamp:
06/01/12 14:36:34 (23 months ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Speed up the the earth basis term domain translation.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/regression/earth.py

    r10855 r10896  
    357357        cuts = self.cuts[self.best_set] 
    358358 
    359         for dir, cut in zip(dirs[1:], cuts[1:]): # Drop the intercept (first)  
     359        for dir, cut in zip(dirs[1:], cuts[1:]):  # Drop the intercept (first column)  
    360360            hinge = [_format_knot(self, attr.name, dir1, cut1) \ 
    361361                     for (attr, dir1, cut1) in \ 
     
    482482    return  rss / (n * (1. - n_effective_params / n) ** 2) 
    483483 
     484 
    484485class term_computer(Orange.core.ClassifierFD): 
    485486    """An utility class for computing basis terms. Can be used as 
    486487    a :obj:`~Orange.feature.Descriptior.get_value_from` member. 
     488 
    487489    """ 
    488490    def __init__(self, term_var=None, domain=None, dir=None, cut=None): 
     
    493495 
    494496    def __call__(self, instance, return_what=Orange.core.GetValue): 
    495         instance = Orange.data.Instance(self.domain, instance) 
    496         attributes = self.domain.attributes 
    497         sum = 0.0 
    498         for val, dir1, cut1 in zip(instance, self.dir, self.cut): 
    499             if dir1 != 0.0 and dir1 != 2 and not val.isSpecial(): 
    500                 sum += max(dir1 * (float(val) - cut1), 0.0) 
    501         return self.class_var(sum) 
     497        instance = Orange.data.Table(self.domain, [instance]) 
     498        (instance,) = instance.to_numpy_MA("A") 
     499        instance = instance[0] 
     500 
     501        mask = self.dir != 0 
     502        dir = self.dir[mask] 
     503        cut = self.cut[mask] 
     504 
     505        values = instance[mask] - cut 
     506        values *= dir 
     507 
     508        values = numpy.where(values > 0, values, 0) 
     509        value = numpy.prod(values.filled(0)) 
     510 
     511        return self.class_var(value if value is not numpy.ma.masked else "?") 
     512 
    502513 
    503514""" 
     
    509520ctypes interface to ForwardPass and EvalSubsetsUsingXtx. 
    510521""" 
    511          
     522 
    512523import ctypes 
    513524from numpy import ctypeslib 
Note: See TracChangeset for help on using the changeset viewer.