source: orange/Orange/OrangeWidgets/Visualize/OWLinProjGraph.py @ 10475:61c2249d671f

Revision 10475:61c2249d671f, 39.5 KB checked in by Matija Polajnar <matija.polajnar@…>, 2 years ago (diff)

Major refactorization of linear projections, fixing some bugs in the process.

Line 
1from OWGraph import *
2from copy import copy
3import time
4from operator import add
5from math import *
6from orngScaleLinProjData import *
7import orngVisFuncts
8import OWColorPalette
9from OWGraphTools import UnconnectedLinesCurve
10
11# indices in curveData
12SYMBOL = 0
13PENCOLOR = 1
14BRUSHCOLOR = 2
15
16LINE_TOOLTIPS = 0
17VISIBLE_ATTRIBUTES = 1
18ALL_ATTRIBUTES = 2
19
20TOOLTIPS_SHOW_DATA = 0
21TOOLTIPS_SHOW_SPRINGS = 1
22
23###########################################################################################
24##### CLASS : OWLINPROJGRAPH
25###########################################################################################
26class OWLinProjGraph(OWGraph, orngScaleLinProjData):
27    def __init__(self, widget, parent = None, name = "None"):
28        OWGraph.__init__(self, parent, name)
29        orngScaleLinProjData.__init__(self)
30
31        self.totalPossibilities = 0 # a variable used in optimization - tells us the total number of different attribute positions
32        self.triedPossibilities = 0 # how many possibilities did we already try
33        self.p = None
34
35        self.dataMap = {}        # each key is of form: "xVal-yVal", where xVal and yVal are discretized continuous values. Value of each key has form: (x,y, HSVValue, [data vals])
36        self.tooltipCurves = []
37        self.tooltipMarkers   = []
38        self.widget = widget
39
40        # moving anchors manually
41        self.shownAttributes = []
42        self.selectedAnchorIndex = None
43
44        self.hideRadius = 0
45        self.showAnchors = 1
46        self.showValueLines = 0
47        self.valueLineLength = 5
48
49        self.onlyOnePerSubset = 1
50        self.showLegend = 1
51        self.useDifferentSymbols = 0
52        self.useDifferentColors = 1
53        self.tooltipKind = 0        # index in ["Show line tooltips", "Show visible attributes", "Show all attributes"]
54        self.tooltipValue = 0       # index in ["Tooltips show data values", "Tooltips show spring values"]
55        self.scaleFactor = 1.0
56        self.showAttributeNames = 1
57
58        self.showProbabilities = 0
59        self.squareGranularity = 3
60        self.spaceBetweenCells = 1
61
62        self.showKNN = 0   # widget sets this to 1 or 2 if you want to see correct or wrong classifications
63        self.insideColors = None
64        self.valueLineCurves = [{}, {}]    # dicts for x and y set of coordinates for unconnected lines
65
66        self.enableXaxis(0)
67        self.enableYLaxis(0)
68        self.setAxisScale(QwtPlot.xBottom, -1.13, 1.13, 1)
69        self.setAxisScale(QwtPlot.yLeft, -1.13, 1.13, 1)
70
71    def setData(self, data, subsetData = None, **args):
72        OWGraph.setData(self, data)
73        orngScaleLinProjData.setData(self, data, subsetData, **args)
74        #self.anchorData = []
75
76        self.setAxisScale(QwtPlot.yLeft, -1.13, 1.13, 1)
77        self.setAxisScale(QwtPlot.xBottom, -1.13, 1.13, 1)
78
79        if data and data.domain.classVar and data.domain.classVar.varType == orange.VarTypes.Continuous:
80            self.setAxisScale(QwtPlot.xBottom, -1.13, 1.13 + 0.1, 1)   # if we have a continuous class we need a bit more space on the right to show a color legend
81
82    # ####################################################################
83    # update shown data. Set labels, coloring by className ....
84    def updateData(self, labels = None, setAnchors = 0, **args):
85        self.removeDrawingCurves()  # my function, that doesn't delete selection curves
86        #self.removeCurves()
87        self.removeMarkers()
88        self.tooltipMarkers = []
89
90        self.__dict__.update(args)
91        if labels == None: labels = [anchor[2] for anchor in self.anchorData]
92        self.shownAttributes = labels
93        self.dataMap = {}   # dictionary with keys of form "x_i-y_i" with values (x_i, y_i, color, data)
94        self.valueLineCurves = [{}, {}]    # dicts for x and y set of coordinates for unconnected lines
95
96        if not self.haveData or len(labels) < 3:
97            self.anchorData = []
98            self.updateLayout()
99            return
100
101        if setAnchors or (args.has_key("XAnchors") and args.has_key("YAnchors")):
102            self.potentialsBmp = None
103            self.setAnchors(args.get("XAnchors"), args.get("YAnchors"), labels)
104            #self.anchorData = self.createAnchors(len(labels), labels)    # used for showing tooltips
105
106        indices = [self.attributeNameIndex[anchor[2]] for anchor in self.anchorData]  # store indices to shown attributes
107
108        # do we want to show anchors and their labels
109        if self.showAnchors:
110            if self.hideRadius > 0:
111                xdata = self.createXAnchors(100)*(self.hideRadius / 10)
112                ydata = self.createYAnchors(100)*(self.hideRadius / 10)
113                self.addCurve("hidecircle", QColor(200,200,200), QColor(200,200,200), 1, style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, xData = xdata.tolist() + [xdata[0]], yData = ydata.tolist() + [ydata[0]])
114
115            # draw dots at anchors
116            shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
117
118            if not self.normalizeExamples:
119                r=self.hideRadius**2/100
120                for i,(x,y,a) in enumerate(shownAnchorData):
121                    self.addCurve("l%i" % i, QColor(160, 160, 160), QColor(160, 160, 160), 10, style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, xData = [0, x], yData = [0, y], showFilledSymbols = 1, lineWidth=2)
122                    if self.showAttributeNames:
123                        self.addMarker(a, x*1.07, y*1.04, Qt.AlignCenter, bold=1)
124            else:
125                XAnchors = [a[0] for a in shownAnchorData]
126                YAnchors = [a[1] for a in shownAnchorData]
127                self.addCurve("dots", QColor(160,160,160), QColor(160,160,160), 10, style = QwtPlotCurve.NoCurve, symbol = QwtSymbol.Ellipse, xData = XAnchors, yData = YAnchors, showFilledSymbols = 1)
128
129                # draw text at anchors
130                if self.showAttributeNames:
131                    for x, y, a in shownAnchorData:
132                        self.addMarker(a, x*1.07, y*1.04, Qt.AlignCenter, bold = 1)
133
134        if self.showAnchors and self.normalizeExamples:
135            # draw "circle"
136            xdata = self.createXAnchors(100)
137            ydata = self.createYAnchors(100)
138            self.addCurve("circle", QColor(Qt.black), QColor(Qt.black), 1, style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, xData = xdata.tolist() + [xdata[0]], yData = ydata.tolist() + [ydata[0]])
139
140        self.potentialsClassifier = None # remove the classifier so that repaint won't recompute it
141        self.updateLayout()
142
143        if self.dataHasDiscreteClass:
144            self.discPalette.setNumberOfColors(len(self.dataDomain.classVar.values))
145
146        useDifferentSymbols = self.useDifferentSymbols and self.dataHasDiscreteClass and len(self.dataDomain.classVar.values) < len(self.curveSymbols)
147        dataSize = len(self.rawData)
148        validData = self.getValidList(indices)
149        transProjData = self.createProjectionAsNumericArray(indices, validData = validData, scaleFactor = self.scaleFactor, normalize = self.normalizeExamples, jitterSize = -1, useAnchorData = 1, removeMissingData = 0)
150        if transProjData == None:
151            return
152        projData = transProjData.T
153        x_positions = projData[0]
154        y_positions = projData[1]
155        xPointsToAdd = {}
156        yPointsToAdd = {}
157
158
159        if self.showProbabilities and self.haveData and self.dataHasClass:
160            # construct potentialsClassifier from unscaled positions
161            domain = orange.Domain([self.dataDomain[i].name for i in indices]+[self.dataDomain.classVar.name], self.dataDomain)
162            offsets = [self.attrValues[self.attributeNames[i]][0] for i in indices]
163            normalizers = [self.getMinMaxVal(i) for i in indices]
164            selectedData = numpy.take(self.originalData, indices, axis = 0)
165            averages = numpy.average(numpy.compress(validData, selectedData, axis=1), 1)
166            classData = numpy.compress(validData, self.originalData[self.dataClassIndex])
167            if classData.any():
168                self.potentialsClassifier = orange.P2NN(domain, numpy.transpose(numpy.array([numpy.compress(validData, self.unscaled_x_positions), numpy.compress(validData, self.unscaled_y_positions), classData])), self.anchorData, offsets, normalizers, averages, self.normalizeExamples, law=1)
169            else:
170                self.potentialsClassifier = None
171            self.potentialsImage = None
172
173
174        # ##############################################################
175        # show model quality
176        # ##############################################################
177        if self.insideColors != None or self.showKNN and self.haveData:
178            # if we want to show knn classifications of the examples then turn the projection into example table and run knn
179            if self.insideColors:
180                insideData, stringData = self.insideColors
181            else:
182                shortData = self.createProjectionAsExampleTable([self.attributeNameIndex[attr] for attr in labels], useAnchorData = 1)
183                predictions, probabilities = self.widget.vizrank.kNNClassifyData(shortData)
184                if self.showKNN == 2: insideData, stringData = [1.0 - val for val in predictions], "Probability of wrong classification = %.2f%%"
185                else:                 insideData, stringData = predictions, "Probability of correct classification = %.2f%%"
186
187            if self.dataHasDiscreteClass:        classColors = self.discPalette
188            elif self.dataHasContinuousClass:    classColors = self.contPalette
189
190            if len(insideData) != len(self.rawData):
191                j = 0
192                for i in range(len(self.rawData)):
193                    if not validData[i]: continue
194                    if self.dataHasClass:
195                        fillColor = classColors.getRGB(self.originalData[self.dataClassIndex][i], 255*insideData[j])
196                        edgeColor = classColors.getRGB(self.originalData[self.dataClassIndex][i])
197                    else:
198                        fillColor = edgeColor = (0,0,0)
199                    self.addCurve(str(i), QColor(*fillColor+ (self.alphaValue,)), QColor(*edgeColor+ (self.alphaValue,)), self.pointWidth, xData = [x_positions[i]], yData = [y_positions[i]])
200                    if self.showValueLines:
201                        self.addValueLineCurve(x_positions[i], y_positions[i], edgeColor, i, indices)
202                    self.addTooltipKey(x_positions[i], y_positions[i], QColor(*edgeColor), i, stringData % (100*insideData[j]))
203                    j+= 1
204            else:
205                for i in range(len(self.rawData)):
206                    if not validData[i]: continue
207                    if self.dataHasClass:
208                        fillColor = classColors.getRGB(self.originalData[self.dataClassIndex][i], 255*insideData[i])
209                        edgeColor = classColors.getRGB(self.originalData[self.dataClassIndex][i])
210                    else:
211                        fillColor = edgeColor = (0,0,0)
212                    self.addCurve(str(i), QColor(*fillColor+ (self.alphaValue,)), QColor(*edgeColor+ (self.alphaValue,)), self.pointWidth, xData = [x_positions[i]], yData = [y_positions[i]])
213                    if self.showValueLines:
214                        self.addValueLineCurve(x_positions[i], y_positions[i], edgeColor, i, indices)
215                    self.addTooltipKey(x_positions[i], y_positions[i], QColor(*edgeColor), i, stringData % (100*insideData[i]))
216
217        # ##############################################################
218        # do we have a subset data to show?
219        # ##############################################################
220        elif self.haveSubsetData:
221            shownSubsetCount = 0
222            subsetIdsToDraw = dict([(example.id,1) for example in self.rawSubsetData])
223
224            # draw the rawData data set. examples that exist also in the subset data draw full, other empty
225            for i in range(dataSize):
226                if not validData[i]: continue
227                if subsetIdsToDraw.has_key(self.rawData[i].id):
228                    continue
229
230                if self.dataHasDiscreteClass and self.useDifferentColors:
231                    newColor = self.discPalette.getRGB(self.originalData[self.dataClassIndex][i])
232                elif self.dataHasContinuousClass and self.useDifferentColors:
233                    newColor = self.contPalette.getRGB(self.noJitteringScaledData[self.dataClassIndex][i])
234                else:
235                    newColor = (0,0,0)
236
237                if self.useDifferentSymbols and self.dataHasDiscreteClass:
238                    curveSymbol = self.curveSymbols[int(self.originalData[self.dataClassIndex][i])]
239                else:
240                    curveSymbol = self.curveSymbols[0]
241
242                if not xPointsToAdd.has_key((newColor, curveSymbol,0)):
243                    xPointsToAdd[(newColor, curveSymbol,0)] = []
244                    yPointsToAdd[(newColor, curveSymbol,0)] = []
245                xPointsToAdd[(newColor, curveSymbol,0)].append(x_positions[i])
246                yPointsToAdd[(newColor, curveSymbol,0)].append(y_positions[i])
247                if self.showValueLines:
248                    self.addValueLineCurve(x_positions[i], y_positions[i], newColor, i, indices)
249
250                self.addTooltipKey(x_positions[i], y_positions[i], QColor(*newColor), i)
251
252            # if we have a data subset that contains examples that don't exist in the original dataset we show them here
253            XAnchors = numpy.array([val[0] for val in self.anchorData])
254            YAnchors = numpy.array([val[1] for val in self.anchorData])
255            anchorRadius = numpy.sqrt(XAnchors*XAnchors + YAnchors*YAnchors)
256            validSubData = self.getValidSubsetList(indices)
257            projSubData = self.createProjectionAsNumericArray(indices, validData = validSubData, scaleFactor = self.scaleFactor, normalize = self.normalizeExamples, jitterSize = -1, useAnchorData = 1, removeMissingData = 0, useSubsetData = 1).T
258            sub_x_positions = projSubData[0]
259            sub_y_positions = projSubData[1]
260
261            for i in range(len(self.rawSubsetData)):
262                if not validSubData[i]: continue    # check if has missing values
263
264                if not self.dataHasClass or self.rawSubsetData[i].getclass().isSpecial():
265                    newColor = (0,0,0)
266                else:
267                    if self.dataHasDiscreteClass:
268                        newColor = self.discPalette.getRGB(self.originalSubsetData[self.dataClassIndex][i])
269                    else:
270                        newColor = self.contPalette.getRGB(self.noJitteringScaledSubsetData[self.dataClassIndex][i])
271
272                if self.useDifferentSymbols and self.dataHasDiscreteClass and self.validSubsetDataArray[self.dataClassIndex][i]:
273                    curveSymbol = self.curveSymbols[int(self.originalSubsetData[self.dataClassIndex][i])]
274                else:
275                    curveSymbol = self.curveSymbols[0]
276
277                if not xPointsToAdd.has_key((newColor, curveSymbol, 1)):
278                    xPointsToAdd[(newColor, curveSymbol, 1)] = []
279                    yPointsToAdd[(newColor, curveSymbol, 1)] = []
280                xPointsToAdd[(newColor, curveSymbol, 1)].append(sub_x_positions[i])
281                yPointsToAdd[(newColor, curveSymbol, 1)].append(sub_y_positions[i])
282
283        elif not self.dataHasClass:
284            xs = []; ys = []
285            for i in range(dataSize):
286                if not validData[i]: continue
287                xs.append(x_positions[i])
288                ys.append(y_positions[i])
289                self.addTooltipKey(x_positions[i], y_positions[i], QColor(Qt.black), i)
290                if self.showValueLines:
291                    self.addValueLineCurve(x_positions[i], y_positions[i], (0,0,0), i, indices)
292            self.addCurve(str(1), QColor(0,0,0,self.alphaValue), QColor(0,0,0,self.alphaValue), self.pointWidth, symbol = self.curveSymbols[0], xData = xs, yData = ys, penAlpha = self.alphaValue, brushAlpha = self.alphaValue)
293
294        # ##############################################################
295        # CONTINUOUS class
296        # ##############################################################
297        elif self.dataHasContinuousClass:
298            for i in range(dataSize):
299                if not validData[i]: continue
300                newColor = self.contPalette.getRGB(self.noJitteringScaledData[self.dataClassIndex][i])
301                self.addCurve(str(i), QColor(*newColor+ (self.alphaValue,)), QColor(*newColor+ (self.alphaValue,)), self.pointWidth, symbol = QwtSymbol.Ellipse, xData = [x_positions[i]], yData = [y_positions[i]])
302                if self.showValueLines:
303                    self.addValueLineCurve(x_positions[i], y_positions[i], newColor, i, indices)
304                self.addTooltipKey(x_positions[i], y_positions[i], QColor(*newColor), i)
305
306        # ##############################################################
307        # DISCRETE class
308        # ##############################################################
309        elif self.dataHasDiscreteClass:
310            for i in range(dataSize):
311                if not validData[i]: continue
312                if self.useDifferentColors: newColor = self.discPalette.getRGB(self.originalData[self.dataClassIndex][i])
313                else:                       newColor = (0,0,0)
314                if self.useDifferentSymbols: curveSymbol = self.curveSymbols[int(self.originalData[self.dataClassIndex][i])]
315                else:                        curveSymbol = self.curveSymbols[0]
316                if not xPointsToAdd.has_key((newColor, curveSymbol, self.showFilledSymbols)):
317                    xPointsToAdd[(newColor, curveSymbol, self.showFilledSymbols)] = []
318                    yPointsToAdd[(newColor, curveSymbol, self.showFilledSymbols)] = []
319                xPointsToAdd[(newColor, curveSymbol, self.showFilledSymbols)].append(x_positions[i])
320                yPointsToAdd[(newColor, curveSymbol, self.showFilledSymbols)].append(y_positions[i])
321                if self.showValueLines:
322                    self.addValueLineCurve(x_positions[i], y_positions[i], newColor, i, indices)
323                self.addTooltipKey(x_positions[i], y_positions[i], QColor(*newColor), i)
324
325        # first draw value lines
326        if self.showValueLines:
327            for i, color in enumerate(self.valueLineCurves[0].keys()):
328                curve = UnconnectedLinesCurve("", QPen(QColor(*color + (self.alphaValue,))), self.valueLineCurves[0][color], self.valueLineCurves[1][color])
329                curve.attach(self)
330
331        # draw all the points with a small number of curves
332        for i, (color, symbol, showFilled) in enumerate(xPointsToAdd.keys()):
333            xData = xPointsToAdd[(color, symbol, showFilled)]
334            yData = yPointsToAdd[(color, symbol, showFilled)]
335            self.addCurve(str(i), QColor(*color + (self.alphaValue,)), QColor(*color + (self.alphaValue,)), self.pointWidth, symbol = symbol, xData = xData, yData = yData, showFilledSymbols = showFilled)
336
337        # ##############################################################
338        # draw the legend
339        # ##############################################################
340        if self.showLegend:
341            # show legend for discrete class
342            if self.dataHasDiscreteClass:
343                self.addMarker(self.dataDomain.classVar.name, 0.87, 1.05, Qt.AlignLeft | Qt.AlignVCenter)
344
345                classVariableValues = getVariableValuesSorted(self.dataDomain.classVar)
346                for index in range(len(classVariableValues)):
347                    if self.useDifferentColors: color = QColor(self.discPalette[index])
348                    else:                       color = QColor(Qt.black)
349                    y = 1.0 - index * 0.05
350
351                    if not self.useDifferentSymbols:  curveSymbol = self.curveSymbols[0]
352                    else:                             curveSymbol = self.curveSymbols[index]
353
354                    self.addCurve(str(index), color, color, self.pointWidth, symbol = curveSymbol, xData = [0.95], yData = [y], penAlpha = self.alphaValue, brushAlpha = self.alphaValue)
355                    self.addMarker(classVariableValues[index], 0.90, y, Qt.AlignLeft | Qt.AlignVCenter)
356            # show legend for continuous class
357            elif self.dataHasContinuousClass:
358                xs = [1.15, 1.20, 1.20, 1.15]
359                count = 200
360                height = 2 / float(count)
361                for i in range(count):
362                    y = -1.0 + i*2.0/float(count)
363                    col = self.contPalette[i/float(count)]
364                    col.setAlpha(self.alphaValue)
365                    PolygonCurve(QPen(col), QBrush(col), xData = xs, yData = [y,y, y+height, y+height]).attach(self)
366
367                # add markers for min and max value of color attribute
368                [minVal, maxVal] = self.attrValues[self.dataDomain.classVar.name]
369                self.addMarker("%s = %%.%df" % (self.dataDomain.classVar.name, self.dataDomain.classVar.numberOfDecimals) % (minVal), xs[0] - 0.02, -1.0 + 0.04, Qt.AlignLeft)
370                self.addMarker("%s = %%.%df" % (self.dataDomain.classVar.name, self.dataDomain.classVar.numberOfDecimals) % (maxVal), xs[0] - 0.02, +1.0 - 0.04, Qt.AlignLeft)
371
372        self.replot()
373
374
375    # ##############################################################
376    # create a dictionary value for the data point
377    # this will enable to show tooltips faster and to make selection of examples available
378    def addTooltipKey(self, x, y, color, index, extraString = None):
379        dictValue = "%.1f-%.1f"%(x, y)
380        if not self.dataMap.has_key(dictValue): self.dataMap[dictValue] = []
381        self.dataMap[dictValue].append((x, y, color, index, extraString))
382
383
384    def addValueLineCurve(self, x, y, color, exampleIndex, attrIndices):
385        XAnchors = numpy.array([val[0] for val in self.anchorData])
386        YAnchors = numpy.array([val[1] for val in self.anchorData])
387        xs = numpy.array([x] * len(self.anchorData))
388        ys = numpy.array([y] * len(self.anchorData))
389        dists = numpy.sqrt((XAnchors-xs)**2 + (YAnchors-ys)**2)
390        xVect = 0.01 * self.valueLineLength * (XAnchors - xs) / dists
391        yVect = 0.01 * self.valueLineLength * (YAnchors - ys) / dists
392        exVals = [self.noJitteringScaledData[attrInd, exampleIndex] for attrInd in attrIndices]
393
394        xs = []; ys = []
395        for i in range(len(exVals)):
396            xs += [x, x + xVect[i]*exVals[i]]
397            ys += [y, y + yVect[i]*exVals[i]]
398        self.valueLineCurves[0][color] = self.valueLineCurves[0].get(color, []) + xs
399        self.valueLineCurves[1][color] = self.valueLineCurves[1].get(color, []) + ys
400
401
402    def mousePressEvent(self, e):
403        if self.manualPositioning:
404            self.mouseCurrentlyPressed = 1
405            self.selectedAnchorIndex = None
406            if not self.normalizeExamples:
407                marker, dist = self.closestMarker(e.x(), e.y())
408                if dist < 15:
409                    self.selectedAnchorIndex = self.shownAttributes.index(str(marker.label().text()))
410            else:
411                (curve, dist, x, y, index) = self.closestCurve(e.x(), e.y())
412                if dist < 5 and str(curve.title().text()) == "dots":
413                    self.selectedAnchorIndex = index
414        else:
415            OWGraph.mousePressEvent(self, e)
416
417
418    def mouseReleaseEvent(self, e):
419        if self.manualPositioning:
420            self.mouseCurrentlyPressed = 0
421            self.selectedAnchorIndex = None
422        else:
423            OWGraph.mouseReleaseEvent(self, e)
424
425    # ##############################################################
426    # draw tooltips
427    def mouseMoveEvent(self, e):
428        redraw = (self.tooltipCurves != [] or self.tooltipMarkers != [])
429
430        for curve in self.tooltipCurves:  curve.detach()
431        for marker in self.tooltipMarkers: marker.detach()
432        self.tooltipCurves = []
433        self.tooltipMarkers = []
434
435        canvasPos = self.canvas().mapFrom(self, e.pos())
436        xFloat = self.invTransform(QwtPlot.xBottom, canvasPos.x())
437        yFloat = self.invTransform(QwtPlot.yLeft, canvasPos.y())
438
439        # in case we are drawing a rectangle, we don't draw enhanced tooltips
440        # because it would then fail to draw the rectangle
441        if self.mouseCurrentlyPressed:
442            if not self.manualPositioning:
443                OWGraph.mouseMoveEvent(self, e)
444                if redraw: self.replot()
445            else:
446                if self.selectedAnchorIndex != None:
447                    if self.widget.freeVizDlg.restrain == 1:
448                        rad = sqrt(xFloat**2 + yFloat**2)
449                        xFloat /= rad
450                        yFloat /= rad
451                    elif self.widget.freeVizDlg.restrain == 2:
452                        rad = sqrt(xFloat**2 + yFloat**2)
453                        phi = 2 * self.selectedAnchorIndex * math.pi / len(self.anchorData)
454                        xFloat = rad * cos(phi)
455                        yFloat = rad * sin(phi)
456                    self.anchorData[self.selectedAnchorIndex] = (xFloat, yFloat, self.anchorData[self.selectedAnchorIndex][2])
457                    self.updateData(self.shownAttributes)
458                    self.replot()
459                    #self.widget.recomputeEnergy()
460            return
461
462        dictValue = "%.1f-%.1f"%(xFloat, yFloat)
463        if self.dataMap.has_key(dictValue):
464            points = self.dataMap[dictValue]
465            bestDist = 100.0
466            for (x_i, y_i, color, index, extraString) in points:
467                currDist = sqrt((xFloat-x_i)*(xFloat-x_i) + (yFloat-y_i)*(yFloat-y_i))
468                if currDist < bestDist:
469                    bestDist = currDist
470                    nearestPoint = (x_i, y_i, color, index, extraString)
471
472            (x_i, y_i, color, index, extraString) = nearestPoint
473            intX = self.transform(QwtPlot.xBottom, x_i)
474            intY = self.transform(QwtPlot.yLeft, y_i)
475
476            if self.tooltipKind == LINE_TOOLTIPS and bestDist < 0.05:
477                shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
478                if not self.normalizeExamples:
479                    for (xAnchor,yAnchor,label) in shownAnchorData:
480                        attrVal = self.scaledData[self.attributeNameIndex[label]][index]
481                        markerX, markerY = xAnchor*(attrVal+0.03), yAnchor*(attrVal+0.03)
482                        curve = self.addCurve("", color, color, 1, style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, xData = [0, xAnchor*attrVal], yData = [0, yAnchor*attrVal], lineWidth=3)
483
484                        marker = None
485                        fontsize = 9
486                        markerAlign = Qt.AlignCenter
487                        self.tooltipCurves.append(curve)
488                        labelIndex = self.attributeNameIndex[label]
489                        if self.tooltipValue == TOOLTIPS_SHOW_DATA:
490                            if self.dataDomain[labelIndex].varType == orange.VarTypes.Continuous:
491                                text = "%%.%df" % (self.dataDomain[labelIndex].numberOfDecimals) % (self.rawData[index][labelIndex])
492                            else:
493                                text = str(self.rawData[index][labelIndex].value)
494                            marker = self.addMarker(text, markerX, markerY, markerAlign, size = fontsize)
495                        elif self.tooltipValue == TOOLTIPS_SHOW_SPRINGS:
496                            marker = self.addMarker("%.3f" % (self.scaledData[labelIndex][index]), markerX, markerY, markerAlign, size = fontsize)
497                        self.tooltipMarkers.append(marker)
498
499            elif self.tooltipKind == VISIBLE_ATTRIBUTES or self.tooltipKind == ALL_ATTRIBUTES:
500                if self.tooltipKind == VISIBLE_ATTRIBUTES:
501                    shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
502                    labels = [s for (xA, yA, s) in shownAnchorData]
503                else:
504                    labels = []
505
506                text = self.getExampleTooltipText(self.rawData[index], labels)
507                text += "<hr>Example index = %d" % (index)
508                if extraString:
509                    text += "<hr>" + extraString
510                self.showTip(intX, intY, text)
511
512        OWGraph.mouseMoveEvent(self, e)
513        self.replot()
514
515
516    # send 2 example tables. in first is the data that is inside selected rects (polygons), in the second is unselected data
517    def getSelectionsAsExampleTables(self, attrList, useAnchorData = 1, addProjectedPositions = 0):
518        if not self.haveData: return (None, None)
519        if addProjectedPositions == 0 and not self.selectionCurveList: return (None, self.rawData)       # if no selections exist
520        if (useAnchorData and len(self.anchorData) < 3) or len(attrList) < 3: return (None, None)
521
522        xAttr=orange.FloatVariable("X Positions")
523        yAttr=orange.FloatVariable("Y Positions")
524        if addProjectedPositions == 1:
525            domain=orange.Domain([xAttr,yAttr] + [v for v in self.dataDomain.variables])
526        elif addProjectedPositions == 2:
527            domain=orange.Domain(self.dataDomain)
528            domain.addmeta(orange.newmetaid(), xAttr)
529            domain.addmeta(orange.newmetaid(), yAttr)
530        else:
531            domain = orange.Domain(self.dataDomain)
532
533        domain.addmetas(self.dataDomain.getmetas())
534
535        if useAnchorData: indices = [self.attributeNameIndex[val[2]] for val in self.anchorData]
536        else:             indices = [self.attributeNameIndex[label] for label in attrList]
537        validData = self.getValidList(indices)
538        if len(validData) == 0: return (None, None)
539
540        array = self.createProjectionAsNumericArray(attrList, scaleFactor = self.scaleFactor, useAnchorData = useAnchorData, removeMissingData = 0)
541        if array == None:       # if all examples have missing values
542            return (None, None)
543
544        #selIndices, unselIndices = self.getSelectionsAsIndices(attrList, useAnchorData, validData)
545        selIndices, unselIndices = self.getSelectedPoints(array.T[0], array.T[1], validData)
546
547        if addProjectedPositions:
548            selected = orange.ExampleTable(domain, self.rawData.selectref(selIndices))
549            unselected = orange.ExampleTable(domain, self.rawData.selectref(unselIndices))
550            selIndex = 0; unselIndex = 0
551            for i in range(len(selIndices)):
552                if selIndices[i]:
553                    selected[selIndex][xAttr] = array[i][0]
554                    selected[selIndex][yAttr] = array[i][1]
555                    selIndex += 1
556                else:
557                    unselected[unselIndex][xAttr] = array[i][0]
558                    unselected[unselIndex][yAttr] = array[i][1]
559                    unselIndex += 1
560        else:
561            selected = self.rawData.selectref(selIndices)
562            unselected = self.rawData.selectref(unselIndices)
563
564        if len(selected) == 0: selected = None
565        if len(unselected) == 0: unselected = None
566        return (selected, unselected)
567
568
569    def getSelectionsAsIndices(self, attrList, useAnchorData = 1, validData = None):
570        if not self.haveData: return [], []
571
572        attrIndices = [self.attributeNameIndex[attr] for attr in attrList]
573        if validData == None:
574            validData = self.getValidList(attrIndices)
575
576        array = self.createProjectionAsNumericArray(attrList, scaleFactor = self.scaleFactor, useAnchorData = useAnchorData, removeMissingData = 0)
577        if array == None:
578            return [], []
579        array = numpy.transpose(array)
580        return self.getSelectedPoints(array[0], array[1], validData)
581
582
583    # update shown data. Set labels, coloring by className ....
584    def savePicTeX(self):
585        lastSave = getattr(self, "lastPicTeXSave", "C:\\")
586        qfileName = QFileDialog.getSaveFileName(None, "Save to..", lastSave + "graph.pictex","PicTeX (*.pictex);;All files (*.*)")
587        fileName = str(qfileName)
588        if fileName == "":
589            return
590
591        if not os.path.splitext(fileName)[1][1:]:
592            fileName = fileName + ".pictex"
593
594        self.lastSave = os.path.split(fileName)[0]+"/"
595        file = open(fileName, "wt")
596
597        file.write("\\mbox{\n")
598        file.write(\\beginpicture\n")
599        file.write(\\setcoordinatesystem units <0.4\columnwidth, 0.4\columnwidth>\n")
600        file.write(\\setplotarea x from -1.1 to 1.1, y from -1 to 1.1\n")
601
602        if not self.normalizeExamples:
603            file.write("\\circulararc 360 degrees from 1 0 center at 0 0\n")
604
605        if self.showAnchors:
606            if self.hideRadius > 0:
607                file.write("\\setdashes\n")
608                file.write("\\circulararc 360 degrees from %5.3f 0 center at 0 0\n" % (self.hideRadius/10.))
609                file.write("\\setsolid\n")
610
611            if self.showAttributeNames:
612                shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
613                if not self.normalizeExamples:
614                    for x,y,l in shownAnchorData:
615                        file.write("\\plot 0 0 %5.3f %5.3f /\n" % (x, y))
616                        file.write("\\put {{\\footnotesize %s}} [b] at %5.3f %5.3f\n" % (l.replace("_", "-"), x*1.07, y*1.04))
617                else:
618                    file.write("\\multiput {\\small $\\odot$} at %s /\n" % (" ".join(["%5.3f %5.3f" % tuple(i[:2]) for i in shownAnchorData])))
619                    for x,y,l in shownAnchorData:
620                        file.write("\\put {{\\footnotesize %s}} [b] at %5.3f %5.3f\n" % (l.replace("_", "-"), x*1.07, y*1.04))
621
622        symbols = ("{\\small $\\circ$}", "{\\tiny $\\times$}", "{\\tiny $+$}", "{\\small $\\star$}",
623                   "{\\small $\\ast$}", "{\\tiny $\\div$}", "{\\small $\\bullet$}", ) + tuple([chr(x) for x in range(97, 123)])
624        dataSize = len(self.rawData)
625        labels = self.widget.getShownAttributeList()
626        indices = [self.attributeNameIndex[label] for label in labels]
627        selectedData = numpy.take(self.scaledData, indices, axis = 0)
628        XAnchors = numpy.array([a[0] for a in self.anchorData])
629        YAnchors = numpy.array([a[1] for a in self.anchorData])
630
631        r = numpy.sqrt(XAnchors*XAnchors + YAnchors*YAnchors)     # compute the distance of each anchor from the center of the circle
632        XAnchors *= r                                               # we need to normalize the anchors by r, otherwise the anchors won't attract points less if they are placed at the center of the circle
633        YAnchors *= r
634
635        x_positions = numpy.dot(XAnchors, selectedData)
636        y_positions = numpy.dot(YAnchors, selectedData)
637
638        if self.normalizeExamples:
639            sum_i = self._getSum_i(selectedData, useAnchorData = 1, anchorRadius = r)
640            x_positions /= sum_i
641            y_positions /= sum_i
642
643        if self.scaleFactor:
644            self.trueScaleFactor = self.scaleFactor
645        else:
646            abss = x_positions*x_positions + y_positions*y_positions
647            self.trueScaleFactor =  1 / sqrt(abss[numpy.argmax(abss)])
648
649        x_positions *= self.trueScaleFactor
650        y_positions *= self.trueScaleFactor
651
652        validData = self.getValidList(indices)
653
654        pos = [[] for i in range(len(self.dataDomain.classVar.values))]
655        for i in range(dataSize):
656            if validData[i]:
657                pos[int(self.originalData[self.dataClassIndex][i])].append((x_positions[i], y_positions[i]))
658
659        for i in range(len(self.dataDomain.classVar.values)):
660            file.write("\\multiput {%s} at %s /\n" % (symbols[i], " ".join(["%5.3f %5.3f" % p for p in pos[i]])))
661
662        if self.showLegend:
663            classVariableValues = getVariableValuesSorted(self.dataDomain.classVar)
664            file.write("\\put {%s} [lB] at 0.87 1.06\n" % self.dataDomain.classVar.name)
665            for index in range(len(classVariableValues)):
666                file.write("\\put {%s} at 1.0 %5.3f\n" % (symbols[index], 0.93 - 0.115*index))
667                file.write("\\put {%s} [lB] at 1.05 %5.3f\n" % (classVariableValues[index], 0.9 - 0.115*index))
668
669        file.write("\\endpicture\n}\n")
670        file.close()
671
672    def computePotentials(self):
673        import orangeom
674        #rx = self.transform(QwtPlot.xBottom, 1) - self.transform(QwtPlot.xBottom, 0)
675        #ry = self.transform(QwtPlot.yLeft, 0) - self.transform(QwtPlot.yLeft, 1)
676
677        rx = self.transform(QwtPlot.xBottom, 1) - self.transform(QwtPlot.xBottom, -1)
678        ry = self.transform(QwtPlot.yLeft, -1) - self.transform(QwtPlot.yLeft, 1)
679        ox = self.transform(QwtPlot.xBottom, 0) - self.transform(QwtPlot.xBottom, -1)
680        oy = self.transform(QwtPlot.yLeft, -1) - self.transform(QwtPlot.yLeft, 0)
681
682        rx -= rx % self.squareGranularity
683        ry -= ry % self.squareGranularity
684
685        if not getattr(self, "potentialsImage", None) \
686           or getattr(self, "potentialContext", None) != (rx, ry, self.shownAttributes, self.trueScaleFactor, self.squareGranularity, self.jitterSize, self.jitterContinuous, self.spaceBetweenCells):
687            if self.potentialsClassifier.classVar.varType == orange.VarTypes.Continuous:
688                imagebmp = orangeom.potentialsBitmap(self.potentialsClassifier, rx, ry, ox, oy, self.squareGranularity, self.trueScaleFactor/2, self.spaceBetweenCells)
689                palette = [qRgb(255.*i/255., 255.*i/255., 255-(255.*i/255.)) for i in range(255)] + [qRgb(255, 255, 255)]
690            else:
691                imagebmp, nShades = orangeom.potentialsBitmap(self.potentialsClassifier, rx, ry, ox, oy, self.squareGranularity, self.trueScaleFactor/2, self.spaceBetweenCells) # the last argument is self.trueScaleFactor (in LinProjGraph...)
692                palette = []
693                sortedClasses = getVariableValuesSorted(self.potentialsClassifier.domain.classVar)
694                for cls in self.potentialsClassifier.classVar.values:
695                    color = self.discPalette.getRGB(sortedClasses.index(cls))
696                    towhite = [255-c for c in color]
697                    for s in range(nShades):
698                        si = 1-float(s)/nShades
699                        palette.append(qRgb(*tuple([color[i]+towhite[i]*si for i in (0, 1, 2)])))
700                palette.extend([qRgb(255, 255, 255) for i in range(256-len(palette))])
701
702            self.potentialsImage = QImage(imagebmp, rx, ry, QImage.Format_Indexed8)
703            self.potentialsImage.setColorTable(OWColorPalette.signedPalette(palette) if qVersion() < "4.5" else palette)
704            self.potentialsImage.setNumColors(256)
705            self.potentialContext = (rx, ry, self.shownAttributes, self.trueScaleFactor, self.squareGranularity, self.jitterSize, self.jitterContinuous, self.spaceBetweenCells)
706            self.potentialsImageFromClassifier = self.potentialsClassifier
707
708
709
710    def drawCanvas(self, painter):
711        if self.showProbabilities and getattr(self, "potentialsClassifier", None):
712            if not (self.potentialsClassifier is getattr(self, "potentialsImageFromClassifier", None)):
713                self.computePotentials()
714            target = QRectF(self.transform(QwtPlot.xBottom, -1), self.transform(QwtPlot.yLeft, 1),
715                            self.transform(QwtPlot.xBottom, 1) - self.transform(QwtPlot.xBottom, -1),
716                            self.transform(QwtPlot.yLeft, -1) - self.transform(QwtPlot.yLeft, 1))
717            source = QRectF(0, 0, self.potentialsImage.size().width(), self.potentialsImage.size().height())
718            painter.drawImage(target, self.potentialsImage, source)
719#            painter.drawImage(self.transform(QwtPlot.xBottom, -1), self.transform(QwtPlot.yLeft, 1), self.potentialsImage)
720        OWGraph.drawCanvas(self, painter)
721
722OWLinProjGraph = graph_deprecator(OWLinProjGraph)
723
724if __name__== "__main__":
725    #Draw a simple graph
726    import os
727    a = QApplication(sys.argv)
728    graph = OWLinProjGraph(None)
729    data = orange.ExampleTable(r"E:\Development\Orange Datasets\UCI\wine.tab")
730    graph.setData(data)
731    graph.updateData([attr.name for attr in data.domain.attributes])
732    graph.show()
733    a.exec_()
Note: See TracBrowser for help on using the repository browser.