source: orange/orange/OrangeWidgets/Visualize/OWMultiCorrespondenceAnalysis.py @ 9546:2b6cc6f397fe

Revision 9546:2b6cc6f397fe, 10.3 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Renamed widget channel names in line with the new naming rules/convention.
Added backwards compatibility in orngDoc loadDocument to enable loading of schemas saved before the change.

Line 
1"""<name>Multi Correspondence Analysis</name>
2<description>Takes a ExampleTable and runs multi correspondence analysis</description>
3<icon>icons/CorrespondenceAnalysis.png</icon>
4<priority>3350</priority>
5<contact>Ales Erjavec (ales.erjavec(@ at @)fri.uni-lj.si</contact>
6"""
7
8import orngEnviron
9
10from OWCorrespondenceAnalysis import *
11
12import OWGUI
13import itertools
14
15def burtTable(data, attributes):
16    """ Construct a Burt table (all values cross-tabulation) from data for attributes
17    Return and ordered list of (attribute, value) pairs and a numpy.ndarray with the tabulations
18    """
19    values = [(attr, value) for attr in attributes for value in attr.values]
20    table = numpy.zeros((len(values), len(values)))
21    counts = [len(attr.values) for attr in attributes]
22    offsets = [sum(counts[: i]) for i in range(len(attributes))]
23    for i in range(len(attributes)):
24        for j in range(i + 1):
25            attr1 = attributes[i]
26            attr2 = attributes[j]
27           
28            cm = orange.ContingencyAttrAttr(attr1, attr2, data)
29            cm = numpy.array([list(row) for row in cm])
30           
31            range1 = range(offsets[i], offsets[i] + counts[i])
32            range2 = range(offsets[j], offsets[j] + counts[j])
33            start1, end1 = offsets[i], offsets[i] + counts[i]
34            start2, end2 = offsets[j], offsets[j] + counts[j]
35           
36            table[start1: end1, start2: end2] += cm
37            if i != j: #also fill the upper part
38                table[start2: end2, start1: end1] += cm.T
39               
40    return values, table
41
42
43class OWMultiCorrespondenceAnalysis(OWCorrespondenceAnalysis):
44    contextHandlers = {"": DomainContextHandler("", ["xPricipalAxis", "yPrincipalAxis",
45                                                     ContextField("allAttributes", DomainContextHandler.RequiredList, selected="selectedAttrs")])}
46   
47    settingsList = OWCorrespondenceAnalysis.settingsList + []
48    def __init__(self, parent=None, signalManager=None, title="Multiple Correspondence Analysis"):
49        OWCorrespondenceAnalysis.__init__(self, parent, signalManager, title)
50       
51        self.inputs = [("Data", ExampleTable, self.setData)]
52        self.outputs = [("Selected Data", ExampleTable), ("Remaining Data", ExampleTable)]
53       
54#        self.allAttrs = []
55        self.allAttributes = []
56        self.selectedAttrs = []
57       
58        #  GUI
59       
60        #  Hide the row and column attributes combo boxes
61        self.colAttrCB.box.hide()
62        self.rowAttrCB.box.hide()
63       
64        box = OWGUI.widgetBox(self.graphTab, "Attributes", addToLayout=False)
65        self.graphTab.layout().insertWidget(0, box)
66        self.graphTab.layout().insertSpacing(1, 4)
67       
68        self.attrsListBox = OWGUI.listBox(box, self, "selectedAttrs", "allAttributes", 
69                                          tooltip="Attributes to include in the analysis",
70                                          callback=self.runCA,
71                                          selectionMode=QListWidget.ExtendedSelection,
72                                          )
73       
74        # Find and hide the "Percent of Column Points" box
75        boxes = self.graphTab.findChildren(QGroupBox)
76        boxes = [box for box in boxes if str(box.title()).strip() == "Percent of Column Points"]
77        if boxes:
78            boxes[0].hide()
79       
80    def setData(self, data=None):
81        self.closeContext("")
82        self.clear()
83        self.data = data
84        self.warning([0])
85        if data is not None:
86            attrs = data.domain.variables + data.domain.getmetas().values()
87            attrs = [attr for attr in attrs if isinstance(attr, orange.EnumVariable)]
88            if not attrs:
89                self.warning(0, "Data has no discrete variables!")
90                self.clear()
91                return
92            self.allAttrs = attrs
93            self.allAttributes = [(attr.name, attr.varType) for attr in attrs]
94            self.selectedAttrs = [0, 1, 2][:len(attrs)]
95           
96            self.openContext("", data)
97            self.runCA()
98           
99    def clear(self):
100        self.data = None
101        self.contingency = None
102        self.allAttributes = []
103        self.xAxisCB.clear()
104        self.yAxisCB.clear()
105        self.contributionInfo.setText("NA\nNA")
106        self.graph.removeDrawingCurves(True, True, True)
107        self.send("Selected Data", None)
108        self.send("Remaining Data", None)
109        self.allAttrs = []
110               
111    def runCA(self):
112        attrs = [self.allAttrs[i] for i in self.selectedAttrs]
113        if not attrs:
114            return
115       
116        self.labels, self.contingency = burtTable(self.data, attrs)
117        self.error(0)
118        try:
119            self.CA = orngCA.CA(self.contingency)
120        except numpy.linalg.LinAlgError, ex:
121            self.error(0, "Could not compute the mapping! " + str(ex))
122            self.graph.removeDrawingCurves(True, True, True)
123            raise
124       
125        self.xAxisCB.clear()
126        self.yAxisCB.clear()
127       
128        self.axisCount = min(self.CA.D.shape)
129        self.xAxisCB.addItems([str(i + 1) for i in range(self.axisCount)])
130        self.yAxisCB.addItems([str(i + 1) for i in range(self.axisCount)])
131       
132        self.xPrincipalAxis = min(self.xPrincipalAxis, self.axisCount - 1)
133        self.yPrincipalAxis = min(self.yPrincipalAxis, self.axisCount - 1)
134       
135        self.updateGraph()
136       
137    def updateGraph(self): 
138        self.graph.removeAllSelections()
139        self.graph.removeDrawingCurves(True, True, True)
140       
141        attrs = [self.allAttrs[i] for i in self.selectedAttrs]
142        colors = dict(zip(attrs, ColorPaletteHSV(len(attrs))))
143       
144        rowcor = self.CA.getPrincipalRowProfilesCoordinates((self.xPrincipalAxis, self.yPrincipalAxis))
145        numCor = max(int(math.ceil(len(rowcor) * float(self.percRow) / 100.0)), 2)
146        indices = self.CA.PointsWithMostInertia(rowColumn=0, axis=(self.xPrincipalAxis, self.yPrincipalAxis))[:numCor]
147        indices = sorted(indices)
148        rowpoints = numpy.array([rowcor[i] for i in indices])
149        rowlabels = [self.labels[i] for i in indices]
150       
151        maxx, maxy = numpy.max(rowpoints, axis=0)
152        minx, miny = numpy.min(rowpoints, axis=0)
153        spanx = maxx - minx or 1.0
154        spany = maxy - miny or 1.0
155       
156        random = numpy.random.mtrand.RandomState(0)
157         
158        if self.jitter > 0:
159            rowpoints[:,0] += random.normal(0, spanx * self.jitter / 100.0, (len(rowpoints),))
160            rowpoints[:,1] += random.normal(0, spany * self.jitter / 100.0, (len(rowpoints),))
161           
162        # Plot the points
163        groups = itertools.groupby(rowlabels, key=lambda label: label[0])
164        counts = [len(attr.values) for attr in attrs]
165        count = 0
166        for attr, labels in groups:
167            labels = list(labels)
168            advance = len(labels) # TODO add shape for each attribute and colors for each value
169            self.graph.addCurve(attr.name, brushColor=colors[attr],
170                                penColor=colors[attr], size=self.pointSize,
171                                xData=list(rowpoints[count: count+advance, 0]),
172                                yData=list(rowpoints[count: count+advance, 1]),
173                                autoScale=True, brushAlpha=self.alpha, enableLegend=True)
174            count += advance
175           
176        for label, point in zip(rowlabels, rowpoints):
177            self.graph.addMarker("%s: %s" % (label[0].name, label[1]), point[0], point[1], alignment=Qt.AlignCenter | Qt.AlignBottom)
178           
179           
180           
181        if self.jitter > 0:
182        # Update min, max, span values again due to jittering
183            maxx, maxy = numpy.max(rowpoints, axis=0)
184            minx, miny = numpy.min(rowpoints, axis=0)
185            spanx = maxx - minx or 1.0
186            spany = maxy - miny or 1.0
187       
188        self.graph.setAxisScale(QwtPlot.xBottom, minx - spanx * 0.1, maxx + spanx * 0.1)
189        self.graph.setAxisScale(QwtPlot.yLeft, miny - spany * 0.1, maxy + spany * 0.1)
190       
191        self.graph.setAxisTitle(QwtPlot.xBottom, "Axis %i" % (self.xPrincipalAxis + 1))
192        self.graph.setAxisTitle(QwtPlot.yLeft, "Axis %i" % (self.yPrincipalAxis + 1))
193       
194        #  Store labeled points for selection
195        self.rowPointsLabeled = zip(rowpoints, rowlabels) 
196       
197        inertia = self.CA.InertiaOfAxis(1)
198        fmt = """<table><tr><td>Axis %i:</td><td>%.3f%%</td></tr>
199        <tr><td>Axis %i:</td><td>%.3f%%</td></tr></table>
200        """
201        self.contributionInfo.setText(fmt % (self.xPrincipalAxis + 1, inertia[self.xPrincipalAxis],
202                                             self.yPrincipalAxis + 1, inertia[self.yPrincipalAxis]))
203        self.graph.replot()
204       
205    def sendData(self, *args):
206        def selectedLabels(points_labels):
207            return [label for (x, y), label in points_labels if self.graph.isPointSelected(x, y)]
208       
209        if self.contingency is not None and self.data:
210            rowLabels = set(selectedLabels(self.rowPointsLabeled))
211            selected = []
212            remaining = []
213           
214            groups = itertools.groupby(sorted(rowLabels), key=lambda label: label[0])
215            groups = [(attr, [value for _, value in labels]) for attr, labels in groups]
216           
217            def testAttr(ex, attr, values):
218                if values:
219                    return str(ex[attr]) in values
220                else:
221                    return True
222               
223            def testAll(ex):
224                return reduce(bool.__and__, [testAttr(ex, attr, values) for attr, values in groups], bool(groups))
225               
226            for ex in self.data:
227                if testAll(ex):
228                    selected.append(ex)
229                else:
230                    remaining.append(ex)
231                 
232            selected = orange.ExampleTable(self.data.domain, selected) if selected else \
233                            orange.ExampleTable(self.data.domain)
234           
235            remaining = orange.ExampleTable(self.data.domain, remaining) if remaining else \
236                            orange.ExampleTable(self.data.domain)
237                       
238            self.send("Selected Data", selected)
239            self.send("Remaining Data", remaining)
240        else:
241            self.send("Selected Data", None)
242            self.send("Remaining Data", None)
Note: See TracBrowser for help on using the repository browser.