source: orange/Orange/OrangeWidgets/VisualizeQt/OWLinProjGraphQt.py @ 11474:df0622184ee6

Revision 11474:df0622184ee6, 35.4 KB checked in by markotoplak, 12 months ago (diff)

Renamed Visualize Qt to VisualizeQt (so it can be loaded in new canvas).

Line 
1from plot.owplot import *
2from copy import copy
3import time
4from operator import add
5from math import *
6from orngScaleLinProjData import *
7import orngVisFuncts
8import OWColorPalette
9from plot.owtools import UnconnectedLinesCurve, ProbabilitiesItem
10import numpy
11
12# indices in curveData
13SYMBOL = 0
14PENCOLOR = 1
15BRUSHCOLOR = 2
16
17LINE_TOOLTIPS = 0
18VISIBLE_ATTRIBUTES = 1
19ALL_ATTRIBUTES = 2
20
21TOOLTIPS_SHOW_DATA = 0
22TOOLTIPS_SHOW_SPRINGS = 1
23
24###########################################################################################
25##### CLASS : OWLINPROJGRAPH
26###########################################################################################
27class OWLinProjGraph(OWPlot, orngScaleLinProjData):
28    def __init__(self, widget, parent = None, name = "None"):
29        OWPlot.__init__(self, parent, name, axes=[], widget=widget)
30        orngScaleLinProjData.__init__(self)
31
32        self.totalPossibilities = 0 # a variable used in optimization - tells us the total number of different attribute positions
33        self.triedPossibilities = 0 # how many possibilities did we already try
34        self.p = None
35
36        self.dataMap = {}        # each key is of form: "xVal-yVal", where xVal and yVal are discretized continuous values. Value of each key has form: (x,y, HSVValue, [data vals])
37        self.tooltipCurves = []
38        self.tooltipMarkers   = []
39        self.widget = widget
40
41        # moving anchors manually
42        self.shownAttributes = []
43        self.selectedAnchorIndex = None
44
45        self.hideRadius = 0
46        self.showAnchors = 1
47        self.showValueLines = 0
48        self.valueLineLength = 5
49
50        self.onlyOnePerSubset = 1
51        self.showLegend = 1
52        self.useDifferentSymbols = 0
53        self.useDifferentColors = 1
54        self.tooltipKind = 0        # index in ["Show line tooltips", "Show visible attributes", "Show all attributes"]
55        self.tooltipValue = 0       # index in ["Tooltips show data values", "Tooltips show spring values"]
56        self.scaleFactor = 1.0
57        self.showAttributeNames = 1
58
59        self.showProbabilities = 0
60        self.squareGranularity = 3
61        self.spaceBetweenCells = 1
62
63        self.showKNN = 0   # widget sets this to 1 or 2 if you want to see correct or wrong classifications
64        self.insideColors = None
65        self.valueLineCurves = [{}, {}]    # dicts for x and y set of coordinates for unconnected lines
66       
67        range = (-1.13, 1.13)
68        self.data_range[xBottom] = range
69        self.data_range[yLeft] = range
70       
71        self._extra_curves = []
72        self.current_tooltip_point = None
73        self.connect(self, SIGNAL("point_hovered(Point*)"), self.draw_tooltips)
74       
75        self.value_line_curves = []
76        self.potentialsCurve = None
77
78        self.warn_unused_attributes = True
79
80    def setData(self, data, subsetData = None, **args):
81        OWPlot.setData(self, data)
82        orngScaleLinProjData.setData(self, data, subsetData, **args)
83        #self.anchorData = []
84
85        if data and data.domain.classVar and data.domain.classVar.varType == orange.VarTypes.Continuous:
86            self.data_range[xBottom] = (-1.13, 1.13 + 0.1) # if we have a continuous class we need a bit more space on the right to show a color legend
87        else:
88            self.data_range[xBottom] = (-1.13, 1.13)
89    # ####################################################################
90    # update shown data. Set labels, coloring by className ....
91    def updateData(self, labels = None, setAnchors = 0, **args):
92        for c in self._extra_curves:
93            c.detach()
94        self._extra_curves = []
95        self.tooltipMarkers = []
96        for c in self.value_line_curves:
97            c.detach()
98        self.value_line_curves = []
99        self.legend().clear()
100        self.clear_markers()
101
102        self.__dict__.update(args)
103        if labels == None: labels = [anchor[2] for anchor in self.anchorData]
104        self.shownAttributes = labels
105        self.dataMap = {}   # dictionary with keys of form "x_i-y_i" with values (x_i, y_i, color, data)
106        self.valueLineCurves = [{}, {}]    # dicts for x and y set of coordinates for unconnected lines
107
108        if not self.haveData or len(labels) < 3:
109            self.anchorData = []
110            self.updateLayout()
111            return
112
113        if setAnchors or (args.has_key("XAnchors") and args.has_key("YAnchors")):
114            self.potentialsBmp = None
115            self.setAnchors(args.get("XAnchors"), args.get("YAnchors"), labels)
116            #self.anchorData = self.createAnchors(len(labels), labels)    # used for showing tooltips
117
118        indices = [self.attributeNameIndex[anchor[2]] for anchor in self.anchorData]  # store indices to shown attributes
119
120        # do we want to show anchors and their labels
121        if self.showAnchors:
122            if self.hideRadius > 0:
123                circle = CircleCurve(QColor(200,200,200), QColor(200,200,200), radius = self.hideRadius)
124                circle.ignore_alpha = True
125                self.add_custom_curve(circle)
126                self._extra_curves.append(circle)
127
128            # draw dots at anchors
129            shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
130            self.remove_all_axes(user_only = False)
131            if not self.normalizeExamples:
132                r=self.hideRadius**2/100
133                for i,(x,y,a) in enumerate(shownAnchorData):
134                    if x > 0:
135                        line = QLineF(0, 0, x, y)
136                        arrows = AxisEnd
137                        label_pos = AxisEnd
138                    else:
139                        line = QLineF(x, y, 0, 0)
140                        arrows = AxisStart
141                        label_pos = AxisStart
142                    self.add_axis(UserAxis + i, title=a, title_location=label_pos, line=line, arrows=arrows, zoomable=True)
143                    self.setAxisLabels(UserAxis + i, [])
144            else:
145                XAnchors = [a[0] for a in shownAnchorData]
146                YAnchors = [a[1] for a in shownAnchorData]
147                c = self.addCurve("dots", QColor(160,160,160), QColor(160,160,160), 10, style = Qt.NoPen, symbol = OWPoint.Ellipse, xData = XAnchors, yData = YAnchors, showFilledSymbols = 1)
148                c.ignore_alpha = True
149                self._extra_curves.append(c)
150
151                # draw text at anchors
152                if self.showAttributeNames:
153                    for x, y, a in shownAnchorData:
154                        self.addMarker(a, x*1.07, y*1.04, Qt.AlignCenter, bold = 1)
155
156        if self.showAnchors and self.normalizeExamples:
157            # draw "circle"
158            circle = CircleCurve()
159            circle.ignore_alpha = True
160            self.add_custom_curve(circle)
161            self._extra_curves.append(circle)
162
163        self.potentialsClassifier = None # remove the classifier so that repaint won't recompute it
164        self.updateLayout()
165
166        if self.dataHasDiscreteClass:
167            self.discPalette.setNumberOfColors(len(self.dataDomain.classVar.values))
168
169        useDifferentSymbols = self.useDifferentSymbols and self.dataHasDiscreteClass and len(self.dataDomain.classVar.values) < len(self.curveSymbols)
170        dataSize = len(self.rawData)
171        validData = self.getValidList(indices)
172        transProjData = self.createProjectionAsNumericArray(indices, validData = validData, scaleFactor = self.scaleFactor, normalize = self.normalizeExamples, jitterSize = -1, useAnchorData = 1, removeMissingData = 0)
173        if transProjData == None:
174            return
175        projData = transProjData.T
176        x_positions = projData[0]
177        y_positions = projData[1]
178       
179        xPointsToAdd = {}
180        yPointsToAdd = {}
181
182        if self.potentialsCurve:
183            self.potentialsCurve.detach()
184            self.potentialsCurve = None
185        if self.showProbabilities and self.haveData and self.dataHasClass:
186            # construct potentialsClassifier from unscaled positions
187            domain = orange.Domain([self.dataDomain[i].name for i in indices]+[self.dataDomain.classVar.name], self.dataDomain)
188            offsets = [self.attrValues[self.attributeNames[i]][0] for i in indices]
189            normalizers = [self.getMinMaxVal(i) for i in indices]
190            selectedData = numpy.take(self.originalData, indices, axis = 0)
191            averages = numpy.average(numpy.compress(validData, selectedData, axis=1), 1)
192            classData = numpy.compress(validData, self.originalData[self.dataClassIndex])
193            if classData.any():
194                self.potentialsClassifier = orange.P2NN(domain, numpy.transpose(numpy.array([numpy.compress(validData, self.unscaled_x_positions), numpy.compress(validData, self.unscaled_y_positions), classData])), self.anchorData, offsets, normalizers, averages, self.normalizeExamples, law=1)
195                self.potentialsCurve = ProbabilitiesItem(self.potentialsClassifier, self.squareGranularity, self.trueScaleFactor/2, self.spaceBetweenCells, QRectF(-1, -1, 2, 2))
196                self.potentialsCurve.attach(self)
197            else:
198                self.potentialsClassifier = None
199            self.potentialsImage = None
200
201
202        if self.useDifferentColors:
203            if self.dataHasDiscreteClass:
204                color_data = [QColor(*self.discPalette.getRGB(i)) for i in self.originalData[self.dataClassIndex]]
205            elif self.dataHasContinuousClass:   
206                color_data = [QColor(*self.contPalette.getRGB(i)) for i in self.originalData[self.dataClassIndex]]
207            else:
208                color_data = [Qt.black]
209        else:
210            color_data = [Qt.black]
211                       
212        if self.useDifferentSymbols and self.dataHasDiscreteClass:
213            symbol_data = [self.curveSymbols[int(i)] for i in self.originalData[self.dataClassIndex]]
214        else:
215            symbol_data = [OWPoint.Ellipse]
216           
217        size_data = [self.point_width]
218        label_data = []
219       
220        if self.haveSubsetData:
221            subset_ids = [example.id for example in self.rawSubsetData]
222            marked_data = [example.id in subset_ids for example in self.rawData]
223            showFilled = 0
224        else:
225            marked_data = []
226
227        self.set_main_curve_data(x_positions, y_positions, color_data, label_data, size_data, symbol_data, marked_data, validData)
228       
229        # ##############################################################
230        # show model quality
231        # ##############################################################
232        if self.insideColors != None or self.showKNN and self.haveData:
233            # if we want to show knn classifications of the examples then turn the projection into example table and run knn
234            if self.insideColors:
235                insideData, stringData = self.insideColors
236            else:
237                shortData = self.createProjectionAsExampleTable([self.attributeNameIndex[attr] for attr in labels], useAnchorData = 1)
238                predictions, probabilities = self.widget.vizrank.kNNClassifyData(shortData)
239                if self.showKNN == 2: insideData, stringData = [1.0 - val for val in predictions], "Probability of wrong classification = %.2f%%"
240                else:                 insideData, stringData = predictions, "Probability of correct classification = %.2f%%"
241
242            if self.dataHasDiscreteClass:        classColors = self.discPalette
243            elif self.dataHasContinuousClass:    classColors = self.contPalette
244
245            if len(insideData) != len(self.rawData):
246                j = 0
247                for i in range(len(self.rawData)):
248                    if not validData[i]: continue
249                    if self.dataHasClass:
250                        fillColor = classColors.getRGB(self.originalData[self.dataClassIndex][i], 255*insideData[j])
251                        edgeColor = classColors.getRGB(self.originalData[self.dataClassIndex][i])
252                    else:
253                        fillColor = edgeColor = (0,0,0)
254                    if self.showValueLines:
255                        self.addValueLineCurve(x_positions[i], y_positions[i], edgeColor, i, indices)
256                    self.addTooltipKey(x_positions[i], y_positions[i], QColor(*edgeColor), i, stringData % (100*insideData[j]))
257                    j+= 1
258            else:
259                for i in range(len(self.rawData)):
260                    if not validData[i]: continue
261                    if self.dataHasClass:
262                        fillColor = classColors.getRGB(self.originalData[self.dataClassIndex][i], 255*insideData[i])
263                        edgeColor = classColors.getRGB(self.originalData[self.dataClassIndex][i])
264                    else:
265                        fillColor = edgeColor = (0,0,0)
266                    if self.showValueLines:
267                        self.addValueLineCurve(x_positions[i], y_positions[i], edgeColor, i, indices)
268                    self.addTooltipKey(x_positions[i], y_positions[i], QColor(*edgeColor), i, stringData % (100*insideData[i]))
269
270        # ##############################################################
271        # do we have a subset data to show?
272        # ##############################################################
273        elif self.haveSubsetData:
274            shownSubsetCount = 0
275            subsetIdsToDraw = dict([(example.id,1) for example in self.rawSubsetData])
276
277            # draw the rawData data set. examples that exist also in the subset data draw full, other empty
278            for i in range(dataSize):
279                if not validData[i]: continue
280                if subsetIdsToDraw.has_key(self.rawData[i].id):
281                    continue
282
283                if self.dataHasDiscreteClass and self.useDifferentColors:
284                    newColor = self.discPalette.getRGB(self.originalData[self.dataClassIndex][i])
285                elif self.dataHasContinuousClass and self.useDifferentColors:
286                    newColor = self.contPalette.getRGB(self.noJitteringScaledData[self.dataClassIndex][i])
287                else:
288                    newColor = (0,0,0)
289
290                if self.useDifferentSymbols and self.dataHasDiscreteClass:
291                    curveSymbol = self.curveSymbols[int(self.originalData[self.dataClassIndex][i])]
292                else:
293                    curveSymbol = self.curveSymbols[0]
294
295                if self.showValueLines:
296                    self.addValueLineCurve(x_positions[i], y_positions[i], newColor, i, indices)
297
298                self.addTooltipKey(x_positions[i], y_positions[i], QColor(*newColor), i)
299
300        elif not self.dataHasClass:
301            xs = []; ys = []
302            for i in range(dataSize):
303                if not validData[i]: continue
304                xs.append(x_positions[i])
305                ys.append(y_positions[i])
306                self.addTooltipKey(x_positions[i], y_positions[i], QColor(Qt.black), i)
307                if self.showValueLines:
308                    self.addValueLineCurve(x_positions[i], y_positions[i], (0,0,0), i, indices)
309
310        # ##############################################################
311        # CONTINUOUS class
312        # ##############################################################
313        elif self.dataHasContinuousClass:
314            for i in range(dataSize):
315                if not validData[i]: continue
316                newColor = self.contPalette.getRGB(self.noJitteringScaledData[self.dataClassIndex][i])
317                if self.showValueLines:
318                    self.addValueLineCurve(x_positions[i], y_positions[i], newColor, i, indices)
319                self.addTooltipKey(x_positions[i], y_positions[i], QColor(*newColor), i)
320
321        # ##############################################################
322        # DISCRETE class
323        # ##############################################################
324        elif self.dataHasDiscreteClass:
325            for i in range(dataSize):
326                if not validData[i]: continue
327                if self.useDifferentColors: newColor = self.discPalette.getRGB(self.originalData[self.dataClassIndex][i])
328                else:                       newColor = (0,0,0)
329                if self.showValueLines:
330                    self.addValueLineCurve(x_positions[i], y_positions[i], newColor, i, indices)
331                self.addTooltipKey(x_positions[i], y_positions[i], QColor(*newColor), i)
332
333        # first draw value lines
334        if self.showValueLines:
335            for i, color in enumerate(self.valueLineCurves[0].keys()):
336                curve = UnconnectedLinesCurve("", QPen(QColor(*color + (self.alphaValue,))), self.valueLineCurves[0][color], self.valueLineCurves[1][color])
337                curve.attach(self)
338                self.value_line_curves.append(curve)
339
340
341        # ##############################################################
342        # draw the legend
343        # ##############################################################
344        # show legend for discrete class
345        if self.dataHasDiscreteClass:
346            classVariableValues = getVariableValuesSorted(self.dataDomain.classVar)
347            for index in range(len(classVariableValues)):
348                if self.useDifferentColors: color = QColor(self.discPalette[index])
349                else:                       color = QColor(Qt.black)
350
351                if not self.useDifferentSymbols:  curveSymbol = self.curveSymbols[0]
352                else:                             curveSymbol = self.curveSymbols[index]
353
354                self.legend().add_item(self.dataDomain.classVar.name, classVariableValues[index], OWPoint(curveSymbol, color, self.pointWidth))
355        # show legend for continuous class
356        elif self.dataHasContinuousClass:
357            self.legend().add_color_gradient(self.dataDomain.classVar.name, [("%%.%df" % self.dataDomain.classVar.numberOfDecimals % v) for v in self.attrValues[self.dataDomain.classVar.name]])
358        self.replot()
359
360
361    # ##############################################################
362    # create a dictionary value for the data point
363    # this will enable to show tooltips faster and to make selection of examples available
364    def addTooltipKey(self, x, y, color, index, extraString = None):
365        dictValue = (x, y)
366        self.dataMap[dictValue] = (x, y, color, index, extraString)
367
368    def addValueLineCurve(self, x, y, color, exampleIndex, attrIndices):
369        XAnchors = numpy.array([val[0] for val in self.anchorData])
370        YAnchors = numpy.array([val[1] for val in self.anchorData])
371        xs = numpy.array([x] * len(self.anchorData))
372        ys = numpy.array([y] * len(self.anchorData))
373        dists = numpy.sqrt((XAnchors-xs)**2 + (YAnchors-ys)**2)
374        xVect = 0.01 * self.valueLineLength * (XAnchors - xs) / dists
375        yVect = 0.01 * self.valueLineLength * (YAnchors - ys) / dists
376        exVals = [self.noJitteringScaledData[attrInd, exampleIndex] for attrInd in attrIndices]
377
378        xs = []; ys = []
379        for i in range(len(exVals)):
380            xs += [x, x + xVect[i]*exVals[i]]
381            ys += [y, y + yVect[i]*exVals[i]]
382        self.valueLineCurves[0][color] = self.valueLineCurves[0].get(color, []) + xs
383        self.valueLineCurves[1][color] = self.valueLineCurves[1].get(color, []) + ys
384
385    def mousePressEvent(self, e):
386        if self.manualPositioning:
387            self.mouseCurrentlyPressed = 1
388            self.selectedAnchorIndex = None
389            if not self.normalizeExamples:
390                marker, dist = self.closestMarker(e.x(), e.y())
391                if dist < 15:
392                    self.selectedAnchorIndex = self.shownAttributes.index(str(marker.label().text()))
393            else:
394                (curve, dist, x, y, index) = self.closestCurve(e.x(), e.y())
395                if dist < 5 and str(curve.title().text()) == "dots":
396                    self.selectedAnchorIndex = index
397        else:
398            OWPlot.mousePressEvent(self, e)
399
400
401    def mouseReleaseEvent(self, e):
402        if self.manualPositioning:
403            self.mouseCurrentlyPressed = 0
404            self.selectedAnchorIndex = None
405        else:
406            OWPlot.mouseReleaseEvent(self, e)
407
408    def mouseMoveEvent(self, e):
409        if self._pressed_mouse_button and self.manualPositioning and self.selectedAnchorIndex != None:
410            if self.selectedAnchorIndex != None:
411                if self.widget.freeVizDlg.restrain == 1:
412                    rad = sqrt(xFloat**2 + yFloat**2)
413                    xFloat /= rad
414                    yFloat /= rad
415                elif self.widget.freeVizDlg.restrain == 2:
416                    rad = sqrt(xFloat**2 + yFloat**2)
417                    phi = 2 * self.selectedAnchorIndex * math.pi / len(self.anchorData)
418                    xFloat = rad * cos(phi)
419                    yFloat = rad * sin(phi)
420                self.anchorData[self.selectedAnchorIndex] = (xFloat, yFloat, self.anchorData[self.selectedAnchorIndex][2])
421                self.updateData(self.shownAttributes)
422                self.replot()
423                #self.widget.recomputeEnergy()
424        else:
425            OWPlot.mouseMoveEvent(self, e)
426   
427
428    # ##############################################################
429    # draw tooltips       
430    def draw_tooltips(self, point):
431        if point is self.current_tooltip_point:
432            return
433           
434    self.current_tooltip_point = point
435           
436        for curve in self.tooltipCurves:  curve.detach()
437        for marker in self.tooltipMarkers: marker.detach()
438        self.tooltipCurves = []
439        self.tooltipMarkers = []
440       
441        if not point:
442            return
443
444        xFloat, yFloat = point.coordinates()
445
446        dictValue = (xFloat, yFloat)
447        if self.dataMap.has_key(dictValue):
448            (x_i, y_i, color, index, extraString) = self.dataMap[dictValue]
449            intX = self.transform(xBottom, x_i)
450            intY = self.transform(yLeft, y_i)
451
452            if self.tooltipKind == LINE_TOOLTIPS:
453                shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
454                if not self.normalizeExamples:
455                    for (xAnchor,yAnchor,label) in shownAnchorData:
456                        attrVal = self.scaledData[self.attributeNameIndex[label]][index]
457                        markerX, markerY = xAnchor*(attrVal+0.03), yAnchor*(attrVal+0.03)
458                        curve = self.addCurve("", color, color, 1, style = Qt.SolidLine, symbol = OWPoint.NoSymbol, xData = [0, xAnchor*attrVal], yData = [0, yAnchor*attrVal], lineWidth=3)
459                        curve.setZValue(HighlightZValue)
460                        self.tooltipCurves.append(curve)
461
462                        marker = None
463                        fontsize = 9
464                        markerAlign = Qt.AlignCenter
465                        labelIndex = self.attributeNameIndex[label]
466                        if self.tooltipValue == TOOLTIPS_SHOW_DATA:
467                            if self.dataDomain[labelIndex].varType == orange.VarTypes.Continuous:
468                                text = "%%.%df" % (self.dataDomain[labelIndex].numberOfDecimals) % (self.rawData[index][labelIndex])
469                            else:
470                                text = str(self.rawData[index][labelIndex].value)
471                            marker = self.addMarker(text, markerX, markerY, markerAlign, size = fontsize)
472                        elif self.tooltipValue == TOOLTIPS_SHOW_SPRINGS:
473                            marker = self.addMarker("%.3f" % (self.scaledData[labelIndex][index]), markerX, markerY, markerAlign, size = fontsize)
474                        self.tooltipMarkers.append(marker)
475
476            elif self.tooltipKind == VISIBLE_ATTRIBUTES or self.tooltipKind == ALL_ATTRIBUTES:
477                if self.tooltipKind == VISIBLE_ATTRIBUTES:
478                    shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
479                    labels = [s for (xA, yA, s) in shownAnchorData]
480                else:
481                    labels = []
482
483                text = self.getExampleTooltipText(self.rawData[index], labels)
484                text += "<hr>Example index = %d" % (index)
485                if extraString:
486                    text += "<hr>" + extraString
487                self.showTip(intX, intY, text)
488
489    # send 2 example tables. in first is the data that is inside selected rects (polygons), in the second is unselected data
490    def getSelectionsAsExampleTables(self, attrList, useAnchorData = 1, addProjectedPositions = 0):
491        if not self.haveData: return (None, None)
492        if addProjectedPositions == 0 and not self.selectionCurveList: return (None, self.rawData)       # if no selections exist
493        if (useAnchorData and len(self.anchorData) < 3) or len(attrList) < 3: return (None, None)
494
495        xAttr=orange.FloatVariable("X Positions")
496        yAttr=orange.FloatVariable("Y Positions")
497        if addProjectedPositions == 1:
498            domain=orange.Domain([xAttr,yAttr] + [v for v in self.dataDomain.variables])
499        elif addProjectedPositions == 2:
500            domain=orange.Domain(self.dataDomain)
501            domain.addmeta(orange.newmetaid(), xAttr)
502            domain.addmeta(orange.newmetaid(), yAttr)
503        else:
504            domain = orange.Domain(self.dataDomain)
505
506        domain.addmetas(self.dataDomain.getmetas())
507
508        if useAnchorData: indices = [self.attributeNameIndex[val[2]] for val in self.anchorData]
509        else:             indices = [self.attributeNameIndex[label] for label in attrList]
510        validData = self.getValidList(indices)
511        if len(validData) == 0: return (None, None)
512
513        array = self.createProjectionAsNumericArray(attrList, scaleFactor = self.scaleFactor, useAnchorData = useAnchorData, removeMissingData = 0)
514        if array == None:       # if all examples have missing values
515            return (None, None)
516
517        #selIndices, unselIndices = self.getSelectionsAsIndices(attrList, useAnchorData, validData)
518        selIndices, unselIndices = self.getSelectedPoints(array.T[0], array.T[1], validData)
519
520        if addProjectedPositions:
521            selected = orange.ExampleTable(domain, self.rawData.selectref(selIndices))
522            unselected = orange.ExampleTable(domain, self.rawData.selectref(unselIndices))
523            selIndex = 0; unselIndex = 0
524            for i in range(len(selIndices)):
525                if selIndices[i]:
526                    selected[selIndex][xAttr] = array[i][0]
527                    selected[selIndex][yAttr] = array[i][1]
528                    selIndex += 1
529                else:
530                    unselected[unselIndex][xAttr] = array[i][0]
531                    unselected[unselIndex][yAttr] = array[i][1]
532                    unselIndex += 1
533        else:
534            selected = self.rawData.selectref(selIndices)
535            unselected = self.rawData.selectref(unselIndices)
536
537        if len(selected) == 0: selected = None
538        if len(unselected) == 0: unselected = None
539        return (selected, unselected)
540
541
542    def getSelectionsAsIndices(self, attrList, useAnchorData = 1, validData = None):
543        if not self.haveData: return [], []
544
545        attrIndices = [self.attributeNameIndex[attr] for attr in attrList]
546        if validData == None:
547            validData = self.getValidList(attrIndices)
548
549        array = self.createProjectionAsNumericArray(attrList, scaleFactor = self.scaleFactor, useAnchorData = useAnchorData, removeMissingData = 0)
550        if array == None:
551            return [], []
552        array = numpy.transpose(array)
553        return self.getSelectedPoints(array[0], array[1], validData)
554
555
556    # update shown data. Set labels, coloring by className ....
557    def savePicTeX(self):
558        lastSave = getattr(self, "lastPicTeXSave", "C:\\")
559        qfileName = QFileDialog.getSaveFileName(None, "Save to..", lastSave + "graph.pictex","PicTeX (*.pictex);;All files (*.*)")
560        fileName = unicode(qfileName)
561        if fileName == "":
562            return
563
564        if not os.path.splitext(fileName)[1][1:]:
565            fileName = fileName + ".pictex"
566
567        self.lastSave = os.path.split(fileName)[0]+"/"
568        file = open(fileName, "wt")
569
570        file.write("\\mbox{\n")
571        file.write(\\beginpicture\n")
572        file.write(\\setcoordinatesystem units <0.4\columnwidth, 0.4\columnwidth>\n")
573        file.write(\\setplotarea x from -1.1 to 1.1, y from -1 to 1.1\n")
574
575        if not self.normalizeExamples:
576            file.write("\\circulararc 360 degrees from 1 0 center at 0 0\n")
577
578        if self.showAnchors:
579            if self.hideRadius > 0:
580                file.write("\\setdashes\n")
581                file.write("\\circulararc 360 degrees from %5.3f 0 center at 0 0\n" % (self.hideRadius/10.))
582                file.write("\\setsolid\n")
583
584            if self.showAttributeNames:
585                shownAnchorData = filter(lambda p, r=self.hideRadius**2/100: p[0]**2+p[1]**2>r, self.anchorData)
586                if not self.normalizeExamples:
587                    for x,y,l in shownAnchorData:
588                        file.write("\\plot 0 0 %5.3f %5.3f /\n" % (x, y))
589                        file.write("\\put {{\\footnotesize %s}} [b] at %5.3f %5.3f\n" % (l.replace("_", "-"), x*1.07, y*1.04))
590                else:
591                    file.write("\\multiput {\\small $\\odot$} at %s /\n" % (" ".join(["%5.3f %5.3f" % tuple(i[:2]) for i in shownAnchorData])))
592                    for x,y,l in shownAnchorData:
593                        file.write("\\put {{\\footnotesize %s}} [b] at %5.3f %5.3f\n" % (l.replace("_", "-"), x*1.07, y*1.04))
594
595        symbols = ("{\\small $\\circ$}", "{\\tiny $\\times$}", "{\\tiny $+$}", "{\\small $\\star$}",
596                   "{\\small $\\ast$}", "{\\tiny $\\div$}", "{\\small $\\bullet$}", ) + tuple([chr(x) for x in range(97, 123)])
597        dataSize = len(self.rawData)
598        labels = self.widget.getShownAttributeList()
599        indices = [self.attributeNameIndex[label] for label in labels]
600        selectedData = numpy.take(self.scaledData, indices, axis = 0)
601        XAnchors = numpy.array([a[0] for a in self.anchorData])
602        YAnchors = numpy.array([a[1] for a in self.anchorData])
603
604        r = numpy.sqrt(XAnchors*XAnchors + YAnchors*YAnchors)     # compute the distance of each anchor from the center of the circle
605        XAnchors *= r                                               # we need to normalize the anchors by r, otherwise the anchors won't attract points less if they are placed at the center of the circle
606        YAnchors *= r
607
608        x_positions = numpy.dot(XAnchors, selectedData)
609        y_positions = numpy.dot(YAnchors, selectedData)
610
611        if self.normalizeExamples:
612            sum_i = self._getSum_i(selectedData, useAnchorData = 1, anchorRadius = r)
613            x_positions /= sum_i
614            y_positions /= sum_i
615
616        if self.scaleFactor:
617            self.trueScaleFactor = self.scaleFactor
618        else:
619            abss = x_positions*x_positions + y_positions*y_positions
620            self.trueScaleFactor =  1 / sqrt(abss[numpy.argmax(abss)])
621
622        x_positions *= self.trueScaleFactor
623        y_positions *= self.trueScaleFactor
624
625        validData = self.getValidList(indices)
626
627        pos = [[] for i in range(len(self.dataDomain.classVar.values))]
628        for i in range(dataSize):
629            if validData[i]:
630                pos[int(self.originalData[self.dataClassIndex][i])].append((x_positions[i], y_positions[i]))
631
632        for i in range(len(self.dataDomain.classVar.values)):
633            file.write("\\multiput {%s} at %s /\n" % (symbols[i], " ".join(["%5.3f %5.3f" % p for p in pos[i]])))
634
635        if self.showLegend:
636            classVariableValues = getVariableValuesSorted(self.dataDomain.classVar)
637            file.write("\\put {%s} [lB] at 0.87 1.06\n" % self.dataDomain.classVar.name)
638            for index in range(len(classVariableValues)):
639                file.write("\\put {%s} at 1.0 %5.3f\n" % (symbols[index], 0.93 - 0.115*index))
640                file.write("\\put {%s} [lB] at 1.05 %5.3f\n" % (classVariableValues[index], 0.9 - 0.115*index))
641
642        file.write("\\endpicture\n}\n")
643        file.close()
644
645    def computePotentials(self):
646        import orangeom
647        #rx = self.transform(xBottom, 1) - self.transform(xBottom, 0)
648        #ry = self.transform(yLeft, 0) - self.transform(yLeft, 1)
649
650        rx = self.transform(xBottom, 1) - self.transform(xBottom, -1)
651        ry = self.transform(yLeft, -1) - self.transform(yLeft, 1)
652        ox = self.transform(xBottom, 0) - self.transform(xBottom, -1)
653        oy = self.transform(yLeft, -1) - self.transform(yLeft, 0)
654
655        rx -= rx % self.squareGranularity
656        ry -= ry % self.squareGranularity
657
658        if not getattr(self, "potentialsImage", None) \
659           or getattr(self, "potentialContext", None) != (rx, ry, self.shownAttributes, self.trueScaleFactor, self.squareGranularity, self.jitterSize, self.jitterContinuous, self.spaceBetweenCells):
660            if self.potentialsClassifier.classVar.varType == orange.VarTypes.Continuous:
661                imagebmp = orangeom.potentialsBitmap(self.potentialsClassifier, rx, ry, ox, oy, self.squareGranularity, self.trueScaleFactor/2, self.spaceBetweenCells)
662                palette = [qRgb(255.*i/255., 255.*i/255., 255-(255.*i/255.)) for i in range(255)] + [qRgb(255, 255, 255)]
663            else:
664                imagebmp, nShades = orangeom.potentialsBitmap(self.potentialsClassifier, rx, ry, ox, oy, self.squareGranularity, self.trueScaleFactor/2, self.spaceBetweenCells) # the last argument is self.trueScaleFactor (in LinProjGraph...)
665                palette = []
666                sortedClasses = getVariableValuesSorted(self.potentialsClassifier.domain.classVar)
667                for cls in self.potentialsClassifier.classVar.values:
668                    color = self.discPalette.getRGB(sortedClasses.index(cls))
669                    towhite = [255-c for c in color]
670                    for s in range(nShades):
671                        si = 1-float(s)/nShades
672                        palette.append(qRgb(*tuple([color[i]+towhite[i]*si for i in (0, 1, 2)])))
673                palette.extend([qRgb(255, 255, 255) for i in range(256-len(palette))])
674
675            self.potentialsImage = QImage(imagebmp, rx, ry, QImage.Format_Indexed8)
676            self.potentialsImage.setColorTable(OWColorPalette.signedPalette(palette) if qVersion() < "4.5" else palette)
677            self.potentialsImage.setNumColors(256)
678            self.potentialContext = (rx, ry, self.shownAttributes, self.trueScaleFactor, self.squareGranularity, self.jitterSize, self.jitterContinuous, self.spaceBetweenCells)
679            self.potentialsImageFromClassifier = self.potentialsClassifier
680
681
682
683    def drawCanvas(self, painter):
684        if self.showProbabilities and getattr(self, "potentialsClassifier", None):
685            if not (self.potentialsClassifier is getattr(self, "potentialsImageFromClassifier", None)):
686                self.computePotentials()
687            target = QRectF(self.transform(xBottom, -1), self.transform(yLeft, 1),
688                            self.transform(xBottom, 1) - self.transform(xBottom, -1),
689                            self.transform(yLeft, -1) - self.transform(yLeft, 1))
690            source = QRectF(0, 0, self.potentialsImage.size().width(), self.potentialsImage.size().height())
691            painter.drawImage(target, self.potentialsImage, source)
692#            painter.drawImage(self.transform(xBottom, -1), self.transform(yLeft, 1), self.potentialsImage)
693        OWPlot.drawCanvas(self, painter)
694       
695    def update_point_size(self):
696        if self.main_curve:
697            # We never have different sizes in LinProj
698            self.main_curve.set_point_sizes([self.pointWidth])
699
700OWLinProjGraph = graph_deprecator(OWLinProjGraph)
701
702
703if __name__== "__main__":
704    #Draw a simple graph
705    import os
706    a = QApplication(sys.argv)
707    graph = OWLinProjGraph(None)
708    data = orange.ExampleTable(r"E:\Development\Orange Datasets\UCI\wine.tab")
709    graph.setData(data)
710    graph.updateData([attr.name for attr in data.domain.attributes])
711    graph.show()
712    a.exec_()
Note: See TracBrowser for help on using the repository browser.