source: orange/Orange/orng/orngInteract.py @ 11012:19029caa4a32

Revision 11012:19029caa4a32, 34.6 KB checked in by Miha Stajdohar <miha.stajdohar@…>, 18 months ago (diff)

Fixed more network references.

Line 
1#
2# Module Orange Interactions
3# --------------------------
4#
5# CVS Status: $Id$
6#
7# Author: Aleks Jakulin (jakulin@acm.org)
8# (Copyright (C)2004 Aleks Jakulin)
9#
10# Purpose: Analysis of dependencies between attributes given the class.
11#          3-WAY INTERACTIONS
12#
13# Project initiated on 2003/05/08
14#
15# ChangeLog:
16#   - 2003/05/09:
17#       fixed a problem with domains that need no preprocessing
18#       fixed the decimal point printing problem
19#       added the support for dissimilarity matrix, used for attribute clustering
20#   - 2003/05/10:
21#       fixed a problem with negative percentages of less than a percent
22#   - 2003/05/12:
23#       separated the 'prepare' function
24#   - 2003/09/18:
25#       added support for cluster coloring
26#       cleaned up backwards-incompatible changes (grrr) (color changes, discData)
27#       added color-coded dissimilarity matrix export
28#   - 2004/01/31:
29#       removed adhoc stats-gathering code in favor of the orngContingency module
30#       added p-value estimates
31#   - 2004/03/24:
32#       fixed an ugly bug in dep-dissimilarity matrix processing
33#
34
35import orange, statc
36import orngContingency, numpy
37import warnings, math, string, copy
38import Orange
39
40def _nicefloat(f,sig):
41    # pretty-float formatter
42    i = int(f)
43    s = '%1.0f'%f
44    n = sig-len('%d'%abs(f)) # how many digits is the integer part
45    if n > 0:
46        # we can put a few decimals at the end
47        fp = abs(f)-abs(i)
48        s = ''
49        if f < 0:
50            s += '-'
51        s += '%d'%abs(i) + ('%f'%fp)[1:2+n]
52
53    return s
54
55
56
57class InteractionMatrix:
58    def _prepare(self, t):
59        # prepares an Orange table so that it doesn't contain continuous
60        # attributes or missing values
61
62        ### DISCRETIZE VARIABLES ###
63
64        newatt = []
65        oldatt = []
66        entroD = orange.EntropyDiscretization()
67        equiD = orange.EquiNDiscretization(numberOfIntervals = 2)
68        for i in t.domain.attributes:
69            if i.varType == 2:
70                d = entroD(i,t)
71                if len(d.values) < 2:
72                    # prevent discretization into a single value
73                    d = equiD(i,t)
74                    d.name = 'E'+d.name
75                warnings.warn('Discretizing %s into %s with %d values.'%(i.name,d.name,len(d.values)))
76                newatt.append(d)
77            else:
78                oldatt.append(i)
79        if len(newatt) > 0:
80            t = t.select(oldatt+newatt+[t.domain.classVar])
81
82        ### FIX MISSING VALUES ###
83
84        special_attributes = []
85
86        # 2006-08-23: fixed by PJ: append classVar only if it exists
87##        all_attributes = [i for i in t.domain.attributes]+[t.domain.classVar]
88        all_attributes = [i for i in t.domain.attributes]
89        if t.domain.classVar:
90            all_attributes += [t.domain.classVar]
91
92        for i in range(len(all_attributes)):
93            for j in t:
94                if j[i].isSpecial():
95                    special_attributes.append(i)
96                    break
97        # create new attributes
98        if len(special_attributes) > 0:
99            # prepare attributes
100            newatts = []
101            for i in range(len(all_attributes)):
102                old = all_attributes[i]
103                if i in special_attributes:
104                    oldv = [v for v in old.values]
105                    assert('.' not in oldv)
106                    new = orange.EnumVariable(name='M_'+old.name, values=oldv+['.'])
107                    warnings.warn('Removing special values from %s into %s.'%(old.name,new.name))
108                    newatts.append(new)
109                else:
110                    newatts.append(old)
111            # convert table
112            exs = []
113
114            # 2006-08-23: added by PJ: add a class variable (if not already existing)
115            if not t.domain.classVar:
116                newatts.append(orange.EnumVariable("class", values=["."]))
117                t = orange.ExampleTable(orange.Domain(t.domain.attributes, newatts[-1]), t)
118
119            newd = orange.Domain(newatts)
120            for ex in t:
121                nex = []
122                for i in range(len(newatts)):
123                    if ex[i].isSpecial():
124                        v = newatts[i]('.')
125                    else:
126                        v = newatts[i](int(ex[i]))
127                    nex.append(v)
128                exs.append(orange.Example(newd,nex))
129            t = orange.ExampleTable(exs)
130        return t
131
132    def __init__(self, t, save_data=1, interactions_too = 1, dependencies_too=0, prepare=1, pvalues = 0, simple_too=0,iterative_scaling=0,weighting=None):
133        if prepare:
134            t = self._prepare(t)
135        if save_data:
136            self.discData = t   # save the discretized data
137
138        ### PREPARE INDIVIDUAL ATTRIBUTES ###
139
140        # Attribute Preparation
141        NA = len(t.domain.attributes)
142
143        self.names = []
144        self.labelname = ""
145        if t.domain.classVar:
146            self.labelname = t.domain.classVar.name
147        self.gains = []
148        self.freqs = []
149        self.way2 = {}
150        self.way3 = {}
151        self.ig = []
152        self.list = []
153        self.abslist = []
154        self.plist = []
155        self.plut = {}
156        self.ents = {}
157        self.corr = {}
158        self.chi2 = {}
159        self.simple = {}
160        for i in range(NA):
161            if weighting != None:
162                atc = orngContingency.get2Int(t,t.domain.attributes[i],t.domain.classVar,wid=weighting)
163            else:
164                atc = orngContingency.get2Int(t,t.domain.attributes[i],t.domain.classVar)
165            gai = atc.InteractionInformation()
166            self.gains.append(gai)
167            self.corr[(i,-1)] = gai
168            self.ents[(i,)] = orngContingency.Entropy(atc.a)
169            self.way2[(i,-1,)] = atc
170            self.ents[(i,-1)] = orngContingency.Entropy(atc.m)
171            N = sum(atc.a)
172            self.chi2[(i, i)] = statc.chisqprob(N * (numpy.sum(numpy.outer(atc.pa, atc.pa)) - 2 + len(atc.pa)), (len(atc.pa)-1)**2)
173
174#            self.chi2[(i, i)] = N * (numpy.sum(numpy.outer(atc.pa, atc.pa)) - 2 + len(atc.pa))   
175            if simple_too:
176                simp = 0.0
177                for k in xrange(min(len(atc.a),len(atc.b))):
178                    try:
179                        simp += atc.pm[k,k]
180                    except:
181                        pass
182                self.simple[(i,-1)] = simp
183            # fix the name
184            st = '%s'%t.domain.attributes[i].name # copy
185            self.names.append(st)
186            if pvalues:
187                pv = orngContingency.getPvalue(gai,atc)
188                self.plist.append((pv,(gai,i,-1)))
189                self.plut[(i,-1)] = pv
190                #print "%s\t%f\t%f\t%d"%(st,pv,gai,atc.total)
191            line = []
192            for j in range(i):
193                if dependencies_too:
194                    if weighting != None:
195                        c = orngContingency.get2Int(t,t.domain.attributes[j],t.domain.attributes[i],wid=weighting)
196                    else:
197                        c = orngContingency.get2Int(t,t.domain.attributes[j],t.domain.attributes[i])
198                    self.way2[(j,i,)] = c
199                    gai = c.InteractionInformation()
200                    self.ents[(j,i,)] = orngContingency.Entropy(c.m)
201                    self.corr[(j,i,)] = gai
202                    self.chi2[(j,i)] = c.ChiSquareP()   
203                    if simple_too:
204                        simp = 0.0
205                        for k in xrange(min(len(c.a),len(c.b))):
206                            try:
207                                qq = c.pm[k,k]
208                            except:
209                                qq = 0
210                            simp += qq
211                        self.simple[(j,i)] = simp
212                    if pvalues:
213                        pv = orngContingency.getPvalue(gai,c)
214                        self.plist.append((pv,(gai,j,i)))
215                        self.plut[(j,i)] = pv
216                if interactions_too:
217                    if weighting != None:
218                        c = orngContingency.get3Int(t,t.domain.attributes[j],t.domain.attributes[i],t.domain.classVar,wid=weighting)
219                    else:
220                        c = orngContingency.get3Int(t,t.domain.attributes[j],t.domain.attributes[i],t.domain.classVar)
221                    self.way3[(j,i,-1)] = c
222                    igv = c.InteractionInformation()
223                    line.append(igv)
224                    self.list.append((igv,(igv,j,i)))
225                    self.abslist.append((abs(igv),(igv,j,i)))
226                    if pvalues:
227                        if iterative_scaling:
228                            div = c.IPF()
229                        else:
230                            div = c.KSA()[0]
231                        pv = orngContingency.getPvalue(div,c)
232                        #print "%s-%s\t%f\t%f\t%d"%(c.names[0],c.names[1],pv,igv,c.total)
233                        self.plist.append((pv,(igv,j,i,-1)))
234                        self.plut[(j,i,-1)] = pv
235            self.ig.append(line)
236        self.entropy = orngContingency.Entropy(atc.b)
237        self.ents[(-1,)] = self.entropy
238        self.list.sort()
239        self.abslist.sort()
240        self.plist.sort()
241
242        self.attlist = []
243        for i in range(NA):
244            self.attlist.append((self.gains[i],i))
245        self.attlist.sort()
246        self.NA = NA
247
248    def dump(self):
249        NA = len(self.names)
250        for j in range(1,NA):
251            for i in range(j):
252                t = '%s+%s'%(self.names[i],self.names[j])
253                print "%30s\t%2.4f\t%2.4f\t%2.4f\t%2.4f\t%2.4f"%(t,self.igain[(i,j)],self.corr[(i,j)],self.igain[(i,j)]+self.corr[(i,j)],self.gains[i],self.gains[j])
254               
255    def exportNetwork(self,  absolute_int=10, positive_int = 0, negative_int = 0, best_attributes = 0, significant_digits = 2, pretty_names = 1, widget_coloring=1, pcutoff = 1):
256        NA = len(self.names)
257
258        ### SELECTION OF INTERACTIONS AND ATTRIBUTES ###
259
260        # prevent crashes
261        best_attributes = min(best_attributes,len(self.attlist))
262        positive_int = min(positive_int,len(self.list))
263        absolute_int = min(absolute_int,len(self.list))
264        negative_int = min(negative_int,len(self.list))
265
266        # select the top interactions
267        ins = []
268        if positive_int > 0:
269            ins += self.list[-positive_int:]
270        ins += self.list[:negative_int]
271        if absolute_int > 0:
272            ins += self.abslist[-absolute_int:]
273
274        # pick best few attributes
275        atts = []
276        if best_attributes > 0:
277            atts += [i for (x,i) in self.attlist[-best_attributes:]]
278
279        # disregard the insignificant attributes, interactions
280        if len(self.plist) > 0 and pcutoff < 1:
281            # attributes
282            oats = atts
283            atts = []
284            for i in oats:
285                if self.plut[(i,-1)] < pcutoff:
286                    atts.append(i)
287            # interactions
288            oins = ins
289            ins = []
290            for y in oins:
291                (ig,i,j) = y[1]
292                if self.plut[(i,j,-1)] < pcutoff:
293                    ins.append(y)
294       
295        ints = []
296        max_igain = -1e6
297        min_gain = 1e6 # lowest information gain of involved attributes
298        # remove duplicates and sorting keys
299        for (x,v) in ins:
300            if v not in ints:
301                ints.append(v)
302                # add to attribute list
303                (ig,i,j) = v
304                max_igain = max(abs(ig),max_igain)
305                for x in [i,j]:
306                    if x not in atts:
307                        atts.append(x)
308                        min_gain = min(min_gain,self.gains[x])
309
310        # fill-in the attribute list with all possibly more important attributes
311        ## todo
312           
313        ### NODE DRAWING ###
314        map = {}
315        graph = Orange.core.Network(len(atts), 0)
316        table = []
317       
318        for i in range(len(atts)):
319            map[atts[i]] = i
320           
321            ndx = atts[i]
322            t = '%s' % self.names[ndx]
323            if pretty_names:
324                t = string.replace(t, "ED_", "")
325                t = string.replace(t, "D_", "")
326                t = string.replace(t, "M_", "")
327                t = string.replace(t, " ", "\\n")
328                t = string.replace(t, "-", "\\n")
329                t = string.replace(t, "_", "\\n")
330                r = self.gains[ndx] * 100.0 / self.entropy
331                table.append([i + 1, t, r]) 
332       
333        d = orange.Domain([orange.FloatVariable('index'), orange.StringVariable('label'), orange.FloatVariable('norm. gain')])
334        data = orange.ExampleTable(d, table)
335        graph.items = data
336       
337        table = []
338        for (ig,i,j) in ints:
339            j = map[j]
340            i = map[i]
341           
342            perc = int(abs(ig)*100.0/max(max_igain,self.attlist[-1][0])+0.5)
343            graph[i, j] = perc / 30 + 1
344           
345            if self.entropy > 1e-6:
346                mc = _nicefloat(100.0*ig/self.entropy,significant_digits)+"%"
347            else:
348                mc = _nicefloat(0.0,significant_digits)
349            if len(self.plist) > 0 and pcutoff < 1:
350                mc += "\\nP\<%.3f"%self.plut[(i,j,-1)]
351
352            if ig > 0:
353                if widget_coloring:
354                    color = "green"
355                else:
356                    color = '"0.0 %f 0.9"'%(0.3+0.7*perc/100.0) # adjust saturation
357                dir = "both"
358            else:
359                if widget_coloring:
360                    color = "red"
361                else:
362                    color = '"0.5 %f 0.9"'%(0.3+0.7*perc/100.0) # adjust saturation
363                dir = 'none'
364
365            table.append([i, j, mc, dir, color])
366
367        d = orange.Domain([orange.FloatVariable('u'), orange.FloatVariable('v'), orange.StringVariable('label'), orange.EnumVariable('dir', values = ["both", "none"]), orange.EnumVariable('color', values = ["green", "red"])])
368        data = orange.ExampleTable(d, table)
369        graph.links = data
370
371        return graph
372
373    def exportGraph(self, f, absolute_int=10, positive_int = 0, negative_int = 0, best_attributes = 0, print_bits = 1, black_white = 0, significant_digits = 2, postscript = 1, pretty_names = 1, url = 0, widget_coloring=1, pcutoff = 1):
374        NA = len(self.names)
375
376        ### SELECTION OF INTERACTIONS AND ATTRIBUTES ###
377
378        # prevent crashes
379        best_attributes = min(best_attributes,len(self.attlist))
380        positive_int = min(positive_int,len(self.list))
381        absolute_int = min(absolute_int,len(self.list))
382        negative_int = min(negative_int,len(self.list))
383
384        # select the top interactions
385        ins = []
386        if positive_int > 0:
387            ins += self.list[-positive_int:]
388        ins += self.list[:negative_int]
389        if absolute_int > 0:
390            ins += self.abslist[-absolute_int:]
391
392        # pick best few attributes
393        atts = []
394        if best_attributes > 0:
395            atts += [i for (x,i) in self.attlist[-best_attributes:]]
396
397        # disregard the insignificant attributes, interactions
398        if len(self.plist) > 0 and pcutoff < 1:
399            # attributes
400            oats = atts
401            atts = []
402            for i in oats:
403                if self.plut[(i,-1)] < pcutoff:
404                    atts.append(i)
405            # interactions
406            oins = ins
407            ins = []
408            for y in oins:
409                (ig,i,j) = y[1]
410                if self.plut[(i,j,-1)] < pcutoff:
411                    ins.append(y)
412       
413        ints = []
414        max_igain = -1e6
415        min_gain = 1e6 # lowest information gain of involved attributes
416        # remove duplicates and sorting keys
417        for (x,v) in ins:
418            if v not in ints:
419                ints.append(v)
420                # add to attribute list
421                (ig,i,j) = v
422                max_igain = max(abs(ig),max_igain)
423                for x in [i,j]:
424                    if x not in atts:
425                        atts.append(x)
426                        min_gain = min(min_gain,self.gains[x])
427
428        # fill-in the attribute list with all possibly more important attributes
429        ## todo
430
431        ### NODE DRAWING ###
432
433        # output the attributes
434        f.write("digraph G {\n")
435
436        if print_bits:
437            shap = 'record'
438        else:
439            shap = 'box'
440
441        for i in atts:
442            t = '%s'%self.names[i]
443            if pretty_names:
444                t = string.replace(t,"ED_","")
445                t = string.replace(t,"D_","")
446                t = string.replace(t,"M_","")
447                t = string.replace(t," ","\\n")
448                t = string.replace(t,"-","\\n")
449                t = string.replace(t,"_","\\n")
450            if print_bits:
451                r = self.gains[i]*100.0/self.entropy
452                if len(self.plist) > 0 and pcutoff < 1:
453                    t = "{%s|{%s%% | P\<%.3f}}"%(t,_nicefloat(r,significant_digits),self.plut[(i,-1)])
454                else:
455                    t = "{%s|%s%%}"%(t,_nicefloat(r,significant_digits))
456            if not url:
457                f.write("\tnode [ shape=%s, label = \"%s\"] %d;\n"%(shap,t,i))
458            else:
459                f.write("\tnode [ shape=%s, URL = \"%d\", label = \"%s\"] %d;\n"%(shap,i,t,i))
460
461        ### EDGE DRAWING ###
462
463        for (ig,i,j) in ints:
464            perc = int(abs(ig)*100.0/max(max_igain,self.attlist[-1][0])+0.5)
465
466            if self.entropy > 1e-6:
467                mc = _nicefloat(100.0*ig/self.entropy,significant_digits)+"%"
468            else:
469                mc = _nicefloat(0.0,significant_digits)
470            if len(self.plist) > 0 and pcutoff < 1:
471                mc += "\\nP\<%.3f"%self.plut[(i,j,-1)]
472            if postscript:
473                style = "style=\"setlinewidth(%d)\","%(abs(perc)/30+1)
474            else:
475                style = ''
476            if black_white:
477                color = 'black'
478                if ig > 0:
479                    dir = "both"
480                else:
481                    style = 'style=dashed,'
482                    dir = 'none'
483            else:
484                if ig > 0:
485                    if widget_coloring:
486                        color = "green"
487                    else:
488                        color = '"0.0 %f 0.9"'%(0.3+0.7*perc/100.0) # adjust saturation
489                    dir = "both"
490                else:
491                    if widget_coloring:
492                        color = "red"
493                    else:
494                        color = '"0.5 %f 0.9"'%(0.3+0.7*perc/100.0) # adjust saturation
495                    dir = 'none'
496            if not url:
497                f.write("\t%d -> %d [dir=%s,%scolor=%s,label=\"%s\",weight=%d];\n"%(i,j,dir,style,color,mc,(perc/30+1)))
498            else:
499                f.write("\t%d -> %d [URL=\"%d-%d\",dir=%s,%scolor=%s,label=\"%s\",weight=%d];\n"%(i,j,min(i,j),max(i,j),dir,style,color,mc,(perc/30+1)))
500
501        f.write("}\n")
502
503    def exportDissimilarityMatrix(self, truncation = 1000, pretty_names = 1, print_bits = 0, significant_digits = 2, show_gains = 1, color_coding = 0, color_gains = 0, jaccard=0, noclass=0):
504        NA = self.NA
505
506        ### BEAUTIFY THE LABELS ###
507
508        labels = []
509        maxgain = max(self.gains)
510        for i in range(NA):
511            t = '%s'%self.names[i]
512            if pretty_names:
513                t = string.replace(t,"ED_","")
514                t = string.replace(t,"D_","")
515                t = string.replace(t,"M_","")
516            r = self.gains[i]
517            if print_bits:
518                if self.entropy > 1e-6:
519                    t = "%s (%s%%)"%(t,_nicefloat(r*100.0/self.entropy,significant_digits))
520                else:
521                    t = "%s (0%%)"%(t)
522            if show_gains: # a bar indicating the feature importance
523                if maxgain > 1e-6:
524                    t += ' '+'*'*int(8.0*r/maxgain+0.5)
525            labels.append(t)
526
527        ### CREATE THE DISSIMILARITY MATRIX ###
528
529        if jaccard:
530            # create the lookup of 3-entropies
531            ent3 = {}
532            maxx = 1e-6
533            for i in range(1,NA):
534                for j in range(i):
535                    if noclass:
536                        e = self.ents[(j,i)]
537                    else:
538                        e = self.ents[(j,i)]+self.ents[(j,-1)]+self.ents[(i,-1)]
539                        e -= self.ents[(i,)]+self.ents[(j,)]+self.ents[(-1,)]
540                        e -= self.ig[i][j]
541                    ent3[(i,j)] = e
542                    if e > 1e-6:
543                        e = abs(self.ig[i][j])/e
544                    else:
545                        e = 0.0
546                    maxx = max(maxx,e)
547            # check the information gains...
548            if color_gains:
549                for i in range(NA):
550                    e = self.gains[i]
551                    if self.ents[(i,-1)] > 1e-6:
552                        e /= self.ents[(i,-1)]
553                    else:
554                        e = 0.0
555                    ent3[(i,)] = e
556                    maxx = max(maxx,e)
557        else:
558            maxx = self.abslist[-1][0]
559            if color_gains:
560                maxx = max(maxx,self.attlist[-1][0])
561        if color_gains:
562            if maxx > 1e-6:
563                cgains = [0.5*(1-i/maxx) for i in self.gains]
564            else:
565                cgains = [0.0 for i in self.gains]
566        diss = []
567        for i in range(1,NA):
568            newl = []
569            for j in range(i):
570                d = self.ig[i][j]
571                if jaccard:
572                    if ent3[(i,j)] > 1e-6:
573                        d /= ent3[(i,j)]
574                    else:
575                        d = 0.0
576                if color_coding:
577                    if maxx > 1e-6:
578                        if maxx > 1e-6:
579                            t = 0.5*(1-d/maxx)
580                        else:
581                            t = 0.0
582                    else:
583                        t = 0
584                else:
585                    # transform the IG into a distance
586                    ad = abs(d)
587                    if ad*truncation > 1:
588                        t = 1.0 / ad
589                    else:
590                        t = truncation
591                newl.append(t)
592            diss.append(newl)
593
594        if color_gains:
595            return (diss,labels,cgains)
596        else:
597            return (diss,labels)
598
599    def getClusterAverages(self, clust):
600        #assert(len(self.attlist) == clust.n)
601        # get the max value
602        #d = max(self.attlist[-1][0],self.abslist[-1][0])
603        d = self.abslist[-1][0]
604        # prepare a lookup
605        LUT = {}
606        for (ig,(igv,i,j)) in self.list:
607            LUT[i,j] = igv
608            LUT[j,i] = igv
609
610        cols = []
611        merges = []
612        for i in range(clust.n):
613            merges.append((0.0,[clust.n-i-1]))
614        merges.append("sentry")
615        p = clust.n
616        for i in range(clust.n-1):
617            a = merges[p+clust.merging[i][0]] # cluster 1
618            b = merges[p+clust.merging[i][1]] # cluster 2
619            na = len(a[1])
620            nb = len(b[1])
621            # compute cross-average
622            sum = 0.0
623            for x in a[1]:
624                for y in b[1]:
625                    sum += LUT[x,y]
626            avg = (a[0]*(na*na-na) + b[0]*(nb*nb-nb) + 2*sum)/(math.pow(na+nb,2)-na-nb)
627            clustercolor = 0.5*(1-avg/d)
628            intercluster = 0.5*(1-sum/(d*na*nb))
629            cols.append((clustercolor,intercluster)) # positive -> red, negative -> blue
630            merges.append((avg,a[1]+b[1]))
631        return cols
632
633
634
635
636    def depExportGraph(self, f, n_int=1, print_bits = 1, black_white = 0, undirected = 1, significant_digits = 2, pretty_names = 1, pcutoff=-1, postscript=1, spanning_tree = 1, TAN=1, source=-1, labelled=1,jaccard=1,filter=[],diagonal=0,pvlabel=0):
637        NA = self.NA
638
639        ### SELECTION OF INTERACTIONS AND ATTRIBUTES ###
640
641        links = []
642        maxlink = -1e6
643        if n_int == 1 and spanning_tree:
644            # prepare table
645            lmm = []
646            for i in range(1,NA):
647                ei = self.ents[(i,)]
648                for j in range(i):
649                    ej = self.ents[(j,)]
650                    if TAN:
651                        # I(A;B|C)
652                        v = self.way3[(j,i,-1)].InteractionInformation()
653                        v += self.way2[(j,i)].InteractionInformation()
654                    else:
655                        if jaccard:
656                            v = self.way2[(j,i)].JaccardInteraction() # I(A;B) chow-liu, mutual information
657                        else:
658                            v = self.way2[(j,i)].InteractionInformation() # I(A;B) chow-liu, mutual information
659                    if ei > ej:
660                        lmm.append((abs(v),v,ej,(j,i)))
661                    else:
662                        lmm.append((abs(v),v,ei,(i,j)))
663            lmm.sort()
664            maxlink = lmm[-1][0]
665            # use Prim's algorithm here
666            mapped = []
667            for i in range(NA):
668                mapped.append(i)
669            n = NA
670            idx = -1 # running index in the sorted array of possible links
671            while n > 1:
672                # find the cheapest link
673                while 1:
674                    (av,v,e,(i,j)) = lmm[idx]
675                    idx -= 1
676                    if mapped[i] != mapped[j]:
677                        break
678                links.append((v,(i,j),e))
679                toremove = mapped[j]
680                for k in range(NA):
681                    if mapped[k] == toremove:
682                        mapped[k] = mapped[i]
683                n -= 1
684        else:
685            # select the top
686            lmm = []
687            for i in range(NA):
688                if filter==[] or self.names[i] in filter:
689                    for j in range(i):
690                        if filter==[] or self.names[j] in filter:
691                            ii = max(i,j)
692                            jj = min(i,j)
693                            if jaccard and pcutoff < 0.0:
694                                if self.ents[(jj,ii)] == 0.0:
695                                    v = 1.0
696                                else:
697                                    v = self.way2[(jj,ii)].JaccardInteraction()
698                                lmm.append((v,(i,j)))
699                            else:
700                                v = self.way2[(jj,ii)].InteractionInformation()
701                                if pcutoff >= 0.0:
702                                    xt = self.way2[(jj,ii)]
703                                    dof = 1.0
704                                    dof *= len(xt.values[0])-1
705                                    dof *= len(xt.values[1])-1
706                                    pv = orngContingency.getPvalueDOF(v,xt,dof)
707                                    if pv <= pcutoff:
708                                        v = 1-pv
709                                        lmm.append((v,(i,j)))
710                                else:
711                                    lmm.append((v,(i,j)))
712            lmm.sort()
713            maxlink = max(lmm[-1][0],maxlink)
714            links += [(v,p,1.0) for (v,p) in lmm[-n_int:]]
715
716        # mark vertices
717        mv = [0 for x in range(NA)]
718        for (v,(i,j),e) in links:
719            mv[i] = 1
720            mv[j] = 1
721
722        # output the attributes
723        f.write("digraph G {\n")
724
725        if print_bits:
726            shap = 'record'
727        else:
728            shap = 'box'
729
730        for n in range(NA):
731            if mv[n]:
732                if source != -1 and not type(source)==type(1):
733                    # find the name
734                    if string.upper(self.names[n])==string.upper(source):
735                        source = n
736                t = '%s'%self.names[n]
737                if pretty_names:
738                    t = string.replace(t,"ED_","")
739                    t = string.replace(t,"D_","")
740                    t = string.replace(t,"M_","")
741                    t = string.replace(t," ","\\n")
742                    t = string.replace(t,"-","\\n")
743                    t = string.replace(t,"_","\\n")
744                if print_bits:
745                    #t = "{%s|%s}"%(t,_nicefloat(self.ents[(n,)],significant_digits))
746                    t = "{%s|%s}"%(t,_nicefloat(self.way2[(n,-1)].total,significant_digits))
747                f.write("\tnode [ shape=%s, label = \"%s\"] %d;\n"%(shap,t,n))
748
749        if source != -1:
750            # redirect all links
751            age = [-1]*NA
752            age[source] = 0
753            phase = 1
754            remn = NA-1
755            premn = -1
756            while remn > 0 and premn != remn:
757                premn = remn
758                for (v,(i,j),e) in links:
759                    if age[i] >= 0 and age[i] < phase and age[j] < 0:
760                        age[j] = phase
761                        remn -= 1
762                    if age[j] >= 0 and age[j] < phase and age[i] < 0:
763                        age[i] = phase
764                        remn -= 1
765                phase += 1
766
767        ### EDGE DRAWING ###
768        for (v,(i,j),e) in links:
769            if v > 0:
770                c = v/e
771                perc = int(100*v/maxlink + 0.5)
772
773                style = ''
774                if postscript:
775                    style += "style=\"setlinewidth(%d)\","%(abs(perc)/30+1)
776                if not black_white:
777                    l = 0.3+0.7*perc/100.0
778                    style += 'color="0.5 %f %f",'%(l,1-l) # adjust saturation
779                if labelled:
780                    if diagonal:
781                        ct = self.way2[(min(i,j),max(i,j))]
782                        (ni,nj) = numpy.shape(ct.m)
783                        cc = 0.0
784                        if ni==nj:
785                            for x in range(ni):
786                                cc += ct.m[x,x]
787                        style += 'label=\"%s%%\",'%_nicefloat(100.0*cc/ct.total,significant_digits)
788                    elif pvlabel and pcutoff >= 0.0:
789                        style += 'label=\"%e\",'%(1-v)
790                    else:
791                        style += 'label=\"%s%%\",'%_nicefloat(100.0*c,significant_digits)
792                if source == -1 or undirected:
793                    f.write("\t%d -> %d [%sweight=%d,dir=none];\n"%(j,i,style,(perc/30+1)))
794                else:
795                    if age[i] > age[j]:
796                        f.write("\t%d -> %d [%sweight=%d];\n"%(j,i,style,(perc/30+1)))
797                    else:
798                        f.write("\t%d -> %d [%sweight=%d];\n"%(i,j,style,(perc/30+1)))
799        f.write("}\n")
800
801    def exportChi2Matrix(self, pretty_names = 1):
802        labels = []
803        for i in range(self.NA):
804            t = '%s'%self.names[i]
805            if pretty_names:
806                t = string.replace(t,"ED_","")
807                t = string.replace(t,"D_","")
808                t = string.replace(t,"M_","")
809            labels.append(t)
810
811        diss = [[self.chi2[(i,j)] for i in range(j+1)] for j in range(self.NA)]
812        return diss, labels
813
814    def depExportDissimilarityMatrix(self, truncation = 1000, pretty_names = 1, jaccard = 1, simple_metric=0,color_coding = 0, verbose=0, include_label=0):
815        NA = self.NA
816
817        ### BEAUTIFY THE LABELS ###
818
819        labels = []
820        for i in range(NA):
821            t = '%s'%self.names[i]
822            if pretty_names:
823                t = string.replace(t,"ED_","")
824                t = string.replace(t,"D_","")
825                t = string.replace(t,"M_","")
826            labels.append(t)
827        if include_label:
828            labels.append(self.labelname)
829
830        ### CREATE THE DISSIMILARITY MATRIX ###
831
832        if color_coding:
833            maxx = -1
834            pett = range(1,NA)
835            if include_label:
836                pett.append(-1)
837            for x in pett:
838                if x == -1:
839                    sett = range(NA)
840                else:
841                    sett = range(x)
842                for y in sett:
843                    t = self.corr[(y,x)]
844                    if jaccard:
845                        l = self.ents[(y,x)]
846                        if l > 1e-6:
847                            t /= l
848                    maxx = max(maxx,t)
849            if verbose:
850                if jaccard:
851                    print 'maximum intersection is %3d percent.'%(maxx*100.0)
852                else:
853                    print 'maximum intersection is %f bits.'%maxx
854        diss = []
855        pett = range(1,NA)
856        if include_label:
857            pett.append(-1)
858        for x in pett:
859            if x == -1:
860                sett = range(NA)
861            else:
862                sett = range(x)
863            newl = []
864            for y in sett:
865                if simple_metric:
866                    t = 1-self.simple[(y,x)]
867                else:
868                    t = self.corr[(y,x)]
869                if jaccard:
870                    l = self.ents[(y,x)]
871                    if l > 1e-6:
872                        t /= l
873                if color_coding:
874                    #t = 0.5*(1-t/maxx)
875                    if jaccard:
876                        t = (1-t)*0.5
877                    else:
878                        t = 0.5*(1-t/maxx)
879                else:
880                    if t*truncation > 1:
881                        t = 1.0 / t
882                    else:
883                        t = truncation
884                newl.append(t)
885            diss.append(newl)
886        return (diss, labels)
887
888
889    def depGetClusterAverages(self, clust):
890        d = 1.0
891        cols = []
892        merges = []
893        for i in range(clust.n):
894            merges.append((0.0,[clust.n-i-1]))
895        merges.append("sentry")
896        p = clust.n
897        for i in range(clust.n-1):
898            a = merges[p+clust.merging[i][0]] # cluster 1
899            b = merges[p+clust.merging[i][1]] # cluster 2
900            na = len(a[1])
901            nb = len(b[1])
902            # compute cross-average
903            sum = 0.0
904            for x in a[1]:
905                for y in b[1]:
906                    xx = max(x,y)
907                    yy = min(x,y)
908                    if xx == self.NA:
909                        xx = -1
910                    t = self.corr[(yy,xx)]
911                    l = self.ents[(yy,xx)]
912                    if l > 1e-6:
913                        t /= l
914                    sum += t
915            avg = (a[0]*(na*na-na) + b[0]*(nb*nb-nb) + 2*sum)/(math.pow(na+nb,2)-na-nb)
916            clustercolor = 0.5*(1-avg/d)
917            intercluster = 0.5*(1-sum/(d*na*nb))
918            cols.append((clustercolor,intercluster)) # positive -> red, negative -> blue
919            merges.append((avg,a[1]+b[1]))
920        return cols
921
922
923if __name__== "__main__":
924    t = orange.ExampleTable('d_zoo.tab')
925    im = InteractionMatrix(t,save_data=0, pvalues = 1,iterative_scaling=0)
926
927    # interaction graph
928    f = open('zoo.dot','w')
929    im.exportGraph(f,significant_digits=3,pcutoff = 0.01,absolute_int=1000,best_attributes=100,widget_coloring=0,black_white=1)
930    f.close()
931
932    # interaction clustering
933    import orngCluster
934    (diss,labels) = im.exportDissimilarityMatrix(show_gains=0)
935    c = orngCluster.DHClustering(diss)
936    NCLUSTERS = 6
937    c.domapping(NCLUSTERS)
938    print "Clusters:"
939    for j in range(1,NCLUSTERS+1):
940        print "%d: "%j,
941        # print labels of that cluster
942        for i in range(len(labels)):
943            if c.mapping[i] == j:
944                print labels[i],
945        print
Note: See TracBrowser for help on using the repository browser.