Ignore:
Timestamp:
01/30/12 16:53:53 (2 years ago)
Author:
Miha Stajdohar <miha.stajdohar@…>
Branch:
default
rebase_source:
b959bc81c598d90244583a2b3375b025e99015e3
Message:

Fixed a bug in to_network method.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/classification/tree.py

    r9567 r9607  
    14621462            "gain_ratio": "gainRatio" } 
    14631463    #_rename_new_old = {} 
    1464     _rename_old_new = dict((a,b) for b,a in _rename_new_old.items()) 
     1464    _rename_old_new = dict((a, b) for b, a in _rename_new_old.items()) 
    14651465 
    14661466    @classmethod 
    14671467    def _rename_dict(cls, dic): 
    1468         return dict((cls._rename_arg(a),b) for a,b in dic.items()) 
     1468        return dict((cls._rename_arg(a), b) for a, b in dic.items()) 
    14691469 
    14701470    @classmethod 
     
    14741474        return cls._rename_new_old.get(a, a) 
    14751475 
    1476     def __new__(cls, instances = None, weightID = 0, **argkw): 
     1476    def __new__(cls, instances=None, weightID=0, **argkw): 
    14771477        self = Orange.classification.Learner.__new__(cls, **cls._rename_dict(argkw)) 
    14781478        if instances: 
     
    14811481        else: 
    14821482            return self 
    1483          
     1483 
    14841484    def __init__(self, **kwargs): 
    14851485        self.base = _C45Learner(**self._rename_dict(kwargs)) 
     
    15071507        """ 
    15081508        self.base.commandline(ln) 
    1509      
    1510   
     1509 
     1510 
    15111511class C45Classifier(Orange.classification.Classifier): 
    15121512    """ 
     
    15241524        for k, v in self.nativeClassifier.__dict__.items(): 
    15251525            self.__dict__[k] = v 
    1526    
     1526 
    15271527    def __call__(self, instance, result_type=Orange.classification.Classifier.GetValue, 
    15281528                 *args, **kwdargs): 
     
    15511551    def __str__(self): 
    15521552        return self.dump() 
    1553     
    1554  
    1555     def dump(self):   
     1553 
     1554 
     1555    def dump(self): 
    15561556        """ 
    15571557        Print the tree in the same form as Ross Quinlan's  
     
    16061606    str_ = "" 
    16071607    if node.node_type == 1: 
    1608         str_ += "\n"+"|   "*lev + "%s = %s:" % (var.name, var.values[i]) 
    1609         str_ += _c45_printTree0(node.branch[i], classvar, lev+1) 
     1608        str_ += "\n" + "|   "*lev + "%s = %s:" % (var.name, var.values[i]) 
     1609        str_ += _c45_printTree0(node.branch[i], classvar, lev + 1) 
    16101610    elif node.node_type == 2: 
    1611         str_ += "\n"+"|   "*lev + "%s %s %.1f:" % (var.name, ["<=", ">"][i], node.cut) 
    1612         str_ += _c45_printTree0(node.branch[i], classvar, lev+1) 
     1611        str_ += "\n" + "|   "*lev + "%s %s %.1f:" % (var.name, ["<=", ">"][i], node.cut) 
     1612        str_ += _c45_printTree0(node.branch[i], classvar, lev + 1) 
    16131613    else: 
    1614         inset = filter(lambda a:a[1]==i, enumerate(node.mapping)) 
     1614        inset = filter(lambda a:a[1] == i, enumerate(node.mapping)) 
    16151615        inset = [var.values[j[0]] for j in inset] 
    1616         if len(inset)==1: 
    1617             str_ += "\n"+"|   "*lev + "%s = %s:" % (var.name, inset[0]) 
     1616        if len(inset) == 1: 
     1617            str_ += "\n" + "|   "*lev + "%s = %s:" % (var.name, inset[0]) 
    16181618        else: 
    1619             str_ +=  "\n"+"|   "*lev + "%s in {%s}:" % (var.name, ", ".join(inset)) 
    1620         str_ += _c45_printTree0(node.branch[i], classvar, lev+1) 
     1619            str_ += "\n" + "|   "*lev + "%s in {%s}:" % (var.name, ", ".join(inset)) 
     1620        str_ += _c45_printTree0(node.branch[i], classvar, lev + 1) 
    16211621    return str_ 
    1622          
    1623          
     1622 
     1623 
    16241624def _c45_printTree0(node, classvar, lev): 
    16251625    var = node.tested 
    16261626    str_ = "" 
    16271627    if node.node_type == 0: 
    1628         str_ += "%s (%.1f)" % (classvar.values[int(node.leaf)], node.items)  
     1628        str_ += "%s (%.1f)" % (classvar.values[int(node.leaf)], node.items) 
    16291629    else: 
    16301630        for i, branch in enumerate(node.branch): 
     
    18181818        else: 
    18191819            return self 
    1820      
     1820 
    18211821    def __init__(self, **kw): 
    18221822 
     
    18241824        #buildfunctions are not saved as function references 
    18251825        #because that would make problems with object copies 
    1826         for n,(fn,_) in self._built_fn.items(): 
     1826        for n, (fn, _) in self._built_fn.items(): 
    18271827            self.__dict__["_handset_" + n] = False 
    18281828 
     
    18321832        self.stop = None 
    18331833        self.splitter = None 
    1834          
    1835         for n,(fn,_) in self._built_fn.items(): 
     1834 
     1835        for n, (fn, _) in self._built_fn.items(): 
    18361836            self.__dict__[n] = fn(self) 
    18371837 
    1838         for k,v in kw.items(): 
    1839             self.__setattr__(k,v) 
    1840        
     1838        for k, v in kw.items(): 
     1839            self.__setattr__(k, v) 
     1840 
    18411841    def __call__(self, instances, weight=0): 
    18421842        """ 
     
    18531853            bl.split.continuous_split_constructor.measure = measure 
    18541854            bl.split.discrete_split_constructor.measure = measure 
    1855           
     1855 
    18561856        if self.splitter != None: 
    18571857            bl.example_splitter = self.splitter 
     
    18641864            tree = Pruner_m(tree, m=self.m_pruning) 
    18651865 
    1866         return TreeClassifier(base_classifier=tree)  
     1866        return TreeClassifier(base_classifier=tree) 
    18671867 
    18681868    def __setattr__(self, name, value): 
    18691869        self.__dict__[name] = value 
    1870         for n,(fn,v) in self._built_fn.items(): 
     1870        for n, (fn, v) in self._built_fn.items(): 
    18711871            if name in v: 
    18721872                if not self.__dict__["_handset_" + n]: 
     
    19311931        if relM and measureIsRelief: 
    19321932            measure.m = relM 
    1933          
     1933 
    19341934        relK = getattr(self, "relief_k", None) 
    19351935        if relK and measureIsRelief: 
     
    19801980        return learner 
    19811981 
    1982     _built_fn = {  
     1982    _built_fn = { 
    19831983            "split": [ _build_split, [ "binarization", "measure", "relief_m", "relief_k", "worst_acceptable", "min_subset" ] ], \ 
    1984             "stop": [ _build_stop, ["max_majority", "min_instances" ] ]  
     1984            "stop": [ _build_stop, ["max_majority", "min_instances" ] ] 
    19851985        } 
    19861986 
     
    20242024fromto = r"(?P<out>!?)(?P<lowin>\(|\[)(?P<lower>\d*\.?\d+)\s*,\s*(?P<upper>\d*\.?\d+)(?P<upin>\]|\))" 
    20252025re_V = re.compile("%V") 
    2026 re_N = re.compile("%"+fs+"N"+by) 
    2027 re_M = re.compile("%"+fs+"M"+by) 
    2028 re_m = re.compile("%"+fs+"m"+by) 
    2029 re_Ccont = re.compile("%"+fs+"C"+by+opc) 
    2030 re_Cdisc = re.compile("%"+fs+"C"+by+opd) 
    2031 re_ccont = re.compile("%"+fs+"c"+by+opc) 
    2032 re_cdisc = re.compile("%"+fs+"c"+by+opd) 
    2033 re_Cconti = re.compile("%"+fs+"C"+by+fromto) 
    2034 re_cconti = re.compile("%"+fs+"c"+by+fromto) 
    2035 re_D = re.compile("%"+fs+"D"+by) 
    2036 re_d = re.compile("%"+fs+"d"+by) 
    2037 re_AE = re.compile("%"+fs+"(?P<AorE>A|E)"+bysub) 
    2038 re_I = re.compile("%"+fs+"I"+intrvl) 
     2026re_N = re.compile("%" + fs + "N" + by) 
     2027re_M = re.compile("%" + fs + "M" + by) 
     2028re_m = re.compile("%" + fs + "m" + by) 
     2029re_Ccont = re.compile("%" + fs + "C" + by + opc) 
     2030re_Cdisc = re.compile("%" + fs + "C" + by + opd) 
     2031re_ccont = re.compile("%" + fs + "c" + by + opc) 
     2032re_cdisc = re.compile("%" + fs + "c" + by + opd) 
     2033re_Cconti = re.compile("%" + fs + "C" + by + fromto) 
     2034re_cconti = re.compile("%" + fs + "c" + by + fromto) 
     2035re_D = re.compile("%" + fs + "D" + by) 
     2036re_d = re.compile("%" + fs + "d" + by) 
     2037re_AE = re.compile("%" + fs + "(?P<AorE>A|E)" + bysub) 
     2038re_I = re.compile("%" + fs + "I" + intrvl) 
    20392039 
    20402040def insert_str(s, mo, sub): 
     
    20732073    with, when division is required. 
    20742074    """ 
    2075     if by=="bP": 
     2075    if by == "bP": 
    20762076        return parent 
    20772077    else: 
     
    20922092            return insert_dot(strg, mo) 
    20932093    return insert_num(strg, mo, N) 
    2094          
     2094 
    20952095 
    20962096def replaceM(strg, mo, node, parent, tree): 
     
    21062106            return insert_dot(strg, mo) 
    21072107    return insert_num(strg, mo, N) 
    2108          
     2108 
    21092109 
    21102110def replacem(strg, mo, node, parent, tree): 
     
    21282128    if tree.class_var.var_type != Orange.data.Type.Discrete: 
    21292129        return insert_dot(strg, mo) 
    2130      
     2130 
    21312131    by, op, cls = mo.group("by", "op", "cls") 
    21322132    N = node.distribution[cls] 
     
    21422142    return insert_num(strg, mo, N) 
    21432143 
    2144      
     2144 
    21452145def replacecdisc(strg, mo, node, parent, tree): 
    21462146    if tree.class_var.var_type != Orange.data.Type.Discrete: 
    21472147        return insert_dot(strg, mo) 
    2148      
     2148 
    21492149    op, by, cls = mo.group("op", "by", "cls") 
    21502150    N = node.distribution[cls] 
     
    21682168    if tree.class_var.var_type != Orange.data.Type.Continuous: 
    21692169        return insert_dot(strg, mo) 
    2170      
     2170 
    21712171    by, op, num = mo.group("by", "op", "num") 
    21722172    op = __opdict[op] 
     
    21832183 
    21842184    return insert_num(strg, mo, N) 
    2185      
    2186      
     2185 
     2186 
    21872187def replaceccont(strg, mo, node, parent, tree): 
    21882188    if tree.class_var.var_type != Orange.data.Type.Continuous: 
    21892189        return insert_dot(strg, mo) 
    2190      
     2190 
    21912191    by, op, num = mo.group("by", "op", "num") 
    21922192    op = __opdict[op] 
     
    22002200            byN = sum([x[1] for x in whom.distribution.items() if op(x[0], num)], 0.) 
    22012201            if byN > 1e-30: 
    2202                 N /= byN/whom.distribution.abs # abs > byN, so byN>1e-30 => abs>1e-30 
     2202                N /= byN / whom.distribution.abs # abs > byN, so byN>1e-30 => abs>1e-30 
    22032203        else: 
    22042204            return insert_dot(strg, mo) 
     
    22182218        return filter(lambda x:lop(x[0], lower) and hop(x[0], upper), dist.items()) 
    22192219 
    2220      
     2220 
    22212221def replaceCconti(strg, mo, node, parent, tree): 
    22222222    if tree.class_var.var_type != Orange.data.Type.Continuous: 
     
    22332233        else: 
    22342234            return insert_dot(strg, mo) 
    2235          
     2235 
    22362236    return insert_num(strg, mo, N) 
    22372237 
    2238              
     2238 
    22392239def replacecconti(strg, mo, node, parent, tree): 
    22402240    if tree.class_var.var_type != Orange.data.Type.Continuous: 
     
    22522252            byN = sum([x[1] for x in extractInterval(mo, whom.distribution)]) 
    22532253            if byN > 1e-30: 
    2254                 N /= byN/whom.distribution.abs 
     2254                N /= byN / whom.distribution.abs 
    22552255        else: 
    22562256            return insert_dot(strg, mo) 
    2257          
     2257 
    22582258    return insert_num(strg, mo, N) 
    22592259 
    2260      
     2260 
    22612261def replaceD(strg, mo, node, parent, tree): 
    22622262    if tree.class_var.var_type != Orange.data.Type.Discrete: 
     
    22752275    mul = m100 and 100 or 1 
    22762276    fs = fs or (m100 and ".0" or "5.3") 
    2277     return insert_str(strg, mo, "["+", ".join(["%%%sf" % fs % (N*mul) for N in dist])+"]") 
     2277    return insert_str(strg, mo, "[" + ", ".join(["%%%sf" % fs % (N * mul) for N in dist]) + "]") 
    22782278 
    22792279 
     
    22862286    ab = node.distribution.abs 
    22872287    if ab > 1e-30: 
    2288         dist = [d/ab for d in dist] 
     2288        dist = [d / ab for d in dist] 
    22892289    if by: 
    22902290        whom = by_whom(by, parent, tree) 
     
    22922292            for i, d in enumerate(whom.distribution): 
    22932293                if d > 1e-30: 
    2294                     dist[i] /= d/whom.distribution.abs # abs > d => d>1e-30 => abs>1e-30 
     2294                    dist[i] /= d / whom.distribution.abs # abs > d => d>1e-30 => abs>1e-30 
    22952295        else: 
    22962296            return insert_dot(strg, mo) 
    22972297    mul = m100 and 100 or 1 
    22982298    fs = fs or (m100 and ".0" or "5.3") 
    2299     return insert_str(strg, mo, "["+", ".join(["%%%sf" % fs % (N*mul) for N in dist])+"]") 
     2299    return insert_str(strg, mo, "[" + ", ".join(["%%%sf" % fs % (N * mul) for N in dist]) + "]") 
    23002300 
    23012301 
     
    23052305 
    23062306    AorE, bysub, by = mo.group("AorE", "bysub", "by") 
    2307      
     2307 
    23082308    if AorE == "A": 
    23092309        A = node.distribution.average() 
     
    23112311        A = node.distribution.error() 
    23122312    if by: 
    2313         whom = by_whom("b"+by, parent, tree) 
     2313        whom = by_whom("b" + by, parent, tree) 
    23142314        if whom: 
    23152315            if AorE == "A": 
     
    23342334 
    23352335    fs = mo.group("fs") or "5.3" 
    2336     intrvl = float(mo.group("intp") or mo.group("intv") or "95")/100. 
     2336    intrvl = float(mo.group("intp") or mo.group("intv") or "95") / 100. 
    23372337    mul = mo.group("m100") and 100 or 1 
    23382338 
     
    23402340        raise SystemError, "Cannot compute %5.3f% confidence intervals" % intrvl 
    23412341 
    2342     av = node.distribution.average()     
     2342    av = node.distribution.average() 
    23432343    il = node.distribution.error() * Z[intrvl] 
    2344     return insert_str(strg, mo, "[%%%sf-%%%sf]" % (fs, fs) % ((av-il)*mul, (av+il)*mul)) 
     2344    return insert_str(strg, mo, "[%%%sf-%%%sf]" % (fs, fs) % ((av - il) * mul, (av + il) * mul)) 
    23452345 
    23462346 
     
    23502350class _TreeDumper: 
    23512351    defaultStringFormats = [(re_V, replaceV), (re_N, replaceN), 
    2352          (re_M, replaceM), (re_m, replacem),  
     2352         (re_M, replaceM), (re_m, replacem), 
    23532353         (re_Cdisc, replaceCdisc), (re_cdisc, replacecdisc), 
    23542354         (re_Ccont, replaceCcont), (re_ccont, replaceccont), 
    23552355         (re_Cconti, replaceCconti), (re_cconti, replacecconti), 
    2356          (re_D, replaceD), (re_d, replaced), (re_AE, replaceAE),  
     2356         (re_D, replaceD), (re_d, replaced), (re_AE, replaceAE), 
    23572357         (re_I, replaceI) ] 
    23582358 
     
    23602360        return self.tree.tree if "tree" in self.tree.__dict__ else self.tree 
    23612361 
    2362     def __init__(self, leafStr, nodeStr, stringFormats, minExamples,  
     2362    def __init__(self, leafStr, nodeStr, stringFormats, minExamples, 
    23632363        maxDepth, simpleFirst, tree, **kw): 
    23642364        self.stringFormats = stringFormats 
     
    23822382        else: 
    23832383            self.nodeStr = nodeStr 
    2384          
     2384 
    23852385 
    23862386    def formatString(self, strg, node, parent): 
    23872387        if hasattr(strg, "__call__"): 
    23882388            return strg(node, parent, self.tree) 
    2389          
     2389 
    23902390        if not node: 
    23912391            return "<null node>" 
    2392          
     2392 
    23932393        for rgx, replacer in self.stringFormats: 
    23942394            if not node.distribution: 
     
    24012401                        break 
    24022402                    strg = replacer(strg, mo, node, parent, self.tree) 
    2403                     strt = mo.start()+1 
    2404                          
     2403                    strt = mo.start() + 1 
     2404 
    24052405        return strg 
    2406          
     2406 
    24072407 
    24082408    def showBranch(self, node, parent, lev, i): 
     
    24162416            nodedes = "<null node>" 
    24172417        return "|    "*lev + bdes + nodedes 
    2418          
    2419          
     2418 
     2419 
    24202420    def dumpTree0(self, node, parent, lev): 
    24212421        if node.branches: 
     
    24232423                lev > self.maxDepth: 
    24242424                return "|    "*lev + ". . .\n" 
    2425              
     2425 
    24262426            res = "" 
    24272427            if self.leafStr and self.nodeStr and self.leafStr != self.nodeStr: 
    2428                 leafsep = "\n"+("|    "*lev)+"    " 
     2428                leafsep = "\n" + ("|    "*lev) + "    " 
    24292429            else: 
    24302430                leafsep = "" 
     
    24382438                            res += "%s: %s\n" % \ 
    24392439                                (self.showBranch(node, parent, lev, i), 
    2440                                  leafsep +  
     2440                                 leafsep + 
    24412441                                 self.formatString(self.leafStr, branch, node)) 
    24422442            for i, branch in enumerate(node.branches): 
    24432443                if branch and branch.branches: 
    24442444                    res += "%s\n%s" % (self.showBranch(node, parent, lev, i), 
    2445                                        self.dumpTree0(branch, node, lev+1)) 
     2445                                       self.dumpTree0(branch, node, lev + 1)) 
    24462446                elif not self.simpleFirst: 
    24472447                    if self.leafStr == self.nodeStr: 
     
    24502450                        res += "%s: %s\n" % \ 
    24512451                            (self.showBranch(node, parent, lev, i), 
    2452                              leafsep +  
     2452                             leafsep + 
    24532453                             self.formatString(self.leafStr, branch, node)) 
    24542454            return res 
     
    24662466            lev, res = 0, "" 
    24672467        return res + self.dumpTree0(node, None, lev) 
    2468          
     2468 
    24692469 
    24702470    def dotTree0(self, node, parent, internalName): 
    24712471        if node.branches: 
    24722472            if node.distribution.abs < self.minExamples or \ 
    2473                 len(internalName)-1 > self.maxDepth: 
     2473                len(internalName) - 1 > self.maxDepth: 
    24742474                self.fle.write('%s [ shape="plaintext" label="..." ]\n' % \ 
    24752475                    _quoteName(internalName)) 
    24762476                return 
    2477                  
     2477 
    24782478            label = node.branch_selector.class_var.name 
    24792479            if self.nodeStr: 
     
    24812481            self.fle.write('%s [ shape=%s label="%s"]\n' % \ 
    24822482                (_quoteName(internalName), self.nodeShape, label)) 
    2483              
     2483 
    24842484            for i, branch in enumerate(node.branches): 
    24852485                if branch: 
    2486                     internalBranchName = "%s-%d" % (internalName,i) 
     2486                    internalBranchName = "%s-%d" % (internalName, i) 
    24872487                    self.fle.write('%s -> %s [ label="%s" ]\n' % \ 
    2488                         (_quoteName(internalName),  
    2489                          _quoteName(internalBranchName),  
     2488                        (_quoteName(internalName), 
     2489                         _quoteName(internalBranchName), 
    24902490                         node.branch_descriptions[i])) 
    24912491                    self.dotTree0(branch, node, internalBranchName) 
    2492                      
     2492 
    24932493        else: 
    24942494            self.fle.write('%s [ shape=%s label="%s"]\n' % \ 
    2495                 (_quoteName(internalName), self.leafShape,  
     2495                (_quoteName(internalName), self.leafShape, 
    24962496                self.formatString(self.leafStr, node, parent))) 
    24972497 
     
    25342534        deeply as possible according to the instance's feature values. 
    25352535    """ 
    2536      
     2536 
    25372537    def __init__(self, base_classifier=None): 
    25382538        if not base_classifier: base_classifier = _TreeClassifier() 
     
    25402540        for k, v in self.nativeClassifier.__dict__.items(): 
    25412541            self.__dict__[k] = v 
    2542    
     2542 
    25432543    def __call__(self, instance, result_type=Orange.classification.Classifier.GetValue, 
    25442544                 *args, **kwdargs): 
     
    25652565            self.nativeClassifier.__dict__[name] = value 
    25662566        self.__dict__[name] = value 
    2567      
     2567 
    25682568    def __str__(self): 
    25692569        return self.to_string() 
     
    25732573        "userFormats": "user_formats", "minExamples": "min_examples", \ 
    25742574        "maxDepth": "max_depth", "simpleFirst": "simple_first"}) 
    2575     def to_string(self, leaf_str = "", node_str = "", \ 
     2575    def to_string(self, leaf_str="", node_str="", \ 
    25762576            user_formats=[], min_examples=0, max_depth=1e10, \ 
    25772577            simple_first=True): 
     
    26032603          information in the nodes. 
    26042604        """ 
    2605         return _TreeDumper(leaf_str, node_str, user_formats +  
    2606             _TreeDumper.defaultStringFormats, min_examples,  
     2605        return _TreeDumper(leaf_str, node_str, user_formats + 
     2606            _TreeDumper.defaultStringFormats, min_examples, 
    26072607            max_depth, simple_first, self).dumpTree() 
    26082608 
     
    26142614        "userFormats": "user_formats", "minExamples": "min_examples", \ 
    26152615        "maxDepth": "max_depth", "simpleFirst": "simple_first"}) 
    2616     def dot(self, file_name, leaf_str = "", node_str = "", \ 
     2616    def dot(self, file_name, leaf_str="", node_str="", \ 
    26172617            leaf_shape="plaintext", node_shape="plaintext", \ 
    26182618            user_formats=[], min_examples=0, max_depth=1e10, \ 
     
    26352635        fle = type(file_name) == str and open(file_name, "wt") or file_name 
    26362636 
    2637         _TreeDumper(leaf_str, node_str, user_formats +  
    2638             _TreeDumper.defaultStringFormats, min_examples,  
     2637        _TreeDumper(leaf_str, node_str, user_formats + 
     2638            _TreeDumper.defaultStringFormats, min_examples, 
    26392639            max_depth, simple_first, self, 
    26402640            leafShape=leaf_shape, nodeShape=node_shape, fle=fle).dotTree() 
     
    26652665        data = Orange.data.Table(domain) 
    26662666        self.to_network0(self.tree, net, data) 
    2667         return net, data 
     2667        net.set_items(data) 
     2668        return net 
    26682669 
    26692670    def to_network0(self, node, net, table): 
     
    26742675        if self.class_var.var_type == Orange.data.Type.Discrete: 
    26752676            if d.abs > 1e-6: 
    2676                 table.append([maj, d.abs, d[maj]] + [x/d.abs for x in d]) 
     2677                table.append([maj, d.abs, d[maj]] + [x / d.abs for x in d]) 
    26772678            else: 
    2678                 table.append([maj] + [0]*(2 + len(d))) 
     2679                table.append([maj] + [0] * (2 + len(d))) 
    26792680        else: 
    26802681            table.append([maj, d.error(), d.abs]) 
Note: See TracChangeset for help on using the changeset viewer.