source: orange/orange/OrangeWidgets/Visualize/OWCorrespondenceAnalysis.py @ 9546:2b6cc6f397fe

Revision 9546:2b6cc6f397fe, 14.1 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Renamed widget channel names in line with the new naming rules/convention.
Added backwards compatibility in orngDoc loadDocument to enable loading of schemas saved before the change.

Line 
1"""<name>Correspondence Analysis</name>
2<description>Takes a ExampleTable and runs correspondence analysis</description>
3<icon>icons/CorrespondenceAnalysis.png</icon>
4<priority>3300</priority>
5<contact>Ales Erjavec (ales.erjavec(@ at @)fri.uni-lj.si</contact>
6"""
7
8from OWWidget import *
9from OWGraph import *
10from OWToolbars import ZoomSelectToolbar
11from OWColorPalette import ColorPaletteHSV
12
13import OWGUI
14import orngCA
15
16import math
17
18class OWCorrespondenceAnalysis(OWWidget):
19    contextHandlers = {"": DomainContextHandler("", ["colAttr", "rowAttr", "xPricipalAxis", "yPrincipalAxis"])}
20    settingsList = ["pointSize", "alpha", "jitter", "showGridlines"]
21   
22    def __init__(self, parent=None, signalManager=None, name="Correspondence Analysis"):
23        OWWidget.__init__(self, parent, signalManager, name, wantGraph=True)
24       
25        self.inputs = [("Data", ExampleTable, self.setData)]
26        self.outputs = [("Selected Data", ExampleTable), ("Remaining Data", ExampleTable)]
27       
28        self.colAttr = 0
29        self.rowAttr = 1
30        self.xPrincipalAxis = 0
31        self.yPrincipalAxis = 1
32        self.pointSize = 6
33        self.alpha = 240
34        self.jitter = 0
35        self.showGridlines = 0
36        self.percCol = 100
37        self.percRow = 100
38        self.autoSend = 0
39       
40        self.loadSettings()
41       
42        # GUI
43        self.graph = OWGraph(self)
44        self.graph.sendData = self.sendData
45        self.mainArea.layout().addWidget(self.graph)
46       
47       
48        self.controlAreaTab = OWGUI.tabWidget(self.controlArea)
49        # Graph tab
50        self.graphTab = graphTab = OWGUI.createTabPage(self.controlAreaTab, "Graph")
51        self.colAttrCB = OWGUI.comboBox(graphTab, self, "colAttr", "Column Attribute", 
52                                        tooltip="Column attribute",
53                                        callback=self.runCA)
54       
55        self.rowAttrCB = OWGUI.comboBox(graphTab, self, "rowAttr", "Row Attribute", 
56                                        tooltip="Row attribute",
57                                        callback=self.runCA)
58       
59        self.xAxisCB = OWGUI.comboBox(graphTab, self, "xPrincipalAxis", "Principal Axis X",
60                                      tooltip="Principal axis X",
61                                      callback=self.updateGraph)
62       
63        self.yAxisCB = OWGUI.comboBox(graphTab, self, "yPrincipalAxis", "Principal Axis Y",
64                                      tooltip="Principal axis Y",
65                                      callback=self.updateGraph)
66       
67        box = OWGUI.widgetBox(graphTab, "Contribution to Inertia")
68        self.contributionInfo = OWGUI.widgetLabel(box, "NA\nNA")
69       
70        OWGUI.hSlider(graphTab, self, "percCol", "Percent of Column Points", 1, 100, 1,
71                      callback=self.updateGraph,
72                      tooltip="The percent of column points with the largest contribution to inertia")
73       
74        OWGUI.hSlider(graphTab, self, "percRow", "Percent of Row Points", 1, 100, 1,
75                      callback=self.updateGraph,
76                      tooltip="The percent of row points with the largest contribution to inertia")
77       
78        self.zoomSelect = ZoomSelectToolbar(self, graphTab, self.graph, self.autoSend)
79        OWGUI.rubber(graphTab)
80       
81        # Settings tab
82        self.settingsTab = settingsTab = OWGUI.createTabPage(self.controlAreaTab, "Settings")
83        OWGUI.hSlider(settingsTab, self, "pointSize", "Point Size", 3, 20, step=1,
84                      callback=self.setPointSize)
85       
86        OWGUI.hSlider(settingsTab, self, "alpha", "Transparancy", 1, 255, step=1,
87                      callback=self.updateAlpha)
88       
89        OWGUI.hSlider(settingsTab, self, "jitter", "Jitter Points", 0, 20, step=1,
90                      callback=self.updateGraph)
91       
92        box = OWGUI.widgetBox(settingsTab, "General Settings")
93        OWGUI.checkBox(box, self, "showGridlines", "Show gridlines",
94                       tooltip="Show gridlines in the plot.",
95                       callback=self.updateGridlines)
96        OWGUI.rubber(settingsTab)
97       
98        self.connect(self.graphButton, SIGNAL("clicked()"), self.graph.saveToFile)
99       
100        self.contingency = None
101        self.contColAttr = None
102        self.contRowAttr = None
103       
104        self.resize(800, 600)
105       
106    def setData(self, data=None):
107        self.closeContext("")
108        self.clear()
109        self.data = data
110        self.warning([0])
111        if data is not None:
112            attrs = data.domain.variables + data.domain.getmetas().values()
113            attrs = [attr for attr in attrs if isinstance(attr, orange.EnumVariable)]
114            if not attrs:
115                self.warning(0, "Data has no discrete variables!")
116                self.clear()
117                return
118            self.allAttrs = attrs
119            self.colAttrCB.clear()
120            self.rowAttrCB.clear()
121            icons = OWGUI.getAttributeIcons()
122            for attr in attrs:
123                self.colAttrCB.addItem(QIcon(icons[attr.varType]), attr.name)
124                self.rowAttrCB.addItem(QIcon(icons[attr.varType]), attr.name)
125               
126            self.colAttr = max(min(len(attrs) - 1, self.colAttr), 0)
127            self.rowAttr = max(min(len(attrs) - 1, self.rowAttr), min(1, len(attrs) - 1))
128           
129            self.openContext("", data)
130            self.runCA()
131           
132    def clear(self):
133        self.data = None
134        self.colAttrCB.clear()
135        self.rowAttrCB.clear()
136        self.xAxisCB.clear()
137        self.yAxisCB.clear()
138        self.contributionInfo.setText("NA\nNA")
139        self.graph.removeDrawingCurves(True, True, True)
140        self.send("Selected Data", None)
141        self.send("Remaining Data", None)
142        self.allAttrs = []
143       
144    def runCA(self):
145        self.contColAttr = colAttr = self.allAttrs[self.colAttr]
146        self.contRowAttr = rowAttr = self.allAttrs[self.rowAttr]
147        self.contingency = orange.ContingencyAttrAttr(rowAttr, colAttr, self.data)
148        self.error(0)
149        try:
150            self.CA = orngCA.CA([[c for c in row] for row in self.contingency])
151        except numpy.linalg.LinAlgError:
152            self.error(0, "Could not compute the mapping! " + str(ex))
153            self.graph.removeDrawingCurves(True, True, True)
154            raise
155           
156        self.rowItems = [s for s, v in self.contingency.outerDistribution.items()]
157        self.colItems = [s for s, v in self.contingency.innerDistribution.items()]
158       
159        self.xAxisCB.clear()
160        self.yAxisCB.clear()
161       
162        self.axisCount = min(self.CA.D.shape)
163        self.xAxisCB.addItems([str(i + 1) for i in range(self.axisCount)])
164        self.yAxisCB.addItems([str(i + 1) for i in range(self.axisCount)])
165       
166        self.xPrincipalAxis = min(self.xPrincipalAxis, self.axisCount - 1)
167        self.yPrincipalAxis = min(self.yPrincipalAxis, self.axisCount - 1)
168       
169        self.updateGraph()
170       
171    def updateGraph(self): 
172        self.graph.removeAllSelections()
173        self.graph.removeDrawingCurves(True, True, True)
174       
175        colors = ColorPaletteHSV(2)
176       
177        rowcor = self.CA.getPrincipalRowProfilesCoordinates((self.xPrincipalAxis, self.yPrincipalAxis))
178        numCor = int(math.ceil(len(rowcor) * float(self.percRow) / 100.0))
179        indices = self.CA.PointsWithMostInertia(rowColumn=0, axis=(self.xPrincipalAxis, self.yPrincipalAxis))[:numCor]
180        rowpoints = numpy.array([rowcor[i] for i in indices])
181        rowlabels = [self.rowItems[i] for i in indices]
182           
183       
184        colcor = self.CA.getPrincipalColProfilesCoordinates((self.xPrincipalAxis, self.yPrincipalAxis))
185        numRow = int(math.ceil(len(colcor) * float(self.percCol) / 100.0))
186        indices = self.CA.PointsWithMostInertia(rowColumn=1, axis=(self.xPrincipalAxis, self.yPrincipalAxis))[:numRow]
187        colpoints = numpy.array([colcor[i] for i in indices])
188        collabels = [self.colItems[i] for i in indices]
189       
190        vstack = ((rowpoints,) if rowpoints.size else ()) + \
191                 ((colpoints,) if colpoints.size else ())
192        allpoints = numpy.vstack(vstack)
193        maxx, maxy = numpy.max(allpoints, axis=0)
194        minx, miny = numpy.min(allpoints, axis=0)
195        spanx = maxx - minx
196        spany = maxy - miny
197       
198        random = numpy.random.mtrand.RandomState(0)
199         
200        if self.jitter > 0:
201            rowpoints[:,0] += random.normal(0, spanx * self.jitter / 100.0, (len(rowpoints),))
202            rowpoints[:,1] += random.normal(0, spany * self.jitter / 100.0, (len(rowpoints),))
203           
204            colpoints[:,0] += random.normal(0, spanx * self.jitter / 100.0, (len(colpoints),))
205            colpoints[:,1] += random.normal(0, spany * self.jitter / 100.0, (len(colpoints),))
206           
207        # Plot the points
208        self.graph.addCurve("Row points", brushColor=colors[0],
209                            penColor=colors[0], size=self.pointSize,
210                            enableLegend=True, xData=rowpoints[:, 0], yData=rowpoints[:, 1],
211                            autoScale=True, brushAlpha=self.alpha)
212       
213        for label, point in zip(rowlabels, rowpoints):
214            self.graph.addMarker(label, point[0], point[1], alignment=Qt.AlignCenter | Qt.AlignBottom)
215           
216        self.graph.addCurve("Column points", brushColor=colors[1],
217                            penColor=colors[1], size=self.pointSize,
218                            enableLegend=True, xData=colpoints[:, 0], yData=colpoints[:, 1],
219                            autoScale=True, brushAlpha=self.alpha)
220       
221        for label, point in zip(collabels, colpoints):
222            self.graph.addMarker(label, point[0], point[1], alignment=Qt.AlignCenter | Qt.AlignBottom)
223           
224        if self.jitter > 0:
225        # Update min, max, span values again due to jittering
226            vstack = ((rowpoints,) if rowpoints.size else ()) + \
227                     ((colpoints,) if colpoints.size else ())
228            allpoints = numpy.vstack(vstack)
229            maxx, maxy = numpy.max(allpoints, axis=0)
230            minx, miny = numpy.min(allpoints, axis=0)
231            spanx = maxx - minx
232            spany = maxy - miny
233       
234        self.graph.setAxisScale(QwtPlot.xBottom, minx - spanx * 0.05, maxx + spanx * 0.05)
235        self.graph.setAxisScale(QwtPlot.yLeft, miny - spany * 0.05, maxy + spany * 0.05)
236       
237        self.graph.setAxisTitle(QwtPlot.xBottom, "Axis %i" % (self.xPrincipalAxis + 1))
238        self.graph.setAxisTitle(QwtPlot.yLeft, "Axis %i" % (self.yPrincipalAxis + 1))
239       
240        #  Store labeled points for selection
241        self.colPointsLabeled = zip(colpoints, collabels)
242        self.rowPointsLabeled = zip(rowpoints, rowlabels) 
243       
244        inertia = self.CA.InertiaOfAxis(1)
245        fmt = """<table><tr><td>Axis %i:</td><td>%.3f%%</td></tr>
246        <tr><td>Axis %i:</td><td>%.3f%%</td></tr></table>
247        """
248        self.contributionInfo.setText(fmt % (self.xPrincipalAxis + 1, inertia[self.xPrincipalAxis],
249                                             self.yPrincipalAxis + 1, inertia[self.yPrincipalAxis]))
250        self.graph.replot()
251       
252    def setPointSize(self):
253        for curve in self.graph.itemList():
254            if isinstance(curve, QwtPlotCurve):
255                symbol = curve.symbol()
256                symbol.setSize(self.pointSize)
257                if QWT_VERSION_STR >= "5.2":
258                    curve.setSymbol(symbol)
259        self.graph.replot()
260   
261    def updateAlpha(self):
262        for curve in self.graph.itemList():
263            if isinstance(curve, QwtPlotCurve):
264                brushColor = curve.symbol().brush().color()
265                penColor = curve.symbol().pen().color()
266                brushColor.setAlpha(self.alpha)
267                brush = QBrush(curve.symbol().brush())
268                brush.setColor(brushColor)
269                penColor.setAlpha(self.alpha)
270                symbol = curve.symbol()
271                symbol.setBrush(brush)
272                symbol.setPen(QPen(penColor))
273                if QWT_VERSION_STR >= "5.2":
274                    curve.setSymbol(symbol)
275        self.graph.replot()
276       
277    def updateGridlines(self):
278        self.graph.enableGridXB(self.showGridlines)
279        self.graph.enableGridYL(self.showGridlines)
280       
281    def sendData(self, *args):
282        def selectedLabels(points_labels):
283            return [label for (x, y), label in points_labels if self.graph.isPointSelected(x, y)]
284       
285        if self.contingency and self.data:
286            colLabels = set(selectedLabels(self.colPointsLabeled))
287            rowLabels = set(selectedLabels(self.rowPointsLabeled))
288            colAttr = self.allAttrs[self.colAttr]
289            rowAttr = self.allAttrs[self.rowAttr]
290            selected = []
291            remaining = []
292           
293            if colLabels and rowLabels:
294                def test(ex):
295                    return str(ex[colAttr]) in colLabels and str(ex[rowAttr]) in rowLabels
296            elif colLabels or rowLabels:
297                def test(ex):
298                    return str(ex[colAttr]) in colLabels or str(ex[rowAttr]) in rowLabels
299            else:
300                def test(ex):
301                    return False
302               
303            for ex in self.data:
304                if test(ex):
305                    selected.append(ex)
306                else:
307                    remaining.append(ex)
308                 
309            selected = orange.ExampleTable(self.data.domain, selected) if selected else \
310                            orange.ExampleTable(self.data.domain)
311           
312            remaining = orange.ExampleTable(self.data.domain, remaining) if remaining else \
313                            orange.ExampleTable(self.data.domain)
314                       
315            self.send("Selected Data", selected)
316            self.send("Remaining Data", remaining)
317        else:
318            self.send("Selected Data", None)
319            self.send("Remaining Data", None)
320
321   
322if __name__ == "__main__":
323    app = QApplication([])
324    w = OWCorrespondenceAnalysis()
325    data = orange.ExampleTable("../doc/datasets/adult-sample.tab")
326    w.setData(data)
327    w.show()
328    app.exec_()
329    w.saveSettings()
330       
331   
332       
333       
Note: See TracBrowser for help on using the repository browser.