source: orange/Orange/OrangeWidgets/Visualize Qt/OWScatterPlotGraphQt.py @ 9671:a7b056375472

Revision 9671:a7b056375472, 16.3 KB checked in by anze <anze.staric@…>, 2 years ago (diff)

Moved orange to Orange (part 2)

RevLine 
[8682]1#
2# OWScatterPlotGraph.py
3#
4from plot.owplot import *
5import time
6from orngCI import FeatureByCartesianProduct
7##import OWClusterOptimization
8import orngVisFuncts
9from orngScaleScatterPlotData import *
10import ColorPalette
11import numpy
12
13DONT_SHOW_TOOLTIPS = 0
14VISIBLE_ATTRIBUTES = 1
15ALL_ATTRIBUTES = 2
16
17MIN_SHAPE_SIZE = 6
18
19
20###########################################################################################
21##### CLASS : OWSCATTERPLOTGRAPH
22###########################################################################################
23class OWScatterPlotGraphQt(OWPlot, orngScaleScatterPlotData):
24    def __init__(self, scatterWidget, parent = None, name = "None"):
[8691]25        OWPlot.__init__(self, parent, name, widget = scatterWidget)
[8682]26        orngScaleScatterPlotData.__init__(self)
27
28        self.pointWidth = 8
29        self.jitterContinuous = 0
30        self.jitterSize = 5
31        self.showXaxisTitle = 1
32        self.showYLaxisTitle = 1
33        self.showLegend = 1
34        self.tooltipKind = 1
35        self.showFilledSymbols = 1
36        self.showProbabilities = 0
37
38        self.tooltipData = []
39        self.scatterWidget = scatterWidget
40        self.insideColors = None
41        self.shownAttributeIndices = []
42        self.shownXAttribute = ""
43        self.shownYAttribute = ""
44        self.squareGranularity = 3
45        self.spaceBetweenCells = 1
46        self.oldLegendKeys = {}
47
48        self.enableWheelZoom = 1
[8717]49        self.potentialsCurve = None
[8682]50
51    def setData(self, data, subsetData = None, **args):
52        OWPlot.setData(self, data)
53        self.oldLegendKeys = {}
54        orngScaleScatterPlotData.setData(self, data, subsetData, **args)
55
56    #########################################################
57    # update shown data. Set labels, coloring by className ....
58    def updateData(self, xAttr, yAttr, colorAttr, shapeAttr = "", sizeShapeAttr = "", labelAttr = None, **args):
59        self.legend().clear()
60        self.tooltipData = []
61        self.potentialsClassifier = None
62        self.potentialsImage = None
63        # self.canvas().invalidatePaintCache()
64        self.shownXAttribute = xAttr
65        self.shownYAttribute = yAttr
66
67        if self.scaledData == None or len(self.scaledData) == 0:
[8228]68            self.setAxisScale(xBottom, 0, 1, 1); 
69            self.setAxisScale(yLeft, 0, 1, 1)
[8682]70            self.setXaxisTitle(""); self.setYLaxisTitle("")
71            self.oldLegendKeys = {}
72            return
73
74        self.__dict__.update(args)      # set value from args dictionary
75
76        colorIndex = -1
77        if colorAttr != "" and colorAttr != "(Same color)":
78            colorIndex = self.attributeNameIndex[colorAttr]
79            if self.dataDomain[colorAttr].varType == orange.VarTypes.Discrete:
80                self.discPalette.setNumberOfColors(len(self.dataDomain[colorAttr].values))
81
82        shapeIndex = -1
83        if shapeAttr != "" and shapeAttr != "(Same shape)" and len(self.dataDomain[shapeAttr].values) < 11:
84            shapeIndex = self.attributeNameIndex[shapeAttr]
85
86        sizeIndex = -1
87        if sizeShapeAttr != "" and sizeShapeAttr != "(Same size)":
88            sizeIndex = self.attributeNameIndex[sizeShapeAttr]
89           
90        showContinuousColorLegend = colorIndex != -1 and self.dataDomain[colorIndex].varType == orange.VarTypes.Continuous
91
92        (xVarMin, xVarMax) = self.attrValues[xAttr]
93        (yVarMin, yVarMax) = self.attrValues[yAttr]
94        xVar = max(xVarMax - xVarMin, 1e-10)
95        yVar = max(yVarMax - yVarMin, 1e-10)
96        xAttrIndex = self.attributeNameIndex[xAttr]
97        yAttrIndex = self.attributeNameIndex[yAttr]
98
99        attrIndices = [xAttrIndex, yAttrIndex, colorIndex, shapeIndex, sizeIndex]
100        while -1 in attrIndices: attrIndices.remove(-1)
101        self.shownAttributeIndices = attrIndices
102
103        # set axis for x attribute
104        discreteX = self.dataDomain[xAttrIndex].varType == orange.VarTypes.Discrete
105        if discreteX:
106            xVarMax -= 1; xVar -= 1
107            xmin = xVarMin - (self.jitterSize + 10.)/100.
108            xmax = xVarMax + (self.jitterSize + 10.)/100.
109            labels = getVariableValuesSorted(self.dataDomain[xAttrIndex])
110        else:
111            off  = (xVarMax - xVarMin) * (self.jitterSize * self.jitterContinuous + 2) / 100.0
112            xmin = xVarMin - off
113            xmax = xVarMax + off
114            labels = None
115        self.setXlabels(labels)
116        self.setAxisScale(xBottom, xmin, xmax,  discreteX)
117
118        # set axis for y attribute
119        discreteY = self.dataDomain[yAttrIndex].varType == orange.VarTypes.Discrete
120        if discreteY:
121            yVarMax -= 1; yVar -= 1
122            ymin = yVarMin - (self.jitterSize + 10.)/100.
123            ymax = yVarMax + (self.jitterSize + 10.)/100.
124            labels = getVariableValuesSorted(self.dataDomain[yAttrIndex])
125        else:
126            off  = (yVarMax - yVarMin) * (self.jitterSize * self.jitterContinuous + 2) / 100.0
127            ymin = yVarMin - off
128            ymax = yVarMax + off
129            labels = None
130        self.setYLlabels(labels)
131        self.setAxisScale(yLeft, ymin, ymax, discreteY)
132
133        self.setXaxisTitle(xAttr)
134        self.setYLaxisTitle(yAttr)
135
136        # compute x and y positions of the points in the scatterplot
137        xData, yData = self.getXYDataPositions(xAttr, yAttr)
138        validData = self.getValidList(attrIndices)      # get examples that have valid data for each used attribute
139
140        # #######################################################
141        # show probabilities
[8717]142        if self.potentialsCurve:
143            self.potentialsCurve.detach()
144            self.potentialsCurve = None
[8682]145        if self.showProbabilities and colorIndex >= 0 and self.dataDomain[colorIndex].varType in [orange.VarTypes.Discrete, orange.VarTypes.Continuous]:
146            if self.dataDomain[colorIndex].varType == orange.VarTypes.Discrete: domain = orange.Domain([self.dataDomain[xAttrIndex], self.dataDomain[yAttrIndex], orange.EnumVariable(self.attributeNames[colorIndex], values = getVariableValuesSorted(self.dataDomain[colorIndex]))])
147            else:                                                               domain = orange.Domain([self.dataDomain[xAttrIndex], self.dataDomain[yAttrIndex], orange.FloatVariable(self.attributeNames[colorIndex])])
148            xdiff = xmax-xmin; ydiff = ymax-ymin
149            scX = xData/xdiff
150            scY = yData/ydiff
151            classData = self.originalData[colorIndex]
152
153            probData = numpy.transpose(numpy.array([scX, scY, classData]))
154            probData= numpy.compress(validData, probData, axis = 0)
155           
156            sys.stderr.flush()
157            self.xmin = xmin; self.xmax = xmax
158            self.ymin = ymin; self.ymax = ymax
159           
160            if probData.any():
161                self.potentialsClassifier = orange.P2NN(domain, probData, None, None, None, None)
[8717]162                self.potentialsCurve = ProbabilitiesItem(self.potentialsClassifier, self.squareGranularity, 1., self.spaceBetweenCells)
163                self.potentialsCurve.attach(self)
[8682]164            else:
165                self.potentialsClassifier = None
166       
167        """
168            Create a single curve with different points
169        """
[8712]170       
[8716]171        def_color = self.color(OWPalette.Data)
[8712]172        def_size = self.point_width
173        def_shape = self.curveSymbols[0]
[8682]174
175        if colorIndex != -1:
176            if self.dataDomain[colorIndex].varType == orange.VarTypes.Continuous:
[8257]177                c_data = self.noJitteringScaledData[colorIndex]
178                palette = self.continuous_palette
[8682]179            else:
[8257]180                c_data = self.originalData[colorIndex]
181                palette = self.discrete_palette
182            checked_color_data = [(c_data[i] if validData[i] else 0) for i in range(len(c_data))]
183            colorData = [QColor(*palette.getRGB(i)) for i in checked_color_data]
184        else:
185            colorData = [def_color]
[8682]186
187        if sizeIndex != -1:
188            sizeData = [MIN_SHAPE_SIZE + round(i * self.pointWidth) for i in self.noJitteringScaledData[sizeIndex]]
189        else:
[8712]190            sizeData = [def_size]
[8682]191           
192        if shapeIndex != -1 and self.dataDomain[shapeIndex].varType == orange.VarTypes.Discrete:
193            shapeData = [self.curveSymbols[int(i)] for i in self.originalData[shapeIndex]]
194        else:
[8712]195            shapeData = [def_shape]
[8682]196           
197        if labelAttr and labelAttr in [self.rawData.domain.getmeta(mykey).name for mykey in self.rawData.domain.getmetas().keys()] + [var.name for var in self.rawData.domain]:
198            if self.rawData[0][labelAttr].varType == orange.VarTypes.Continuous:
199                labelData = ["%4.1f" % orange.Value(i[labelAttr]) if not i[labelAttr].isSpecial() else "" for i in self.rawData]
200            else:
201                labelData = [str(i[labelAttr].value) if not i[labelAttr].isSpecial() else "" for i in self.rawData]
202        else:
203            labelData = [""]
204
205        if self.haveSubsetData:
206            subset_ids = [example.id for example in self.rawSubsetData]
207            marked_data = [example.id in subset_ids for example in self.rawData]
208            showFilled = 0
209        else:
210            marked_data = []
[8696]211        self.set_main_curve_data(xData, yData, colorData, labelData, sizeData, shapeData, marked_data, validData)
[8682]212       
213        '''
214            Create legend items in any case
215            so that show/hide legend only
216        '''
217        discColorIndex = colorIndex if colorIndex != -1 and self.dataDomain[colorIndex].varType == orange.VarTypes.Discrete else -1
218        discShapeIndex = shapeIndex if shapeIndex != -1 and self.dataDomain[shapeIndex].varType == orange.VarTypes.Discrete else -1
219        discSizeIndex = sizeIndex if sizeIndex != -1 and self.dataDomain[sizeIndex].varType == orange.VarTypes.Discrete else -1
220                   
221        if discColorIndex != -1:
222            num = len(self.dataDomain[discColorIndex].values)
223            varValues = getVariableValuesSorted(self.dataDomain[discColorIndex])
224            for ind in range(num):
[8712]225                self.legend().add_item(self.dataDomain[discColorIndex].name, varValues[ind], OWPoint(def_shape, self.discPalette[ind], def_size))
[8682]226
227        if discShapeIndex != -1:
228            num = len(self.dataDomain[discShapeIndex].values)
229            varValues = getVariableValuesSorted(self.dataDomain[discShapeIndex])
230            for ind in range(num):
[8712]231                self.legend().add_item(self.dataDomain[discShapeIndex].name, varValues[ind], OWPoint(self.curveSymbols[ind], def_color, def_size))
[8682]232
233        if discSizeIndex != -1:
234            num = len(self.dataDomain[discSizeIndex].values)
235            varValues = getVariableValuesSorted(self.dataDomain[discSizeIndex])
236            for ind in range(num):
[8712]237                self.legend().add_item(self.dataDomain[discSizeIndex].name, varValues[ind], OWPoint(def_shape, def_color, MIN_SHAPE_SIZE + round(ind*self.pointWidth/len(varValues))))
[8682]238
239        # ##############################################################
240        # draw color scale for continuous coloring attribute
241        if colorIndex != -1 and showContinuousColorLegend:
242            self.legend().add_color_gradient(colorAttr, [("%%.%df" % self.dataDomain[colorAttr].numberOfDecimals % v) for v in self.attrValues[colorAttr]])
243           
244        self.replot()
245
246##    # ##############################################################
247##    # ######  SHOW CLUSTER LINES  ##################################
248##    # ##############################################################
249##    def showClusterLines(self, xAttr, yAttr, width = 1):
250##        classIndices = getVariableValueIndices(self.rawData, self.attributeNameIndex[self.rawData.domain.classVar.name])
251##
252##        shortData = self.rawData.select([self.rawData.domain[xAttr], self.rawData.domain[yAttr], self.rawData.domain.classVar])
253##        shortData = orange.Preprocessor_dropMissing(shortData)
254##
255##        (closure, enlargedClosure, classValue) = self.clusterClosure
256##
257##        (xVarMin, xVarMax) = self.attrValues[xAttr]
258##        (yVarMin, yVarMax) = self.attrValues[yAttr]
259##        xVar = xVarMax - xVarMin
260##        yVar = yVarMax - yVarMin
261##
262##        if type(closure) == dict:
263##            for key in closure.keys():
264##                clusterLines = closure[key]
265##                color = self.discPalette[classIndices[self.rawData.domain.classVar[classValue[key]].value]]
266##                for (p1, p2) in clusterLines:
267##                    self.addCurve("", color, color, 1, QwtPlotCurve.Lines, OWPoint.NoSymbol, xData = [float(shortData[p1][0]), float(shortData[p2][0])], yData = [float(shortData[p1][1]), float(shortData[p2][1])], lineWidth = width)
268##        else:
269##            colorIndex = self.discPalette[classIndices[self.rawData.domain.classVar[classValue].value]]
270##            for (p1, p2) in closure:
271##                self.addCurve("", color, color, 1, QwtPlotCurve.Lines, OWPoint.NoSymbol, xData = [float(shortData[p1][0]), float(shortData[p2][0])], yData = [float(shortData[p1][1]), float(shortData[p2][1])], lineWidth = width)
272   
273    def update_point_size(self):
274        if self.scatterWidget.attrSize:
275            self.scatterWidget.updateGraph()
276        else:
277            self.main_curve.set_point_sizes([self.point_width])
278            self.update_curves()
279   
280
281    def addTip(self, x, y, attrIndices = None, dataindex = None, text = None):
282        if self.tooltipKind == DONT_SHOW_TOOLTIPS: return
283        if text == None:
284            if self.tooltipKind == VISIBLE_ATTRIBUTES:  text = self.getExampleTooltipText(self.rawData[dataindex], attrIndices)
285            elif self.tooltipKind == ALL_ATTRIBUTES:    text = self.getExampleTooltipText(self.rawData[dataindex], range(len(self.attributeNames)))
286        self.tips.addToolTip(x, y, text)
287
288
289    # override the default buildTooltip function defined in OWPlot
290    def buildTooltip(self, exampleIndex):
291        if exampleIndex < 0:
292            example = self.rawSubsetData[-exampleIndex - 1]
293        else:
294            example = self.rawData[exampleIndex]
295
296        if self.tooltipKind == VISIBLE_ATTRIBUTES:
297            text = self.getExampleTooltipText(example, self.shownAttributeIndices)
298        elif self.tooltipKind == ALL_ATTRIBUTES:
299            text = self.getExampleTooltipText(example)
300        return text
301
302
303    # ##############################################################
304    # send 2 example tables. in first is the data that is inside selected rects (polygons), in the second is unselected data
305    def getSelectionsAsExampleTables(self, attrList):
306        [xAttr, yAttr] = attrList
307        #if not self.rawData: return (None, None, None)
308        if not self.haveData: return (None, None)
309
310        selIndices, unselIndices = self.getSelectionsAsIndices(attrList)
311
312        selected = self.rawData.selectref(selIndices)
313        unselected = self.rawData.selectref(unselIndices)
314
315        if len(selected) == 0: selected = None
316        if len(unselected) == 0: unselected = None
317
318        return (selected, unselected)
319
320
321    def getSelectionsAsIndices(self, attrList, validData = None):
322        [xAttr, yAttr] = attrList
323        if not self.haveData: return [], []
324
325        attrIndices = [self.attributeNameIndex[attr] for attr in attrList]
326        if validData == None:
327            validData = self.getValidList(attrIndices)
328
329        (xArray, yArray) = self.getXYDataPositions(xAttr, yAttr)
330
331        return self.getSelectedPoints(xArray, yArray, validData)
332
333
334    def onMouseReleased(self, e):
335        OWPlot.onMouseReleased(self, e)
336        self.updateLayout()
337
338    def computePotentials(self):
339        import orangeom
340        s = self.graph_area.toRect().size()
341        if not s.isValid():
342            self.potentialsImage = QImage()
343            return
344        rx = s.width()
345        ry = s.height()
346        rx -= rx % self.squareGranularity
347        ry -= ry % self.squareGranularity
348
349        ox = int(self.transform(xBottom, 0) - self.transform(xBottom, self.xmin))
350        oy = int(self.transform(yLeft, self.ymin) - self.transform(yLeft, 0))
351
352        if not getattr(self, "potentialsImage", None) or getattr(self, "potentialContext", None) != (rx, ry, self.shownXAttribute, self.shownYAttribute, self.squareGranularity, self.jitterSize, self.jitterContinuous, self.spaceBetweenCells):
353            self.potentialContext = (rx, ry, self.shownXAttribute, self.shownYAttribute, self.squareGranularity, self.jitterSize, self.jitterContinuous, self.spaceBetweenCells)
354            self.potentialsImageFromClassifier = self.potentialsClassifier
355
356if __name__== "__main__":
357    #Draw a simple graph
358    a = QApplication(sys.argv)
359    c = OWScatterPlotGraph(None)
360    c.show()
361    a.exec_()
Note: See TracBrowser for help on using the repository browser.