source: orange/Orange/OrangeWidgets/Prototypes/OWCorrelations.py @ 9996:0191d66861df

Revision 9996:0191d66861df, 18.1 KB checked in by ales_erjavec, 2 years ago (diff)

Orange.data.new_meta_id -> Orange.feature.Descriptor.new_meta_id

Line 
1"""
2<name>Correlations</name>
3<description>Compute all pairwise attribute correlations</description>
4<icon>icons/Correlations.png</icon>
5<contact>ales.erjavec(@ at @)fri.uni-lj.si</contact>
6
7"""
8
9from OWWidget import *
10
11import OWGUI
12import OWGraph
13
14import Orange
15from Orange.data import variable
16
17def is_continuous(var):
18    return isinstance(var, variable.Continuous)
19
20def is_discrete(var):
21    return isinstance(var, variable.Discrete)
22
23def pairwise_pearson_correlations(data, vars=None):
24    if vars is None:
25        vars = list(data.domain.variables)
26       
27    matrix = Orange.core.SymMatrix(len(vars))
28   
29    for i in range(len(vars)):
30        for j in  range(i + 1, len(vars)):
31            matrix[i, j] = Orange.core.PearsonCorrelation(vars[i], vars[j], data, 0).r
32           
33    return matrix
34
35def pairwise_spearman_correlations(data, vars=None):
36    import numpy
37    import statc
38   
39    if vars is None:
40        vars = list(data.domain.variables)
41   
42    matrix = Orange.core.SymMatrix(len(vars))
43   
44    all_vars = list(data.domain.variables)
45    indices = [all_vars.index(v) for v in vars]
46    (data,) = data.to_numpy_MA("Ac")
47   
48    averages = numpy.ma.average(data, axis=0)
49   
50    for i, var_i in enumerate(indices):
51        for j, var_j in enumerate(indices[i + 1:], i + 1):
52            a = data[:, var_i].filled(averages[var_i])
53            b = data[:, var_j].filled(averages[var_j])
54            matrix[i, j] = statc.spearmanr(list(a), list(b))[0]
55           
56    return matrix
57
58def target_pearson_correlations(data, vars=None, target_var=None):
59    if vars is None:
60        vars = list(data.domain.variables)
61   
62    if target_var is None:
63        if is_continuous(data.domain.class_var):
64            target_var = data.domain.class_var
65        else:
66            raise ValueError("A data with continuous class variable expected if 'target_var' is not explicitly declared.")
67   
68    correlations = []
69    for var in vars:
70        correlations.append(Orange.core.PearsonCorrelation(var, target_var, data, 0).r)
71       
72    return correlations
73
74
75def target_spearman_correlations(data, vars=None, target_var=None):
76    import numpy
77    import statc
78   
79    if vars is None:
80        vars = list(data.domain.variables)
81   
82    if target_var is None:
83        if is_continuous(data.domain.class_var):
84            target_var = data.domain.class_var
85        else:
86            raise ValueError("A data with continuous class variable expected if 'target_var' is not explicitly declared.")
87   
88    all_vars = list(data.domain.variables)
89    indices = [all_vars.index(v) for v in vars]
90    target_index = all_vars.index(target_var)
91    (data,) = data.to_numpy_MA("Ac")
92   
93    averages = numpy.ma.average(data, axis=0)
94    target_values = data[:, target_index].filled(averages[target_index])
95    target_values = list(target_values)
96   
97    correlations = []
98    for i, var_i in enumerate(indices):
99        a = data[:,var_i].filled(averages[var_i])
100        correlations.append(statc.spearmanr(list(a), target_values)[0])
101       
102    return correlations
103
104   
105def matrix_to_table(matrix, items=None):
106    from Orange.data import variable
107    if items is None:
108        items = getattr(matrix, "items", None)
109    if items is None:
110        items = range(matrix.dim)
111       
112    items = map(str, items)
113   
114    attrs = [variable.Continuous(name) for name in items]
115    domain = Orange.data.Domain(attrs, None)
116    row_name = variable.String("Row name")
117    domain.addmeta(Orange.feature.Descriptor.new_meta_id(), row_name)
118   
119    table = Orange.data.Table(domain, [list(r) for r in matrix])
120    for item, row in zip(items, table):
121        row[row_name] = item
122       
123    return table
124
125class CorrelationsItemDelegate(QStyledItemDelegate):
126    def displayText(self, value, locale):
127        v = value.toPyObject()
128        if isinstance(v, float):
129            return QString("%.4f" % v)
130        else:
131            return QStyledItemDelegate.displayText(value, locale)
132       
133class CorrelationsTableView(QTableView):
134    def sizeHint(self):
135        hint = QTableView.sizeHint(self)
136        h_header = self.horizontalHeader()
137        v_header = self.verticalHeader()
138        width = v_header.width() + h_header.length() + 4 + self.verticalScrollBar().width()
139        height = v_header.length() + h_header.height() + 4 + self.horizontalScrollBar().height()
140        return QSize(width, height)
141   
142class OWCorrelations(OWWidget):
143    contextHandlers = {"": DomainContextHandler("", ["selected_index"])}
144   
145    settingsList = ["correlations_type", "pairwise_correlations", "splitter_state"]
146   
147    COR_TYPES = ["Pairwise Pearson correlation",
148                 "Pairwise Spearman correlation",
149                 "Correlate with class"
150                 ]
151    def __init__(self, parent=None, signalManager=None, title="Correlations"):
152        # Call OWBaseWidget constructor to bypass OWWidget layout
153        OWBaseWidget.__init__(self, parent, signalManager, title,)
154#                          wantMainAre=False, noReport=True)
155       
156        self.inputs = [("Data", Orange.data.Table, self.set_data)]
157       
158        self.outputs = [("Correlations", Orange.data.Table),
159                        ("Variables", AttributeList)]
160       
161        # Settings
162       
163        self.pairwise_correlations = True
164        self.correlations_type = 0
165       
166        self.selected_index = None
167        self.changed_flag = False
168        self.auto_commit = True
169       
170        self.splitter_state = None
171       
172        self.loadSettings()
173       
174        #####
175        # GUI
176        #####
177       
178        layout = QVBoxLayout(self)
179        layout.setMargin(4)
180        self.setLayout(layout)
181       
182        self.splitter = QSplitter()
183        self.layout().addWidget(self.splitter)
184        self.splitter.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
185       
186        self.controlArea = OWGUI.widgetBox(self.splitter, addToLayout=False)
187        self.splitter.addWidget(self.controlArea)
188        self.mainArea = OWGUI.widgetBox(self.splitter, addToLayout=False)
189        self.mainArea.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
190        self.splitter.addWidget(self.mainArea)
191        self.splitter.setSizes([1,1])
192       
193        if self.splitter_state is not None:
194            try:
195                self.splitter.restoreState(QByteArray(self.splitter_state))
196            except Exception, ex:
197                pass
198           
199        self.splitter.splitterMoved.connect(
200                        self.on_splitter_moved
201                        )
202       
203        box = OWGUI.widgetBox(self.controlArea, "Correlations")
204        self.corr_radio_buttons = OWGUI.radioButtonsInBox(box, 
205                                self, "correlations_type",
206                                btnLabels=self.COR_TYPES, 
207                                callback=self.on_corr_type_change,
208                                )
209       
210        self.corr_table = CorrelationsTableView()
211        self.corr_table.setSelectionMode(QTableView.SingleSelection)
212        self.corr_table.setItemDelegate(CorrelationsItemDelegate(self))
213        self.corr_table.setEditTriggers(QTableView.NoEditTriggers)
214        self.corr_table.horizontalHeader().sectionClicked.connect(
215                    self.on_horizontal_header_section_click
216                    )
217        self.corr_table.verticalHeader().sectionClicked.connect(
218                    self.on_vertical_header_section_click
219                    )
220        self.corr_model = QStandardItemModel()
221        self.corr_table.setModel(self.corr_model)
222       
223        self.corr_table.selectionModel().selectionChanged.connect(
224                    self.on_table_selection_change
225                    )
226       
227        self.controlArea.layout().addWidget(self.corr_table)
228       
229        self.corr_graph = CorrelationsGraph(self)
230        self.corr_graph.showFilledSymbols = False
231       
232        self.mainArea.layout().addWidget(self.corr_graph)
233       
234        self.clear()
235       
236        self.resize(1000, 600)
237       
238    def clear(self):
239        self.data = None
240        self.cont_vars = None
241        self.var_names = None
242        self.selected_vars = None
243        self.clear_computed()
244        self.clear_graph()
245       
246    def clear_computed(self):
247        self.corr_model.clear()
248        self.set_all_pairwise_matrix(None, None)
249        self.set_target_correlations(None, None)
250       
251    def clear_graph(self):
252        self.corr_graph.clear()
253        self.corr_graph.setData(None, None)
254        self.corr_graph.replot()
255       
256    def set_data(self, data=None):
257        self.closeContext("")
258        self.clear()
259        self.information(0)
260        self.data = data
261        if data is not None and len(filter(is_continuous, data.domain)) >= 2:
262            self.set_variables_list(data)
263            self.selected_index = None
264            self.corr_graph.setData(data)
265            self.openContext("", data)
266           
267            b = self.corr_radio_buttons.buttons[-1]
268            if not is_continuous(data.domain.class_var):
269                self.correlations_type = min(self.correlations_type, 1)
270                b.setEnabled(False)
271            else:
272                b.setEnabled(True)
273               
274            if self.selected_index is None or \
275                    any(n in self.data.domain for n in self.selected_index):
276                self.selected_index = self.var_names[:2]
277               
278            self.run()
279               
280        elif data is not None:
281            self.data = None
282            self.information(0, "Need data with at least 2 continuous variables.")
283           
284        self.commit_if()
285           
286    def set_variables_list(self, data):
287        vars = list(data.domain.variables)
288        vars = [v for v in vars if is_continuous(v)]
289        self.cont_vars = vars
290        self.var_names = [v.name for v in vars]
291   
292    @property
293    def target_variable(self):
294        if self.data:
295            return self.data.domain.class_var
296        else:
297            return None
298       
299    def run(self):
300        if self.correlations_type < 2:
301            if self.correlations_type == 0:
302                matrix = pairwise_pearson_correlations(self.data, self.cont_vars)
303            elif self.correlations_type == 1:
304                matrix = pairwise_spearman_correlations(self.data, self.cont_vars)
305            self.set_all_pairwise_matrix(matrix)
306           
307        elif is_continuous(self.target_variable):
308            vars = [v for v in self.cont_vars if v != self.target_variable]
309            p_corr = target_pearson_correlations(self.data, vars, self.target_variable)
310            s_corr = target_spearman_correlations(self.data, vars, self.target_variable)
311            correlations = map(list, zip(p_corr, s_corr))
312            self.set_target_correlations(correlations, vars, self.target_variable)
313           
314    def set_all_pairwise_matrix(self, matrix, vars=None):
315        self.matrix = matrix
316        if matrix is not None:
317            for i, row in enumerate(matrix):
318                for j, e in enumerate(row):
319                    item = QStandardItem()
320                    if i != j:
321                        item.setData(e, Qt.DisplayRole)
322                    else:
323                        item.setData(QVariant(QColor(Qt.gray)), Qt.BackgroundRole)
324                    self.corr_model.setItem(i, j, item)
325                   
326            if vars is None:
327                vars = self.cont_vars
328            header = [v.name for v in vars]
329            self.corr_model.setVerticalHeaderLabels(header)
330            self.corr_model.setHorizontalHeaderLabels(header)
331           
332            self.corr_table.resizeColumnsToContents()
333            self.corr_table.resizeRowsToContents()
334           
335            QTimer.singleShot(100, self.corr_table.updateGeometry)
336#            self.corr_table.updateGeometry()
337   
338    def set_target_correlations(self, correlations, vars=None, target_var=None):
339        self.target_correlations = correlations
340        if correlations is not None:
341            for i, row in enumerate(correlations):
342                for j, c in enumerate(row):
343                    item = QStandardItem()
344                    item.setData(c, Qt.DisplayRole)
345                    self.corr_model.setItem(i, j, item)
346               
347            if vars is None:
348                vars = self.cont_vars
349           
350            v_header = [v.name for v in vars]
351            h_header = ["Pearson", "Spearman"]
352            self.corr_model.setVerticalHeaderLabels(v_header)
353            self.corr_model.setHorizontalHeaderLabels(h_header)
354           
355            self.corr_table.resizeColumnsToContents()
356            self.corr_table.resizeRowsToContents()
357           
358            QTimer.singleShot(100, self.corr_table.updateGeometry)
359#            self.corr_table.updateGeometry()
360           
361    def set_selected_vars(self, x, y):
362        x = self.cont_vars.index(x)
363        y = self.cont_vars.index(y)
364        if self.correlations_type == 2:
365            y = 0
366       
367        model = self.corr_model
368        sel_model = self.corr_table.selectionModel()
369        sel_model.select(model.index(x, y),
370                         QItemSelectionModel.ClearAndSelect)
371   
372    def on_corr_type_change(self):
373        if self.data is not None:
374            curr_selection = self.selected_vars
375            self.clear_computed()
376            self.run()
377           
378            if curr_selection:
379                try:
380                    self.set_selected_vars(*curr_selection)
381                except Exception, ex:
382                    import traceback
383                    traceback.print_exc()
384           
385            self.commit_if()
386       
387    def on_table_selection_change(self, selected, deselected):
388        indexes = self.corr_table.selectionModel().selectedIndexes()
389        if indexes:
390            index = indexes[0]
391            i, j = index.row(), index.column()
392            if self.correlations_type == 2 and is_continuous(self.target_variable):
393                j = len(self.var_names) - 1
394           
395            self.corr_graph.updateData(self.var_names[i],
396                           self.var_names[j],
397                           self.data.domain.class_var.name \
398                           if is_discrete(self.data.domain.class_var) else \
399                           "(Same color)")
400            vars = [self.cont_vars[i], self.cont_vars[j]]
401        else:
402            # TODO: Clear graph
403            vars = None
404        self.selected_vars = vars
405       
406        self.send("Variables", vars)
407       
408    def on_horizontal_header_section_click(self, section):
409        sel_model = self.corr_table.selectionModel()
410        indexes = sel_model.selectedIndexes()
411        if indexes:
412            index = indexes[0]
413            i, j = index.row(), index.column()
414            sel_index = self.corr_model.index(i, section)
415            sel_model.setCurrentIndex(sel_index,
416                                QItemSelectionModel.ClearAndSelect)
417           
418    def on_vertical_header_section_click(self, section):
419        sel_model = self.corr_table.selectionModel()
420        indexes = sel_model.selectedIndexes()
421        if indexes:
422            index = indexes[0]
423            i, j = index.row(), index.column()
424            sel_index = self.corr_model.index(section, j)
425            sel_model.setCurrentIndex(sel_index,
426                                QItemSelectionModel.ClearAndSelect)
427           
428    def on_splitter_moved(self, *args):
429        self.splitter_state = str(self.splitter.saveState())
430           
431    def commit_if(self):
432        if self.auto_commit:
433            self.commit()
434        else:
435            self.changed_flag = True
436   
437    def commit(self):
438        table = None
439        if self.data is not None:
440            if self.correlations_type == 2 and \
441                    is_continuous(self.target_variable):
442                pearson, _ = variable.make("Pearson", Orange.core.VarTypes.Continuous)
443                spearman, _ = variable.make("Spearman", Orange.core.VarTypes.Continuous)
444                row_name, _ = variable.make("Variable", Orange.core.VarTypes.String)
445               
446                domain = Orange.data.Domain([pearson, spearman], None)
447                domain.addmeta(Orange.feature.Descriptor.new_meta_id(), row_name)
448                table = Orange.data.Table(domain, self.target_correlations)
449                for inst, name in zip(table, self.var_names):
450                    inst[row_name] = name
451#            else:
452#                table = matrix_to_table(self.matrix, self.var_names)
453       
454        self.send("Correlations", table)
455       
456from OWScatterPlotGraph import OWScatterPlotGraph
457
458class CorrelationsGraph(OWScatterPlotGraph):
459    def updateData(self, x_attr, y_attr, *args, **kwargs):
460        OWScatterPlotGraph.updateData(self, x_attr, y_attr, *args, **kwargs)
461        if not hasattr(self, "regresson_line"):
462            self.regression_line = self.addCurve("regresson_line",
463                                                  style=OWGraph.QwtPlotCurve.Lines,
464                                                  symbol=OWGraph.QwtSymbol.NoSymbol,
465                                                  autoScale=True)
466        if isinstance(x_attr, basestring):
467            x_index = self.attribute_name_index[x_attr]
468        else:
469            x_index = x_attr
470           
471        if isinstance(y_attr, basestring):
472            y_index = self.attribute_name_index[y_attr]
473        else:
474            y_index = y_attr
475       
476        X = self.original_data[x_index]
477        Y = self.original_data[y_index]
478       
479        valid = self.getValidList([x_index, y_index])
480       
481        X = X[valid]
482        Y = Y[valid]
483        x_min, x_max = self.attr_values[x_attr]
484       
485        import numpy
486        X = numpy.array([numpy.ones_like(X), X]).T
487        try:
488            beta, _, _, _ = numpy.linalg.lstsq(X, Y)
489        except numpy.linalg.LinAlgError:
490            beta = [0, 0]
491       
492        y1 = beta[0] + x_min * beta[1]
493        y2 = beta[0] + x_max * beta[1]
494       
495        self.regression_line.setData([x_min, x_max], [y1, y2])       
496        self.replot()
497
498def main():
499    import sys
500    app = QApplication(sys.argv)
501    data = Orange.data.Table("housing")
502#    data = Orange.data.Table("iris")
503    w = OWCorrelations()
504    w.set_data(None)
505    w.set_data(data)
506    w.show()
507    sys.exit(app.exec_())
508   
509if __name__ == "__main__":
510    main()
511   
Note: See TracBrowser for help on using the repository browser.