source: orange/orange/OrangeWidgets/Prototypes/OWPivot.py @ 9546:2b6cc6f397fe

Revision 9546:2b6cc6f397fe, 12.2 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Renamed widget channel names in line with the new naming rules/convention.
Added backwards compatibility in orngDoc loadDocument to enable loading of schemas saved before the change.

Line 
1"""<name>Pivot</name>
2<description>Pivot</description>
3<icon>icons/Pivot.png</icon>
4<priority>30</priority>
5<contact>Janez Demsar (janez.demsar@fri.uni-lj.si)</contact>"""
6
7from OWWidget import *
8from OWGUI import *
9
10disctype, conttype = orange.Variable.Discrete, orange.Variable.Continuous
11 
12class OWPivot(OWWidget):
13    contextHandlers = {"": DomainContextHandler("", ["rowAttribute", "columnAttribute", "attribute"])}
14    # settingsList is computed from aggregates in __init__
15    aggregates = [("Count", "count", [disctype, conttype]),
16                  ("Sum", "sum" ,[conttype]), 
17                  ("Average", "average", [conttype]),
18                  ("Deviation", "deviation", [conttype]),
19                  ("Minimum", "minimum", [conttype]),
20                  ("Maximum", "maximum", [conttype]),
21                  ("Most common", "mostCommon", [disctype]),
22                  ("Distribution", "distribution", [disctype]),
23                  ("Relative frequencies", "frequencies", [disctype])]
24   
25    def __init__(self,parent=None, signalManager = None):
26        OWWidget.__init__(self, parent, signalManager, "Pivot")
27        self.inputs = [("Data", ExampleTable, self.setData, Default)]
28        self.outputs = [("Selected Group", ExampleTable, Default)]
29        self.settingsList = [a[1] for a in self.aggregates]
30        self.icons = self.createAttributeIconDict() 
31
32        self.rowAttribute = self.columnAttribute = self.attribute = ""
33        for agg, fld, flag in self.aggregates:
34            setattr(self, fld, fld=="count")
35        self.resize(640, 480)
36        self.loadSettings()
37
38        self.rowCombo = OWGUI.comboBox(self.controlArea, self, "rowAttribute", box="Row attribute", callback=self.rowColAttributeChanged,
39                                       sendSelectedValue=1, valueType=str, addSpace=True)
40        self.colCombo = OWGUI.comboBox(self.controlArea, self, "columnAttribute", box="Column attribute", callback=self.rowColAttributeChanged,
41                                       sendSelectedValue=1, valueType=str, addSpace=True)
42        b = OWGUI.widgetBox(self.controlArea, box="Content")
43        self.attrCombo = OWGUI.comboBox(b, self, "attribute", callback=self.attributeChanged, sendSelectedValue=1, valueType=str, addSpace=True)
44        self.checkBoxes = []
45        for agg, fld, flag in self.aggregates:
46            cb = OWGUI.checkBox(b, self, fld, agg, callback=self.updateMatrix)
47            self.checkBoxes.append((cb, flag))
48        OWGUI.rubber(self.controlArea)
49           
50        self.table = OWGUI.table(self.mainArea, rows=0, columns=0, selectionMode=QTableWidget.SingleSelection)
51        self.table.horizontalHeader().hide()
52        self.table.verticalHeader().hide()
53        self.table.setGridStyle(Qt.DotLine)
54        self.connect(self.table, SIGNAL("itemSelectionChanged()"), self.selectionChanged)
55
56    def setData(self, data):
57        self.closeContext()
58        self.rowCombo.clear()
59        self.colCombo.clear()
60        self.attrCombo.clear()
61        self.data = data
62        if data:
63            self.discAttrs = [attr for attr in data.domain if attr.varType == orange.Variable.Discrete]
64            for attr in self.discAttrs:
65                self.rowCombo.addItem(self.icons[attr.varType], attr.name)
66                self.colCombo.addItem(self.icons[attr.varType], attr.name)
67            for attr in data.domain:
68                self.attrCombo.addItem(self.icons[attr.varType], attr.name)
69            if self.discAttrs:
70                self.rowAttribute = self.discAttrs[0].name
71                self.columnAttribute = self.discAttrs[min(1, len(self.discAttrs))].name
72            if self.attrCombo.count():
73                self.attribute = self.data.domain[0].name
74        self.openContext("", self.data)
75        self.rowColAttributeChanged()
76       
77    def rowColAttributeChanged(self):
78        if self.rowCombo.count():
79            self.realRowAttr = self.data.domain[self.rowAttribute]
80            self.realColAttr = self.data.domain[self.columnAttribute]
81            ncols = len(self.realColAttr.values)
82            nrows = len(self.realRowAttr.values)
83            self.subsets = [[orange.ExampleTable(self.data.domain) for i in range(ncols+1)] for j in range(nrows+1)]
84            for ex in self.data:
85                rowVal, colVal =ex[self.realRowAttr], ex[self.realColAttr]
86                if not (rowVal.isSpecial() or colVal.isSpecial()):
87                    self.subsets[int(rowVal)][int(colVal)].append(ex)
88                    self.subsets[int(rowVal)][ncols].append(ex)
89                    self.subsets[nrows][int(colVal)].append(ex)
90                    self.subsets[nrows][ncols].append(ex)
91            self.horizontalValues = list(self.realColAttr.values)
92            self.verticalValues = list(self.realRowAttr.values)
93            self.table.show()
94        else:
95            self.subsets = None
96            self.table.hide()
97        self.attributeChanged()
98
99           
100    def attributeChanged(self):
101        self.stats = None
102        if self.attrCombo.count():
103            self.realAttribute = self.data.domain[self.attribute]
104            for cb, flag in self.checkBoxes:
105                cb.setDisabled(self.realAttribute.varType not in flag)
106            if self.subsets:
107                if self.realAttribute.varType == orange.Variable.Continuous:
108                    self.stats = [[orange.BasicAttrStat(self.realAttribute, examples) for examples in row] for row in self.subsets]
109                else:
110                    self.stats = [[orange.Distribution(self.realAttribute, examples) for examples in row] for row in self.subsets]
111        self.updateMatrix()
112       
113    def rowData(self, shw, row):
114        attr = self.data.domain[self.attribute]
115        if shw == "Count":
116            return [str(stat.n) if type(stat)==orange.BasicAttrStat else str(stat.cases) for stat in row]
117        elif shw == "Sum":
118            return [str(attr(stat.sum)) for stat in row]
119        elif shw == "Average":
120            return [(str(attr(stat.avg)) if stat.n else "") for stat in row]
121        elif shw == "Deviation":
122            return [(str(attr(stat.dev)) if stat.sum > 1 else "") for stat in row]
123        elif shw == "Minimum":
124            return [(str(attr(stat.min)) if stat.n else "") for stat in row]
125        elif shw == "Maximum":
126            return [(str(attr(stat.max)) if stat.n else "") for stat in row]
127        elif shw == "Most common":
128            values = []
129            for stat in row:
130                if stat.cases:
131                    m = stat.modus()
132                    values.append(("%s (%.2f%%)" % (str(m), stat[int(m)]*100./stat.cases)).decode("utf-8"))
133                else:
134                    values.append("")
135            return values
136        elif shw == "Distribution":
137            values = []
138            for stat in row:
139                if stat.cases:
140                    values.append(", ".join("%s:%i" % (str(val), stat[i]) for i, val in enumerate(attr.values)).decode("utf-8"))
141                else:
142                    values.append("")
143            return values
144        elif shw == "Relative frequencies":
145            values = []
146            for stat in row:
147                if stat.cases:
148                    values.append(", ".join("%s:%.2f%%" % (str(val), stat[i]*100./(stat.cases or 1)) for i, val in enumerate(attr.values)).decode("utf-8"))
149                else:
150                    values.append("")
151            return values
152        else:
153            return [""]*len(row)
154
155    def updateMatrix(self):
156        if self.attrCombo.count() and self.stats:
157            self.table.clearSelection()
158            attr = self.data.domain[self.attribute]
159            vtype = attr.varType
160            cont = self.cont = [name for name, field, flag in self.aggregates if getattr(self, field) and vtype in flag]
161            if len(cont)!=1:
162                cont.append("")
163
164            verticalValues = self.verticalValues + ["(Total)"]
165            realRows = len(verticalValues)*len(cont) + (len(cont)==1)
166            self.table.setRowCount(realRows)
167            self.table.setColumnCount(len(self.horizontalValues)+3)
168           
169            totalBrush = QBrush(QColor(240, 240, 240))
170            aggBrush = QBrush(QColor(216, 216, 216))
171            whiteBrush = QBrush(QColor(255, 255, 255))
172            for coli, val in enumerate(["", ""]+self.horizontalValues+["(Total)"]):
173                w = QTableWidgetItem(val)
174                w.setBackground(aggBrush if coli>1 else whiteBrush)
175                w.setTextAlignment(Qt.AlignHCenter | Qt.AlignVCenter)
176                self.table.setItem(0, coli, w)
177
178            rrowi = 1
179            for rowi, row in enumerate(self.stats):
180                lastRow = rowi == len(self.stats)-1
181                for shwi, shw in enumerate(cont):
182                    if lastRow and not shw:
183                        break
184                    w = QTableWidgetItem("" if shwi else verticalValues[rowi])
185                    w.setBackground(whiteBrush if not shw else aggBrush)
186                    w.setTextAlignment(Qt.AlignHCenter | Qt.AlignVCenter)
187                    self.table.setItem(rrowi, 0, w)
188                   
189                    w = QTableWidgetItem(shw)
190                    if shw:
191                        w.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
192                        w.setBackground(aggBrush)
193                    self.table.setItem(rrowi, 1, w)
194                    values = self.rowData(shw, row)
195                    for coli, val in enumerate(values):
196                        w = QTableWidgetItem(val)
197                        w.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
198                        if val and (lastRow or coli == len(values)-1):
199                            w.setBackground(totalBrush)
200                            w.font().setBold(True)
201                        self.table.setItem(rrowi, coli+2, w)
202                    rrowi += 1
203            self.table.resizeColumnsToContents()
204            self.table.resizeRowsToContents()
205       
206    def selectionChanged(self):
207        selected = [(x.row(), x.column()) for x in self.table.selectedIndexes()]
208        if not selected:
209            data = None
210        else:
211            row, column = selected[0]
212            if row and not self.cont[(row-1) % len(self.cont)]:
213                data = None
214            else:
215                dataRow = self.subsets[(row-1)/len(self.cont) if row else -1]
216                data = dataRow[column-2 if column>=2 else -1] 
217        self.send("Selected Group", data)
218
219    def sendReport(self):
220        if self.attrCombo.count() and self.stats:
221            self.startReport()
222            self.reportSettings("Rotation",
223                                [("Rows", self.rowAttribute),
224                                 ("Columns", self.columnAttribute),
225                                 ("Contents", self.attribute)])
226            self.reportSection("Matrix")
227            res = '<table style="border: 0; cell-padding: 3; cell-border: 0">\n'
228   
229            attr = self.data.domain[self.attribute]
230            verticalValues = self.verticalValues + ["(Total)"]
231            res += "<tr><td/><td/>"+"".join('<td style="background-color: #d8d8d8; font-weight: bold">%s</td>' % x for x in self.horizontalValues+["(Total)"])+"</tr>\n"
232           
233            rrowi = 1
234            for rowi, row in enumerate(self.stats):
235                lastRow = rowi == len(self.stats)-1
236                rowcol = 'style="'+('background-color: #f0f0f0' if lastRow else '')+'; text-align: right"'
237                for shwi, shw in enumerate(self.cont):
238                    if lastRow and not shw:
239                        break
240                    res += "<tr>"
241                    if not shw:
242                        res += "<tr><td/></tr>\n"
243                        continue
244                    if not shwi:
245                        res += '<td rowspan="%i" valign="top" style="background-color: #d8d8d8; font-weight: bold">%s</td>' % (len(self.cont)-1 or 1, verticalValues[rowi])
246                    res += '<td style="background-color: #d8d8d8">%s</td>' % shw
247                    values = self.rowData(shw, row)
248                    res += ''.join("<td %s>%s</td>" % (rowcol, val) for val in values[:-1])
249                    res += '<td style="background-color: #f0f0f0; font-weight: bold; text-align: right">%s</td>' % values[-1]
250                    res += '</tr>'
251            res += '</table>'
252            self.reportRaw(res)
Note: See TracBrowser for help on using the repository browser.