Changeset 3572:b46ee08f3ffe in orange


Ignore:
Timestamp:
04/23/07 13:24:20 (7 years ago)
Author:
markotoplak
Branch:
default
Convert:
e1a14c124700bddb0e55f02e9cb952b69010a04e
Message:

Added class MeasureAttribute_randomForests.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/orngEnsemble.py

    r2247 r3572  
    235235        elif resultType == orange.GetProbabilities: return cprob 
    236236        else: return (cvalue, cprob) 
     237 
     238 
     239########################################################## 
     240### MeasureAttribute_randomForests 
     241 
     242class MeasureAttribute_randomForests(orange.MeasureAttribute): 
     243 
     244  def __init__(self, learner=None, trees = 100, attributes=None, rand=None): 
     245    self.trees = trees 
     246    self.learner = learner 
     247    self.bufexamples = None 
     248    self.attributes = attributes 
     249     
     250    if self.learner == None: 
     251      temp = RandomForestLearner(attributes=self.attributes) 
     252      self.learner = temp.learner 
     253     
     254    if hasattr(self.learner.split, 'attributes'): 
     255      self.origattr = self.learner.split.attributes 
     256       
     257    if rand: 
     258      self.rand = rand             # a random generator 
     259    else: 
     260      self.rand = random.Random() 
     261      self.rand.seed(0) 
     262 
     263  def __call__(self, a1, a2, a3=None): 
     264    """ 
     265    Returns importance of a given attribute. Can be given by index,  
     266    name or as a orange.Variable. 
     267    """ 
     268    attrNo = None 
     269    examples = None 
     270 
     271    if type(a1) == int: #by attr. index 
     272      attrNo, examples, apriorClass = a1, a2, a3 
     273    elif type(a1) == type("a"): #by attr. name 
     274      attrName, examples, apriorClass = a1, a2, a3 
     275      attrNo = examples.domain.index(attrName) 
     276    elif isinstance(a1, orange.Variable): 
     277      a1, examples, apriorClass = a1, a2, a3 
     278      atrs = [a for a in examples.domain.attributes] 
     279      attrNo = atrs.index(a1) 
     280    else: 
     281      contingency, classDistribution, apriorClass = a1, a2, a3 
     282      raise Exception("MeasureAttribute_rf can not be called with (contingency, classDistribution, apriorClass) as fuction arguments.") 
     283 
     284    self.buffer(examples) 
     285 
     286    return self.avimp[attrNo]*100/self.trees 
     287 
     288  def importances(self, examples): 
     289    """ 
     290    Returns importances of all attributes in dataset in a list. Buffered. 
     291    """ 
     292    self.buffer(examples) 
     293     
     294    return [a*100/self.trees for a in self.avimp] 
     295 
     296  def buffer(self, examples): 
     297    """ 
     298    recalcule importances if needed (new examples) 
     299    """ 
     300    recalculate = False 
     301     
     302    if examples != self.bufexamples: 
     303      recalculate = True 
     304    elif examples.version != self.bufexamples.version: 
     305      recalculate = True 
     306          
     307    if (recalculate): 
     308      self.bufexamples = examples 
     309      self.avimp = [0.0]*len(self.bufexamples.domain.attributes) 
     310      self.acu = 0 
     311       
     312      if hasattr(self.learner.split, 'attributes'): 
     313          self.learner.split.attributes = self.origattr 
     314       
     315      # if number of attributes for subset is not set, use square root 
     316      if hasattr(self.learner.split, 'attributes') and not self.learner.split.attributes: 
     317          self.learner.split.attributes = int(sqrt(len(examples.domain.attributes))) 
     318       
     319      self.importanceAcu(self.bufexamples, self.trees, self.avimp) 
     320       
     321  def getOOB(self, examples, selection, nexamples): 
     322        ooblist = filter(lambda x: x not in selection, range(nexamples)) 
     323        return examples.getitems(ooblist) 
     324 
     325  def numRight(self, oob, classifier): 
     326        """ 
     327        returns a number of examples which are classified correcty 
     328        """ 
     329        right = 0 
     330        for el in oob: 
     331            if (el.getclass() == classifier(el)): 
     332                right = right + 1 
     333        return right 
     334     
     335  def numRightMix(self, oob, classifier, attr): 
     336        """ 
     337        returns a number of examples  which are classified 
     338        correctly even if an attribute is shuffled 
     339        """ 
     340        n = len(oob) 
     341 
     342        perm = range(n) 
     343        self.rand.shuffle(perm) 
     344 
     345        right = 0 
     346 
     347        for i in range(n): 
     348            ex = orange.Example(oob[i]) 
     349            ex[attr] = oob[perm[i]][attr] 
     350             
     351            if (ex.getclass() == classifier(ex)): 
     352                right = right + 1 
     353                 
     354        return right 
     355 
     356  def importanceAcu(self, examples, trees, avimp): 
     357        """ 
     358        accumulate avimp by importances for a given number of trees 
     359        """ 
     360   
     361 
     362        n = len(examples) 
     363 
     364        attrs = len(examples.domain.attributes) 
     365 
     366        attrnum = {} 
     367        for attr in range(len(examples.domain.attributes)): 
     368           attrnum[examples.domain.attributes[attr].name] = attr             
     369    
     370        # build the forest 
     371        classifiers = []   
     372        for i in range(trees): 
     373             
     374            # draw bootstrap sample 
     375            selection = [] 
     376            for j in range(n): 
     377                selection.append(self.rand.randrange(n)) 
     378            data = examples.getitems(selection) 
     379             
     380            # build the model from the bootstrap sample 
     381            cla = self.learner(data) 
     382 
     383            #prepare OOB data 
     384            oob = self.getOOB(examples, selection, n) 
     385             
     386            #right on unmixed 
     387            right = self.numRight(oob, cla) 
     388             
     389            presl = list(self.presentInTree(cla.tree, attrnum)) 
     390                       
     391            #randomize each attribute in data and test 
     392            #only those on which there was a split 
     393            for attr in presl: 
     394                #calculate number of right classifications 
     395                #if the values of this attribute are permutated randomly 
     396                rightimp = self.numRightMix(oob, cla, attr)                 
     397                avimp[attr] += (float(right-rightimp))/len(oob) 
     398 
     399        self.acu += trees   
     400 
     401  def presentInTree(self, node, attrnum): 
     402        """ 
     403        returns attributes present in tree (attributes that split) 
     404        """ 
     405 
     406        if not node: 
     407          return set([]) 
     408 
     409        if  node.branchSelector: 
     410            j = attrnum[node.branchSelector.classVar.name] 
     411             
     412            cs = set([]) 
     413            for i in range(len(node.branches)): 
     414                s = self.presentInTree(node.branches[i], attrnum) 
     415                cs = s | cs 
     416             
     417            cs = cs | set([j]) 
     418             
     419            return cs 
     420             
     421        else: 
     422          return set([]) 
     423 
     424 
Note: See TracChangeset for help on using the changeset viewer.