source: orange/orange/OrangeWidgets/OWFreeVizOptimization.py @ 8765:184c4022f54a

Revision 8765:184c4022f54a, 21.8 KB checked in by matejd <matejd@…>, 3 years ago (diff)

Fixed freeviz anchor restraints

Line 
1from OWWidget import OWWidget
2from OWkNNOptimization import *
3import orange, math, random
4import OWGUI, orngVisFuncts, numpy
5from math import sqrt
6
7from orngScaleLinProjData import *
8from orngLinProj import *
9
class FreeVizOptimization(OWWidget, FreeViz):
    """FreeViz optimization dialog attached to a projection widget (e.g. Radviz).

    Combines the OWWidget GUI with the FreeViz optimizer (star-imported from
    orngLinProj). The "Main" tab runs the force-based anchor optimization and
    controls anchor placement, restraints and forces. The "Projections" tab
    offers PCA/SPCA/PLS projections (not for Radviz/SphereViz), random
    projection tours and the signal-to-noise anchor placement heuristic.

    Anchors are kept in ``self.graph.anchorData`` as ``(x, y, label)`` tuples,
    or ``(x, y, z, label)`` for 3D graphs (see setRestraints).
    """
    # Settings names handed to the OWWidget settings mechanism
    # (loadSettings/saveSettings). Each name is listed once; the previous list
    # repeated "restrain" and "forceRelation".
    settingsList = ["stepsBeforeUpdate", "restrain", "differentialEvolutionPopSize",
                    "s2nSpread", "s2nPlaceAttributes", "autoSetParameters",
                    "forceRelation", "mirrorSymmetry", "forceSigma", "law",
                    "disableAttractive", "disableRepulsive",
                    "useGeneralizedEigenvectors", "touringSpeed"]

    # Labels of the "Attractive : Repulsive" combo box and the matching
    # (attractive, repulsive) weights; both lists are indexed by forceRelation.
    forceRelValues = ["4 : 1", "3 : 1", "2 : 1", "3 : 2", "1 : 1", "2 : 3", "1 : 2", "1 : 3", "1 : 4"]
    attractRepelValues = [(4, 1), (3, 1), (2, 1), (3, 2), (1, 1), (2, 3), (1, 2), (1, 3), (1, 4)]

    def __init__(self, parentWidget = None, signalManager = None, graph = None, parentName = "Visualization widget"):
        """Build the dialog and its two tabs.

        parentWidget  -- the owning visualization widget; its updateGraph,
                         setShownAttributeList, distances and vizrank members
                         are used by this dialog.
        signalManager -- passed on to OWWidget.
        graph         -- the projection graph whose anchors are edited.
        parentName    -- caption of the parent widget; "3d" in it switches to
                         3D anchor handling, "radviz"/"sphereviz" hide the
                         PCA/PLS projection buttons.
        """
        OWWidget.__init__(self, None, signalManager, "FreeViz Dialog", savePosition = True, wantMainArea = 0, wantStatusBar = 1)
        FreeViz.__init__(self, graph)

        self.parentWidget = parentWidget
        self.parentName = parentName
        self.setCaption("FreeViz Optimization Dialog")
        self.cancelOptimization = 0
        self.forceRelation = 5
        self.disableAttractive = 0
        self.disableRepulsive = 0
        self.touringSpeed = 4
        self.graph = graph

        if self.graph:
            self.graph.hideRadius = 0
            self.graph.showAnchors = 1

        # Differential evolution parameters. No solver is wired up in this
        # dialog; the values are kept so stored settings stay compatible.
        self.differentialEvolutionPopSize = 100
        self.DERadvizSolver = None

        # stored settings override the defaults set above
        self.loadSettings()

        self.layout().setMargin(0)
        self.tabs = OWGUI.tabWidget(self.controlArea)
        self.MainTab = OWGUI.createTabPage(self.tabs, "Main")
        self.ProjectionsTab = OWGUI.createTabPage(self.tabs, "Projections")

        # ###########################
        # MAIN TAB
        OWGUI.comboBox(self.MainTab, self, "implementation", box = "FreeViz implementation", items = ["Fast (C) implementation", "Slow (Python) implementation", "LDA"])

        box = OWGUI.widgetBox(self.MainTab, "Optimization")

        self.optimizeButton = OWGUI.button(box, self, "Optimize Separation", callback = self.optimizeSeparation)
        self.stopButton = OWGUI.button(box, self, "Stop Optimization", callback = self.stopOptimization)
        self.singleStepButton = OWGUI.button(box, self, "Single Step", callback = self.singleStepOptimization)
        boldFont = self.optimizeButton.font()
        boldFont.setBold(1)
        self.optimizeButton.setFont(boldFont)
        self.stopButton.setFont(boldFont)
        self.stopButton.hide()   # shown only while an optimization runs
        self.attrKNeighboursCombo = OWGUI.comboBoxWithCaption(box, self, "stepsBeforeUpdate", "Number of steps before updating graph: ", tooltip = "Set the number of optimization steps that will be executed before the updated anchor positions will be visualized", items = [1, 3, 5, 10, 15, 20, 30, 50, 75, 100, 150, 200, 300], sendSelectedValue = 1, valueType = int)
        OWGUI.checkBox(box, self, "mirrorSymmetry", "Keep mirror symmetry", tooltip = "'Rotational' keeps the second anchor upside")

        vbox = OWGUI.widgetBox(self.MainTab, "Set anchor positions")
        hbox1 = OWGUI.widgetBox(vbox, orientation = "horizontal")
        OWGUI.button(hbox1, self, "Sphere" if "3d" in self.parentName.lower() else "Circle", callback = self.radialAnchors)
        OWGUI.button(hbox1, self, "Random", callback = self.randomAnchors)
        self.manualPositioningButton = OWGUI.button(hbox1, self, "Manual", callback = self.setManualPosition)
        self.manualPositioningButton.setCheckable(1)
        OWGUI.comboBox(vbox, self, "restrain", label="Restrain anchors:", orientation = "horizontal", items = ["Unrestrained", "Fixed Length", "Fixed Angle"], callback = self.setRestraints)

        box2 = OWGUI.widgetBox(self.MainTab, "Forces", orientation = "vertical")

        self.cbLaw = OWGUI.comboBox(box2, self, "law", label="Law", labelWidth = 40, orientation="horizontal", items=["Linear", "Square", "Gaussian", "KNN", "Variance"], callback = self.forceLawChanged)

        hbox2 = OWGUI.widgetBox(box2, orientation = "horizontal")
        hbox2.layout().addSpacing(10)

        self.spinSigma = OWGUI.lineEdit(hbox2, self, "forceSigma", label = "Kernel width (sigma) ", labelWidth = 110, orientation = "horizontal", valueType = float)
        self.spinSigma.setFixedSize(60, self.spinSigma.sizeHint().height())
        self.spinSigma.setSizePolicy(QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed))
        # Keep the kernel width at or above 0.01. The validator used to be
        # created but never attached to the line edit, so it had no effect.
        validSigma = QDoubleValidator(self)
        validSigma.setBottom(0.01)
        self.spinSigma.setValidator(validSigma)

        box2.layout().addSpacing(20)

        self.cbforcerel = OWGUI.comboBox(box2, self, "forceRelation", label= "Attractive : Repulsive  ",orientation = "horizontal", items=self.forceRelValues, callback = self.updateForces)
        self.cbforcebal = OWGUI.checkBox(box2, self, "forceBalancing", "Dynamic force balancing", tooltip="Normalize the forces so that the total sums of the\nrepulsive and attractive are in the above proportion.")

        box2.layout().addSpacing(20)

        self.cbDisableAttractive = OWGUI.checkBox(box2, self, "disableAttractive", "Disable attractive forces", callback = self.setDisableAttractive)
        self.cbDisableRepulsive = OWGUI.checkBox(box2, self, "disableRepulsive", "Disable repulsive forces", callback = self.setDisableRepulsive)

        box = OWGUI.widgetBox(self.MainTab, "Show anchors")
        OWGUI.checkBox(box, self, 'graph.showAnchors', 'Show attribute anchors', callback = self.parentWidget.updateGraph)
        OWGUI.qwtHSlider(box, self, "graph.hideRadius", label="Hide radius", minValue=0, maxValue=9, step=1, ticks=0, callback = self.parentWidget.updateGraph)
        self.freeAttributesButton = OWGUI.button(box, self, "Remove hidden attributes", callback = self.removeHidden)

        # ###########################
        # PROJECTIONS TAB
        if parentName.lower() not in ("radviz", "sphereviz"):
            pcaBox = OWGUI.widgetBox(self.ProjectionsTab, "Principal Component Analysis")
            OWGUI.button(pcaBox, self, "Principal component analysis", callback = self.findPCAProjection)
            OWGUI.button(pcaBox, self, "Supervised principal component analysis", callback = self.findSPCAProjection)
            OWGUI.checkBox(pcaBox, self, "useGeneralizedEigenvectors", "Merge examples with same class value")
            plsBox = OWGUI.widgetBox(self.ProjectionsTab, "Partial Least Squares")
            OWGUI.button(plsBox, self, "Partial least squares", callback = self.findPLSProjection)

        box = OWGUI.widgetBox(self.ProjectionsTab, "Projection Tours")
        self.startTourButton = OWGUI.button(box, self, "Start Random Touring", callback = self.startRandomTouring)
        self.stopTourButton = OWGUI.button(box, self, "Stop Touring", callback = self.stopRandomTouring)
        self.stopTourButton.hide()   # shown only while touring
        OWGUI.hSlider(box, self, 'touringSpeed', label = "Speed:  ", minValue=1, maxValue=10, step=1)
        OWGUI.rubber(self.ProjectionsTab)

        box = OWGUI.widgetBox(self.ProjectionsTab, "Signal to Noise Heuristic")
        box2 = OWGUI.widgetBox(box, 0, orientation = "horizontal")
        OWGUI.widgetLabel(box2, "Anchor spread:           ")
        # the opening quote of the label format was missing (syntax error)
        OWGUI.hSlider(box2, self, 's2nSpread', minValue=0, maxValue=10, step=1, callback = self.s2nMixAnchors, labelFormat="%d", ticks=0)
        OWGUI.comboBoxWithCaption(box, self, "s2nPlaceAttributes", "Attributes to place: ", tooltip = "Set the number of top ranked attributes to place. You can select a higher value than the actual number of attributes", items = self.attrsNum, callback = self.s2nMixAnchors, sendSelectedValue = 1, valueType = int)
        OWGUI.checkBox(box, self, 'autoSetParameters', 'Automatically find optimal parameters')
        self.s2nMixButton = OWGUI.button(box, self, "Place anchors", callback = self.s2nMixAnchorsAutoSet)

        # bring the enabled/disabled state of the force controls in sync
        self.forceLawChanged()
        self.updateForces()
        self.cbforcebal.setDisabled(self.cbDisableAttractive.isChecked() or self.cbDisableRepulsive.isChecked())
        self.resize(320,650)


    def startRandomTouring(self):
        """Animate the anchors through an endless series of random projections.

        Anchors move linearly from their current positions towards random
        targets, rescaled so the farthest target lies on the unit circle.
        When a target is reached, the next one is drawn. The loop runs until
        stopRandomTouring() removes the ``canTour`` flag, so the Qt event
        queue is serviced on every frame to let that click through.
        """
        anchorData = self.graph.anchorData
        if not anchorData:
            return      # no anchors to move; max() below fails on empty arrays

        self.startTourButton.hide()
        self.stopTourButton.show()

        labels = [anchor[2] for anchor in anchorData]
        count = len(labels)
        newXPositions = numpy.array([anchor[0] for anchor in anchorData])
        newYPositions = numpy.array([anchor[1] for anchor in anchorData])
        step = steps = 0
        self.canTour = 1
        while hasattr(self, "canTour"):
            if step >= steps:
                # target reached: it becomes the start of the next segment
                oldXPositions = newXPositions
                oldYPositions = newYPositions
                newXPositions = numpy.random.uniform(-1, 1, count)
                newYPositions = numpy.random.uniform(-1, 1, count)
                m = math.sqrt(max(newXPositions**2 + newYPositions**2))
                newXPositions /= m
                newYPositions /= m
                maxDist = max(numpy.sqrt((newXPositions - oldXPositions)**2 + (newYPositions - oldYPositions)**2))
                # at least one step, otherwise the interpolation divides by zero
                steps = max(1, int(maxDist * 300))
                step = 0
            midX = newXPositions * step / steps + oldXPositions * (steps - step) / steps
            midY = newYPositions * step / steps + oldYPositions * (steps - step) / steps
            self.graph.anchorData = [(midX[i], midY[i], labels[i]) for i in range(count)]
            step += self.touringSpeed
            self.graph.updateData()
            # Events used to be processed only when step % 10 == 0, which
            # skipped whole segments at some speeds and left Stop unresponsive.
            qApp.processEvents()


    def stopRandomTouring(self):
        """Swap the tour buttons back and end the loop in startRandomTouring."""
        self.startTourButton.show()
        self.stopTourButton.hide()
        if hasattr(self, "canTour"):
            delattr(self, "canTour")


    # ##############################################################
    # EVENTS
    # ##############################################################
    def setManualPosition(self):
        """Enable manual anchor dragging in the parent's graph while "Manual" is checked."""
        self.parentWidget.graph.manualPositioning = self.manualPositioningButton.isChecked()

    def updateForces(self):
        """Derive the force weights attractG/repelG from the force controls.

        If one force type is disabled the weights become 0/1 and the ratio and
        balancing controls are disabled; otherwise the weights are taken from
        attractRepelValues[forceRelation].
        """
        if self.disableAttractive or self.disableRepulsive:
            self.attractG, self.repelG = 1 - self.disableAttractive, 1 - self.disableRepulsive
            self.cbforcerel.setDisabled(True)
            self.cbforcebal.setDisabled(True)
        else:
            self.attractG, self.repelG = self.attractRepelValues[self.forceRelation]
            self.cbforcerel.setDisabled(False)
            self.cbforcebal.setDisabled(False)

        self.printEvent("Updated: %i, %i" % (self.attractG, self.repelG), eventVerbosity = 1)

    def forceLawChanged(self):
        """Enable the sigma edit only for the "Gaussian" and "KNN" laws (indices 2 and 3)."""
        self.spinSigma.setDisabled(self.cbLaw.currentIndex() not in [2, 3])

    def _unitAnchorPositions(self, dims):
        """Return the first `dims` anchor coordinates scaled to unit length.

        The result has one row per coordinate (shape ``(dims, nAnchors)``).
        Anchors lying exactly at the origin have no direction and are left
        there instead of becoming NaN.
        """
        positions = numpy.array([anchor[:dims] for anchor in self.graph.anchorData], dtype = float)
        norms = numpy.sqrt(numpy.sum(positions**2, 1))
        norms[norms == 0] = 1.0
        return numpy.transpose(positions / norms[:, numpy.newaxis])

    def setRestraints(self):
        """Apply the anchor restraint selected in the "Restrain anchors" combo.

        restrain == 0 (Unrestrained): nothing to do.
        restrain == 1 (Fixed Length): move every anchor onto the unit circle
            (unit sphere in 3D), keeping its direction.
        restrain == 2 (Fixed Angle): in 2D keep each anchor's length but space
            the anchors evenly around the circle; in 3D recreate the anchors
            with graph.create_anchors().
        """
        if not self.restrain:
            return
        attrList = self.getShownAttributeList()
        if not attrList:
            return

        if "3d" in self.parentName.lower():
            if self.restrain == 1:
                xs, ys, zs = self._unitAnchorPositions(3)
                self.graph.anchorData = [(xs[i], ys[i], zs[i], a) for i, a in enumerate(attrList)]
            else:
                self.graph.create_anchors(len(attrList), attrList)
        elif self.restrain == 1:
            xs, ys = self._unitAnchorPositions(2)
            self.graph.setAnchors(xs, ys, attrList)
        else:
            positions = numpy.array([anchor[:2] for anchor in self.graph.anchorData])
            r = numpy.sqrt(numpy.sum(positions**2, 1))
            phi = 2*math.pi/len(r)
            self.graph.anchorData = [(r[i] * math.cos(i*phi), r[i] * math.sin(i*phi), a) for i, a in enumerate(attrList)]

        self.graph.updateData()
        self.graph.repaint()


    def setDisableAttractive(self):
        """Checking "Disable attractive forces" re-enables repulsive ones (both cannot be off)."""
        if self.cbDisableAttractive.isChecked():
            self.disableRepulsive = 0
        self.updateForces()

    def setDisableRepulsive(self):
        """Checking "Disable repulsive forces" re-enables attractive ones (both cannot be off)."""
        if self.cbDisableRepulsive.isChecked():
            self.disableAttractive = 0
        self.updateForces()

    # ###############################################################
    ## FREE VIZ FUNCTIONS
    # ###############################################################
    def randomAnchors(self):
        """Place anchors randomly (FreeViz.randomAnchors) and redraw."""
        FreeViz.randomAnchors(self)
        self.graph.updateData()
        self.graph.repaint()

    def radialAnchors(self):
        """Place anchors on the circle/sphere (FreeViz.radialAnchors) and redraw."""
        FreeViz.radialAnchors(self)
        self.graph.updateData()
        self.graph.repaint()

    def removeHidden(self):
        """Drop anchors that lie inside the hide radius from the shown attributes.

        graph.hideRadius comes from a 0..9 slider and is read in tenths of the
        unit circle; anchors with x^2 + y^2 below (hideRadius / 10)^2 are
        removed and the parent widget's attribute list is updated.
        """
        # 10.0, not 10: with an integer hideRadius, Python 2 integer division
        # made the radius 0, so no anchor was ever removed.
        rad2 = (self.graph.hideRadius / 10.0)**2
        newAnchorData = [t for t in self.graph.anchorData if t[0]**2 + t[1]**2 >= rad2]
        shownAttrList = [t[2] for t in newAnchorData]
        self.parentWidget.setShownAttributeList(shownAttrList)
        self.graph.anchorData = newAnchorData
        self.graph.updateData()
        self.graph.repaint()

    def singleStepOptimization(self):
        """Run a single FreeViz optimization step and refresh the graph."""
        FreeViz.optimizeSeparation(self, 1, 1)
        self.graph.potentialsBmp = None
        self.graph.updateData()

    def optimizeSeparation(self, steps = 10, singleStep = False):
        """Run the FreeViz optimization until it finishes or is cancelled.

        `steps` is accepted for interface compatibility; the number of steps
        between graph updates is taken from self.stepsBeforeUpdate. Point
        animation is switched off while optimizing. The animation setting and
        the Optimize/Stop buttons are restored even if the optimizer raises or
        finishes without the Stop button being pressed.
        """
        self.optimizeButton.hide()
        self.stopButton.show()
        self.cancelOptimization = 0

        if hasattr(self.graph, 'animate_points'):
            self.graph_is_animated = self.graph.animate_points
            self.graph.animate_points = False

        try:
            FreeViz.optimizeSeparation(self, self.stepsBeforeUpdate, singleStep, self.parentWidget.distances)
            self.graph.potentialsBmp = None
            self.graph.updateData()
        finally:
            if hasattr(self, 'graph_is_animated'):
                self.graph.animate_points = self.graph_is_animated
                del self.graph_is_animated
            self.stopButton.hide()
            self.optimizeButton.show()

    def stopOptimization(self):
        """Request cancellation of a running optimization and restore point animation.

        Sets cancelOptimization, which the FreeViz optimization loop presumably
        polls (not visible in this file).
        """
        self.cancelOptimization = 1
        if hasattr(self, 'graph_is_animated'):
            self.graph.animate_points = self.graph_is_animated

    def findPCAProjection(self):
        """Set anchors from a principal component analysis projection."""
        self.findProjection(DR_PCA, setAnchors = 1)

    def findSPCAProjection(self):
        """Set anchors from supervised PCA; requires data with a class attribute."""
        if not self.graph.dataHasClass:
            QMessageBox.information( None, self.parentName, 'Supervised PCA can only be applied on data with a class attribute.', QMessageBox.Ok + QMessageBox.Default)
            return
        self.findProjection(DR_SPCA, setAnchors = 1)

    def findPLSProjection(self):
        """Set anchors from a partial least squares projection."""
        self.findProjection(DR_PLS, setAnchors = 1)

    def hideEvent(self, ev):
        """Stop any running tour and save the settings before the dialog hides."""
        self.stopRandomTouring()
        self.saveSettings()
        OWWidget.hideEvent(self, ev)


    def s2nMixAnchorsAutoSet(self):
        """Place anchors with the signal-to-noise heuristic.

        With autoSetParameters on, every class permutation and every candidate
        number of placed attributes (self.attrsNum) is scored with
        vizrank.kNNComputeAccuracy on the resulting projection. The best
        combination is kept, and then the anchor spread (0..10, the slider's
        range) is tuned the same way. Without autoSetParameters the current
        settings are used. In both cases s2nMixAnchors() is finally called.
        """
        # check if we have data and a discrete class
        if not self.graph.haveData or len(self.graph.rawData) == 0 or not self.graph.dataHasDiscreteClass:
            self.setStatusBarText("No data or data without a discrete class")
            return

        vizrank = self.parentWidget.vizrank
        if self.__class__ != FreeViz:
            from PyQt4.QtGui import qApp

        if self.autoSetParameters:
            numAttributes = len(self.graph.dataDomain.attributes)
            results = {}
            self.s2nSpread = 0
            permutations = orngVisFuncts.generateDifferentPermutations(range(len(self.graph.dataDomain.classVar.values)))
            for perm in permutations:
                self.classPermutationList = perm
                for i, val in enumerate(self.attrsNum):
                    # Once the previous candidate already exceeds the number of
                    # attributes, larger candidates place the same attributes,
                    # so only the first such value is evaluated. The first
                    # candidate has no predecessor and is always evaluated (the
                    # old attrsNum[index - 1] wrapped around to the last item).
                    if i > 0 and self.attrsNum[i - 1] > numAttributes:
                        continue
                    self.s2nPlaceAttributes = val
                    if not self.s2nMixAnchors(0):
                        return
                    if self.__class__ != FreeViz:
                        qApp.processEvents()

                    acc, other = vizrank.kNNComputeAccuracy(self.graph.createProjectionAsExampleTable(None, useAnchorData = 1))
                    if results:
                        self.setStatusBarText("Current projection value is %.2f (best is %.2f)" % (acc, max(results)))
                    else:
                        self.setStatusBarText("Current projection value is %.2f" % (acc))

                    results[acc] = (perm, val)
            if not results:
                return
            self.classPermutationList, self.s2nPlaceAttributes = results[max(results)]
            if self.__class__ != FreeViz:
                qApp.processEvents()
            if not self.s2nMixAnchors(0):        # update the best number of attributes
                return

            # tune the anchor spread for the chosen attributes
            results = []
            attributeNameIndex = self.graph.attributeNameIndex
            attrIndices = [attributeNameIndex[val[2]] for val in self.graph.anchorData]
            for val in range(11):   # the spread slider ranges over 0..10
                self.s2nSpread = val
                if not self.s2nMixAnchors(0):
                    return
                acc, other = vizrank.kNNComputeAccuracy(self.graph.createProjectionAsExampleTable(attrIndices, useAnchorData = 1))
                results.append(acc)
                self.setStatusBarText("Current projection value is %.2f (best is %.2f)" % (acc, max(results)))
            self.s2nSpread = results.index(max(results))

            self.setStatusBarText("Best projection value is %.2f" % (max(results)))

        # Always call this: with autoSetParameters it sets the attribute list in
        # the parent widget; otherwise it places anchors for the current settings.
        self.s2nMixAnchors()
384
385
386
# #############################################################################
# class that represents S2N Heuristic classifier
class S2NHeuristicClassifier(orange.Classifier):
    """Classifier backed by the S2N anchor placement of a Radviz-like widget.

    Construction loads the training data into the widget and runs the S2N
    heuristic (optionally followed by FreeViz optimization). Prediction
    projects an example with the resulting anchors and classifies it with a
    kNN model built on the 2D projection of the training data.
    """
    def __init__(self, optimizationDlg, radvizWidget, data, nrOfFreeVizSteps = 0):
        self.optimizationDlg = optimizationDlg
        self.radvizWidget = radvizWidget

        self.radvizWidget.setData(data)
        self.optimizationDlg.s2nMixAnchorsAutoSet()

        if nrOfFreeVizSteps > 0:
            # NOTE(review): FreeVizOptimization defines no optimize() method in
            # this file; presumably it is inherited from FreeViz -- confirm.
            self.optimizationDlg.optimize(nrOfFreeVizSteps)

    def __call__(self, example, returnType = orange.GetValue):
        """Classify `example` by kNN in the current anchor projection.

        returnType follows Orange's convention: orange.GetValue (the new
        default) returns the class value, orange.GetProbabilities the
        distribution, and orange.GetBoth a (value, distribution) pair.
        """
        table = orange.ExampleTable(example.domain)
        table.append(example)
        self.radvizWidget.setSubsetData(table)       # show the example in the widget
        self.radvizWidget.handleNewSignals()

        graph = self.radvizWidget.graph
        attributeNameIndex = graph.attributeNameIndex
        attrListIndices = [attributeNameIndex[val[2]] for val in graph.anchorData]
        attrVals = [graph.scaleExampleValue(example, index) for index in attrListIndices]

        projection = graph.createProjectionAsExampleTable(attrListIndices, scaleFactor = graph.trueScaleFactor, useAnchorData = 1)
        knn = self.radvizWidget.optimizationDlg.createkNNLearner(kValueFormula = 0)(projection)

        [xTest, yTest] = graph.getProjectedPointPosition(attrListIndices, attrVals, useAnchorData = 1)
        (classVal, prob) = knn(orange.Example(projection.domain, [xTest, yTest, "?"]), orange.GetBoth)

        if returnType == orange.GetBoth:
            return classVal, prob
        if returnType == orange.GetProbabilities:
            return prob
        return classVal
422
423
class S2NHeuristicLearner(orange.Learner):
    """Learner that produces an S2NHeuristicClassifier for given examples.

    It stores the FreeViz optimization dialog and the visualization widget;
    each call trains a new classifier through them.
    """
    def __init__(self, optimizationDlg, radvizWidget):
        self.name = "S2N Feature Selection Learner"
        self.optimizationDlg = optimizationDlg
        self.radvizWidget = radvizWidget

    def __call__(self, examples, weightID = 0, nrOfFreeVizSteps = 0):
        """Train on `examples`; `weightID` is accepted for the Learner interface but unused."""
        dialog = self.optimizationDlg
        widget = self.radvizWidget
        classifier = S2NHeuristicClassifier(dialog, widget, examples, nrOfFreeVizSteps)
        return classifier
432
433
434
435
# Manual check of the dialog's appearance.
if __name__ == "__main__":
    import sys
    # NOTE(review): __init__ passes self.parentWidget.updateGraph as a callback,
    # so constructing the dialog without a parent widget (as below) fails with
    # AttributeError; pass a real parent widget to use this entry point.
    application = QApplication(sys.argv)
    dialog = FreeVizOptimization()
    dialog.show()
    application.exec_()
443
Note: See TracBrowser for help on using the repository browser.