source: orange/orange/OrangeWidgets/Classify/OWCN2RulesViewer.py @ 9546:2b6cc6f397fe

Revision 9546:2b6cc6f397fe, 16.2 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Renamed widget channel names in line with the new naming rules/convention.
Added backwards compatibility in orngDoc loadDocument to enable loading of schemas saved before the change.

Line 
1"""
2<name>CN2 Rules Viewer</name>
3<description>Viewer of classification rules.</description>
4<icon>icons/CN2RulesViewer.png</icon>
5<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
6<priority>2120</priority>
7"""
8
9import orngEnviron
10
11import orange, orngCN2
12from OWWidget import *
13from OWPredictions import PyTableModel
14import OWGUI, OWColorPalette
15import sys, re, math
16
17def _toPyObject(variant):
18    val = variant.toPyObject()
19    if isinstance(val, type(NotImplemented)): # PyQt 4.4 converts python int, floats ... to C types
20        qtype = variant.type()
21        if qtype == QVariant.Double:
22            val, ok = variant.toDouble()
23        elif qtype == QVariant.Int:
24            val, ok = variant.toInt()
25        elif qtype == QVariant.LongLong:
26            val, ok = variant.toLongLong()
27        elif qtype == QVariant.String:
28            val = variant.toString()
29    return val
30       
31   
32class DistributionItemDelegate(QStyledItemDelegate):
33    def __init__(self, parent):
34        QStyledItemDelegate.__init__(self, parent)
35
36    def displayText(self, value, locale):
37        dist = value.toPyObject()
38        if isinstance(dist, orange.Distribution):
39            return QString("<" + ",".join(["%.1f" % c for c in dist]) + ">")
40        else:
41            return QStyledItemDelegate.displayText(value, locale)
42       
43    def sizeHint(self, option, index):
44        metrics = QFontMetrics(option.font)
45        height = metrics.lineSpacing() * 2 + 8 # 4 pixel margin
46        width = metrics.width(self.displayText(index.data(Qt.DisplayRole), QLocale())) + 8
47        return QSize(width, height)
48   
49    def paint(self, painter, option, index):
50        dist = index.data(Qt.DisplayRole).toPyObject()
51        rect = option.rect.adjusted(4, 4, -4, -4)
52        rect_w = rect.width() - len([c for c in dist if c]) # This is for the separators in the distribution bar
53        rect_h = rect.height()
54        colors = OWColorPalette.ColorPaletteHSV(len(dist))
55        abs = dist.abs
56        dist_sum = 0
57       
58        painter.save()
59        painter.setFont(option.font)
60       
61        qApp.style().drawPrimitive(QStyle.PE_PanelItemViewRow, option, painter)
62        qApp.style().drawPrimitive(QStyle.PE_PanelItemViewItem, option, painter)
63       
64        showText = getattr(self, "showDistText", True)
65        metrics = QFontMetrics(option.font)
66        drect_h = metrics.height()
67        lineSpacing = metrics.lineSpacing()
68        leading = metrics.leading()
69        distText = self.displayText(index.data(Qt.DisplayRole), QLocale())
70       
71        if option.state & QStyle.State_Selected:
72            color = option.palette.highlightedText().color()
73        else:
74            color = option.palette.text().color()
75#        painter.setBrush(QBrush(color))
76        painter.setPen(QPen(color))
77           
78        if showText:
79            textPos = rect.topLeft()
80            textRect = QRect(textPos, QSize(rect.width(), rect.height() / 2 - leading))
81            painter.drawText(textRect, Qt.AlignHCenter | Qt.AlignBottom, distText)
82           
83        painter.setPen(QPen(Qt.black))
84        painter.translate(QPoint(rect.topLeft().x(), rect.center().y() - (drect_h/2 if not showText else  0)))
85        for i, count in enumerate(dist):
86            if count:
87                color = colors[i]
88                painter.setBrush(color)
89                painter.setRenderHint(QPainter.Antialiasing)
90                width = round(rect_w * float(count) / abs)
91                painter.drawRoundedRect(QRect(1, 3, width, 5), 1, 2)
92                painter.translate(width, 0)
93        painter.restore()
94       
95class MultiLineStringItemDelegate(QStyledItemDelegate):
96    def sizeHint(self, option, index):
97        metrics = QFontMetrics(option.font)
98        text = index.data(Qt.DisplayRole).toString()
99        size = metrics.size(0, text)
100        return QSize(size.width() + 8, size.height() + 8) # 4 pixel margin
101   
102    def paint(self, painter, option, index):
103        text = self.displayText(index.data(Qt.DisplayRole), QLocale())
104        painter.save()
105       
106        qApp.style().drawPrimitive(QStyle.PE_PanelItemViewRow, option, painter)
107        qApp.style().drawPrimitive(QStyle.PE_PanelItemViewItem, option, painter)
108       
109        rect = option.rect.adjusted(4, 4, -4, -4)
110           
111        if option.state & QStyle.State_Selected:
112            color = option.palette.highlightedText().color()
113        else:
114            color = option.palette.text().color()
115#        painter.setBrush(QBrush(color))
116        painter.setPen(QPen(color))
117       
118           
119        painter.drawText(rect, option.displayAlignment, text)
120        painter.restore()
121       
122       
123class PyObjectItemDelegate(QStyledItemDelegate):
124    def displayText(self, value, locale):
125        obj = _toPyObject(value) #value.toPyObject()
126        return QString(str(obj))
127   
128   
129class PyFloatItemDelegate(QStyledItemDelegate):
130    def displayText(self, value, locale):
131        obj = _toPyObject(value)
132        if isinstance(obj, float):
133            return QString("%.2f" % obj)
134        else:
135            return QString(str(obj))
136       
137def rule_to_string(rule, show_distribution = True):
138    """
139    Write a string presentation of rule in human readable format.
140   
141    :param rule: rule to pretty-print.
142    :type rule: :class:`Orange.classification.rules.Rule`
143   
144    :param show_distribution: determines whether presentation should also
145        contain the distribution of covered instances
146    :type show_distribution: bool
147   
148    """
149    import Orange
150    def selectSign(oper):
151        if oper == Orange.core.ValueFilter_continuous.Less:
152            return "<"
153        elif oper == Orange.core.ValueFilter_continuous.LessEqual:
154            return "<="
155        elif oper == Orange.core.ValueFilter_continuous.Greater:
156            return ">"
157        elif oper == Orange.core.ValueFilter_continuous.GreaterEqual:
158            return ">="
159        else: return "="
160
161    if not rule:
162        return "None"
163    conds = rule.filter.conditions
164    domain = rule.filter.domain
165   
166    def pprint_values(values):
167        if len(values) > 1:
168            return "[" + ",".join(values) + "]"
169        else:
170            return str(values[0])
171       
172    ret = "IF "
173    if len(conds)==0:
174        ret = ret + "TRUE"
175
176    for i,c in enumerate(conds):
177        if i > 0:
178            ret += " AND "
179        if type(c) == Orange.core.ValueFilter_discrete:
180            ret += domain[c.position].name + "=" + pprint_values( \
181                   [domain[c.position].values[int(v)] for v in c.values])
182        elif type(c) == Orange.core.ValueFilter_continuous:
183            ret += domain[c.position].name + selectSign(c.oper) + str(c.ref)
184    if rule.classifier and type(rule.classifier) == Orange.classification.ConstantClassifier\
185            and rule.classifier.default_val:
186        ret = ret + " THEN "+domain.class_var.name+"="+\
187        str(rule.classifier.default_value)
188        if show_distribution:
189            ret += str(rule.class_distribution)
190    elif rule.classifier and type(rule.classifier) == Orange.classification.ConstantClassifier\
191            and type(domain.class_var) == Orange.core.EnumVariable:
192        ret = ret + " THEN "+domain.class_var.name+"="+\
193        str(rule.class_distribution.modus())
194        if show_distribution:
195            ret += str(rule.class_distribution)
196    return ret       
197
198       
199class OWCN2RulesViewer(OWWidget):
200    settingsList = ["show_Rule_length", "show_Rule_quality", "show_Coverage",
201                    "show_Predicted_class", "show_Distribution", "show_Rule"]
202   
203    def __init__(self, parent=None, signalManager=None, name="CN2 Rules Viewer"):
204        OWWidget.__init__(self, parent, signalManager, name)
205        self.inputs = [("Rule Classifier", orange.RuleClassifier, self.setRuleClassifier)]
206        self.outputs = [("Data", ExampleTable), ("Features", AttributeList)]
207       
208        self.show_Rule_length = True
209        self.show_Rule_quality = True
210        self.show_Coverage = True
211        self.show_Predicted_class = True
212        self.show_Distribution = True
213        self.show_Rule = True
214       
215        self.autoCommit = False
216        self.selectedAttrsOnly = True
217       
218       
219        self.loadSettings()
220       
221        #####
222        # GUI
223        #####
224       
225        box = OWGUI.widgetBox(self.controlArea, "Show Info", addSpace=True)
226        box.layout().setSpacing(3)
227        self.headers = ["Rule length",
228                        "Rule quality",
229                        "Coverage",
230                        "Predicted class",
231                        "Distribution",
232                        "Rule"]
233       
234        for i, header in enumerate(self.headers):
235            OWGUI.checkBox(box, self, "show_%s" % header.replace(" ", "_"), header,
236                           tooltip="Show %s column" % header.lower(),
237                           callback=self.updateVisibleColumns)
238           
239        box = OWGUI.widgetBox(self.controlArea, "Output")
240        box.layout().setSpacing(3)
241        cb = OWGUI.checkBox(box, self, "autoCommit", "Commit on any change",
242                            callback=self.commitIf)
243       
244        OWGUI.checkBox(box, self, "selectedAttrsOnly", "Selected attributes only",
245                       tooltip="Send selected attributes only",
246                       callback=self.commitIf)
247       
248        b = OWGUI.button(box, self, "Commit", callback=self.commit, default=True)
249        OWGUI.setStopper(self, b, cb, "changedFlag", callback=self.commit)
250       
251        OWGUI.rubber(self.controlArea)
252       
253        self.tableView = QTableView()
254        self.tableView.setItemDelegate(PyObjectItemDelegate(self))
255        self.tableView.setItemDelegateForColumn(1, PyFloatItemDelegate(self))
256        self.tableView.setItemDelegateForColumn(2, PyFloatItemDelegate(self))
257        self.tableView.setItemDelegateForColumn(4, DistributionItemDelegate(self))
258        self.tableView.setItemDelegateForColumn(5, MultiLineStringItemDelegate(self))
259        self.tableView.setSortingEnabled(True)
260        self.tableView.setSelectionBehavior(QTableView.SelectRows)
261        self.tableView.setAlternatingRowColors(True)
262       
263        self.rulesTableModel = PyTableModel([], self.headers)
264        self.proxyModel = QSortFilterProxyModel(self)
265        self.proxyModel.setSourceModel(self.rulesTableModel)
266       
267        self.tableView.setModel(self.proxyModel)
268        self.connect(self.tableView.selectionModel(),
269                     SIGNAL("selectionChanged(QItemSelection, QItemSelection)"),
270                     lambda is1, is2: self.commitIf())
271        self.connect(self.tableView.horizontalHeader(), SIGNAL("sectionClicked(int)"), lambda section: self.tableView.resizeRowsToContents())
272        self.mainArea.layout().addWidget(self.tableView)
273
274        self.updateVisibleColumns()
275       
276        self.changedFlag = False
277        self.classifier = None
278        self.rules = []
279        self.resize(800, 600)
280
281    def sendReport(self):
282        nrules = self.rulesTableModel.rowCount()
283        print nrules
284        if not nrules:
285            self.reportRaw("<p>No rules.</p>")
286            return
287       
288        shown = [i for i, header in enumerate(self.headers) if getattr(self, "show_%s" % header.replace(" ", "_"))]
289        rep = '<table>\n<tr style="height: 2px"><th colspan="11"  style="border-bottom: thin solid black; height: 2px;">\n'
290        rep += "<tr>"+"".join("<th>%s</th>" % self.headers[i] for i in shown)+"</tr>\n"
291        for row in range(nrules):
292            rep += "<tr>"
293            for col in shown:
294                data = _toPyObject(self.rulesTableModel.data(self.rulesTableModel.createIndex(row, col)))
295                if col==4:
296                    rep += "<td>%s</td>" % ":".join(map(str, data))
297                elif col in (0, 3):
298                    rep += '<td align="center">%s</td>' % data
299                elif col in (1, 2):
300                    rep += '<td align="right">%.3f&nbsp;</td>' % data
301                else:
302                    rep += '<td>%s</td>' % data
303        rep += '<tr style="height: 2px"><th colspan="11"  style="border-bottom: thin solid black; height: 2px;">\n</table>\n'
304        self.reportRaw(rep)
305       
306    def setRuleClassifier(self, classifier=None):
307        self.classifier = classifier
308        if classifier is not None:
309            self.rules = classifier.rules
310        else:
311            self.rules = []
312       
313    def handleNewSignals(self):
314        self.updateRulesModel()
315        self.commit()
316   
317    def updateRulesModel(self):
318        table = []
319        if self.classifier is not None:
320            for i, r in enumerate(self.classifier.rules):
321                table.append((int(r.complexity),
322                              r.quality,
323                              r.classDistribution.abs,
324                              str(r.classifier.defaultValue),
325                              r.classDistribution,
326                              self.ruleText(r)))
327
328        self.rulesTableModel = PyTableModel(table, self.headers)
329        self.proxyModel.setSourceModel(self.rulesTableModel)
330        self.tableView.resizeColumnsToContents()
331        self.tableView.resizeRowsToContents()
332        self.updateVisibleColumns() # if the widget got data for the first time
333           
334   
335    def ruleText(self, rule):
336        text = rule_to_string(rule, show_distribution=False)
337        p = re.compile(r"[0-9]\.[0-9]+")
338        text = p.sub(lambda match: "%.2f" % float(match.group()[0]), text)
339        text = text.replace("AND", "AND\n   ")
340        text = text.replace("THEN", "\nTHEN")
341        return text
342   
343    def updateVisibleColumns(self):
344        anyVisible = False
345        for i, header in enumerate(self.headers):
346            visible = getattr(self, "show_%s" % header.replace(" ", "_"))
347            self.tableView.horizontalHeader().setSectionHidden(i, not visible)
348            anyVisible = anyVisible or visible
349       
350        # report button is not available if not running canvas
351        if hasattr(self, "reportButton"):
352            self.reportButton.setEnabled(anyVisible)
353
354   
355    def commitIf(self):
356        if self.autoCommit:
357            self.commit()
358        else:
359            self.changedFlag = True
360           
361    def selectedAttrsFromRules(self, rules):
362        selected = []
363        for rule in rules:
364            for c in rule.filter.conditions:
365                selected.append(rule.filter.domain[c.position])
366        return set(selected)
367   
368    def selectedExamplesFromRules(self, rules, examples):
369        selected = []
370        for rule in rules:
371            selected.extend(examples.filterref(rule.filter))
372            rule.filter.negate=1
373            examples = examples.filterref(rule.filter)
374            rule.filter.negate=0
375        return selected
376             
377    def commit(self):
378        rows = self.tableView.selectionModel().selectedRows()
379        rows = [self.proxyModel.mapToSource(index) for index in rows]
380        rows = [index.row() for index in rows]
381        selectedRules = [self.classifier.rules[row] for row in rows]
382       
383        if selectedRules:
384            examples = self.classifier.examples
385            selectedExamples = self.selectedExamplesFromRules(selectedRules, self.classifier.examples)
386            selectedAttrs = self.selectedAttrsFromRules(selectedRules)
387            selectedAttrs = [attr for attr in examples.domain.attributes if attr in selectedAttrs] # restore the order
388            if self.selectedAttrsOnly:
389                domain = orange.Domain(selectedAttrs, examples.domain.classVar)
390                domain.addmetas(examples.domain.getmetas())
391                selectedExamples = orange.ExampleTable(domain, selectedExamples)
392            else:
393                selectedExamples = orange.ExampleTable(selectedExamples)
394               
395            self.send("Data", selectedExamples)
396            self.send("Features", orange.VarList(list(selectedAttrs)))
397       
398        else:
399            self.send("Data", None)
400            self.send("Features", None)
401       
402        self.changedFlag = False
403       
404   
405if __name__=="__main__":
406    ap=QApplication(sys.argv)
407    w=OWCN2RulesViewer()
408    #data=orange.ExampleTable("../../doc/datasets/car.tab")
409    data = orange.ExampleTable("../../doc/datasets/car.tab")
410    l=orngCN2.CN2UnorderedLearner()
411    l.ruleFinder.ruleStoppingValidator=orange.RuleValidator_LRS()
412    w.setRuleClassifier(l(data))
413    w.setRuleClassifier(l(data))
414    w.handleNewSignals()
415    w.show()
416    ap.exec_()
417
Note: See TracBrowser for help on using the repository browser.