source: orange/Orange/OrangeWidgets/Classify/OWNeuralNetwork.py @ 11765:1546cd04481b

Revision 11765:1546cd04481b, 3.8 KB checked in by Ales Erjavec <ales.erjavec@…>, 5 months ago (diff)

Fixed widget layouts.

Line 
1"""
2<name>Neural Network</name>
3<description>Neural network learner.</description>
4<priority>20</priority>
5<icon>icons/NeuralNetwork.svg</icon>
6
7"""
8
9import Orange
10from orngWrap import PreprocessedLearner
11
12from OWWidget import *
13import OWGUI
14
15class OWNeuralNetwork(OWWidget):
16    settingsList = ["name", "n_mid", "reg_fact", "max_iter", "normalize"]
17
18    def __init__(self, parent=None, signalManager=None,
19                 title="Neural Network"):
20        OWWidget.__init__(self, parent, signalManager, title,
21                          wantMainArea=False, resizingEnabled=False)
22
23        self.inputs = [("Data", Orange.data.Table, self.set_data),
24                       ("Preprocess", PreprocessedLearner,
25                        self.set_preprocessor)
26                       ]
27        self.outputs = [("Learner", Orange.classification.Learner),
28                        ("Classifier", Orange.classification.Classifier)
29                        ]
30
31        self.name = "Neural Network"
32        self.n_mid = 20
33        self.reg_fact = 1
34        self.max_iter = 300
35        self.normalize = True
36
37        self.loadSettings()
38
39        box = OWGUI.widgetBox(self.controlArea, "Name", addSpace=True)
40        OWGUI.lineEdit(box, self, "name")
41
42        box = OWGUI.widgetBox(self.controlArea, "Settings",
43                              addSpace=True)
44
45        form = QFormLayout(
46            spacing=8, formAlignment=Qt.AlignLeft, labelAlignment=Qt.AlignLeft,
47            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow
48        )
49        box.layout().addLayout(form)
50
51        form.addRow(
52            "Hidden layer neurons",
53            OWGUI.spin(box, self, "n_mid", 2, 10000, 1,
54                       tooltip="Number of neurons in the hidden layer.",
55                       addToLayout=False)
56        )
57
58        form.addRow(
59            "Regularization factor",
60            OWGUI.doubleSpin(box, self, "reg_fact", 0.1, 10.0, 0.1,
61                             addToLayout=False)
62        )
63
64        form.addRow(
65            "Max iterations",
66            OWGUI.spin(box, self, "max_iter", 100, 10000, 1,
67                       tooltip="Maximal number of optimization iterations.",
68                       addToLayout=False)
69        )
70
71        OWGUI.checkBox(box, self, 'normalize', 'Normalize the data')
72
73        OWGUI.button(self.controlArea, self, "&Apply",
74                     callback=self.apply,
75                     tooltip="Create the learner and apply it on input data.",
76                     autoDefault=True)
77
78        self.data = None
79        self.preprocessor = None
80        self.apply()
81
82    def set_data(self, data=None):
83        self.data = data
84        self.error([0])
85        self.data = data
86        self.apply()
87
88    def set_preprocessor(self, preprocessor=None):
89        self.preprocessor = preprocessor
90
91    def apply(self):
92        learner = Orange.classification.neural.NeuralNetworkLearner(
93            name=self.name, n_mid=self.n_mid,
94            reg_fact=self.reg_fact, max_iter=self.max_iter
95        )
96
97        if self.preprocessor is not None:
98            learner = self.preprocessor.wrapLearner(learner)
99
100        classifier = None
101        self.error([1])
102        if self.data is not None:
103            try:
104                classifier = learner(self.data)
105                classifier.name = self.name
106            except Exception, ex:
107                self.error(1, str(ex))
108
109        self.send("Learner", learner)
110        self.send("Classifier", classifier)
111
112    def sendReport(self):
113        self.reportSettings("Parameters",
114                            [("Hidden layer neurons", self.n_mid),
115                             ("Regularization factor", self.reg_fact),
116                             ("Max iterations", self.max_iter)]
117                            )
118
119
120if __name__ == "__main__":
121    app = QApplication(sys.argv)
122    w = OWNeuralNetwork()
123    w.show()
124    app.exec_()
125    w.saveSettings()
Note: See TracBrowser for help on using the repository browser.