source: orange/Orange/OrangeWidgets/Visualize/OWLinProjGraph.py @ 10705:cb0cb39cd12c

Revision 10705:cb0cb39cd12c, 39.9 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Fixed scatterplot (and linear projection) point jittering (now remains the same through subset data changes), fixes #1170.

Line 
1from OWGraph import *
2from copy import copy
3import time
4from operator import add
5from math import *
6from orngScaleLinProjData import *
7import orngVisFuncts
8import OWColorPalette
9from OWGraphTools import UnconnectedLinesCurve
10
11# indices in curveData
12SYMBOL = 0
13PENCOLOR = 1
14BRUSHCOLOR = 2
15
16LINE_TOOLTIPS = 0
17VISIBLE_ATTRIBUTES = 1
18ALL_ATTRIBUTES = 2
19
20TOOLTIPS_SHOW_DATA = 0
21TOOLTIPS_SHOW_SPRINGS = 1
22
23###########################################################################################
24##### CLASS : OWLINPROJGRAPH
25###########################################################################################
26class OWLinProjGraph(OWGraph, orngScaleLinProjData):
27    def __init__(self, widget, parent = None, name = "None"):
28        OWGraph.__init__(self, parent, name)
29        orngScaleLinProjData.__init__(self)
30
31        self.totalPossibilities = 0 # a variable used in optimization - tells us the total number of different attribute positions
32        self.triedPossibilities = 0 # how many possibilities did we already try
33        self.p = None
34
35        self.dataMap = {}        # each key is of form: "xVal-yVal", where xVal and yVal are discretized continuous values. Value of each key has form: (x,y, HSVValue, [data vals])
36        self.tooltipCurves = []
37        self.tooltipMarkers   = []
38        self.widget = widget
39
40        # moving anchors manually
41        self.shownAttributes = []
42        self.selectedAnchorIndex = None
43
44        self.hideRadius = 0
45        self.showAnchors = 1
46        self.showValueLines = 0
47        self.valueLineLength = 5
48
49        self.onlyOnePerSubset = 1
50        self.showLegend = 1
51        self.useDifferentSymbols = 0
52        self.useDifferentColors = 1
53        self.tooltipKind = 0        # index in ["Show line tooltips", "Show visible attributes", "Show all attributes"]
54        self.tooltipValue = 0       # index in ["Tooltips show data values", "Tooltips show spring values"]
55        self.scaleFactor = 1.0
56        self.showAttributeNames = 1
57
58        self.showProbabilities = 0
59        self.squareGranularity = 3
60        self.spaceBetweenCells = 1
61
62        self.showKNN = 0   # widget sets this to 1 or 2 if you want to see correct or wrong classifications
63        self.insideColors = None
64        self.valueLineCurves = [{}, {}]    # dicts for x and y set of coordinates for unconnected lines
65
66        self.enableXaxis(0)
67        self.enableYLaxis(0)
68        self.setAxisScale(QwtPlot.xBottom, -1.13, 1.13, 1)
69        self.setAxisScale(QwtPlot.yLeft, -1.13, 1.13, 1)
70
71    def setData(self, data, subsetData = None, **args):
72        OWGraph.setData(self, data)
73        orngScaleLinProjData.setData(self, data, subsetData, **args)
74        #self.anchorData = []
75
76        self.setAxisScale(QwtPlot.yLeft, -1.13, 1.13, 1)
77        self.setAxisScale(QwtPlot.xBottom, -1.13, 1.13, 1)
78
79        if data and data.domain.classVar and data.domain.classVar.varType == orange.VarTypes.Continuous:
80            self.setAxisScale(QwtPlot.xBottom, -1.13, 1.13 + 0.1, 1)   # if we have a continuous class we need a bit more space on the right to show a color legend
81
82    # ####################################################################
83    # update shown data. Set labels, coloring by className ....
84    def updateData(self, labels = None, setAnchors = 0, **args):
85        self.removeDrawingCurves()  # my function, that doesn't delete selection curves
86        #self.removeCurves()
87        self.removeMarkers()
88        self.tooltipMarkers = []
89
90        self.__dict__.update(args)
91        if labels == None: labels = [anchor[2] for anchor in self.anchorData]
92        self.shownAttributes = labels
93        self.dataMap = {}   # dictionary with keys of form "x_i-y_i" with values (x_i, y_i, color, data)
94        self.valueLineCurves = [{}, {}]    # dicts for x and y set of coordinates for unconnected lines
95
96        if not self.haveData or len(labels) < 3:
97            self.anchorData = []
98            self.updateLayout()
99            return
100
101        if setAnchors or (args.has_key("XAnchors") and args.has_key("YAnchors")):
102            self.potentialsBmp = None
103            self.setAnchors(args.get("XAnchors"), args.get("YAnchors"), labels)
104            #self.anchorData = self.createAnchors(len(labels), labels)    # used for showing tooltips
105
106        indices = [self.attributeNameIndex[anchor[2]] for anchor in self.anchorData]  # store indices to shown attributes
107
108        # do we want to show anchors and their labels
109        if self.showAnchors:
110            if self.hideRadius > 0:
111                xdata = self.createXAnchors(100)*(self.hideRadius / 10)
112                ydata = self.createYAnchors(100)*(self.hideRadius / 10)
113                self.addCurve("hidecircle", QColor(200,200,200), QColor(200,200,200), 1, style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, xData = xdata.tolist() + [xdata[0]], yData = ydata.tolist() + [ydata[0]])
114
115            # draw dots at anchors
116            shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
117
118            if not self.normalizeExamples:
119                r=self.hideRadius**2/100
120                for i,(x,y,a) in enumerate(shownAnchorData):
121                    self.addCurve("l%i" % i, QColor(160, 160, 160), QColor(160, 160, 160), 10, style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, xData = [0, x], yData = [0, y], showFilledSymbols = 1, lineWidth=2)
122                    if self.showAttributeNames:
123                        self.addMarker(a, x*1.07, y*1.04, Qt.AlignCenter, bold=1)
124            else:
125                XAnchors = [a[0] for a in shownAnchorData]
126                YAnchors = [a[1] for a in shownAnchorData]
127                self.addCurve("dots", QColor(160,160,160), QColor(160,160,160), 10, style = QwtPlotCurve.NoCurve, symbol = QwtSymbol.Ellipse, xData = XAnchors, yData = YAnchors, showFilledSymbols = 1)
128
129                # draw text at anchors
130                if self.showAttributeNames:
131                    for x, y, a in shownAnchorData:
132                        self.addMarker(a, x*1.07, y*1.04, Qt.AlignCenter, bold = 1)
133
134        if self.showAnchors and self.normalizeExamples:
135            # draw "circle"
136            xdata = self.createXAnchors(100)
137            ydata = self.createYAnchors(100)
138            self.addCurve("circle", QColor(Qt.black), QColor(Qt.black), 1, style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, xData = xdata.tolist() + [xdata[0]], yData = ydata.tolist() + [ydata[0]])
139
140        self.potentialsClassifier = None # remove the classifier so that repaint won't recompute it
141        self.updateLayout()
142
143        if self.dataHasDiscreteClass:
144            self.discPalette.setNumberOfColors(len(self.dataDomain.classVar.values))
145
146        useDifferentSymbols = self.useDifferentSymbols and self.dataHasDiscreteClass and len(self.dataDomain.classVar.values) < len(self.curveSymbols)
147        dataSize = len(self.rawData)
148        validData = self.getValidList(indices)
149        transProjData = self.createProjectionAsNumericArray(indices, validData = validData, scaleFactor = self.scaleFactor, normalize = self.normalizeExamples, jitterSize = -1, useAnchorData = 1, removeMissingData = 0)
150        if transProjData == None:
151            return
152        projData = transProjData.T
153        x_positions = projData[0]
154        y_positions = projData[1]
155        xPointsToAdd = {}
156        yPointsToAdd = {}
157
158
159        if self.showProbabilities and self.haveData and self.dataHasClass:
160            # construct potentialsClassifier from unscaled positions
161            domain = orange.Domain([self.dataDomain[i].name for i in indices]+[self.dataDomain.classVar.name], self.dataDomain)
162            offsets = [self.attrValues[self.attributeNames[i]][0] for i in indices]
163            normalizers = [self.getMinMaxVal(i) for i in indices]
164            selectedData = numpy.take(self.originalData, indices, axis = 0)
165            averages = numpy.average(numpy.compress(validData, selectedData, axis=1), 1)
166            classData = numpy.compress(validData, self.originalData[self.dataClassIndex])
167            if classData.any():
168                self.potentialsClassifier = orange.P2NN(domain, numpy.transpose(numpy.array([numpy.compress(validData, self.unscaled_x_positions), numpy.compress(validData, self.unscaled_y_positions), classData])), self.anchorData, offsets, normalizers, averages, self.normalizeExamples, law=1)
169            else:
170                self.potentialsClassifier = None
171            self.potentialsImage = None
172
173
174        # ##############################################################
175        # show model quality
176        # ##############################################################
177        if self.insideColors != None or self.showKNN and self.haveData:
178            # if we want to show knn classifications of the examples then turn the projection into example table and run knn
179            if self.insideColors:
180                insideData, stringData = self.insideColors
181            else:
182                shortData = self.createProjectionAsExampleTable([self.attributeNameIndex[attr] for attr in labels], useAnchorData = 1)
183                predictions, probabilities = self.widget.vizrank.kNNClassifyData(shortData)
184                if self.showKNN == 2: insideData, stringData = [1.0 - val for val in predictions], "Probability of wrong classification = %.2f%%"
185                else:                 insideData, stringData = predictions, "Probability of correct classification = %.2f%%"
186
187            if self.dataHasDiscreteClass:        classColors = self.discPalette
188            elif self.dataHasContinuousClass:    classColors = self.contPalette
189
190            if len(insideData) != len(self.rawData):
191                j = 0
192                for i in range(len(self.rawData)):
193                    if not validData[i]: continue
194                    if self.dataHasClass:
195                        fillColor = classColors.getRGB(self.originalData[self.dataClassIndex][i], 255*insideData[j])
196                        edgeColor = classColors.getRGB(self.originalData[self.dataClassIndex][i])
197                    else:
198                        fillColor = edgeColor = (0,0,0)
199                    self.addCurve(str(i), QColor(*fillColor+ (self.alphaValue,)), QColor(*edgeColor+ (self.alphaValue,)), self.pointWidth, xData = [x_positions[i]], yData = [y_positions[i]])
200                    if self.showValueLines:
201                        self.addValueLineCurve(x_positions[i], y_positions[i], edgeColor, i, indices)
202                    self.addTooltipKey(x_positions[i], y_positions[i], QColor(*edgeColor), i, stringData % (100*insideData[j]))
203                    j+= 1
204            else:
205                for i in range(len(self.rawData)):
206                    if not validData[i]: continue
207                    if self.dataHasClass:
208                        fillColor = classColors.getRGB(self.originalData[self.dataClassIndex][i], 255*insideData[i])
209                        edgeColor = classColors.getRGB(self.originalData[self.dataClassIndex][i])
210                    else:
211                        fillColor = edgeColor = (0,0,0)
212                    self.addCurve(str(i), QColor(*fillColor+ (self.alphaValue,)), QColor(*edgeColor+ (self.alphaValue,)), self.pointWidth, xData = [x_positions[i]], yData = [y_positions[i]])
213                    if self.showValueLines:
214                        self.addValueLineCurve(x_positions[i], y_positions[i], edgeColor, i, indices)
215                    self.addTooltipKey(x_positions[i], y_positions[i], QColor(*edgeColor), i, stringData % (100*insideData[i]))
216
217        # ##############################################################
218        # do we have a subset data to show?
219        # ##############################################################
220        elif self.haveSubsetData:
221            shownSubsetCount = 0
222            subsetIdsToDraw = dict([(example.id,1) for example in self.rawSubsetData])
223            subsetIdsAlreadyDrawn = set()
224            # draw the rawData data set. examples that exist also in the subset data draw full, other empty
225            for i in range(dataSize):
226                if not validData[i]:
227                    continue
228                if subsetIdsToDraw.has_key(self.rawData[i].id):
229                    instance_filled = 1
230                    subsetIdsAlreadyDrawn.add(self.rawData[i].id)
231                else:
232                    instance_filled = 0
233
234                if self.dataHasDiscreteClass and self.useDifferentColors:
235                    newColor = self.discPalette.getRGB(self.originalData[self.dataClassIndex][i])
236                elif self.dataHasContinuousClass and self.useDifferentColors:
237                    newColor = self.contPalette.getRGB(self.noJitteringScaledData[self.dataClassIndex][i])
238                else:
239                    newColor = (0,0,0)
240
241                if self.useDifferentSymbols and self.dataHasDiscreteClass:
242                    curveSymbol = self.curveSymbols[int(self.originalData[self.dataClassIndex][i])]
243                else:
244                    curveSymbol = self.curveSymbols[0]
245
246                if not xPointsToAdd.has_key((newColor, curveSymbol, instance_filled)):
247                    xPointsToAdd[(newColor, curveSymbol, instance_filled)] = []
248                    yPointsToAdd[(newColor, curveSymbol, instance_filled)] = []
249                xPointsToAdd[(newColor, curveSymbol, instance_filled)].append(x_positions[i])
250                yPointsToAdd[(newColor, curveSymbol, instance_filled)].append(y_positions[i])
251                if self.showValueLines:
252                    self.addValueLineCurve(x_positions[i], y_positions[i], newColor, i, indices)
253
254                self.addTooltipKey(x_positions[i], y_positions[i], QColor(*newColor), i)
255
256            # if we have a data subset that contains examples that don't exist in the original dataset we show them here
257            XAnchors = numpy.array([val[0] for val in self.anchorData])
258            YAnchors = numpy.array([val[1] for val in self.anchorData])
259            anchorRadius = numpy.sqrt(XAnchors*XAnchors + YAnchors*YAnchors)
260            validSubData = self.getValidSubsetList(indices)
261            projSubData = self.createProjectionAsNumericArray(indices, validData = validSubData, scaleFactor = self.scaleFactor, normalize = self.normalizeExamples, jitterSize = -1, useAnchorData = 1, removeMissingData = 0, useSubsetData = 1).T
262            sub_x_positions = projSubData[0]
263            sub_y_positions = projSubData[1]
264
265            for i in range(len(self.rawSubsetData)):
266                if not validSubData[i]: # check if has missing values
267                    continue
268                if self.rawSubsetData[i].id in subsetIdsAlreadyDrawn:
269                    continue
270
271                if not self.dataHasClass or self.rawSubsetData[i].getclass().isSpecial():
272                    newColor = (0,0,0)
273                else:
274                    if self.dataHasDiscreteClass:
275                        newColor = self.discPalette.getRGB(self.originalSubsetData[self.dataClassIndex][i])
276                    else:
277                        newColor = self.contPalette.getRGB(self.noJitteringScaledSubsetData[self.dataClassIndex][i])
278
279                if self.useDifferentSymbols and self.dataHasDiscreteClass and self.validSubsetDataArray[self.dataClassIndex][i]:
280                    curveSymbol = self.curveSymbols[int(self.originalSubsetData[self.dataClassIndex][i])]
281                else:
282                    curveSymbol = self.curveSymbols[0]
283
284                if not xPointsToAdd.has_key((newColor, curveSymbol, 1)):
285                    xPointsToAdd[(newColor, curveSymbol, 1)] = []
286                    yPointsToAdd[(newColor, curveSymbol, 1)] = []
287                xPointsToAdd[(newColor, curveSymbol, 1)].append(sub_x_positions[i])
288                yPointsToAdd[(newColor, curveSymbol, 1)].append(sub_y_positions[i])
289
290        elif not self.dataHasClass:
291            xs = []; ys = []
292            for i in range(dataSize):
293                if not validData[i]: continue
294                xs.append(x_positions[i])
295                ys.append(y_positions[i])
296                self.addTooltipKey(x_positions[i], y_positions[i], QColor(Qt.black), i)
297                if self.showValueLines:
298                    self.addValueLineCurve(x_positions[i], y_positions[i], (0,0,0), i, indices)
299            self.addCurve(str(1), QColor(0,0,0,self.alphaValue), QColor(0,0,0,self.alphaValue), self.pointWidth, symbol = self.curveSymbols[0], xData = xs, yData = ys, penAlpha = self.alphaValue, brushAlpha = self.alphaValue)
300
301        # ##############################################################
302        # CONTINUOUS class
303        # ##############################################################
304        elif self.dataHasContinuousClass:
305            for i in range(dataSize):
306                if not validData[i]: continue
307                newColor = self.contPalette.getRGB(self.noJitteringScaledData[self.dataClassIndex][i])
308                self.addCurve(str(i), QColor(*newColor+ (self.alphaValue,)), QColor(*newColor+ (self.alphaValue,)), self.pointWidth, symbol = QwtSymbol.Ellipse, xData = [x_positions[i]], yData = [y_positions[i]])
309                if self.showValueLines:
310                    self.addValueLineCurve(x_positions[i], y_positions[i], newColor, i, indices)
311                self.addTooltipKey(x_positions[i], y_positions[i], QColor(*newColor), i)
312
313        # ##############################################################
314        # DISCRETE class
315        # ##############################################################
316        elif self.dataHasDiscreteClass:
317            for i in range(dataSize):
318                if not validData[i]: continue
319                if self.useDifferentColors: newColor = self.discPalette.getRGB(self.originalData[self.dataClassIndex][i])
320                else:                       newColor = (0,0,0)
321                if self.useDifferentSymbols: curveSymbol = self.curveSymbols[int(self.originalData[self.dataClassIndex][i])]
322                else:                        curveSymbol = self.curveSymbols[0]
323                if not xPointsToAdd.has_key((newColor, curveSymbol, self.showFilledSymbols)):
324                    xPointsToAdd[(newColor, curveSymbol, self.showFilledSymbols)] = []
325                    yPointsToAdd[(newColor, curveSymbol, self.showFilledSymbols)] = []
326                xPointsToAdd[(newColor, curveSymbol, self.showFilledSymbols)].append(x_positions[i])
327                yPointsToAdd[(newColor, curveSymbol, self.showFilledSymbols)].append(y_positions[i])
328                if self.showValueLines:
329                    self.addValueLineCurve(x_positions[i], y_positions[i], newColor, i, indices)
330                self.addTooltipKey(x_positions[i], y_positions[i], QColor(*newColor), i)
331
332        # first draw value lines
333        if self.showValueLines:
334            for i, color in enumerate(self.valueLineCurves[0].keys()):
335                curve = UnconnectedLinesCurve("", QPen(QColor(*color + (self.alphaValue,))), self.valueLineCurves[0][color], self.valueLineCurves[1][color])
336                curve.attach(self)
337
338        # draw all the points with a small number of curves
339        for i, (color, symbol, showFilled) in enumerate(xPointsToAdd.keys()):
340            xData = xPointsToAdd[(color, symbol, showFilled)]
341            yData = yPointsToAdd[(color, symbol, showFilled)]
342            self.addCurve(str(i), QColor(*color + (self.alphaValue,)), QColor(*color + (self.alphaValue,)), self.pointWidth, symbol = symbol, xData = xData, yData = yData, showFilledSymbols = showFilled)
343
344        # ##############################################################
345        # draw the legend
346        # ##############################################################
347        if self.showLegend:
348            # show legend for discrete class
349            if self.dataHasDiscreteClass:
350                self.addMarker(self.dataDomain.classVar.name, 0.87, 1.05, Qt.AlignLeft | Qt.AlignVCenter)
351
352                classVariableValues = getVariableValuesSorted(self.dataDomain.classVar)
353                for index in range(len(classVariableValues)):
354                    if self.useDifferentColors: color = QColor(self.discPalette[index])
355                    else:                       color = QColor(Qt.black)
356                    y = 1.0 - index * 0.05
357
358                    if not self.useDifferentSymbols:  curveSymbol = self.curveSymbols[0]
359                    else:                             curveSymbol = self.curveSymbols[index]
360
361                    self.addCurve(str(index), color, color, self.pointWidth, symbol = curveSymbol, xData = [0.95], yData = [y], penAlpha = self.alphaValue, brushAlpha = self.alphaValue)
362                    self.addMarker(classVariableValues[index], 0.90, y, Qt.AlignLeft | Qt.AlignVCenter)
363            # show legend for continuous class
364            elif self.dataHasContinuousClass:
365                xs = [1.15, 1.20, 1.20, 1.15]
366                count = 200
367                height = 2 / float(count)
368                for i in range(count):
369                    y = -1.0 + i*2.0/float(count)
370                    col = self.contPalette[i/float(count)]
371                    col.setAlpha(self.alphaValue)
372                    PolygonCurve(QPen(col), QBrush(col), xData = xs, yData = [y,y, y+height, y+height]).attach(self)
373
374                # add markers for min and max value of color attribute
375                [minVal, maxVal] = self.attrValues[self.dataDomain.classVar.name]
376                self.addMarker("%s = %%.%df" % (self.dataDomain.classVar.name, self.dataDomain.classVar.numberOfDecimals) % (minVal), xs[0] - 0.02, -1.0 + 0.04, Qt.AlignLeft)
377                self.addMarker("%s = %%.%df" % (self.dataDomain.classVar.name, self.dataDomain.classVar.numberOfDecimals) % (maxVal), xs[0] - 0.02, +1.0 - 0.04, Qt.AlignLeft)
378
379        self.replot()
380
381
382    # ##############################################################
383    # create a dictionary value for the data point
384    # this will enable to show tooltips faster and to make selection of examples available
385    def addTooltipKey(self, x, y, color, index, extraString = None):
386        dictValue = "%.1f-%.1f"%(x, y)
387        if not self.dataMap.has_key(dictValue): self.dataMap[dictValue] = []
388        self.dataMap[dictValue].append((x, y, color, index, extraString))
389
390
391    def addValueLineCurve(self, x, y, color, exampleIndex, attrIndices):
392        XAnchors = numpy.array([val[0] for val in self.anchorData])
393        YAnchors = numpy.array([val[1] for val in self.anchorData])
394        xs = numpy.array([x] * len(self.anchorData))
395        ys = numpy.array([y] * len(self.anchorData))
396        dists = numpy.sqrt((XAnchors-xs)**2 + (YAnchors-ys)**2)
397        xVect = 0.01 * self.valueLineLength * (XAnchors - xs) / dists
398        yVect = 0.01 * self.valueLineLength * (YAnchors - ys) / dists
399        exVals = [self.noJitteringScaledData[attrInd, exampleIndex] for attrInd in attrIndices]
400
401        xs = []; ys = []
402        for i in range(len(exVals)):
403            xs += [x, x + xVect[i]*exVals[i]]
404            ys += [y, y + yVect[i]*exVals[i]]
405        self.valueLineCurves[0][color] = self.valueLineCurves[0].get(color, []) + xs
406        self.valueLineCurves[1][color] = self.valueLineCurves[1].get(color, []) + ys
407
408
409    def mousePressEvent(self, e):
410        if self.manualPositioning:
411            self.mouseCurrentlyPressed = 1
412            self.selectedAnchorIndex = None
413            if not self.normalizeExamples:
414                marker, dist = self.closestMarker(e.x(), e.y())
415                if dist < 15:
416                    self.selectedAnchorIndex = self.shownAttributes.index(str(marker.label().text()))
417            else:
418                (curve, dist, x, y, index) = self.closestCurve(e.x(), e.y())
419                if dist < 5 and str(curve.title().text()) == "dots":
420                    self.selectedAnchorIndex = index
421        else:
422            OWGraph.mousePressEvent(self, e)
423
424
425    def mouseReleaseEvent(self, e):
426        if self.manualPositioning:
427            self.mouseCurrentlyPressed = 0
428            self.selectedAnchorIndex = None
429        else:
430            OWGraph.mouseReleaseEvent(self, e)
431
432    # ##############################################################
433    # draw tooltips
434    def mouseMoveEvent(self, e):
435        redraw = (self.tooltipCurves != [] or self.tooltipMarkers != [])
436
437        for curve in self.tooltipCurves:  curve.detach()
438        for marker in self.tooltipMarkers: marker.detach()
439        self.tooltipCurves = []
440        self.tooltipMarkers = []
441
442        canvasPos = self.canvas().mapFrom(self, e.pos())
443        xFloat = self.invTransform(QwtPlot.xBottom, canvasPos.x())
444        yFloat = self.invTransform(QwtPlot.yLeft, canvasPos.y())
445
446        # in case we are drawing a rectangle, we don't draw enhanced tooltips
447        # because it would then fail to draw the rectangle
448        if self.mouseCurrentlyPressed:
449            if not self.manualPositioning:
450                OWGraph.mouseMoveEvent(self, e)
451                if redraw: self.replot()
452            else:
453                if self.selectedAnchorIndex != None:
454                    if self.widget.freeVizDlg.restrain == 1:
455                        rad = sqrt(xFloat**2 + yFloat**2)
456                        xFloat /= rad
457                        yFloat /= rad
458                    elif self.widget.freeVizDlg.restrain == 2:
459                        rad = sqrt(xFloat**2 + yFloat**2)
460                        phi = 2 * self.selectedAnchorIndex * math.pi / len(self.anchorData)
461                        xFloat = rad * cos(phi)
462                        yFloat = rad * sin(phi)
463                    self.anchorData[self.selectedAnchorIndex] = (xFloat, yFloat, self.anchorData[self.selectedAnchorIndex][2])
464                    self.updateData(self.shownAttributes)
465                    self.replot()
466                    #self.widget.recomputeEnergy()
467            return
468
469        dictValue = "%.1f-%.1f"%(xFloat, yFloat)
470        if self.dataMap.has_key(dictValue):
471            points = self.dataMap[dictValue]
472            bestDist = 100.0
473            for (x_i, y_i, color, index, extraString) in points:
474                currDist = sqrt((xFloat-x_i)*(xFloat-x_i) + (yFloat-y_i)*(yFloat-y_i))
475                if currDist < bestDist:
476                    bestDist = currDist
477                    nearestPoint = (x_i, y_i, color, index, extraString)
478
479            (x_i, y_i, color, index, extraString) = nearestPoint
480            intX = self.transform(QwtPlot.xBottom, x_i)
481            intY = self.transform(QwtPlot.yLeft, y_i)
482
483            if self.tooltipKind == LINE_TOOLTIPS and bestDist < 0.05:
484                shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
485                if not self.normalizeExamples:
486                    for (xAnchor,yAnchor,label) in shownAnchorData:
487                        attrVal = self.scaledData[self.attributeNameIndex[label]][index]
488                        markerX, markerY = xAnchor*(attrVal+0.03), yAnchor*(attrVal+0.03)
489                        curve = self.addCurve("", color, color, 1, style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, xData = [0, xAnchor*attrVal], yData = [0, yAnchor*attrVal], lineWidth=3)
490
491                        marker = None
492                        fontsize = 9
493                        markerAlign = Qt.AlignCenter
494                        self.tooltipCurves.append(curve)
495                        labelIndex = self.attributeNameIndex[label]
496                        if self.tooltipValue == TOOLTIPS_SHOW_DATA:
497                            if self.dataDomain[labelIndex].varType == orange.VarTypes.Continuous:
498                                text = "%%.%df" % (self.dataDomain[labelIndex].numberOfDecimals) % (self.rawData[index][labelIndex])
499                            else:
500                                text = str(self.rawData[index][labelIndex].value)
501                            marker = self.addMarker(text, markerX, markerY, markerAlign, size = fontsize)
502                        elif self.tooltipValue == TOOLTIPS_SHOW_SPRINGS:
503                            marker = self.addMarker("%.3f" % (self.scaledData[labelIndex][index]), markerX, markerY, markerAlign, size = fontsize)
504                        self.tooltipMarkers.append(marker)
505
506            elif self.tooltipKind == VISIBLE_ATTRIBUTES or self.tooltipKind == ALL_ATTRIBUTES:
507                if self.tooltipKind == VISIBLE_ATTRIBUTES:
508                    shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
509                    labels = [s for (xA, yA, s) in shownAnchorData]
510                else:
511                    labels = []
512
513                text = self.getExampleTooltipText(self.rawData[index], labels)
514                text += "<hr>Example index = %d" % (index)
515                if extraString:
516                    text += "<hr>" + extraString
517                self.showTip(intX, intY, text)
518
519        OWGraph.mouseMoveEvent(self, e)
520        self.replot()
521
522
523    # send 2 example tables. in first is the data that is inside selected rects (polygons), in the second is unselected data
524    def getSelectionsAsExampleTables(self, attrList, useAnchorData = 1, addProjectedPositions = 0):
525        if not self.haveData: return (None, None)
526        if addProjectedPositions == 0 and not self.selectionCurveList: return (None, self.rawData)       # if no selections exist
527        if (useAnchorData and len(self.anchorData) < 3) or len(attrList) < 3: return (None, None)
528
529        xAttr=orange.FloatVariable("X Positions")
530        yAttr=orange.FloatVariable("Y Positions")
531        if addProjectedPositions == 1:
532            domain=orange.Domain([xAttr,yAttr] + [v for v in self.dataDomain.variables])
533        elif addProjectedPositions == 2:
534            domain=orange.Domain(self.dataDomain)
535            domain.addmeta(orange.newmetaid(), xAttr)
536            domain.addmeta(orange.newmetaid(), yAttr)
537        else:
538            domain = orange.Domain(self.dataDomain)
539
540        domain.addmetas(self.dataDomain.getmetas())
541
542        if useAnchorData: indices = [self.attributeNameIndex[val[2]] for val in self.anchorData]
543        else:             indices = [self.attributeNameIndex[label] for label in attrList]
544        validData = self.getValidList(indices)
545        if len(validData) == 0: return (None, None)
546
547        array = self.createProjectionAsNumericArray(attrList, scaleFactor = self.scaleFactor, useAnchorData = useAnchorData, removeMissingData = 0)
548        if array == None:       # if all examples have missing values
549            return (None, None)
550
551        #selIndices, unselIndices = self.getSelectionsAsIndices(attrList, useAnchorData, validData)
552        selIndices, unselIndices = self.getSelectedPoints(array.T[0], array.T[1], validData)
553
554        if addProjectedPositions:
555            selected = orange.ExampleTable(domain, self.rawData.selectref(selIndices))
556            unselected = orange.ExampleTable(domain, self.rawData.selectref(unselIndices))
557            selIndex = 0; unselIndex = 0
558            for i in range(len(selIndices)):
559                if selIndices[i]:
560                    selected[selIndex][xAttr] = array[i][0]
561                    selected[selIndex][yAttr] = array[i][1]
562                    selIndex += 1
563                else:
564                    unselected[unselIndex][xAttr] = array[i][0]
565                    unselected[unselIndex][yAttr] = array[i][1]
566                    unselIndex += 1
567        else:
568            selected = self.rawData.selectref(selIndices)
569            unselected = self.rawData.selectref(unselIndices)
570
571        if len(selected) == 0: selected = None
572        if len(unselected) == 0: unselected = None
573        return (selected, unselected)
574
575
576    def getSelectionsAsIndices(self, attrList, useAnchorData = 1, validData = None):
577        if not self.haveData: return [], []
578
579        attrIndices = [self.attributeNameIndex[attr] for attr in attrList]
580        if validData == None:
581            validData = self.getValidList(attrIndices)
582
583        array = self.createProjectionAsNumericArray(attrList, scaleFactor = self.scaleFactor, useAnchorData = useAnchorData, removeMissingData = 0)
584        if array == None:
585            return [], []
586        array = numpy.transpose(array)
587        return self.getSelectedPoints(array[0], array[1], validData)
588
589
590    # update shown data. Set labels, coloring by className ....
591    def savePicTeX(self):
592        lastSave = getattr(self, "lastPicTeXSave", "C:\\")
593        qfileName = QFileDialog.getSaveFileName(None, "Save to..", lastSave + "graph.pictex","PicTeX (*.pictex);;All files (*.*)")
594        fileName = unicode(qfileName)
595        if fileName == "":
596            return
597
598        if not os.path.splitext(fileName)[1][1:]:
599            fileName = fileName + ".pictex"
600
601        self.lastSave = os.path.split(fileName)[0]+"/"
602        file = open(fileName, "wt")
603
604        file.write("\\mbox{\n")
605        file.write(\\beginpicture\n")
606        file.write(\\setcoordinatesystem units <0.4\columnwidth, 0.4\columnwidth>\n")
607        file.write(\\setplotarea x from -1.1 to 1.1, y from -1 to 1.1\n")
608
609        if not self.normalizeExamples:
610            file.write("\\circulararc 360 degrees from 1 0 center at 0 0\n")
611
612        if self.showAnchors:
613            if self.hideRadius > 0:
614                file.write("\\setdashes\n")
615                file.write("\\circulararc 360 degrees from %5.3f 0 center at 0 0\n" % (self.hideRadius/10.))
616                file.write("\\setsolid\n")
617
618            if self.showAttributeNames:
619                shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
620                if not self.normalizeExamples:
621                    for x,y,l in shownAnchorData:
622                        file.write("\\plot 0 0 %5.3f %5.3f /\n" % (x, y))
623                        file.write("\\put {{\\footnotesize %s}} [b] at %5.3f %5.3f\n" % (l.replace("_", "-"), x*1.07, y*1.04))
624                else:
625                    file.write("\\multiput {\\small $\\odot$} at %s /\n" % (" ".join(["%5.3f %5.3f" % tuple(i[:2]) for i in shownAnchorData])))
626                    for x,y,l in shownAnchorData:
627                        file.write("\\put {{\\footnotesize %s}} [b] at %5.3f %5.3f\n" % (l.replace("_", "-"), x*1.07, y*1.04))
628
629        symbols = ("{\\small $\\circ$}", "{\\tiny $\\times$}", "{\\tiny $+$}", "{\\small $\\star$}",
630                   "{\\small $\\ast$}", "{\\tiny $\\div$}", "{\\small $\\bullet$}", ) + tuple([chr(x) for x in range(97, 123)])
631        dataSize = len(self.rawData)
632        labels = self.widget.getShownAttributeList()
633        indices = [self.attributeNameIndex[label] for label in labels]
634        selectedData = numpy.take(self.scaledData, indices, axis = 0)
635        XAnchors = numpy.array([a[0] for a in self.anchorData])
636        YAnchors = numpy.array([a[1] for a in self.anchorData])
637
638        r = numpy.sqrt(XAnchors*XAnchors + YAnchors*YAnchors)     # compute the distance of each anchor from the center of the circle
639        XAnchors *= r                                               # we need to normalize the anchors by r, otherwise the anchors won't attract points less if they are placed at the center of the circle
640        YAnchors *= r
641
642        x_positions = numpy.dot(XAnchors, selectedData)
643        y_positions = numpy.dot(YAnchors, selectedData)
644
645        if self.normalizeExamples:
646            sum_i = self._getSum_i(selectedData, useAnchorData = 1, anchorRadius = r)
647            x_positions /= sum_i
648            y_positions /= sum_i
649
650        if self.scaleFactor:
651            self.trueScaleFactor = self.scaleFactor
652        else:
653            abss = x_positions*x_positions + y_positions*y_positions
654            self.trueScaleFactor =  1 / sqrt(abss[numpy.argmax(abss)])
655
656        x_positions *= self.trueScaleFactor
657        y_positions *= self.trueScaleFactor
658
659        validData = self.getValidList(indices)
660
661        pos = [[] for i in range(len(self.dataDomain.classVar.values))]
662        for i in range(dataSize):
663            if validData[i]:
664                pos[int(self.originalData[self.dataClassIndex][i])].append((x_positions[i], y_positions[i]))
665
666        for i in range(len(self.dataDomain.classVar.values)):
667            file.write("\\multiput {%s} at %s /\n" % (symbols[i], " ".join(["%5.3f %5.3f" % p for p in pos[i]])))
668
669        if self.showLegend:
670            classVariableValues = getVariableValuesSorted(self.dataDomain.classVar)
671            file.write("\\put {%s} [lB] at 0.87 1.06\n" % self.dataDomain.classVar.name)
672            for index in range(len(classVariableValues)):
673                file.write("\\put {%s} at 1.0 %5.3f\n" % (symbols[index], 0.93 - 0.115*index))
674                file.write("\\put {%s} [lB] at 1.05 %5.3f\n" % (classVariableValues[index], 0.9 - 0.115*index))
675
676        file.write("\\endpicture\n}\n")
677        file.close()
678
679    def computePotentials(self):
680        import orangeom
681        #rx = self.transform(QwtPlot.xBottom, 1) - self.transform(QwtPlot.xBottom, 0)
682        #ry = self.transform(QwtPlot.yLeft, 0) - self.transform(QwtPlot.yLeft, 1)
683
684        rx = self.transform(QwtPlot.xBottom, 1) - self.transform(QwtPlot.xBottom, -1)
685        ry = self.transform(QwtPlot.yLeft, -1) - self.transform(QwtPlot.yLeft, 1)
686        ox = self.transform(QwtPlot.xBottom, 0) - self.transform(QwtPlot.xBottom, -1)
687        oy = self.transform(QwtPlot.yLeft, -1) - self.transform(QwtPlot.yLeft, 0)
688
689        rx -= rx % self.squareGranularity
690        ry -= ry % self.squareGranularity
691
692        if not getattr(self, "potentialsImage", None) \
693           or getattr(self, "potentialContext", None) != (rx, ry, self.shownAttributes, self.trueScaleFactor, self.squareGranularity, self.jitterSize, self.jitterContinuous, self.spaceBetweenCells):
694            if self.potentialsClassifier.classVar.varType == orange.VarTypes.Continuous:
695                imagebmp = orangeom.potentialsBitmap(self.potentialsClassifier, rx, ry, ox, oy, self.squareGranularity, self.trueScaleFactor/2, self.spaceBetweenCells)
696                palette = [qRgb(255.*i/255., 255.*i/255., 255-(255.*i/255.)) for i in range(255)] + [qRgb(255, 255, 255)]
697            else:
698                imagebmp, nShades = orangeom.potentialsBitmap(self.potentialsClassifier, rx, ry, ox, oy, self.squareGranularity, self.trueScaleFactor/2, self.spaceBetweenCells) # the last argument is self.trueScaleFactor (in LinProjGraph...)
699                palette = []
700                sortedClasses = getVariableValuesSorted(self.potentialsClassifier.domain.classVar)
701                for cls in self.potentialsClassifier.classVar.values:
702                    color = self.discPalette.getRGB(sortedClasses.index(cls))
703                    towhite = [255-c for c in color]
704                    for s in range(nShades):
705                        si = 1-float(s)/nShades
706                        palette.append(qRgb(*tuple([color[i]+towhite[i]*si for i in (0, 1, 2)])))
707                palette.extend([qRgb(255, 255, 255) for i in range(256-len(palette))])
708
709            self.potentialsImage = QImage(imagebmp, rx, ry, QImage.Format_Indexed8)
710            self.potentialsImage.setColorTable(OWColorPalette.signedPalette(palette) if qVersion() < "4.5" else palette)
711            self.potentialsImage.setNumColors(256)
712            self.potentialContext = (rx, ry, self.shownAttributes, self.trueScaleFactor, self.squareGranularity, self.jitterSize, self.jitterContinuous, self.spaceBetweenCells)
713            self.potentialsImageFromClassifier = self.potentialsClassifier
714
715
716
717    def drawCanvas(self, painter):
718        if self.showProbabilities and getattr(self, "potentialsClassifier", None):
719            if not (self.potentialsClassifier is getattr(self, "potentialsImageFromClassifier", None)):
720                self.computePotentials()
721            target = QRectF(self.transform(QwtPlot.xBottom, -1), self.transform(QwtPlot.yLeft, 1),
722                            self.transform(QwtPlot.xBottom, 1) - self.transform(QwtPlot.xBottom, -1),
723                            self.transform(QwtPlot.yLeft, -1) - self.transform(QwtPlot.yLeft, 1))
724            source = QRectF(0, 0, self.potentialsImage.size().width(), self.potentialsImage.size().height())
725            painter.drawImage(target, self.potentialsImage, source)
726#            painter.drawImage(self.transform(QwtPlot.xBottom, -1), self.transform(QwtPlot.yLeft, 1), self.potentialsImage)
727        OWGraph.drawCanvas(self, painter)
728
729OWLinProjGraph = graph_deprecator(OWLinProjGraph)
730
731if __name__== "__main__":
732    #Draw a simple graph
733    import os
734    a = QApplication(sys.argv)
735    graph = OWLinProjGraph(None)
736    data = orange.ExampleTable(r"E:\Development\Orange Datasets\UCI\wine.tab")
737    graph.setData(data)
738    graph.updateData([attr.name for attr in data.domain.attributes])
739    graph.show()
740    a.exec_()
Note: See TracBrowser for help on using the repository browser.