Changeset 11824:dcbaa22c3634 in orange


Ignore:
Timestamp:
01/02/14 13:08:35 (4 months ago)
Author:
astaric <anze.staric@…>
Branch:
default
committer:
astaric <anze.staric@gmail.com> 1388665309 -3600
Message:

Refactor OWParallelGraph.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/OrangeWidgets/Visualize/OWParallelGraph.py

    r11823 r11824  
    4949        self.domainContingency = None 
    5050 
    51  
    5251    # update shown data. Set attributes, coloring by className .... 
    5352    def updateData(self, attributes, midLabels=None, updateAxisScale=1): 
     
    6665        self.visualizedAttributes = attributes 
    6766        self.visualizedMidLabels = midLabels 
    68         for name in self.selectionConditions.keys():        # keep only conditions that are related to the currently visualized attributes 
     67        self.filter_stale_conditions() 
     68 
     69        # set the limits for panning 
     70        self.xPanningInfo = (1, 0, len(attributes) - 1) 
     71        self.yPanningInfo = (0, 0, 0) 
     72 
     73        self.update_scale(attributes, updateAxisScale) 
     74 
     75        length = len(attributes) 
     76        indices = [self.attributeNameIndex[label] for label in attributes] 
     77 
     78        xs = range(length) 
     79        dataSize = len(self.scaledData[0]) 
     80 
     81        if self.dataHasDiscreteClass: 
     82            self.discPalette.setNumberOfColors(len(self.dataDomain.classVar.values)) 
     83 
     84        validData = self.getValidList(indices) 
     85 
     86        self.draw_lines(attributes, dataSize, indices, validData) 
     87 
     88        if self.showDistributions and self.dataHasDiscreteClass and self.haveData: 
     89            self.draw_distributions(validData, indices) 
     90 
     91        self.draw_axes(attributes) 
     92 
     93        if self.showStatistics and self.haveData: 
     94            self.draw_statistics(indices, length) 
     95 
     96        if midLabels: 
     97            self.draw_midlabels(midLabels) 
     98 
     99        if self.enabledLegend == 1 and self.dataHasDiscreteClass: 
     100            self.draw_legend(attributes) 
     101        else: 
     102            self.legend().clear() 
     103            self.oldLegendKeys = [] 
     104 
     105        self.replot() 
     106 
     107    def filter_stale_conditions(self): 
     108        for name in self.selectionConditions.keys(): 
    69109            if name not in self.visualizedAttributes: 
    70110                self.selectionConditions.pop(name) 
    71111 
    72         # set the limits for panning 
    73         self.xPanningInfo = (1, 0, len(attributes) - 1) 
    74         self.yPanningInfo = ( 
    75         0, 0, 0)   # we don't enable panning in y direction so it doesn't matter what values we put in for the limits 
    76  
     112    def update_scale(self, attributes, updateAxisScale): 
    77113        if updateAxisScale: 
    78114            if self.showAttrValues: 
     
    91127                if m < 0 or M > len(attributes) - 2: 
    92128                    self.setAxisScale(QwtPlot.xBottom, 0, len(attributes) - 1, 1) 
    93  
    94129        self.setAxisScaleDraw(QwtPlot.xBottom, 
    95130                              DiscreteAxisScaleDraw([self.getAttributeLabel(attr) for attr in attributes])) 
     
    98133        self.setAxisMaxMinor(QwtPlot.xBottom, 0) 
    99134 
    100         length = len(attributes) 
    101         indices = [self.attributeNameIndex[label] for label in attributes] 
    102  
    103         xs = range(length) 
    104         dataSize = len(self.scaledData[0]) 
    105  
    106         if self.dataHasDiscreteClass: 
    107             self.discPalette.setNumberOfColors(len(self.dataDomain.classVar.values)) 
    108  
    109  
    110         # ############################################ 
    111         # draw the data 
    112         # ############################################ 
     135    def draw_lines(self, attributes, dataSize, indices, validData): 
    113136        subsetIdsToDraw = self.haveSubsetData and dict( 
    114137            [(self.rawSubsetData[i].id, 1) for i in self.getValidSubsetIndices(indices)]) or {} 
    115         validData = self.getValidList(indices) 
    116138        mainCurves = {} 
    117139        subCurves = {} 
    118140        conditions = dict([(name, attributes.index(name)) for name in self.selectionConditions.keys()]) 
    119  
    120141        for i in range(dataSize): 
    121142            if not validData[i]: 
     
    206227            curve.attach(self) 
    207228 
    208  
    209  
    210         # ############################################ 
    211         # do we want to show distributions with discrete attributes 
    212         if self.showDistributions and self.dataHasDiscreteClass and self.haveData: 
    213             self.showDistributionValues(validData, indices) 
    214  
    215         # ############################################ 
    216         # draw vertical lines that represent attributes 
     229    def draw_axes(self, attributes): 
    217230        for i in range(len(attributes)): 
    218231            self.addCurve("", lineWidth=2, style=QwtPlotCurve.Lines, symbol=QwtSymbol.NoSymbol, xData=[i, i], 
     
    238251                                       alignment=Qt.AlignRight | Qt.AlignVCenter, bold=1, brushColor=Qt.white) 
    239252 
    240         # ############################################## 
    241         # show lines that represent standard deviation or quartiles 
    242         # ############################################## 
    243         if self.showStatistics and self.haveData: 
    244             data = [] 
    245             for i in range(length): 
    246                 if self.dataDomain[indices[i]].varType != orange.VarTypes.Continuous: 
    247                     data.append([()]) 
    248                     continue  # only for continuous attributes 
    249                 array = numpy.compress(numpy.equal(self.validDataArray[indices[i]], 1), 
    250                                        self.scaledData[indices[i]])  # remove missing values 
    251  
    252                 if not self.dataHasClass or self.dataHasContinuousClass:    # no class 
     253    def draw_midlabels(self, midLabels): 
     254        for j in range(len(midLabels)): 
     255            self.addMarker(midLabels[j], j + 0.5, 1.0, alignment=Qt.AlignCenter | Qt.AlignTop) 
     256 
     257    def draw_statistics(self, indices, length): 
     258        data = [] 
     259        for i in range(length): 
     260            if self.dataDomain[indices[i]].varType != orange.VarTypes.Continuous: 
     261                data.append([()]) 
     262                continue  # only for continuous attributes 
     263            array = numpy.compress(numpy.equal(self.validDataArray[indices[i]], 1), 
     264                                   self.scaledData[indices[i]])  # remove missing values 
     265 
     266            if not self.dataHasClass or self.dataHasContinuousClass:    # no class 
     267                if self.showStatistics == MEANS: 
     268                    m = array.mean() 
     269                    dev = array.std() 
     270                    data.append([(m - dev, m, m + dev)]) 
     271                elif self.showStatistics == MEDIAN: 
     272                    sorted = numpy.sort(array) 
     273                    if len(sorted) > 0: 
     274                        data.append([(sorted[int(len(sorted) / 4.0)], sorted[int(len(sorted) / 2.0)], 
     275                                      sorted[int(len(sorted) * 0.75)])]) 
     276                    else: 
     277                        data.append([(0, 0, 0)]) 
     278            else: 
     279                curr = [] 
     280                classValues = getVariableValuesSorted(self.dataDomain.classVar) 
     281                classValueIndices = getVariableValueIndices(self.dataDomain.classVar) 
     282                for c in range(len(classValues)): 
     283                    scaledVal = ((classValueIndices[classValues[c]] * 2) + 1) / float(2 * len(classValueIndices)) 
     284                    nonMissingValues = numpy.compress(numpy.equal(self.validDataArray[indices[i]], 1), 
     285                                                      self.noJitteringScaledData[ 
     286                                                          self.dataClassIndex])  # remove missing values 
     287                    arr_c = numpy.compress(numpy.equal(nonMissingValues, scaledVal), array) 
     288                    if len(arr_c) == 0: 
     289                        curr.append((0, 0, 0)); 
     290                        continue 
    253291                    if self.showStatistics == MEANS: 
    254                         m = array.mean() 
    255                         dev = array.std() 
    256                         data.append([(m - dev, m, m + dev)]) 
     292                        m = arr_c.mean() 
     293                        dev = arr_c.std() 
     294                        curr.append((m - dev, m, m + dev)) 
    257295                    elif self.showStatistics == MEDIAN: 
    258                         sorted = numpy.sort(array) 
    259                         if len(sorted) > 0: 
    260                             data.append([(sorted[int(len(sorted) / 4.0)], sorted[int(len(sorted) / 2.0)], 
    261                                           sorted[int(len(sorted) * 0.75)])]) 
    262                         else: 
    263                             data.append([(0, 0, 0)]) 
    264                 else: 
    265                     curr = [] 
    266                     classValues = getVariableValuesSorted(self.dataDomain.classVar) 
    267                     classValueIndices = getVariableValueIndices(self.dataDomain.classVar) 
    268                     for c in range(len(classValues)): 
    269                         scaledVal = ((classValueIndices[classValues[c]] * 2) + 1) / float(2 * len(classValueIndices)) 
    270                         nonMissingValues = numpy.compress(numpy.equal(self.validDataArray[indices[i]], 1), 
    271                                                           self.noJitteringScaledData[ 
    272                                                               self.dataClassIndex])  # remove missing values 
    273                         arr_c = numpy.compress(numpy.equal(nonMissingValues, scaledVal), array) 
    274                         if len(arr_c) == 0: 
    275                             curr.append((0, 0, 0)); 
    276                             continue 
    277                         if self.showStatistics == MEANS: 
    278                             m = arr_c.mean() 
    279                             dev = arr_c.std() 
    280                             curr.append((m - dev, m, m + dev)) 
    281                         elif self.showStatistics == MEDIAN: 
    282                             sorted = numpy.sort(arr_c) 
    283                             curr.append((sorted[int(len(arr_c) / 4.0)], sorted[int(len(arr_c) / 2.0)], 
    284                                          sorted[int(len(arr_c) * 0.75)])) 
    285                     data.append(curr) 
    286  
    287             # draw vertical lines 
    288             for i in range(len(data)): 
    289                 for c in range(len(data[i])): 
    290                     if data[i][c] == (): continue 
    291                     x = i - 0.03 * (len(data[i]) - 1) / 2.0 + c * 0.03 
    292                     col = QColor(self.discPalette[c]) 
    293                     col.setAlpha(self.alphaValue2) 
    294                     self.addCurve("", col, col, 3, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x, x, x], 
    295                                   yData=[data[i][c][0], data[i][c][1], data[i][c][2]], lineWidth=4) 
    296                     self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03], 
    297                                   yData=[data[i][c][0], data[i][c][0]], lineWidth=4) 
    298                     self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03], 
    299                                   yData=[data[i][c][1], data[i][c][1]], lineWidth=4) 
    300                     self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03], 
    301                                   yData=[data[i][c][2], data[i][c][2]], lineWidth=4) 
    302  
    303             # draw lines with mean/median values 
    304             classCount = 1 
    305             if not self.dataHasClass or self.dataHasContinuousClass: 
    306                 classCount = 1 # no class 
    307             else: 
    308                 classCount = len(self.dataDomain.classVar.values) 
    309             for c in range(classCount): 
    310                 diff = - 0.03 * (classCount - 1) / 2.0 + c * 0.03 
    311                 ys = [] 
    312                 xs = [] 
    313                 for i in range(len(data)): 
    314                     if data[i] != [()]: 
    315                         ys.append(data[i][c][1]); xs.append(i + diff) 
    316                     else: 
    317                         if len(xs) > 1: 
    318                             col = QColor(self.discPalette[c]) 
    319                             col.setAlpha(self.alphaValue2) 
    320                             self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=xs, yData=ys, 
    321                                           lineWidth=4) 
    322                         xs = []; 
    323                         ys = [] 
     296                        sorted = numpy.sort(arr_c) 
     297                        curr.append(( 
     298                            sorted[int(len(arr_c) / 4.0)], sorted[int(len(arr_c) / 2.0)], 
     299                            sorted[int(len(arr_c) * 0.75)])) 
     300                data.append(curr) 
     301 
     302        # draw vertical lines 
     303        for i in range(len(data)): 
     304            for c in range(len(data[i])): 
     305                if data[i][c] == (): continue 
     306                x = i - 0.03 * (len(data[i]) - 1) / 2.0 + c * 0.03 
    324307                col = QColor(self.discPalette[c]) 
    325308                col.setAlpha(self.alphaValue2) 
    326                 self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=xs, yData=ys, lineWidth=4) 
    327  
    328  
    329         # ################################################## 
    330         # show labels in the middle of the axis 
    331         if midLabels: 
    332             for j in range(len(midLabels)): 
    333                 self.addMarker(midLabels[j], j + 0.5, 1.0, alignment=Qt.AlignCenter | Qt.AlignTop) 
    334  
    335         # show the legend 
    336         if self.enabledLegend == 1 and self.dataHasDiscreteClass: 
    337             if self.dataDomain.classVar.varType == orange.VarTypes.Discrete: 
    338                 legendKeys = [] 
    339                 varValues = getVariableValuesSorted(self.dataDomain.classVar) 
    340                 #self.addCurve("<b>" + self.dataDomain.classVar.name + ":</b>", QColor(0,0,0), QColor(0,0,0), 0, symbol = QwtSymbol.NoSymbol, enableLegend = 1) 
    341                 for ind in range(len(varValues)): 
    342                     #self.addCurve(varValues[ind], self.discPalette[ind], self.discPalette[ind], 15, symbol = QwtSymbol.Rect, enableLegend = 1) 
    343                     legendKeys.append((varValues[ind], self.discPalette[ind])) 
    344                 if legendKeys != self.oldLegendKeys: 
    345                     self.oldLegendKeys = legendKeys 
    346                     self.legend().clear() 
    347                     self.addCurve("<b>" + self.dataDomain.classVar.name + ":</b>", QColor(0, 0, 0), QColor(0, 0, 0), 0, 
    348                                   symbol=QwtSymbol.NoSymbol, enableLegend=1) 
    349                     for (name, color) in legendKeys: 
    350                         self.addCurve(name, color, color, 15, symbol=QwtSymbol.Rect, enableLegend=1) 
    351             else: 
    352                 l = len(attributes) - 1 
    353                 xs = [l * 1.15, l * 1.20, l * 1.20, l * 1.15] 
    354                 count = 200; 
    355                 height = 1 / 200. 
    356                 for i in range(count): 
    357                     y = i / float(count) 
    358                     col = self.contPalette[y] 
    359                     curve = PolygonCurve(QPen(col), QBrush(col), xData=xs, yData=[y, y, y + height, y + height]) 
    360                     curve.attach(self) 
    361  
    362                 # add markers for min and max value of color attribute 
    363                 [minVal, maxVal] = self.attrValues[self.dataDomain.classVar.name] 
    364                 decimals = self.dataDomain.classVar.numberOfDecimals 
    365                 self.addMarker("%%.%df" % (decimals) % (minVal), xs[0] - l * 0.02, 0.04, Qt.AlignLeft) 
    366                 self.addMarker("%%.%df" % (decimals) % (maxVal), xs[0] - l * 0.02, 1.0 - 0.04, Qt.AlignLeft) 
     309                self.addCurve("", col, col, 3, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x, x, x], 
     310                              yData=[data[i][c][0], data[i][c][1], data[i][c][2]], lineWidth=4) 
     311                self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03], 
     312                              yData=[data[i][c][0], data[i][c][0]], lineWidth=4) 
     313                self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03], 
     314                              yData=[data[i][c][1], data[i][c][1]], lineWidth=4) 
     315                self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03], 
     316                              yData=[data[i][c][2], data[i][c][2]], lineWidth=4) 
     317 
     318        # draw lines with mean/median values 
     319        classCount = 1 
     320        if not self.dataHasClass or self.dataHasContinuousClass: 
     321            classCount = 1 # no class 
    367322        else: 
    368             self.legend().clear() 
    369             self.oldLegendKeys = [] 
    370  
    371         self.replot() 
    372  
     323            classCount = len(self.dataDomain.classVar.values) 
     324        for c in range(classCount): 
     325            diff = - 0.03 * (classCount - 1) / 2.0 + c * 0.03 
     326            ys = [] 
     327            xs = [] 
     328            for i in range(len(data)): 
     329                if data[i] != [()]: 
     330                    ys.append(data[i][c][1]); 
     331                    xs.append(i + diff) 
     332                else: 
     333                    if len(xs) > 1: 
     334                        col = QColor(self.discPalette[c]) 
     335                        col.setAlpha(self.alphaValue2) 
     336                        self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=xs, yData=ys, 
     337                                      lineWidth=4) 
     338                    xs = []; 
     339                    ys = [] 
     340            col = QColor(self.discPalette[c]) 
     341            col.setAlpha(self.alphaValue2) 
     342            self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=xs, yData=ys, lineWidth=4) 
     343 
     344    def draw_legend(self, attributes): 
     345        if self.dataDomain.classVar.varType == orange.VarTypes.Discrete: 
     346            legendKeys = [] 
     347            varValues = getVariableValuesSorted(self.dataDomain.classVar) 
     348            #self.addCurve("<b>" + self.dataDomain.classVar.name + ":</b>", QColor(0,0,0), QColor(0,0,0), 0, symbol = QwtSymbol.NoSymbol, enableLegend = 1) 
     349            for ind in range(len(varValues)): 
     350                #self.addCurve(varValues[ind], self.discPalette[ind], self.discPalette[ind], 15, symbol = QwtSymbol.Rect, enableLegend = 1) 
     351                legendKeys.append((varValues[ind], self.discPalette[ind])) 
     352            if legendKeys != self.oldLegendKeys: 
     353                self.oldLegendKeys = legendKeys 
     354                self.legend().clear() 
     355                self.addCurve("<b>" + self.dataDomain.classVar.name + ":</b>", QColor(0, 0, 0), QColor(0, 0, 0), 0, 
     356                              symbol=QwtSymbol.NoSymbol, enableLegend=1) 
     357                for (name, color) in legendKeys: 
     358                    self.addCurve(name, color, color, 15, symbol=QwtSymbol.Rect, enableLegend=1) 
     359        else: 
     360            l = len(attributes) - 1 
     361            xs = [l * 1.15, l * 1.20, l * 1.20, l * 1.15] 
     362            count = 200; 
     363            height = 1 / 200. 
     364            for i in range(count): 
     365                y = i / float(count) 
     366                col = self.contPalette[y] 
     367                curve = PolygonCurve(QPen(col), QBrush(col), xData=xs, yData=[y, y, y + height, y + height]) 
     368                curve.attach(self) 
     369 
     370            # add markers for min and max value of color attribute 
     371            [minVal, maxVal] = self.attrValues[self.dataDomain.classVar.name] 
     372            decimals = self.dataDomain.classVar.numberOfDecimals 
     373            self.addMarker("%%.%df" % (decimals) % (minVal), xs[0] - l * 0.02, 0.04, Qt.AlignLeft) 
     374            self.addMarker("%%.%df" % (decimals) % (maxVal), xs[0] - l * 0.02, 1.0 - 0.04, Qt.AlignLeft) 
    373375 
    374376    # ########################################## 
    375377    # SHOW DISTRIBUTION BAR GRAPH 
    376     def showDistributionValues(self, validData, indices): 
     378    def draw_distributions(self, validData, indices): 
    377379        # create color table 
    378380        clsCount = len(self.dataDomain.classVar.values) 
     
    431433 
    432434                    yLowBott = yOff + float(clsCount * height) / 2.0 - i * height 
    433                     curve = PolygonCurve(QPen(newColor), QBrush(newColor), 
    434                                          xData=[graphAttrIndex, graphAttrIndex + width, graphAttrIndex + width, 
    435                                                 graphAttrIndex], 
    436                                          yData=[yLowBott, yLowBott, yLowBott - height, yLowBott - height], tooltip=( 
    437                         self.dataDomain[index].name, variableValueSorted[j], len(self.rawData), 
    438                         [(clsVal, attrValCont[clsVal]) for clsVal in classValueSorted])) 
     435                    xData = [graphAttrIndex, graphAttrIndex + width, graphAttrIndex + width, graphAttrIndex] 
     436                    yData = [yLowBott, yLowBott, yLowBott - height, yLowBott - height] 
     437                    tooltip = (self.dataDomain[index].name, variableValueSorted[j], len(self.rawData), 
     438                               [(clsVal, attrValCont[clsVal]) for clsVal in classValueSorted]) 
     439                    curve = PolygonCurve(QPen(newColor), QBrush(newColor), xData, yData, tooltip) 
    439440                    curve.attach(self) 
    440441 
     
    454455                    condition = self.selectionConditions.get(attr.name, [0, 1]) 
    455456                    val = self.attrValues[attr.name][0] + condition[pos] * ( 
    456                     self.attrValues[attr.name][1] - self.attrValues[attr.name][0]) 
     457                        self.attrValues[attr.name][1] - self.attrValues[attr.name][0]) 
    457458                    strVal = attr.name + "= %%.%df" % (attr.numberOfDecimals) % (val) 
    458459                    QToolTip.showText(ev.globalPos(), strVal) 
     
    465466                        if count == 0: continue 
    466467                        tooltipText = "Attribute: <b>%s</b><br>Value: <b>%s</b><br>Total instances: <b>%i</b> (%.1f%%)<br>Class distribution:<br>" % ( 
    467                         name, value, count, 100.0 * count / float(total)) 
     468                            name, value, count, 100.0 * count / float(total)) 
    468469                        for (val, n) in dist: 
    469470                            tooltipText += "&nbsp; &nbsp; <b>%s</b> : <b>%i</b> (%.1f%%)<br>" % ( 
    470                             val, n, 100.0 * float(n) / float(count)) 
     471                                val, n, 100.0 * float(n) / float(count)) 
    471472                        QToolTip.showText(ev.globalPos(), tooltipText[:-4]) 
    472473 
     
    518519            if attr.varType == orange.VarTypes.Continuous: 
    519520                val = self.attrValues[attr.name][0] + oldCondition[pos] * ( 
    520                 self.attrValues[attr.name][1] - self.attrValues[attr.name][0]) 
     521                    self.attrValues[attr.name][1] - self.attrValues[attr.name][0]) 
    521522                strVal = attr.name + "= %%.%df" % (attr.numberOfDecimals) % (val) 
    522523                QToolTip.showText(e.globalPos(), strVal) 
Note: See TracChangeset for help on using the changeset viewer.