source: orange/Orange/OrangeWidgets/Data/OWDataSampler.py @ 11816:6265b5fb0c0d

Revision 11816:6265b5fb0c0d, 16.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 4 months ago (diff)

Added "Sample on any change" check box. Code cleanup.

Line 
1import random
2
3import Orange
4from Orange.data import sample
5
6import OWGUI
7from OWWidget import *
8
9NAME = "Data Sampler"
10DESCRIPTION = "Samples data from a data set."
11ICON = "icons/DataSampler.svg"
12PRIORITY = 1125
13CATEGORY = "Data"
14MAINTAINER = "Aleksander Sadikov"
15MAINTAINER_EMAIL = "aleksander.sadikov(@at@)fri.uni-lj.si"
16INPUTS = [("Data", Orange.data.Table, "setData", Default)]
17OUTPUTS = [("Data Sample", Orange.data.Table, ),
18           ("Remaining Data", Orange.data.Table, )]
19
20
21class OWDataSampler(OWWidget):
22    settingsList = [
23        "Stratified", "Repeat", "UseSpecificSeed", "RandomSeed",
24        "GroupSeed", "outFold", "Folds", "SelectType", "useCases", "nCases",
25        "selPercentage", "CVFolds", "nGroups",
26        "pGroups", "GroupText", "autocommit"]
27
28    contextHandlers = {
29        "": DomainContextHandler("", ["nCases", "selPercentage"])
30    }
31
32    def __init__(self, parent=None, signalManager=None):
33        OWWidget.__init__(self, parent, signalManager, 'SampleData',
34                          wantMainArea=0)
35
36        self.inputs = [("Data", ExampleTable, self.setData)]
37        self.outputs = [("Data Sample", ExampleTable),
38                        ("Remaining Data", ExampleTable)]
39
40        # initialization of variables
41        self.data = None                        # dataset (incoming stream)
42        self.indices = None                     # indices that control sampling
43
44        self.Stratified = 1                     # use stratified sampling if possible?
45        self.Repeat = 0                         # can elements repeat in a sample?
46        self.UseSpecificSeed = 0                # use a specific random seed?
47        self.RandomSeed = 1                     # specific seed used
48        self.GroupSeed = 1                      # current seed for multiple group selection
49        self.outFold = 1                        # folder/group to output
50        self.Folds = 1                          # total number of folds/groups
51
52        self.SelectType = 0                     # sampling type (LOO, CV, ...)
53        self.useCases = 0                       # use a specific number of cases?
54        self.nCases = 25                        # number of cases to use
55        self.selPercentage = 30                 # sample size in %
56        self.CVFolds = 10                       # number of CV folds
57        self.nGroups = 3                        # number of groups
58        self.pGroups = [0.1, 0.25, 0.5]         # sizes of groups
59        self.GroupText = '0.1,0.25,0.5'         # assigned to Groups Control (for internal use)
60        self.autocommit = False
61
62        # Invalidated settings flag.
63        self.outputInvalidateFlag = False
64
65        self.loadSettings()
66
67        # GUI
68
69        # Info Box
70        box1 = OWGUI.widgetBox(self.controlArea, "Information", addSpace=True)
71        # Input data set info
72        self.infoa = OWGUI.widgetLabel(box1, 'No data on input.')
73        # Sampling type/parameters info
74        self.infob = OWGUI.widgetLabel(box1, ' ')
75        # Output data set info
76        self.infoc = OWGUI.widgetLabel(box1, ' ')
77
78        # Options Box
79        box2 = OWGUI.widgetBox(self.controlArea, 'Options', addSpace=True)
80        OWGUI.checkBox(box2, self, 'Stratified', 'Stratified (if possible)',
81                       callback=self.settingsChanged)
82
83        OWGUI.checkWithSpin(
84            box2, self, 'Set random seed:', 0, 32767,
85            'UseSpecificSeed',
86            'RandomSeed',
87            checkCallback=self.settingsChanged,
88            spinCallback=self.settingsChanged
89        )
90
91        # Sampling Type Box
92        self.s = [None, None, None, None]
93        self.sBox = OWGUI.widgetBox(self.controlArea, "Sampling type",
94                                    addSpace=True)
95        self.sBox.buttons = []
96
97        # Random Sampling
98        self.s[0] = OWGUI.appendRadioButton(self.sBox, self, "SelectType",
99                                            'Random sampling')
100
101        # indent
102        indent = OWGUI.checkButtonOffsetHint(self.s[0])
103        # repeat checkbox
104        self.h1Box = OWGUI.indentedBox(self.sBox, sep=indent,
105                                       orientation="horizontal")
106        OWGUI.checkBox(self.h1Box, self, 'Repeat', 'With replacement',
107                       callback=self.settingsChanged)
108
109        # specified number of elements checkbox
110        self.h2Box = OWGUI.indentedBox(self.sBox, sep=indent,
111                                       orientation="horizontal")
112        check, _ = OWGUI.checkWithSpin(
113            self.h2Box, self, 'Sample size (instances):', 1, 1000000000,
114            'useCases', 'nCases',
115            checkCallback=self.settingsChanged,
116            spinCallback=self.settingsChanged
117        )
118
119        # percentage slider
120        self.h3Box = OWGUI.indentedBox(self.sBox, sep=indent)
121        OWGUI.widgetLabel(self.h3Box, "Sample size:")
122
123        self.slidebox = OWGUI.widgetBox(self.h3Box, orientation="horizontal")
124        OWGUI.hSlider(self.slidebox, self, 'selPercentage',
125                      minValue=1, maxValue=100, step=1, ticks=10,
126                      labelFormat="   %d%%",
127                      callback=self.settingsChanged)
128
129        # Sample size (instances) check disables the Percentage slider.
130        # TODO: Should be an exclusive option (radio buttons)
131        check.disables.extend([(-1, self.h3Box)])
132        check.makeConsistent()
133
134        # Cross Validation sampling options
135        self.s[1] = OWGUI.appendRadioButton(self.sBox, self, "SelectType",
136                                            "Cross validation")
137
138        box = OWGUI.indentedBox(self.sBox, sep=indent,
139                                orientation="horizontal")
140        OWGUI.spin(box, self, 'CVFolds', 2, 100, step=1,
141                   label='Number of folds:  ',
142                   callback=self.settingsChanged)
143
144        # Leave-One-Out
145        self.s[2] = OWGUI.appendRadioButton(self.sBox, self, "SelectType",
146                                            "Leave-one-out")
147
148        # Multiple Groups
149        self.s[3] = OWGUI.appendRadioButton(self.sBox, self, "SelectType",
150                                            'Multiple subsets')
151        gbox = OWGUI.indentedBox(self.sBox, sep=indent,
152                                 orientation="horizontal")
153        OWGUI.lineEdit(gbox, self, 'GroupText',
154                       label='Subset sizes (e.g. "0.1, 0.2, 0.5"):',
155                       callback=self.multipleChanged)
156
157        # Output Group Box
158        box = OWGUI.widgetBox(self.controlArea, 'Output Data for Fold / Group',
159                              addSpace=True)
160        self.foldcombo = OWGUI.comboBox(
161            box, self, "outFold", items=range(1, 101),
162            label='Fold / group:', orientation="horizontal",
163            sendSelectedValue=1, valueType=int,
164            callback=self.invalidate
165        )
166        self.foldcombo.setEnabled(self.SelectType != 0)
167
168        # Sample Data box
169        OWGUI.rubber(self.controlArea)
170        box = OWGUI.widgetBox(self.controlArea, "Sample Data")
171        cb = OWGUI.checkBox(box, self, "autocommit", "Sample on any change")
172        self.sampleButton = OWGUI.button(box, self, 'Sample &Data',
173                                         callback=self.sdata, default=True)
174        OWGUI.setStopper(self, self.sampleButton, cb, "outputInvalidateFlag",
175                         callback=self.sdata)
176
177        # set initial radio button on (default sample type)
178        self.s[self.SelectType].setChecked(True)
179
180        # Connect radio buttons (SelectType)
181        for i, button in enumerate(self.s):
182            button.toggled[bool].connect(
183                lambda state, i=i: self.samplingTypeChanged(state, i)
184            )
185
186        self.process()
187
188        self.resize(200, 275)
189
190    # CONNECTION TRIGGER AND GUI ROUTINES
191    # enables RadioButton switching
192    def samplingTypeChanged(self, value, i):
193        """Sampling type changed."""
194        self.SelectType = i
195        self.settingsChanged()
196
197    def multipleChanged(self):
198        """Multiple subsets (Groups) changed."""
199        self.error(1)
200        try:
201            self.pGroups = [float(x) for x in self.GroupText.split(',')]
202            self.nGroups = len(self.pGroups)
203        except:
204            self.error(1, "Invalid specification for sizes of subsets.")
205        else:
206            self.settingsChanged()
207
208    def updateFoldCombo(self):
209        """Update the 'Folds' combo box contents."""
210        fold = self.outFold
211        self.Folds = 1
212        if self.SelectType == 1:
213            self.Folds = self.CVFolds
214        elif self.SelectType == 2:
215            if self.data:
216                self.Folds = len(self.data)
217            else:
218                self.Folds = 1
219        elif self.SelectType == 3:
220            self.Folds = self.nGroups
221
222        self.foldcombo.clear()
223        for x in range(self.Folds):
224            self.foldcombo.addItem(str(x + 1))
225        self.outFold = min(fold, self.Folds)
226
227    def setData(self, dataset):
228        """Set the input data set."""
229        self.closeContext()
230        if dataset is not None:
231            self.infoa.setText('%d instances in input data set.' %
232                               len(dataset))
233            self.data = dataset
234            self.openContext("", dataset)
235            self.process()
236            self.sdata()
237        else:
238            self.infoa.setText('No data on input.')
239            self.infob.setText('')
240            self.infoc.setText('')
241            self.send("Data Sample", None)
242            self.send("Remaining Data", None)
243            self.data = None
244
245    # feeds the output stream
246    def sdata(self):
247        if not self.data:
248            return
249
250        # select data
251        if self.SelectType == 0:
252            if self.useCases == 1 and self.Repeat == 1:
253                indices = self.indices(self.data)
254                sample = [self.data[i] for i in indices]
255                sample = Orange.data.Table(self.data.domain, sample)
256                remainder = None
257            else:
258                indices = self.indices(self.data)
259                sample = self.data.select(indices, 0)
260                remainder = self.data.select(indices, 1)
261            self.infoc.setText('Output: %d instances.' % len(sample))
262        elif self.SelectType == 3:
263            indices = self.indices(self.data, p0=self.pGroups[self.outFold - 1])
264            sample = self.data.select(indices, 0)
265            remainder = self.data.select(indices, 1)
266            self.infoc.setText(
267                'Output: subset %(fold)d of %(folds)d, %(len)d instance(s).' %
268                {"fold": self.outFold, "folds": self.Folds, "len": len(sample)}
269            )
270        else:
271            # CV/LOO
272            indices = self.indices(self.data)
273            sample = self.data.select(indices, self.outFold - 1)
274            remainder = self.data.select(indices, self.outFold - 1, negate=1)
275            self.infoc.setText(
276                'Output: fold %(fold)d of %(folds)d, %(len)d instance(s).' %
277                {"fold": self.outFold, "folds": self.Folds, "len": len(sample)}
278            )
279
280        if sample is not None:
281            sample.name = self.data.name
282        if remainder is not None:
283            remainder.name = self.data.name
284
285        # send data
286        self.nSample = len(sample)
287        self.nRemainder = len(remainder) if remainder is not None else 0
288        self.send("Data Sample", sample)
289        self.send("Remaining Data", remainder)
290
291        self.outputInvalidateFlag = False
292
293    def process(self):
294        self.error(0)
295        self.warning(0)
296
297        self.infob.setText('')
298
299        if self.SelectType == 0:
300            # Random Selection
301            if self.useCases == 1:
302                ncases = self.nCases
303                if self.Repeat == 0:
304                    ncases = self.nCases
305                    if self.data is not None and ncases > len(self.data):
306                        self.warning(0, "Sample size (w/o repetitions) larger than dataset.")
307                        ncases = len(self.data)
308                    p0 = ncases + 1e-7 if ncases == 1 else ncases
309                    self.indices = sample.SubsetIndices2(p0=p0)
310                    self.infob.setText('Random sampling, using exactly %d instances.' % ncases)
311                else:
312                    p0 = ncases + 1e-7 if ncases == 1 else ncases
313                    self.indices = sample.SubsetIndicesMultiple(p0=p0)
314                    self.infob.setText('Random sampling with repetitions, %d instances.' % ncases)
315            else:
316                if self.selPercentage == 100:
317                    p0 = len(self.data) if self.data is not None else 1.0
318                else:
319                    p0 = float(self.selPercentage) / 100.0
320                self.indices = sample.SubsetIndices2(p0=p0)
321                self.infob.setText('Random sampling, %d%% of input instances.' % self.selPercentage)
322            if self.Stratified == 1:
323                self.indices.stratified = self.indices.StratifiedIfPossible
324            else:
325                self.indices.stratified = self.indices.NotStratified
326            if self.UseSpecificSeed == 1:
327                self.indices.randseed = self.RandomSeed
328            else:
329                self.indices.randomGenerator = Orange.misc.Random(random.randint(0,65536))
330
331        # Cross Validation / LOO
332        elif self.SelectType == 1 or self.SelectType == 2:
333            # apply selected options
334            if self.SelectType == 2:
335                folds = len(self.data) if self.data is not None else 1
336                self.infob.setText('Leave-one-out.')
337            else:
338                folds = self.CVFolds
339                self.infob.setText('%d-fold cross validation.' % self.CVFolds)
340            self.indices = sample.SubsetIndicesCV(folds=folds)
341            if self.Stratified == 1:
342                self.indices.stratified = self.indices.StratifiedIfPossible
343            else:
344                self.indices.stratified = self.indices.NotStratified
345            if self.UseSpecificSeed == 1:
346                self.indices.randseed = self.RandomSeed
347            else:
348                self.indices.randseed = random.randint(0, 65536)
349
350        # MultiGroup
351        elif self.SelectType == 3:
352            self.infob.setText('Multiple subsets.')
353            #prepare indices generator
354            self.indices = sample.SubsetIndices2()
355            if self.Stratified == 1:
356                self.indices.stratified = self.indices.StratifiedIfPossible
357            else:
358                self.indices.stratified = self.indices.NotStratified
359            if self.UseSpecificSeed == 1:
360                self.indices.randseed = self.RandomSeed
361            else:
362                self.indices.randomGenerator = Orange.misc.Random(random.randint(0,65536))
363
364    def settingsChanged(self):
365        # enable fold selection and fill combobox if applicable
366        if self.SelectType == 0:
367            self.foldcombo.setEnabled(False)
368        else:
369            self.foldcombo.setEnabled(True)
370            self.updateFoldCombo()
371
372        self.process()
373        self.invalidate()
374
375    def invalidate(self):
376        """Invalidate current output."""
377        self.infoc.setText('...')
378        if self.autocommit:
379            self.sdata()
380        else:
381            self.outputInvalidateFlag = True
382
383    def sendReport(self):
384        if self.SelectType == 0:
385            if self.useCases:
386                stype = "Random sample of %i instances" % self.nCases
387            else:
388                stype = "Random sample with %i%% instances" % self.selPercentage
389        elif self.SelectType == 1:
390            stype = "%i-fold cross validation" % self.CVFolds
391        elif self.SelectType == 2:
392            stype = "Leave one out"
393        elif self.SelectType == 3:
394            stype = "Multiple subsets"
395        self.reportSettings("Settings", [("Sampling type", stype), 
396                                         ("Stratification", OWGUI.YesNo[self.Stratified]),
397                                         ("Random seed", str(self.RandomSeed) if self.UseSpecificSeed else "auto")])
398        if self.data is not None:
399            self.reportSettings("Data", [("Input", "%i examples" % len(self.data)), 
400                                         ("Sample", "%i examples" % self.nSample), 
401                                         ("Rest", "%i examples" % self.nRemainder)])
402        else:
403            self.reportSettings("Data", [("Input", "None")])
404
405
406if __name__ == "__main__":
407    appl = QApplication(sys.argv)
408    ow = OWDataSampler()
409    data = Orange.data.Table('iris.tab')
410    ow.setData(data)
411    ow.show()
412    appl.exec_()
413    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.