source: orange/Orange/OrangeWidgets/Visualize/OWParallelGraph.py @ 11823:c92ee492d90d

Revision 11823:c92ee492d90d, 32.4 KB checked in by astaric <anze.staric@…>, 4 months ago (diff)

Consistent code style.

Line 
1#
2# OWParallelGraph.py
3#
4import orngEnviron
5from OWGraph import *
6#from OWDistributions import *
7from orngScaleData import *
8from statc import pearsonr
9
10NO_STATISTICS = 0
11MEANS = 1
12MEDIAN = 2
13
14
15class OWParallelGraph(OWGraph, orngScaleData):
16    def __init__(self, parallelDlg, parent=None, name=None):
17        OWGraph.__init__(self, parent, name)
18        orngScaleData.__init__(self)
19
20        self.parallelDlg = parallelDlg
21        self.showDistributions = 0
22        self.toolRects = []
23        self.useSplines = 0
24        self.showStatistics = 0
25        self.lastSelectedCurve = None
26        self.enabledLegend = 0
27        self.enableGridXB(0)
28        self.enableGridYL(0)
29        self.domainContingency = None
30        self.alphaValue2 = 150
31        self.autoUpdateAxes = 1
32        self.oldLegendKeys = []
33        self.selectionConditions = {}
34        self.visualizedAttributes = []
35        self.visualizedMidLabels = []
36        self.selectedExamples = []
37        self.unselectedExamples = []
38        self.bottomPixmap = QPixmap(os.path.join(orngEnviron.directoryNames["widgetDir"], "icons/upgreenarrow.png"))
39        self.topPixmap = QPixmap(os.path.join(orngEnviron.directoryNames["widgetDir"], "icons/downgreenarrow.png"))
40
41        self.axisScaleDraw(QwtPlot.xBottom).enableComponent(QwtScaleDraw.Backbone, 0)
42        self.axisScaleDraw(QwtPlot.xBottom).enableComponent(QwtScaleDraw.Ticks, 0)
43        self.axisScaleDraw(QwtPlot.yLeft).enableComponent(QwtScaleDraw.Backbone, 0)
44        self.axisScaleDraw(QwtPlot.yLeft).enableComponent(QwtScaleDraw.Ticks, 0)
45
46    def setData(self, data, subsetData=None, **args):
47        OWGraph.setData(self, data)
48        orngScaleData.setData(self, data, subsetData, **args)
49        self.domainContingency = None
50
51
52    # update shown data. Set attributes, coloring by className ....
53    def updateData(self, attributes, midLabels=None, updateAxisScale=1):
54        self.removeDrawingCurves(removeLegendItems=0, removeMarkers=1)  # don't delete legend items
55        if attributes != self.visualizedAttributes:
56            self.selectionConditions = {}       # reset selections
57
58        self.visualizedAttributes = []
59        self.visualizedMidLabels = []
60        self.selectedExamples = []
61        self.unselectedExamples = []
62
63        if not (self.haveData or self.haveSubsetData):  return
64        if len(attributes) < 2: return
65
66        self.visualizedAttributes = attributes
67        self.visualizedMidLabels = midLabels
68        for name in self.selectionConditions.keys():        # keep only conditions that are related to the currently visualized attributes
69            if name not in self.visualizedAttributes:
70                self.selectionConditions.pop(name)
71
72        # set the limits for panning
73        self.xPanningInfo = (1, 0, len(attributes) - 1)
74        self.yPanningInfo = (
75        0, 0, 0)   # we don't enable panning in y direction so it doesn't matter what values we put in for the limits
76
77        if updateAxisScale:
78            if self.showAttrValues:
79                self.setAxisScale(QwtPlot.yLeft, -0.04, 1.04, 1)
80            else:
81                self.setAxisScale(QwtPlot.yLeft, -0.02, 1.02, 1)
82
83            if self.autoUpdateAxes:
84                if attributes and isinstance(self.dataDomain[attributes[-1]], orange.EnumVariable):
85                    self.setAxisScale(QwtPlot.xBottom, -0.1, len(attributes) - 0.4, 1)
86                else:
87                    self.setAxisScale(QwtPlot.xBottom, -0.1, len(attributes) - 0.9, 1)
88            else:
89                m = self.axisScaleDiv(QwtPlot.xBottom).interval().minValue()
90                M = self.axisScaleDiv(QwtPlot.xBottom).interval().maxValue()
91                if m < 0 or M > len(attributes) - 2:
92                    self.setAxisScale(QwtPlot.xBottom, 0, len(attributes) - 1, 1)
93
94        self.setAxisScaleDraw(QwtPlot.xBottom,
95                              DiscreteAxisScaleDraw([self.getAttributeLabel(attr) for attr in attributes]))
96        #self.setAxisScaleDraw(QwtPlot.yLeft, HiddenScaleDraw())
97        self.setAxisMaxMajor(QwtPlot.xBottom, len(attributes))
98        self.setAxisMaxMinor(QwtPlot.xBottom, 0)
99
100        length = len(attributes)
101        indices = [self.attributeNameIndex[label] for label in attributes]
102
103        xs = range(length)
104        dataSize = len(self.scaledData[0])
105
106        if self.dataHasDiscreteClass:
107            self.discPalette.setNumberOfColors(len(self.dataDomain.classVar.values))
108
109
110        # ############################################
111        # draw the data
112        # ############################################
113        subsetIdsToDraw = self.haveSubsetData and dict(
114            [(self.rawSubsetData[i].id, 1) for i in self.getValidSubsetIndices(indices)]) or {}
115        validData = self.getValidList(indices)
116        mainCurves = {}
117        subCurves = {}
118        conditions = dict([(name, attributes.index(name)) for name in self.selectionConditions.keys()])
119
120        for i in range(dataSize):
121            if not validData[i]:
122                continue
123
124            if not self.dataHasClass:
125                newColor = (0, 0, 0)
126            elif self.dataHasContinuousClass:
127                newColor = self.contPalette.getRGB(self.noJitteringScaledData[self.dataClassIndex][i])
128            else:
129                newColor = self.discPalette.getRGB(self.originalData[self.dataClassIndex][i])
130
131            data = [self.scaledData[index][i] for index in indices]
132
133            # if we have selected some conditions and the example does not match it we show it as a subset data
134            if 0 in [data[index] >= self.selectionConditions[name][0] and data[index] <= self.selectionConditions[name][
135                1] for (name, index) in conditions.items()]:
136                alpha = self.alphaValue2
137                curves = subCurves
138                self.unselectedExamples.append(i)
139            # if we have subset data then use alpha2 for main data and alpha for subset data
140            elif self.haveSubsetData and not subsetIdsToDraw.has_key(self.rawData[i].id):
141                alpha = self.alphaValue2
142                curves = subCurves
143                self.unselectedExamples.append(i)
144            else:
145                alpha = self.alphaValue
146                curves = mainCurves
147                self.selectedExamples.append(i)
148                if subsetIdsToDraw.has_key(self.rawData[i].id):
149                    subsetIdsToDraw.pop(self.rawData[i].id)
150
151            newColor += (alpha,)
152
153            if not curves.has_key(newColor):
154                curves[newColor] = []
155            curves[newColor].extend(data)
156
157        # if we have a data subset that contains examples that don't exist in the original dataset we show them here
158        if subsetIdsToDraw != {}:
159            validSubsetData = self.getValidSubsetList(indices)
160
161            for i in range(len(self.rawSubsetData)):
162                if not validSubsetData[i]: continue
163                if not subsetIdsToDraw.has_key(self.rawSubsetData[i].id): continue
164
165                data = [self.scaledSubsetData[index][i] for index in indices]
166                if not self.dataDomain.classVar or self.rawSubsetData[i].getclass().isSpecial():
167                    newColor = (0, 0, 0)
168                elif self.dataHasContinuousClass:
169                    newColor = self.contPalette.getRGB(self.noJitteringScaledSubsetData[self.dataClassIndex][i])
170                else:
171                    newColor = self.discPalette.getRGB(self.originalSubsetData[self.dataClassIndex][i])
172
173                if 0 in [data[index] >= self.selectionConditions[name][0] and data[index] <=
174                        self.selectionConditions[name][1] for (name, index) in conditions.items()]:
175                    newColor += (self.alphaValue2,)
176                    curves = subCurves
177                else:
178                    newColor += (self.alphaValue,)
179                    curves = mainCurves
180
181                if not curves.has_key(newColor):
182                    curves[newColor] = []
183                curves[newColor].extend(data)
184
185        # add main curves
186        keys = mainCurves.keys()
187        keys.sort()     # otherwise the order of curves change when we slide the alpha slider
188        for key in keys:
189            curve = ParallelCoordinatesCurve(len(attributes), mainCurves[key], key)
190            if self.useAntialiasing:
191                curve.setRenderHint(QwtPlotItem.RenderAntialiased)
192            if self.useSplines:
193                curve.setCurveAttribute(QwtPlotCurve.Fitted)
194            #                curve.setCurveFitter(QwtSplineCurveFitter())
195            curve.attach(self)
196
197        # add sub curves
198        keys = subCurves.keys()
199        keys.sort()     # otherwise the order of curves change when we slide the alpha slider
200        for key in keys:
201            curve = ParallelCoordinatesCurve(len(attributes), subCurves[key], key)
202            if self.useAntialiasing:
203                curve.setRenderHint(QwtPlotItem.RenderAntialiased)
204            if self.useSplines:
205                curve.setCurveAttribute(QwtPlotCurve.Fitted)
206            curve.attach(self)
207
208
209
210        # ############################################
211        # do we want to show distributions with discrete attributes
212        if self.showDistributions and self.dataHasDiscreteClass and self.haveData:
213            self.showDistributionValues(validData, indices)
214
215        # ############################################
216        # draw vertical lines that represent attributes
217        for i in range(len(attributes)):
218            self.addCurve("", lineWidth=2, style=QwtPlotCurve.Lines, symbol=QwtSymbol.NoSymbol, xData=[i, i],
219                          yData=[0, 1])
220            if self.showAttrValues == 1:
221                attr = self.dataDomain[attributes[i]]
222                if attr.varType == orange.VarTypes.Continuous:
223                    strVal1 = "%%.%df" % (attr.numberOfDecimals) % (self.attrValues[attr.name][0])
224                    strVal2 = "%%.%df" % (attr.numberOfDecimals) % (self.attrValues[attr.name][1])
225                    align1 = i == 0 and Qt.AlignRight | Qt.AlignBottom or i == len(
226                        attributes) - 1 and Qt.AlignLeft | Qt.AlignBottom or Qt.AlignHCenter | Qt.AlignBottom
227                    align2 = i == 0 and Qt.AlignRight | Qt.AlignTop or i == len(
228                        attributes) - 1 and Qt.AlignLeft | Qt.AlignTop or Qt.AlignHCenter | Qt.AlignTop
229                    self.addMarker(strVal1, i, 0.0 - 0.01, alignment=align1)
230                    self.addMarker(strVal2, i, 1.0 + 0.01, alignment=align2)
231
232                elif attr.varType == orange.VarTypes.Discrete:
233                    attrVals = getVariableValuesSorted(self.dataDomain[attributes[i]])
234                    valsLen = len(attrVals)
235                    for pos in range(len(attrVals)):
236                        # show a rectangle behind the marker
237                        self.addMarker(attrVals[pos], i + 0.01, float(1 + 2 * pos) / float(2 * valsLen),
238                                       alignment=Qt.AlignRight | Qt.AlignVCenter, bold=1, brushColor=Qt.white)
239
240        # ##############################################
241        # show lines that represent standard deviation or quartiles
242        # ##############################################
243        if self.showStatistics and self.haveData:
244            data = []
245            for i in range(length):
246                if self.dataDomain[indices[i]].varType != orange.VarTypes.Continuous:
247                    data.append([()])
248                    continue  # only for continuous attributes
249                array = numpy.compress(numpy.equal(self.validDataArray[indices[i]], 1),
250                                       self.scaledData[indices[i]])  # remove missing values
251
252                if not self.dataHasClass or self.dataHasContinuousClass:    # no class
253                    if self.showStatistics == MEANS:
254                        m = array.mean()
255                        dev = array.std()
256                        data.append([(m - dev, m, m + dev)])
257                    elif self.showStatistics == MEDIAN:
258                        sorted = numpy.sort(array)
259                        if len(sorted) > 0:
260                            data.append([(sorted[int(len(sorted) / 4.0)], sorted[int(len(sorted) / 2.0)],
261                                          sorted[int(len(sorted) * 0.75)])])
262                        else:
263                            data.append([(0, 0, 0)])
264                else:
265                    curr = []
266                    classValues = getVariableValuesSorted(self.dataDomain.classVar)
267                    classValueIndices = getVariableValueIndices(self.dataDomain.classVar)
268                    for c in range(len(classValues)):
269                        scaledVal = ((classValueIndices[classValues[c]] * 2) + 1) / float(2 * len(classValueIndices))
270                        nonMissingValues = numpy.compress(numpy.equal(self.validDataArray[indices[i]], 1),
271                                                          self.noJitteringScaledData[
272                                                              self.dataClassIndex])  # remove missing values
273                        arr_c = numpy.compress(numpy.equal(nonMissingValues, scaledVal), array)
274                        if len(arr_c) == 0:
275                            curr.append((0, 0, 0));
276                            continue
277                        if self.showStatistics == MEANS:
278                            m = arr_c.mean()
279                            dev = arr_c.std()
280                            curr.append((m - dev, m, m + dev))
281                        elif self.showStatistics == MEDIAN:
282                            sorted = numpy.sort(arr_c)
283                            curr.append((sorted[int(len(arr_c) / 4.0)], sorted[int(len(arr_c) / 2.0)],
284                                         sorted[int(len(arr_c) * 0.75)]))
285                    data.append(curr)
286
287            # draw vertical lines
288            for i in range(len(data)):
289                for c in range(len(data[i])):
290                    if data[i][c] == (): continue
291                    x = i - 0.03 * (len(data[i]) - 1) / 2.0 + c * 0.03
292                    col = QColor(self.discPalette[c])
293                    col.setAlpha(self.alphaValue2)
294                    self.addCurve("", col, col, 3, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x, x, x],
295                                  yData=[data[i][c][0], data[i][c][1], data[i][c][2]], lineWidth=4)
296                    self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03],
297                                  yData=[data[i][c][0], data[i][c][0]], lineWidth=4)
298                    self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03],
299                                  yData=[data[i][c][1], data[i][c][1]], lineWidth=4)
300                    self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03],
301                                  yData=[data[i][c][2], data[i][c][2]], lineWidth=4)
302
303            # draw lines with mean/median values
304            classCount = 1
305            if not self.dataHasClass or self.dataHasContinuousClass:
306                classCount = 1 # no class
307            else:
308                classCount = len(self.dataDomain.classVar.values)
309            for c in range(classCount):
310                diff = - 0.03 * (classCount - 1) / 2.0 + c * 0.03
311                ys = []
312                xs = []
313                for i in range(len(data)):
314                    if data[i] != [()]:
315                        ys.append(data[i][c][1]); xs.append(i + diff)
316                    else:
317                        if len(xs) > 1:
318                            col = QColor(self.discPalette[c])
319                            col.setAlpha(self.alphaValue2)
320                            self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=xs, yData=ys,
321                                          lineWidth=4)
322                        xs = [];
323                        ys = []
324                col = QColor(self.discPalette[c])
325                col.setAlpha(self.alphaValue2)
326                self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=xs, yData=ys, lineWidth=4)
327
328
329        # ##################################################
330        # show labels in the middle of the axis
331        if midLabels:
332            for j in range(len(midLabels)):
333                self.addMarker(midLabels[j], j + 0.5, 1.0, alignment=Qt.AlignCenter | Qt.AlignTop)
334
335        # show the legend
336        if self.enabledLegend == 1 and self.dataHasDiscreteClass:
337            if self.dataDomain.classVar.varType == orange.VarTypes.Discrete:
338                legendKeys = []
339                varValues = getVariableValuesSorted(self.dataDomain.classVar)
340                #self.addCurve("<b>" + self.dataDomain.classVar.name + ":</b>", QColor(0,0,0), QColor(0,0,0), 0, symbol = QwtSymbol.NoSymbol, enableLegend = 1)
341                for ind in range(len(varValues)):
342                    #self.addCurve(varValues[ind], self.discPalette[ind], self.discPalette[ind], 15, symbol = QwtSymbol.Rect, enableLegend = 1)
343                    legendKeys.append((varValues[ind], self.discPalette[ind]))
344                if legendKeys != self.oldLegendKeys:
345                    self.oldLegendKeys = legendKeys
346                    self.legend().clear()
347                    self.addCurve("<b>" + self.dataDomain.classVar.name + ":</b>", QColor(0, 0, 0), QColor(0, 0, 0), 0,
348                                  symbol=QwtSymbol.NoSymbol, enableLegend=1)
349                    for (name, color) in legendKeys:
350                        self.addCurve(name, color, color, 15, symbol=QwtSymbol.Rect, enableLegend=1)
351            else:
352                l = len(attributes) - 1
353                xs = [l * 1.15, l * 1.20, l * 1.20, l * 1.15]
354                count = 200;
355                height = 1 / 200.
356                for i in range(count):
357                    y = i / float(count)
358                    col = self.contPalette[y]
359                    curve = PolygonCurve(QPen(col), QBrush(col), xData=xs, yData=[y, y, y + height, y + height])
360                    curve.attach(self)
361
362                # add markers for min and max value of color attribute
363                [minVal, maxVal] = self.attrValues[self.dataDomain.classVar.name]
364                decimals = self.dataDomain.classVar.numberOfDecimals
365                self.addMarker("%%.%df" % (decimals) % (minVal), xs[0] - l * 0.02, 0.04, Qt.AlignLeft)
366                self.addMarker("%%.%df" % (decimals) % (maxVal), xs[0] - l * 0.02, 1.0 - 0.04, Qt.AlignLeft)
367        else:
368            self.legend().clear()
369            self.oldLegendKeys = []
370
371        self.replot()
372
373
374    # ##########################################
375    # SHOW DISTRIBUTION BAR GRAPH
376    def showDistributionValues(self, validData, indices):
377        # create color table
378        clsCount = len(self.dataDomain.classVar.values)
379        #if clsCount < 1: clsCount = 1.0
380
381        # we create a hash table of possible class values (happens only if we have a discrete class)
382        classValueSorted = getVariableValuesSorted(self.dataDomain.classVar)
383        if self.domainContingency == None:
384            self.domainContingency = orange.DomainContingency(self.rawData)
385
386        maxVal = 1
387        for attr in indices:
388            if self.dataDomain[attr].varType != orange.VarTypes.Discrete:
389                continue
390            if self.dataDomain[attr] == self.dataDomain.classVar:
391                maxVal = max(maxVal, max(orange.Distribution(attr, self.rawData) or [1]))
392            else:
393                maxVal = max(maxVal, max([max(val or [1]) for val in self.domainContingency[attr].values()] or [1]))
394
395        for graphAttrIndex, index in enumerate(indices):
396            attr = self.dataDomain[index]
397            if attr.varType != orange.VarTypes.Discrete: continue
398            if self.dataDomain[index] == self.dataDomain.classVar:
399                contingency = orange.Contingency(self.dataDomain[index], self.dataDomain[index])
400                dist = orange.Distribution(self.dataDomain[index], self.rawData)
401                for val in self.dataDomain[index].values:
402                    contingency[val][val] = dist[val]
403            else:
404                contingency = self.domainContingency[index]
405
406            attrLen = len(attr.values)
407
408            # we create a hash table of variable values and their indices
409            variableValueIndices = getVariableValueIndices(self.dataDomain[index])
410            variableValueSorted = getVariableValuesSorted(self.dataDomain[index])
411
412            # create bar curve
413            for j in range(attrLen):
414                attrVal = variableValueSorted[j]
415                try:
416                    attrValCont = contingency[attrVal]
417                except IndexError, ex:
418                    print >> sys.stderr, ex, attrVal, contingency
419                    continue
420
421                for i in range(clsCount):
422                    clsVal = classValueSorted[i]
423
424                    newColor = QColor(self.discPalette[i])
425                    newColor.setAlpha(self.alphaValue)
426
427                    width = float(attrValCont[clsVal] * 0.5) / float(maxVal)
428                    interval = 1.0 / float(2 * attrLen)
429                    yOff = float(1.0 + 2.0 * j) / float(2 * attrLen)
430                    height = 0.7 / float(clsCount * attrLen)
431
432                    yLowBott = yOff + float(clsCount * height) / 2.0 - i * height
433                    curve = PolygonCurve(QPen(newColor), QBrush(newColor),
434                                         xData=[graphAttrIndex, graphAttrIndex + width, graphAttrIndex + width,
435                                                graphAttrIndex],
436                                         yData=[yLowBott, yLowBott, yLowBott - height, yLowBott - height], tooltip=(
437                        self.dataDomain[index].name, variableValueSorted[j], len(self.rawData),
438                        [(clsVal, attrValCont[clsVal]) for clsVal in classValueSorted]))
439                    curve.attach(self)
440
441
442    # handle tooltip events
443    def event(self, ev):
444        if ev.type() == QEvent.ToolTip:
445            x = self.invTransform(QwtPlot.xBottom, ev.pos().x())
446            y = self.invTransform(QwtPlot.yLeft, ev.pos().y())
447
448            canvasPos = self.canvas().mapFrom(self, ev.pos())
449            xFloat = self.invTransform(QwtPlot.xBottom, canvasPos.x())
450            contact, (index, pos) = self.testArrowContact(int(round(xFloat)), canvasPos.x(), canvasPos.y())
451            if contact:
452                attr = self.dataDomain[self.visualizedAttributes[index]]
453                if attr.varType == orange.VarTypes.Continuous:
454                    condition = self.selectionConditions.get(attr.name, [0, 1])
455                    val = self.attrValues[attr.name][0] + condition[pos] * (
456                    self.attrValues[attr.name][1] - self.attrValues[attr.name][0])
457                    strVal = attr.name + "= %%.%df" % (attr.numberOfDecimals) % (val)
458                    QToolTip.showText(ev.globalPos(), strVal)
459            else:
460                for curve in self.itemList():
461                    if type(curve) == PolygonCurve and curve.boundingRect().contains(x, y) and getattr(curve, "tooltip",
462                                                                                                       None):
463                        (name, value, total, dist) = curve.tooltip
464                        count = sum([v[1] for v in dist])
465                        if count == 0: continue
466                        tooltipText = "Attribute: <b>%s</b><br>Value: <b>%s</b><br>Total instances: <b>%i</b> (%.1f%%)<br>Class distribution:<br>" % (
467                        name, value, count, 100.0 * count / float(total))
468                        for (val, n) in dist:
469                            tooltipText += "&nbsp; &nbsp; <b>%s</b> : <b>%i</b> (%.1f%%)<br>" % (
470                            val, n, 100.0 * float(n) / float(count))
471                        QToolTip.showText(ev.globalPos(), tooltipText[:-4])
472
473        elif ev.type() == QEvent.MouseMove:
474            QToolTip.hideText()
475
476        return OWGraph.event(self, ev)
477
478
479    def testArrowContact(self, indices, x, y):
480        if type(indices) != list: indices = [indices]
481        for index in indices:
482            if index >= len(self.visualizedAttributes) or index < 0: continue
483            intX = self.transform(QwtPlot.xBottom, index)
484            bottom = self.transform(QwtPlot.yLeft,
485                                    self.selectionConditions.get(self.visualizedAttributes[index], [0, 1])[0])
486            bottomRect = QRect(intX - self.bottomPixmap.width() / 2, bottom, self.bottomPixmap.width(),
487                               self.bottomPixmap.height())
488            if bottomRect.contains(QPoint(x, y)): return 1, (index, 0)
489            top = self.transform(QwtPlot.yLeft,
490                                 self.selectionConditions.get(self.visualizedAttributes[index], [0, 1])[1])
491            topRect = QRect(intX - self.topPixmap.width() / 2, top - self.topPixmap.height(), self.topPixmap.width(),
492                            self.topPixmap.height())
493            if topRect.contains(QPoint(x, y)): return 1, (index, 1)
494        return 0, (0, 0)
495
496    def mousePressEvent(self, e):
497        canvasPos = self.canvas().mapFrom(self, e.pos())
498        xFloat = self.invTransform(QwtPlot.xBottom, canvasPos.x())
499        contact, info = self.testArrowContact(int(round(xFloat)), canvasPos.x(), canvasPos.y())
500
501        if contact:
502            self.pressedArrow = info
503        elif self.state in [ZOOMING, PANNING]:
504            OWGraph.mousePressEvent(self, e)
505
506
507    def mouseMoveEvent(self, e):
508        if hasattr(self, "pressedArrow"):
509            canvasPos = self.canvas().mapFrom(self, e.pos())
510            yFloat = min(1, max(0, self.invTransform(QwtPlot.yLeft, canvasPos.y())))
511            index, pos = self.pressedArrow
512            attr = self.dataDomain[self.visualizedAttributes[index]]
513            oldCondition = self.selectionConditions.get(attr.name, [0, 1])
514            oldCondition[pos] = yFloat
515            self.selectionConditions[attr.name] = oldCondition
516            self.updateData(self.visualizedAttributes, self.visualizedMidLabels, updateAxisScale=0)
517
518            if attr.varType == orange.VarTypes.Continuous:
519                val = self.attrValues[attr.name][0] + oldCondition[pos] * (
520                self.attrValues[attr.name][1] - self.attrValues[attr.name][0])
521                strVal = attr.name + "= %%.%df" % (attr.numberOfDecimals) % (val)
522                QToolTip.showText(e.globalPos(), strVal)
523            if self.sendSelectionOnUpdate and self.autoSendSelectionCallback:
524                self.autoSendSelectionCallback()
525
526        elif self.state in [ZOOMING, PANNING]:
527            OWGraph.mouseMoveEvent(self, e)
528
529    def mouseReleaseEvent(self, e):
530        if hasattr(self, "pressedArrow"):
531            del self.pressedArrow
532            if self.autoSendSelectionCallback and not (self.sendSelectionOnUpdate and self.autoSendSelectionCallback):
533                self.autoSendSelectionCallback() # send the new selection
534        elif self.state in [ZOOMING, PANNING]:
535            OWGraph.mouseReleaseEvent(self, e)
536
537
538    def staticMouseClick(self, e):
539        if e.button() == Qt.LeftButton and self.state == ZOOMING:
540            if self.tempSelectionCurve: self.tempSelectionCurve.detach()
541            self.tempSelectionCurve = None
542            canvasPos = self.canvas().mapFrom(self, e.pos())
543            x = self.invTransform(QwtPlot.xBottom, canvasPos.x())
544            y = self.invTransform(QwtPlot.yLeft, canvasPos.y())
545            diffX = (self.axisScaleDiv(QwtPlot.xBottom).interval().maxValue() - self.axisScaleDiv(
546                QwtPlot.xBottom).interval().minValue()) / 2.
547
548            xmin = x - (diffX / 2.) * (x - self.axisScaleDiv(QwtPlot.xBottom).interval().minValue()) / diffX
549            xmax = x + (diffX / 2.) * (self.axisScaleDiv(QwtPlot.xBottom).interval().maxValue() - x) / diffX
550            ymin = self.axisScaleDiv(QwtPlot.yLeft).interval().maxValue()
551            ymax = self.axisScaleDiv(QwtPlot.yLeft).interval().minValue()
552
553            self.zoomStack.append((self.axisScaleDiv(QwtPlot.xBottom).interval().minValue(),
554                                   self.axisScaleDiv(QwtPlot.xBottom).interval().maxValue(),
555                                   self.axisScaleDiv(QwtPlot.yLeft).interval().minValue(),
556                                   self.axisScaleDiv(QwtPlot.yLeft).interval().maxValue()))
557            self.setNewZoom(xmin, xmax, ymax, ymin)
558            return 1
559
560        # if the user clicked between two lines send a list with the names of the two attributes
561        elif self.parallelDlg:
562            x1 = int(self.invTransform(QwtPlot.xBottom, e.x()))
563            axis = self.axisScaleDraw(QwtPlot.xBottom)
564            self.parallelDlg.sendShownAttributes([str(axis.label(x1)), str(axis.label(x1 + 1))])
565        return 0
566
567    def removeAllSelections(self, send=1):
568        self.selectionConditions = {}
569        self.updateData(self.visualizedAttributes, self.visualizedMidLabels, updateAxisScale=0)
570        if send and self.autoSendSelectionCallback:
571            self.autoSendSelectionCallback() # do we want to send new selection
572
573    # draw the curves and the selection conditions
574    def drawCanvas(self, painter):
575        OWGraph.drawCanvas(self, painter)
576        for i in range(int(max(0, math.floor(self.axisScaleDiv(QwtPlot.xBottom).interval().minValue()))), int(
577                min(len(self.visualizedAttributes),
578                    math.ceil(self.axisScaleDiv(QwtPlot.xBottom).interval().maxValue()) + 1))):
579            bottom, top = self.selectionConditions.get(self.visualizedAttributes[i], (0, 1))
580            painter.drawPixmap(self.transform(QwtPlot.xBottom, i) - self.bottomPixmap.width() / 2,
581                               self.transform(QwtPlot.yLeft, bottom), self.bottomPixmap)
582            painter.drawPixmap(self.transform(QwtPlot.xBottom, i) - self.topPixmap.width() / 2,
583                               self.transform(QwtPlot.yLeft, top) - self.topPixmap.height(), self.topPixmap)
584
585    # get selected examples
586    # this function must be called after calling self.updateGraph
587    def getSelectionsAsExampleTables(self):
588        if not self.haveData:
589            return (None, None)
590
591        selected = self.rawData.getitemsref(self.selectedExamples)
592        unselected = self.rawData.getitemsref(self.unselectedExamples)
593
594        if len(selected) == 0: selected = None
595        if len(unselected) == 0: unselected = None
596        return (selected, unselected)
597
598
599# ####################################################################
600# a curve that is able to draw several series of lines
601class ParallelCoordinatesCurve(QwtPlotCurve):
602    def __init__(self, attrCount, yData, color, name=""):
603        QwtPlotCurve.__init__(self, name)
604        self.setStyle(QwtPlotCurve.Lines)
605        self.setItemAttribute(QwtPlotItem.Legend, 0)
606
607        lineCount = len(yData) / attrCount
608        self.attrCount = attrCount
609        self.xData = range(attrCount) * lineCount
610        self.yData = yData
611
612        #        self._cubic = self.cubicPath(None, None)
613
614        self.setData(QPolygonF(map(lambda t: QPointF(*t), zip(self.xData, self.yData))))
615        if type(color) == tuple:
616            self.setPen(QPen(QColor(*color)))
617        else:
618            self.setPen(QPen(QColor(color)))
619
620
621    def drawCurve(self, painter, style, xMap, yMap, iFrom, iTo):
622        low = max(0, int(math.floor(xMap.s1())))
623        high = min(self.attrCount - 1, int(math.ceil(xMap.s2())))
624        painter.setPen(self.pen())
625        if not self.testCurveAttribute(QwtPlotCurve.Fitted):
626            for i in range(self.dataSize() / self.attrCount):
627                start = self.attrCount * i + low
628                end = self.attrCount * i + high
629                self.drawLines(painter, xMap, yMap, start, end)
630        else:
631            painter.save()
632            #            painter.scale(xMap.transform(1.0), yMap.transform(1.0))
633            painter.strokePath(self.cubicPath(xMap, yMap), self.pen())
634            #            painter.strokePath(self._cubic, self.pen())
635            painter.restore()
636
637    def cubicPath(self, xMap, yMap):
638        path = QPainterPath()
639        transform = lambda x, y: QPointF(xMap.transform(x), yMap.transform(y))
640        #        transform = lambda x, y: QPointF(x, y)
641        #        data = [QPointF(transform(x, y)) for x, y in zip(self.xData, self.yData)]
642        data = [(x, y) for x, y in zip(self.xData, self.yData)]
643        for i in range(self.dataSize() / self.attrCount):
644            segment = data[i * self.attrCount: (i + 1) * self.attrCount]
645            for i, p in enumerate(segment[:-1]):
646                x1, y1 = p
647                x2, y2 = segment[i + 1]
648                path.moveTo(transform(x1, y1))
649                path.cubicTo(transform(x1 + 0.5, y1), transform(x2 - 0.5, y2), transform(x2, y2))
650        return path       
651               
652               
653           
654
Note: See TracBrowser for help on using the repository browser.