source: orange/orange/OrangeWidgets/OWContexts.py @ 9602:ca787805af6e

Revision 9602:ca787805af6e, 25.8 KB checked in by ales_erjavec, 2 years ago (diff)

Added parentContext parameter to initLocalContext

Line 
1import time, copy, orange
2from string import *
3
4contextStructureVersion = 100
5
6class Context:
7    def __init__(self, **argkw):
8        self.time = time.time()
9        self.__dict__.update(argkw)
10
11    def __getstate__(self):
12        s = dict(self.__dict__)
13        for nc in getattr(self, "noCopy", []):
14            if s.has_key(nc):
15                del s[nc]
16        return s
17
18   
19class ContextHandler:
20    maxSavedContexts = 50
21   
22    def __init__(self, contextName = "", cloneIfImperfect = True, findImperfect = True, syncWithGlobal = True, contextDataVersion = 0, **args):
23        self.contextName = contextName
24        self.localContextName = "localContexts"+contextName
25        self.cloneIfImperfect, self.findImperfect = cloneIfImperfect, findImperfect
26        self.contextDataVersion = contextDataVersion
27        self.syncWithGlobal = syncWithGlobal
28        self.globalContexts = []
29        self.__dict__.update(args)
30
31    def newContext(self):
32        return Context()
33
34    def openContext(self, widget, *arg, **argkw):
35        context, isNew = self.findOrCreateContext(widget, *arg, **argkw)
36        if context:
37            if isNew:
38                self.settingsFromWidget(widget, context)
39            else:
40                self.settingsToWidget(widget, context)
41        return context
42
43    def initLocalContext(self, widget, parentContext=None):
44        if not hasattr(widget, self.localContextName):
45            if parentContext is None:
46                globalContexts = self.globalContexts
47            else:
48                # parentContext is a Schema level context repo
49                globalContexts = parentContext.globalContexts
50               
51            if self.syncWithGlobal:
52                setattr(widget, self.localContextName, globalContexts)
53            else:
54                setattr(widget, self.localContextName, copy.deepcopy(globalContexts))
55       
56    def findOrCreateContext(self, widget, *arg, **argkw):       
57        index, context, score = self.findMatch(widget, self.findImperfect, *arg, **argkw)
58        if context:
59            if index < 0:
60                self.addContext(widget, context)
61            else:
62                self.moveContextUp(widget, index)
63            return context, False
64        else:
65            context = self.newContext()
66            self.addContext(widget, context)
67            return context, True
68
69    def closeContext(self, widget, context):
70        self.settingsFromWidget(widget, context)
71
72    def fastSave(self, context, widget, name, value):
73        pass
74
75    def settingsToWidget(self, widget, context):
76        cb = getattr(widget, "settingsToWidgetCallback" + self.contextName, None)
77        return cb and cb(self, context)
78
79    def settingsFromWidget(self, widget, context):
80        cb = getattr(widget, "settingsFromWidgetCallback" + self.contextName, None)
81        return cb and cb(self, context)
82
83    def findMatch(self, widget, imperfect = True, *arg, **argkw):
84        bestI, bestContext, bestScore = -1, None, -1
85        for i, c in enumerate(getattr(widget, self.localContextName)):
86            score = self.match(c, imperfect, *arg, **argkw)
87            if score == 2:
88                return i, c, score
89            if score and score > bestScore:
90                bestI, bestContext, bestScore = i, c, score
91
92        if bestContext and self.cloneIfImperfect:
93            if hasattr(self, "cloneContext"):
94                bestContext = self.cloneContext(bestContext, *arg, **argkw)
95            else:
96                bestContext = copy.deepcopy(bestContext)
97            bestI = -1
98               
99        return bestI, bestContext, bestScore
100           
101    def moveContextUp(self, widget, index):
102        localContexts = getattr(widget, self.localContextName)
103        l = getattr(widget, self.localContextName)
104        context = l.pop(index)
105        context.time = time.time()
106        l.insert(0, context)
107
108    def addContext(self, widget, context):
109        l = getattr(widget, self.localContextName)
110        l.insert(0, context)
111        while len(l) > self.maxSavedContexts:
112            del l[-1]
113
114    def mergeBack(self, widget):
115        if not self.syncWithGlobal or getattr(widget, self.localContextName) is not  self.globalContexts:
116            self.globalContexts.extend([c for c in getattr(widget, self.localContextName) if c not in self.globalContexts])
117            self.globalContexts.sort(lambda c1,c2: -cmp(c1.time, c2.time))
118            self.globalContexts[:] = self.globalContexts[:self.maxSavedContexts]
119
120
121class ContextField:
122    def __init__(self, name, flags = 0, **argkw):
123        self.name = name
124        self.flags = flags
125        self.__dict__.update(argkw)
126   
127
128class DomainContextHandler(ContextHandler):
129    Optional, SelectedRequired, Required = range(3)
130    RequirementMask = 3
131    NotAttribute = 4
132    List = 8
133    RequiredList = Required + List
134    SelectedRequiredList = SelectedRequired + List
135    ExcludeOrdinaryAttributes, IncludeMetaAttributes = 16, 32
136
137    MatchValuesNo, MatchValuesClass, MatchValuesAttributes = range(3)
138   
139    def __init__(self, contextName, fields = [],
140                 cloneIfImperfect = True, findImperfect = True, syncWithGlobal = True,
141                 maxAttributesToPickle = 100, matchValues = 0, 
142                 forceOrdinaryAttributes = False, forceMetaAttributes = False, contextDataVersion = 0, **args):
143        ContextHandler.__init__(self, contextName, cloneIfImperfect, findImperfect, syncWithGlobal, contextDataVersion = contextDataVersion, **args)
144        self.maxAttributesToPickle = maxAttributesToPickle
145        self.matchValues = matchValues
146        self.fields = []
147        hasMetaAttributes = hasOrdinaryAttributes = False
148
149        for field in fields:
150            if isinstance(field, ContextField):
151                self.fields.append(field)
152                if not field.flags & self.NotAttribute:
153                    hasOrdinaryAttributes = hasOrdinaryAttributes or not field.flags & self.ExcludeOrdinaryAttributes
154                    hasMetaAttributes = hasMetaAttributes or field.flags & self.IncludeMetaAttributes
155            elif type(field)==str:
156                self.fields.append(ContextField(field, self.Required))
157                hasOrdinaryAttributes = True
158            # else it's a tuple
159            else:
160                flags = field[1]
161                if isinstance(field[0], list):
162                    self.fields.extend([ContextField(x, flags) for x in field[0]])
163                else:
164                    self.fields.append(ContextField(field[0], flags))
165                if not flags & self.NotAttribute:
166                    hasOrdinaryAttributes = hasOrdinaryAttributes or not flags & self.ExcludeOrdinaryAttributes
167                    hasMetaAttributes = hasMetaAttributes or flags & self.IncludeMetaAttributes
168                   
169        self.hasOrdinaryAttributes, self.hasMetaAttributes = hasOrdinaryAttributes, hasMetaAttributes
170       
171    def encodeDomain(self, domain):
172        if self.matchValues == 2:
173            attributes = self.hasOrdinaryAttributes and \
174                         dict([(attr.name, attr.varType != orange.VarTypes.Discrete and attr.varType or attr.values)
175                                for attr in domain])
176            metas = self.hasMetaAttributes and \
177                         dict([(attr.name, attr.varType != orange.VarTypes.Discrete and attr.varType or attr.values)
178                                for attr in domain.getmetas().values()])
179        else:
180            if self.hasOrdinaryAttributes:
181                attributes = dict([(attr.name, attr.varType) for attr in domain.attributes])
182                classVar = domain.classVar
183                if classVar:
184                    if self.matchValues and classVar.varType == orange.VarTypes.Discrete:
185                        attributes[classVar.name] = classVar.values
186                    else:
187                        attributes[classVar.name] = classVar.varType
188            else:
189                attributes = False
190
191            metas = self.hasMetaAttributes and dict([(attr.name, attr.varType) for attr in domain.getmetas().values()]) or {}
192
193        return attributes, metas
194   
195    def findOrCreateContext(self, widget, domain):
196        if not domain:
197            return None, False
198       
199        if not isinstance(domain, orange.Domain):
200            domain = domain.domain
201           
202        encodedDomain = self.encodeDomain(domain)
203        context, isNew = ContextHandler.findOrCreateContext(self, widget, domain, *encodedDomain)
204        if not context:
205            return None, False
206       
207        if len(encodedDomain) == 2:
208            context.attributes, context.metas = encodedDomain
209        else:
210            context.attributes, context.classVar, context.metas = encodedDomain
211
212        metaIds = domain.getmetas().keys()
213        metaIds.sort()
214        context.orderedDomain = []
215        if self.hasOrdinaryAttributes:
216            context.orderedDomain.extend([(attr.name, attr.varType) for attr in domain])
217        if self.hasMetaAttributes:
218            context.orderedDomain.extend([(domain[i].name, domain[i].varType) for i in metaIds])
219
220        if isNew:
221            context.values = {}
222            context.noCopy = ["orderedDomain"]
223        return context, isNew
224
225#    def exists(self, field, value, context):
226#        for check, what in ((not field.flags & self.ExcludeOrdinaryAttributes, context.attributes),
227#                            (field.flags & self.IncludeMetaAttributes, context.metas)):
228#            if check:
229#                inDomainType = context.attributes.get(value[0])
230#                if isinstance(savedType, list):
231#                    if what.get(value[0], False) == savedType
232#                                f
233#                    if saved
234             
235    def settingsToWidget(self, widget, context):
236        ContextHandler.settingsToWidget(self, widget, context)
237        excluded = {}
238        addOrdinaryTo = []
239        addMetaTo = []
240
241        def set_if_all_hashable(iter):
242            try:
243                return set(iter)
244            except TypeError:
245                return list(iter)
246           
247        def attrSet(attrs):
248            if isinstance(attrs, dict):
249                return set_if_all_hashable(attrs.items())
250            elif isinstance(attrs, bool):
251                return {}
252            else:
253                return set([])
254           
255        attrItemsSet = attrSet(context.attributes)
256        metaItemsSet = attrSet(context.metas)
257        for field in self.fields:
258            name, flags = field.name, field.flags
259
260            excludes = getattr(field, "reservoir", [])
261            if excludes:
262                if not isinstance(excludes, list):
263                    excludes = [excludes]
264                for exclude in excludes:
265                    excluded.setdefault(exclude, [])
266                    if not (flags & self.NotAttribute + self.ExcludeOrdinaryAttributes):
267                        addOrdinaryTo.append(exclude)
268                    if flags & self.IncludeMetaAttributes:
269                        addMetaTo.append(exclude)                   
270
271            if not context.values.has_key(name):
272                continue
273           
274            value = context.values[name]
275
276            if not flags & self.List:
277# TODO: is setattr supposed to check that we do not assign values that are optional and do not exist?
278# is context cloning's filter enough to get rid of such attributes?
279                setattr(widget, name, value[0])
280                for exclude in excludes:
281                    excluded[exclude].append(value)
282
283            else:
284                newLabels, newSelected = [], []
285                oldSelected = hasattr(field, "selected") and context.values.get(field.selected, []) or []
286                for i, saved in enumerate(value):
287                    if not flags & self.ExcludeOrdinaryAttributes and (saved in context.attributes or saved in attrItemsSet) \
288                       or flags & self.IncludeMetaAttributes and (saved in context.metas or saved in metaItemsSet):
289                        if i in oldSelected:
290                            newSelected.append(len(newLabels))
291                        newLabels.append(saved)
292
293                context.values[name] = newLabels
294                setattr(widget, name, value)
295
296                if hasattr(field, "selected"):
297                    context.values[field.selected] = newSelected
298                    setattr(widget, field.selected, context.values[field.selected])
299
300                for exclude in excludes:
301                    excluded[exclude].extend(value)
302
303        for name, values in excluded.items():
304            addOrd, addMeta = name in addOrdinaryTo, name in addMetaTo
305            ll = [a for a in context.orderedDomain if a not in values and ((addOrd and context.attributes.get(a[0], None) == a[1]) or (addMeta and context.metas.get(a[0], None) == a[1]))]
306            setattr(widget, name, ll)
307           
308
309    def settingsFromWidget(self, widget, context):
310        ContextHandler.settingsFromWidget(self, widget, context)
311        context.values = {}
312        for field in self.fields:
313            if not field.flags & self.List:
314                self.saveLow(context, widget, field.name, widget.getdeepattr(field.name), field.flags)
315            else:
316                value = widget.getdeepattr(field.name)
317                # shallow copy of the list
318                context.values[field.name] = copy.copy(value) #type(value)(value)
319                if hasattr(field, "selected"):
320                    context.values[field.selected] = list(widget.getdeepattr(field.selected))
321
322    def fastSave(self, context, widget, name, value):
323        if context:
324            for field in self.fields:
325                if name == field.name:
326                    if field.flags & self.List:
327                        # shallow copy of the list
328                        context.values[field.name] = copy.copy(value) #type(value)(value)
329                    else:
330                        self.saveLow(context, widget, name, value, field.flags)
331                    return
332                if name == getattr(field, "selected", None):
333                    context.values[field.selected] = list(value)
334                    return
335
336    def saveLow(self, context, widget, field, value, flags):
337        # The code below uses type(value)(value) to make at least a shallow copy of the
338        # attributes. This will mostly make copies of integers, floats and strings, but
339        # it is crucial to make copies of lists
340        value = copy.copy(value) #type(value)(value)
341        if isinstance(value, str):
342            valtype = not flags & self.ExcludeOrdinaryAttributes and context.attributes.get(value, -1)
343            if valtype == -1:
344                valtype = flags & self.IncludeMetaAttributes and context.attributes.get(value, -1)
345            context.values[field] = value, valtype # -1 means it's not an attribute
346        else:
347            context.values[field] = value, -2
348
349    def attributeExists(self, value, flags, attributes, metas):
350        return not flags & self.ExcludeOrdinaryAttributes and attributes.get(value[0], -1) == value[1] \
351                or flags & self.IncludeMetaAttributes and metas.get(value[0], -1) == value[1]
352   
353    def match(self, context, imperfect, domain, attributes, metas):
354        if (attributes, metas) == (context.attributes, context.metas):
355            return 2
356        if not imperfect:
357            return 0
358
359        filled = potentiallyFilled = 0
360        for field in self.fields:
361            flags = field.flags
362            value = context.values.get(field.name, None)
363            if value:
364                if flags & self.List:
365                    if flags & self.RequirementMask == self.Required:
366                        potentiallyFilled += len(value)
367                        filled += len(value)
368                        for item in value:
369                            if not self.attributeExists(item, flags, attributes, metas): 
370                                return 0
371                    else:
372                        selectedRequired = field.flags & self.RequirementMask == self.SelectedRequired
373                        selected = context.values.get(field.selected, [])
374                        potentiallyFilled += len(selected) 
375                        for i in selected:
376                            # TODO: shouldn't we check the attribute type here, too? or should we change self.saveLow for these field types then?
377                            if (not flags & self.ExcludeOrdinaryAttributes and value[i] in attributes
378                                 or flags & self.IncludeMetaAttributes and value[i] in metas): 
379                                filled += 1
380                            else:
381                                if selectedRequired:
382                                    return 0
383                else:
384                    potentiallyFilled += 1
385                    if value[1] >= 0:
386                        if (not flags & self.ExcludeOrdinaryAttributes and attributes.get(value[0], None) == value[1]
387                             or flags & self.IncludeMetaAttributes and metas.get(value[0], None) == value[1]): 
388                            filled += 1
389                        else:
390                            if flags & self.Required:
391                                return 0
392
393            if not potentiallyFilled:
394                return 1.0
395            else:
396                return filled / float(potentiallyFilled)
397
398    def cloneContext(self, context, domain, attributes, metas):
399        import copy
400        context = copy.deepcopy(context)
401       
402        for field in self.fields:
403            value = context.values.get(field.name, None)
404            if value:
405                if field.flags & self.List:
406                    i = j = realI = 0
407                    selected = context.values.get(field.selected, [])
408                    selected.sort()
409                    nextSel = selected and selected[0] or None
410                    while i < len(value):
411                        if not self.attributeExists(value[i], field.flags, attributes, metas):
412                            del value[i]
413                            if nextSel == realI:
414                                del selected[j]
415                                nextSel = j < len(selected) and selected[j] or None
416                        else:
417                            if nextSel == realI:
418                                selected[j] -= realI - i
419                                j += 1
420                                nextSel = j < len(selected) and selected[j] or None
421                            i += 1
422                        realI += 1
423                    if hasattr(field, "selected"):
424                        context.values[field.selected] = selected[:j]
425                else:
426                    if value[1] >= 0 and not self.attributeExists(value, field.flags, attributes, metas):
427                        del context.values[field.name]
428                       
429        context.attributes, context.metas = attributes, metas
430        context.orderedDomain = [(attr.name, attr.varType) for attr in domain]
431        return context
432
433    # this is overloaded to get rid of the huge domains
434    def mergeBack(self, widget):
435        if not self.syncWithGlobal or getattr(widget, self.localContextName) is not self.globalContexts:
436            self.globalContexts.extend([c for c in getattr(widget, self.localContextName) if c not in self.globalContexts])
437            mp = self.maxAttributesToPickle
438            self.globalContexts[:] = filter(lambda c: (c.attributes and len(c.attributes) or 0) + (c.metas and len(c.metas) or 0) < mp, self.globalContexts)
439            self.globalContexts.sort(lambda c1,c2: -cmp(c1.time, c2.time))
440            self.globalContexts[:] = self.globalContexts[:self.maxSavedContexts]
441
442   
443class ClassValuesContextHandler(ContextHandler):
444    def __init__(self, contextName, fields = [], syncWithGlobal = True, contextDataVersion = 0, **args):
445        ContextHandler.__init__(self, contextName, False, False, syncWithGlobal, contextDataVersion = contextDataVersion, **args)
446        if isinstance(fields, list):
447            self.fields = fields
448        else:
449            self.fields = [fields]
450       
451    def findOrCreateContext(self, widget, classes):
452        if isinstance(classes, orange.Variable):
453            classes = classes.varType == orange.VarTypes.Discrete and classes.values
454        if not classes:
455            return None, False
456        context, isNew = ContextHandler.findOrCreateContext(self, widget, classes)
457        if not context:
458            return None, False
459        context.classes = classes
460        if isNew:
461            context.values = {}
462        return context, isNew
463
464    def settingsToWidget(self, widget, context):
465        ContextHandler.settingsToWidget(self, widget, context)
466        for field in self.fields:
467            setattr(widget, field, context.values[field])
468           
469    def settingsFromWidget(self, widget, context):
470        ContextHandler.settingsFromWidget(self, widget, context)
471        # shallow copy!
472        values = context.values = {}
473        for field in self.fields:
474            value = widget.getdeepattr(field)
475            values[field] = copy.copy(value) #type(value)(value)
476
477    def fastSave(self, context, widget, name, value):
478        if context and name in self.fields:
479            # shallow copy!
480            context.values[name] = copy.copy(value) #type(value)(value)
481
482    def match(self, context, imperfect, classes):
483        return context.classes == classes and 2
484
485    def cloneContext(self, context, domain, encodedDomain):
486        import copy
487        return copy.deepcopy(context)
488       
489
490
491### Requires the same the same attributes in the same order
492### The class overloads domain encoding and matching.
493### Due to different encoding, it also needs to overload saveLow and cloneContext
494### (the latter gets really simple now).
495### We could simplify some other methods, but prefer not to replicate the code
496###
497### Note that forceOrdinaryAttributes is here True by default!
498class PerfectDomainContextHandler(DomainContextHandler):
499    def __init__(self, contextName = "", fields = [],
500                 syncWithGlobal = True, **args):
501            DomainContextHandler.__init__(self, contextName, fields, False, False, syncWithGlobal, **args)
502
503       
504    def encodeDomain(self, domain):
505        if self.matchValues == 2:
506            attributes = tuple([(attr.name, attr.varType != orange.VarTypes.Discrete and attr.varType or attr.values)
507                         for attr in domain])
508            classVar = domain.classVar
509            if classVar:
510                classVar = classVar.name, classVar.varType != orange.VarTypes.Discrete and classVar.varType or classVar.values
511            metas = dict([(attr.name, attr.varType != orange.VarTypes.Discrete and attr.varType or attr.values)
512                         for attr in domain.getmetas().values()])
513        else:
514            attributes = tuple([(attr.name, attr.varType) for attr in domain.attributes])
515            classVar = domain.classVar
516            if classVar:
517                classVar = classVar.name, classVar.varType
518            metas = dict([(attr.name, attr.varType) for attr in domain.getmetas().values()])
519        return attributes, classVar, metas
520   
521
522
523    def match(self, context, imperfect, domain, attributes, classVar, metas):
524        return (attributes, classVar, metas) == (context.attributes, context.classVar, context.metas) and 2
525
526
527    def saveLow(self, context, widget, field, value, flags):
528        if isinstance(value, str):
529            if not flags & self.ExcludeOrdinaryAttributes:
530                attr = [x[1] for x in context.attributes if x[0] == value]
531            if not attr and context.classVar and context.classVar[0] == value:
532                attr = [context.classVar[1]]
533            if not attr and flags & self.IncludeMetaAttributes:
534                attr = [x[1] for x in context.metas if x[0] == value]
535
536            value = copy.copy(value) #type(value)(value)
537            if attr:
538                context.values[field] = value, attr[0]
539            else:
540                context.values[field] = value, -1
541        else:
542            context.values[field] = value, -2
543
544
545    def cloneContext(self, context, domain, encodedDomain):
546        import copy
547        context = copy.deepcopy(context)
548       
549
550class EvaluationResultsContextHandler(ContextHandler):
551    def __init__(self, contextName, targetAttr, selectedAttr,
552                 syncWithGlobal = True, **args):
553            self.targetAttr, self.selectedAttr = targetAttr, selectedAttr
554            ContextHandler.__init__(self, contextName, False, False, syncWithGlobal, **args)
555       
556    def match(self, context, imperfect, cnames, cvalues):
557        return (cnames, cvalues) == (context.classifierNames, context.classValues) and 2
558
559    def fastSave(self, context, widget, name, value):
560        if context:
561            if name == self.targetAttr:
562                context.targetClass = value
563            elif name == self.selectedAttr:
564                context.selectedClassifiers = list(value)
565
566    def settingsFromWidget(self, widget, context):
567        context.targetClass = widget.getdeepattr(self.targetAttr)
568        context.selectedClassifiers = list(widget.getdeepattr(self.selectedAttr))
569
570    def settingsToWidget(self, widget, context):
571        if context.targetClass is not None:
572            setattr(widget, self.targetAttr, context.targetClass)
573        if context.selectedClassifiers is not None:
574            setattr(widget, self.selectedAttr, context.selectedClassifiers)
575           
576    def cloneContext(self, context, domain):
577        import copy
578        context = copy.deepcopy(context)
579       
580    def findOrCreateContext(self, widget, results):
581        if not results:
582            return None, False
583        cnames = [c.name for c in results.classifiers]
584        cvalues = results.classValues
585        context, isNew = ContextHandler.findOrCreateContext(self, widget, results.classifierNames, results.classValues)
586        if not context:
587            return None, False
588        if isNew:
589            context.classifierNames = results.classifierNames
590            context.classValues = results.classValues
591            context.selectedClassifiers = None
592            context.targetClass = None
593        return context, isNew
Note: See TracBrowser for help on using the repository browser.