source: orange/orange/OrangeWidgets/OWContexts.py @ 8042:ffcb93bc9028

Revision 8042:ffcb93bc9028, 25.6 KB checked in by markotoplak, 3 years ago (diff)

Hierarchical clustering: also catch RuntimeError when importing matplotlib (or the documentation could not be built on server).

Line 
1import time, copy, orange
2from string import *
3
4contextStructureVersion = 100
5
6class Context:
7    def __init__(self, **argkw):
8        self.time = time.time()
9        self.__dict__.update(argkw)
10
11    def __getstate__(self):
12        s = dict(self.__dict__)
13        for nc in getattr(self, "noCopy", []):
14            if s.has_key(nc):
15                del s[nc]
16        return s
17
18   
19class ContextHandler:
20    maxSavedContexts = 50
21   
22    def __init__(self, contextName = "", cloneIfImperfect = True, findImperfect = True, syncWithGlobal = True, contextDataVersion = 0, **args):
23        self.contextName = contextName
24        self.localContextName = "localContexts"+contextName
25        self.cloneIfImperfect, self.findImperfect = cloneIfImperfect, findImperfect
26        self.contextDataVersion = contextDataVersion
27        self.syncWithGlobal = syncWithGlobal
28        self.globalContexts = []
29        self.__dict__.update(args)
30
31    def newContext(self):
32        return Context()
33
34    def openContext(self, widget, *arg, **argkw):
35        context, isNew = self.findOrCreateContext(widget, *arg, **argkw)
36        if context:
37            if isNew:
38                self.settingsFromWidget(widget, context)
39            else:
40                self.settingsToWidget(widget, context)
41        return context
42
43    def initLocalContext(self, widget):
44        if not hasattr(widget, self.localContextName):
45            if self.syncWithGlobal:
46                setattr(widget, self.localContextName, self.globalContexts)
47            else:
48                setattr(widget, self.localContextName, copy.deepcopy(self.globalContexts))
49       
50    def findOrCreateContext(self, widget, *arg, **argkw):       
51        index, context, score = self.findMatch(widget, self.findImperfect, *arg, **argkw)
52        if context:
53            if index < 0:
54                self.addContext(widget, context)
55            else:
56                self.moveContextUp(widget, index)
57            return context, False
58        else:
59            context = self.newContext()
60            self.addContext(widget, context)
61            return context, True
62
63    def closeContext(self, widget, context):
64        self.settingsFromWidget(widget, context)
65
66    def fastSave(self, context, widget, name, value):
67        pass
68
69    def settingsToWidget(self, widget, context):
70        cb = getattr(widget, "settingsToWidgetCallback" + self.contextName, None)
71        return cb and cb(self, context)
72
73    def settingsFromWidget(self, widget, context):
74        cb = getattr(widget, "settingsFromWidgetCallback" + self.contextName, None)
75        return cb and cb(self, context)
76
77    def findMatch(self, widget, imperfect = True, *arg, **argkw):
78        bestI, bestContext, bestScore = -1, None, -1
79        for i, c in enumerate(getattr(widget, self.localContextName)):
80            score = self.match(c, imperfect, *arg, **argkw)
81            if score == 2:
82                return i, c, score
83            if score and score > bestScore:
84                bestI, bestContext, bestScore = i, c, score
85
86        if bestContext and self.cloneIfImperfect:
87            if hasattr(self, "cloneContext"):
88                bestContext = self.cloneContext(bestContext, *arg, **argkw)
89            else:
90                bestContext = copy.deepcopy(bestContext)
91            bestI = -1
92               
93        return bestI, bestContext, bestScore
94           
95    def moveContextUp(self, widget, index):
96        localContexts = getattr(widget, self.localContextName)
97        l = getattr(widget, self.localContextName)
98        context = l.pop(index)
99        context.time = time.time()
100        l.insert(0, context)
101
102    def addContext(self, widget, context):
103        l = getattr(widget, self.localContextName)
104        l.insert(0, context)
105        while len(l) > self.maxSavedContexts:
106            del l[-1]
107
108    def mergeBack(self, widget):
109        if not self.syncWithGlobal or getattr(widget, self.localContextName) is not  self.globalContexts:
110            self.globalContexts.extend([c for c in getattr(widget, self.localContextName) if c not in self.globalContexts])
111            self.globalContexts.sort(lambda c1,c2: -cmp(c1.time, c2.time))
112            self.globalContexts[:] = self.globalContexts[:self.maxSavedContexts]
113
114
115class ContextField:
116    def __init__(self, name, flags = 0, **argkw):
117        self.name = name
118        self.flags = flags
119        self.__dict__.update(argkw)
120   
121
122class DomainContextHandler(ContextHandler):
123    Optional, SelectedRequired, Required = range(3)
124    RequirementMask = 3
125    NotAttribute = 4
126    List = 8
127    RequiredList = Required + List
128    SelectedRequiredList = SelectedRequired + List
129    ExcludeOrdinaryAttributes, IncludeMetaAttributes = 16, 32
130
131    MatchValuesNo, MatchValuesClass, MatchValuesAttributes = range(3)
132   
133    def __init__(self, contextName, fields = [],
134                 cloneIfImperfect = True, findImperfect = True, syncWithGlobal = True,
135                 maxAttributesToPickle = 100, matchValues = 0, 
136                 forceOrdinaryAttributes = False, forceMetaAttributes = False, contextDataVersion = 0, **args):
137        ContextHandler.__init__(self, contextName, cloneIfImperfect, findImperfect, syncWithGlobal, contextDataVersion = contextDataVersion, **args)
138        self.maxAttributesToPickle = maxAttributesToPickle
139        self.matchValues = matchValues
140        self.fields = []
141        hasMetaAttributes = hasOrdinaryAttributes = False
142
143        for field in fields:
144            if isinstance(field, ContextField):
145                self.fields.append(field)
146                if not field.flags & self.NotAttribute:
147                    hasOrdinaryAttributes = hasOrdinaryAttributes or not field.flags & self.ExcludeOrdinaryAttributes
148                    hasMetaAttributes = hasMetaAttributes or field.flags & self.IncludeMetaAttributes
149            elif type(field)==str:
150                self.fields.append(ContextField(field, self.Required))
151                hasOrdinaryAttributes = True
152            # else it's a tuple
153            else:
154                flags = field[1]
155                if isinstance(field[0], list):
156                    self.fields.extend([ContextField(x, flags) for x in field[0]])
157                else:
158                    self.fields.append(ContextField(field[0], flags))
159                if not flags & self.NotAttribute:
160                    hasOrdinaryAttributes = hasOrdinaryAttributes or not flags & self.ExcludeOrdinaryAttributes
161                    hasMetaAttributes = hasMetaAttributes or flags & self.IncludeMetaAttributes
162                   
163        self.hasOrdinaryAttributes, self.hasMetaAttributes = hasOrdinaryAttributes, hasMetaAttributes
164       
165    def encodeDomain(self, domain):
166        if self.matchValues == 2:
167            attributes = self.hasOrdinaryAttributes and \
168                         dict([(attr.name, attr.varType != orange.VarTypes.Discrete and attr.varType or attr.values)
169                                for attr in domain])
170            metas = self.hasMetaAttributes and \
171                         dict([(attr.name, attr.varType != orange.VarTypes.Discrete and attr.varType or attr.values)
172                                for attr in domain.getmetas().values()])
173        else:
174            if self.hasOrdinaryAttributes:
175                attributes = dict([(attr.name, attr.varType) for attr in domain.attributes])
176                classVar = domain.classVar
177                if classVar:
178                    if self.matchValues and classVar.varType == orange.VarTypes.Discrete:
179                        attributes[classVar.name] = classVar.values
180                    else:
181                        attributes[classVar.name] = classVar.varType
182            else:
183                attributes = False
184
185            metas = self.hasMetaAttributes and dict([(attr.name, attr.varType) for attr in domain.getmetas().values()]) or {}
186
187        return attributes, metas
188   
189    def findOrCreateContext(self, widget, domain):
190        if not domain:
191            return None, False
192       
193        if not isinstance(domain, orange.Domain):
194            domain = domain.domain
195           
196        encodedDomain = self.encodeDomain(domain)
197        context, isNew = ContextHandler.findOrCreateContext(self, widget, domain, *encodedDomain)
198        if not context:
199            return None, False
200       
201        if len(encodedDomain) == 2:
202            context.attributes, context.metas = encodedDomain
203        else:
204            context.attributes, context.classVar, context.metas = encodedDomain
205
206        metaIds = domain.getmetas().keys()
207        metaIds.sort()
208        context.orderedDomain = []
209        if self.hasOrdinaryAttributes:
210            context.orderedDomain.extend([(attr.name, attr.varType) for attr in domain])
211        if self.hasMetaAttributes:
212            context.orderedDomain.extend([(domain[i].name, domain[i].varType) for i in metaIds])
213
214        if isNew:
215            context.values = {}
216            context.noCopy = ["orderedDomain"]
217        return context, isNew
218
219#    def exists(self, field, value, context):
220#        for check, what in ((not field.flags & self.ExcludeOrdinaryAttributes, context.attributes),
221#                            (field.flags & self.IncludeMetaAttributes, context.metas)):
222#            if check:
223#                inDomainType = context.attributes.get(value[0])
224#                if isinstance(savedType, list):
225#                    if what.get(value[0], False) == savedType
226#                                f
227#                    if saved
228             
229    def settingsToWidget(self, widget, context):
230        ContextHandler.settingsToWidget(self, widget, context)
231        excluded = {}
232        addOrdinaryTo = []
233        addMetaTo = []
234
235        def set_if_all_hashable(iter):
236            try:
237                return set(iter)
238            except TypeError:
239                return list(iter)
240           
241        def attrSet(attrs):
242            if isinstance(attrs, dict):
243                return set_if_all_hashable(attrs.items())
244            elif isinstance(attrs, bool):
245                return {}
246            else:
247                return set([])
248           
249        attrItemsSet = attrSet(context.attributes)
250        metaItemsSet = attrSet(context.metas)
251        for field in self.fields:
252            name, flags = field.name, field.flags
253
254            excludes = getattr(field, "reservoir", [])
255            if excludes:
256                if not isinstance(excludes, list):
257                    excludes = [excludes]
258                for exclude in excludes:
259                    excluded.setdefault(exclude, [])
260                    if not (flags & self.NotAttribute + self.ExcludeOrdinaryAttributes):
261                        addOrdinaryTo.append(exclude)
262                    if flags & self.IncludeMetaAttributes:
263                        addMetaTo.append(exclude)                   
264
265            if not context.values.has_key(name):
266                continue
267           
268            value = context.values[name]
269
270            if not flags & self.List:
271# TODO: is setattr supposed to check that we do not assign values that are optional and do not exist?
272# is context cloning's filter enough to get rid of such attributes?
273                setattr(widget, name, value[0])
274                for exclude in excludes:
275                    excluded[exclude].append(value)
276
277            else:
278                newLabels, newSelected = [], []
279                oldSelected = hasattr(field, "selected") and context.values.get(field.selected, []) or []
280                for i, saved in enumerate(value):
281                    if not flags & self.ExcludeOrdinaryAttributes and (saved in context.attributes or saved in attrItemsSet) \
282                       or flags & self.IncludeMetaAttributes and (saved in context.metas or saved in metaItemsSet):
283                        if i in oldSelected:
284                            newSelected.append(len(newLabels))
285                        newLabels.append(saved)
286
287                context.values[name] = newLabels
288                setattr(widget, name, value)
289
290                if hasattr(field, "selected"):
291                    context.values[field.selected] = newSelected
292                    setattr(widget, field.selected, context.values[field.selected])
293
294                for exclude in excludes:
295                    excluded[exclude].extend(value)
296
297        for name, values in excluded.items():
298            addOrd, addMeta = name in addOrdinaryTo, name in addMetaTo
299            ll = [a for a in context.orderedDomain if a not in values and ((addOrd and context.attributes.get(a[0], None) == a[1]) or (addMeta and context.metas.get(a[0], None) == a[1]))]
300            setattr(widget, name, ll)
301           
302
303    def settingsFromWidget(self, widget, context):
304        ContextHandler.settingsFromWidget(self, widget, context)
305        context.values = {}
306        for field in self.fields:
307            if not field.flags & self.List:
308                self.saveLow(context, widget, field.name, widget.getdeepattr(field.name), field.flags)
309            else:
310                value = widget.getdeepattr(field.name)
311                # shallow copy of the list
312                context.values[field.name] = copy.copy(value) #type(value)(value)
313                if hasattr(field, "selected"):
314                    context.values[field.selected] = list(widget.getdeepattr(field.selected))
315
316    def fastSave(self, context, widget, name, value):
317        if context:
318            for field in self.fields:
319                if name == field.name:
320                    if field.flags & self.List:
321                        # shallow copy of the list
322                        context.values[field.name] = copy.copy(value) #type(value)(value)
323                    else:
324                        self.saveLow(context, widget, name, value, field.flags)
325                    return
326                if name == getattr(field, "selected", None):
327                    context.values[field.selected] = list(value)
328                    return
329
330    def saveLow(self, context, widget, field, value, flags):
331        # The code below uses type(value)(value) to make at least a shallow copy of the
332        # attributes. This will mostly make copies of integers, floats and strings, but
333        # it is crucial to make copies of lists
334        value = copy.copy(value) #type(value)(value)
335        if isinstance(value, str):
336            valtype = not flags & self.ExcludeOrdinaryAttributes and context.attributes.get(value, -1)
337            if valtype == -1:
338                valtype = flags & self.IncludeMetaAttributes and context.attributes.get(value, -1)
339            context.values[field] = value, valtype # -1 means it's not an attribute
340        else:
341            context.values[field] = value, -2
342
343    def attributeExists(self, value, flags, attributes, metas):
344        return not flags & self.ExcludeOrdinaryAttributes and attributes.get(value[0], -1) == value[1] \
345                or flags & self.IncludeMetaAttributes and metas.get(value[0], -1) == value[1]
346   
347    def match(self, context, imperfect, domain, attributes, metas):
348        if (attributes, metas) == (context.attributes, context.metas):
349            return 2
350        if not imperfect:
351            return 0
352
353        filled = potentiallyFilled = 0
354        for field in self.fields:
355            flags = field.flags
356            value = context.values.get(field.name, None)
357            if value:
358                if flags & self.List:
359                    if flags & self.RequirementMask == self.Required:
360                        potentiallyFilled += len(value)
361                        filled += len(value)
362                        for item in value:
363                            if not self.attributeExists(item, flags, attributes, metas): 
364                                return 0
365                    else:
366                        selectedRequired = field.flags & self.RequirementMask == self.SelectedRequired
367                        selected = context.values.get(field.selected, [])
368                        potentiallyFilled += len(selected) 
369                        for i in selected:
370                            # TODO: shouldn't we check the attribute type here, too? or should we change self.saveLow for these field types then?
371                            if (not flags & self.ExcludeOrdinaryAttributes and value[i] in attributes
372                                 or flags & self.IncludeMetaAttributes and value[i] in metas): 
373                                filled += 1
374                            else:
375                                if selectedRequired:
376                                    return 0
377                else:
378                    potentiallyFilled += 1
379                    if value[1] >= 0:
380                        if (not flags & self.ExcludeOrdinaryAttributes and attributes.get(value[0], None) == value[1]
381                             or flags & self.IncludeMetaAttributes and metas.get(value[0], None) == value[1]): 
382                            filled += 1
383                        else:
384                            if flags & self.Required:
385                                return 0
386
387            if not potentiallyFilled:
388                return 1.0
389            else:
390                return filled / float(potentiallyFilled)
391
392    def cloneContext(self, context, domain, attributes, metas):
393        import copy
394        context = copy.deepcopy(context)
395       
396        for field in self.fields:
397            value = context.values.get(field.name, None)
398            if value:
399                if field.flags & self.List:
400                    i = j = realI = 0
401                    selected = context.values.get(field.selected, [])
402                    selected.sort()
403                    nextSel = selected and selected[0] or None
404                    while i < len(value):
405                        if not self.attributeExists(value[i], field.flags, attributes, metas):
406                            del value[i]
407                            if nextSel == realI:
408                                del selected[j]
409                                nextSel = j < len(selected) and selected[j] or None
410                        else:
411                            if nextSel == realI:
412                                selected[j] -= realI - i
413                                j += 1
414                                nextSel = j < len(selected) and selected[j] or None
415                            i += 1
416                        realI += 1
417                    if hasattr(field, "selected"):
418                        context.values[field.selected] = selected[:j]
419                else:
420                    if value[1] >= 0 and not self.attributeExists(value, field.flags, attributes, metas):
421                        del context.values[field.name]
422                       
423        context.attributes, context.metas = attributes, metas
424        context.orderedDomain = [(attr.name, attr.varType) for attr in domain]
425        return context
426
427    # this is overloaded to get rid of the huge domains
428    def mergeBack(self, widget):
429        if not self.syncWithGlobal or getattr(widget, self.localContextName) is not self.globalContexts:
430            self.globalContexts.extend([c for c in getattr(widget, self.localContextName) if c not in self.globalContexts])
431            mp = self.maxAttributesToPickle
432            self.globalContexts[:] = filter(lambda c: (c.attributes and len(c.attributes) or 0) + (c.metas and len(c.metas) or 0) < mp, self.globalContexts)
433            self.globalContexts.sort(lambda c1,c2: -cmp(c1.time, c2.time))
434            self.globalContexts[:] = self.globalContexts[:self.maxSavedContexts]
435
436   
437class ClassValuesContextHandler(ContextHandler):
438    def __init__(self, contextName, fields = [], syncWithGlobal = True, contextDataVersion = 0, **args):
439        ContextHandler.__init__(self, contextName, False, False, syncWithGlobal, contextDataVersion = contextDataVersion, **args)
440        if isinstance(fields, list):
441            self.fields = fields
442        else:
443            self.fields = [fields]
444       
445    def findOrCreateContext(self, widget, classes):
446        if isinstance(classes, orange.Variable):
447            classes = classes.varType == orange.VarTypes.Discrete and classes.values
448        if not classes:
449            return None, False
450        context, isNew = ContextHandler.findOrCreateContext(self, widget, classes)
451        if not context:
452            return None, False
453        context.classes = classes
454        if isNew:
455            context.values = {}
456        return context, isNew
457
458    def settingsToWidget(self, widget, context):
459        ContextHandler.settingsToWidget(self, widget, context)
460        for field in self.fields:
461            setattr(widget, field, context.values[field])
462           
463    def settingsFromWidget(self, widget, context):
464        ContextHandler.settingsFromWidget(self, widget, context)
465        # shallow copy!
466        values = context.values = {}
467        for field in self.fields:
468            value = widget.getdeepattr(field)
469            values[field] = copy.copy(value) #type(value)(value)
470
471    def fastSave(self, context, widget, name, value):
472        if context and name in self.fields:
473            # shallow copy!
474            context.values[name] = copy.copy(value) #type(value)(value)
475
476    def match(self, context, imperfect, classes):
477        return context.classes == classes and 2
478
479    def cloneContext(self, context, domain, encodedDomain):
480        import copy
481        return copy.deepcopy(context)
482       
483
484
485### Requires the same the same attributes in the same order
486### The class overloads domain encoding and matching.
487### Due to different encoding, it also needs to overload saveLow and cloneContext
488### (the latter gets really simple now).
489### We could simplify some other methods, but prefer not to replicate the code
490###
491### Note that forceOrdinaryAttributes is here True by default!
492class PerfectDomainContextHandler(DomainContextHandler):
493    def __init__(self, contextName = "", fields = [],
494                 syncWithGlobal = True, **args):
495            DomainContextHandler.__init__(self, contextName, fields, False, False, syncWithGlobal, **args)
496
497       
498    def encodeDomain(self, domain):
499        if self.matchValues == 2:
500            attributes = tuple([(attr.name, attr.varType != orange.VarTypes.Discrete and attr.varType or attr.values)
501                         for attr in domain])
502            classVar = domain.classVar
503            if classVar:
504                classVar = classVar.name, classVar.varType != orange.VarTypes.Discrete and classVar.varType or classVar.values
505            metas = dict([(attr.name, attr.varType != orange.VarTypes.Discrete and attr.varType or attr.values)
506                         for attr in domain.getmetas().values()])
507        else:
508            attributes = tuple([(attr.name, attr.varType) for attr in domain.attributes])
509            classVar = domain.classVar
510            if classVar:
511                classVar = classVar.name, classVar.varType
512            metas = dict([(attr.name, attr.varType) for attr in domain.getmetas().values()])
513        return attributes, classVar, metas
514   
515
516
517    def match(self, context, imperfect, domain, attributes, classVar, metas):
518        return (attributes, classVar, metas) == (context.attributes, context.classVar, context.metas) and 2
519
520
521    def saveLow(self, context, widget, field, value, flags):
522        if isinstance(value, str):
523            if not flags & self.ExcludeOrdinaryAttributes:
524                attr = [x[1] for x in context.attributes if x[0] == value]
525            if not attr and context.classVar and context.classVar[0] == value:
526                attr = [context.classVar[1]]
527            if not attr and flags & self.IncludeMetaAttributes:
528                attr = [x[1] for x in context.metas if x[0] == value]
529
530            value = copy.copy(value) #type(value)(value)
531            if attr:
532                context.values[field] = value, attr[0]
533            else:
534                context.values[field] = value, -1
535        else:
536            context.values[field] = value, -2
537
538
539    def cloneContext(self, context, domain, encodedDomain):
540        import copy
541        context = copy.deepcopy(context)
542       
543
544class EvaluationResultsContextHandler(ContextHandler):
545    def __init__(self, contextName, targetAttr, selectedAttr,
546                 syncWithGlobal = True, **args):
547            self.targetAttr, self.selectedAttr = targetAttr, selectedAttr
548            ContextHandler.__init__(self, contextName, False, False, syncWithGlobal, **args)
549       
550    def match(self, context, imperfect, cnames, cvalues):
551        return (cnames, cvalues) == (context.classifierNames, context.classValues) and 2
552
553    def fastSave(self, context, widget, name, value):
554        if context:
555            if name == self.targetAttr:
556                context.targetClass = value
557            elif name == self.selectedAttr:
558                context.selectedClassifiers = list(value)
559
560    def settingsFromWidget(self, widget, context):
561        context.targetClass = widget.getdeepattr(self.targetAttr)
562        context.selectedClassifiers = list(widget.getdeepattr(self.selectedAttr))
563
564    def settingsToWidget(self, widget, context):
565        if context.targetClass is not None:
566            setattr(widget, self.targetAttr, context.targetClass)
567        if context.selectedClassifiers is not None:
568            setattr(widget, self.selectedAttr, context.selectedClassifiers)
569           
570    def cloneContext(self, context, domain):
571        import copy
572        context = copy.deepcopy(context)
573       
574    def findOrCreateContext(self, widget, results):
575        if not results:
576            return None, False
577        cnames = [c.name for c in results.classifiers]
578        cvalues = results.classValues
579        context, isNew = ContextHandler.findOrCreateContext(self, widget, results.classifierNames, results.classValues)
580        if not context:
581            return None, False
582        if isNew:
583            context.classifierNames = results.classifierNames
584            context.classValues = results.classValues
585            context.selectedClassifiers = None
586            context.targetClass = None
587        return context, isNew
Note: See TracBrowser for help on using the repository browser.