source: orange/Orange/orng/orngABML.py @ 9737:585926e9995e

Revision 9737:585926e9995e, 16.3 KB checked in by Miha Stajdohar <miha.stajdohar@…>, 2 years ago (diff)

Moved createDichotomousClass to rules.py.

Line 
1# This module is used to handle argumented examples.
2
3import Orange.core
4import re
5import string
6import warnings
7
8import numpy
9import math
10
11from Orange.classification.rules import create_dichotomous_class as createDichotomousClass
12from Orange.classification.rules import ConvertClass
13# regular expressions
14# exppression for testing validity of a set of arguments:
15testVal = re.compile(r"""[" \s]*                              # remove any special characters at the beginning
16                     ~?                                       # argument could be negative (~) or positive (without ~) 
17                     {                                        # left parenthesis of the argument
18                     \s*[\w\W]+                                # first attribute of the argument 
19                     (\s*,\s*[\w\W]+)*                         # following attributes in the argument 
20                     }                                        # right parenthesis of the argument 
21                     (\s*,\s*~?{\s*[\w\W]+(\s*,\s*[\w\W]+)*})*  # following arguments
22                     [" \s]*"""                               # remove any special characters at the end
23                     , re.VERBOSE)
24
25# splitting regular expressions
26argRE = re.compile(r'[,\s]*(\{[^{}]+\})[,\s]*')
27argAt = re.compile(r'[{},]+')
28argCompare = re.compile(r'[(<=)(>=)<>]')
29
30def strSign(oper):
31    if oper == Orange.core.ValueFilter_continuous.Less:
32        return "<"
33    elif oper == Orange.core.ValueFilter_continuous.LessEqual:
34        return "<="
35    elif oper == Orange.core.ValueFilter_continuous.Greater:
36        return ">"
37    elif oper == Orange.core.ValueFilter_continuous.GreaterEqual:
38        return ">="
39    else: return "="
40
41def strArg(arg, domain, leave_ref):
42    if type(arg) == Orange.core.ValueFilter_discrete:
43        return str(domain[arg.position].name)
44    else:
45        if leave_ref:
46            return str(domain[arg.position].name) + strSign(arg.oper)
47        else:
48            return str(domain[arg.position].name) + strSign(arg.oper) + str(arg.ref)
49
50def listOfAttributeNames(rule, leave_ref=False):
51    if not rule.filter.conditions:
52        return ""
53    list = ""
54    for val in rule.filter.conditions[:-1]:
55        lr = leave_ref or val.unspecialized_condition
56        list += strArg(val, rule.filter.domain, lr) + ","
57    lr = leave_ref or rule.filter.conditions[-1].unspecialized_condition
58    list += strArg(rule.filter.conditions[-1], rule.filter.domain, lr)
59    return list
60
61
62class Argumentation:
63    """ Class that describes a set of positive and negative arguments
64    this class is used as a value for ArgumentationVariable. """
65    def __init__(self):
66        self.positive_arguments = Orange.core.RuleList()
67        self.negative_arguments = Orange.core.RuleList()
68        self.not_yet_computed_arguments = [] # Arguments that need the whole data set
69                                          # when processing are stored here
70
71    # add an argument that supports the class of the example
72    def addPositive(self, argument):
73        self.positive_arguments.append(argument)
74
75    # add an argument that opposes the class of the example
76    def addNegative(self, argument):
77        self.negative_arguments.append(argument)
78
79    def addNotYetComputed(self, argument, notyet, positive):
80        self.not_yet_computed_arguments.append((argument, notyet, positive))
81
82    def __str__(self):
83        retValue = ""
84        # iterate through positive arguments (rules) and
85        # write them down as a text list
86        if len(self.positive_arguments) > 0:
87            for (i, pos) in enumerate(self.positive_arguments[:-1]):
88                retValue += "{" + listOfAttributeNames(pos) + "}"
89                retValue += ","
90            retValue += "{" + listOfAttributeNames(self.positive_arguments[-1]) + "}"
91            # do the same thing for negative argument,
92            # just that this time use sign "~" in front of the list
93        if len(self.negative_arguments) > 0:
94            if len(retValue) > 0:
95                retValue += ","
96            for (i, neg) in enumerate(self.negative_arguments[:-1]):
97                retValue += "~"
98                retValue += "{" + listOfAttributeNames(neg, leave_ref=True) + "}"
99                retValue += ","
100            retValue += "~{" + listOfAttributeNames(self.negative_arguments[-1], leave_ref=True) + "}"
101        return retValue
102
103POSITIVE = True
104NEGATIVE = False
105class ArgumentVariable(Orange.core.PythonVariable):
106    """ For writing and parsing arguments in .tab files. """
107    def str2val(self, strV):
108        """ convert str to val - used for creating variables. """
109        return self.filestr2val(strV, None)
110
111    def filestr2val(self, strV, example=None):
112        """ write arguments (from string in file) to value - used also as a function for reading from data. """
113        mt = testVal.match(strV)
114        if not mt or not mt.end() == len(strV):
115            warnings.warn(strV + " is a badly formed argument.")
116            return Orange.core.PythonValueSpecial(2) # return special if argument doesnt match the formal form
117
118        if not example:
119            example = self.example
120        domain = example.domain
121
122        # get a set of arguments
123        splitedSet = filter(lambda x:x != '' and x != '"', argRE.split(strV))
124
125        # create an Argumentation object - an empty set of arguments
126        argumentation = Argumentation()
127        type = POSITIVE # type of argument - positive = True / negative = False
128        for sp in splitedSet:
129            # for each argument determine whether it is positive or negative
130            if sp == '~':
131                type = NEGATIVE
132                continue
133            argument = Orange.core.Rule(filter=Orange.core.Filter_values(domain=domain))
134            argument.setattr("unspecialized_argument", False)
135
136            reasonValues = filter(lambda x:x != '' and x != '"', argAt.split(sp)) # reasons in this argument
137            # iterate through argument names
138            for r in reasonValues:
139                r = string.strip(r)   # Remove all white characters on both sides
140                try:
141                    attribute = domain[r]
142                except:
143                    attribute = None
144                if attribute: # only attribute name is mentioned as a reason
145                    if domain.index(attribute) < 0:
146                        warnings.warn("Meta attribute %s used in argument. Is this intentional?" % r)
147                        continue
148                    value = example[attribute]
149                    if attribute.varType == Orange.core.VarTypes.Discrete: # discrete argument
150                        argument.filter.conditions.append(Orange.core.ValueFilter_discrete(
151                                                            position=domain.attributes.index(attribute),
152                                                            values=[value],
153                                                            acceptSpecial=0))
154                        argument.filter.conditions[-1].setattr("unspecialized_condition", False)
155                    else: # continuous but without reference point
156                        warnings.warn("Continous attributes (%s) in arguments should not be used without a comparison sign (<,<=,>,>=)" % r)
157
158                else: # attribute and something name is the reason, probably cont. attribute
159                    # one of four possible delimiters should be found, <,>,<=,>=
160                    splitReason = filter(lambda x:x != '' and x != '"', argCompare.split(r))
161                    if len(splitReason) > 2 or len(splitReason) == 0:
162                        warnings.warn("Reason %s is a badly formed part of an argument." % r)
163                        continue
164                    # get attribute name and continous reference value
165                    attributeName = string.strip(splitReason[0])
166                    if len(splitReason) > 1:
167                        refValue = string.strip(splitReason[1])
168                    else:
169                        refValue = ""
170                    if refValue:
171                        sign = r[len(attributeName):-len(refValue)]
172                    else:
173                        sign = r[len(attributeName):]
174
175                    # evaluate name and value                       
176                    try:
177                        attribute = domain[attributeName]
178                    except:
179                        warnings.warn("Attribute %s is not a part of the domain" % attributeName)
180                        continue
181                    if domain.index(attribute) < 0:
182                        warnings.warn("Meta attribute %s used in argument. Is this intentional?" % r)
183                        continue
184                    if refValue:
185                        try:
186                            ref = eval(refValue)
187                        except:
188                            warnings.warn("Error occured while reading value by argument's reason. Argument: %s, value: %s" % (r, refValue))
189                            continue
190                    else:
191                        ref = 0.
192                    if sign == "<": oper = Orange.core.ValueFilter_continuous.Less
193                    elif sign == ">": oper = Orange.core.ValueFilter_continuous.Greater
194                    elif sign == "<=": oper = Orange.core.ValueFilter_continuous.LessEqual
195                    else: oper = Orange.core.ValueFilter_continuous.GreaterEqual
196                    argument.filter.conditions.append(Orange.core.ValueFilter_continuous(
197                                                position=domain.attributes.index(attribute),
198                                                oper=oper,
199                                                ref=ref,
200                                                acceptSpecial=0))
201                    if not refValue and type == POSITIVE:
202                        argument.filter.conditions[-1].setattr("unspecialized_condition", True)
203                        argument.setattr("unspecialized_argument", True)
204                    else:
205                        argument.filter.conditions[-1].setattr("unspecialized_condition", False)
206
207            if example.domain.classVar:
208                argument.classifier = Orange.core.DefaultClassifier(defaultVal=example.getclass())
209            argument.complexity = len(argument.filter.conditions)
210
211            if type: # and len(argument.filter.conditions):
212                argumentation.addPositive(argument)
213            else: # len(argument.filter.conditions):
214                argumentation.addNegative(argument)
215            type = POSITIVE
216        return argumentation
217
218
219    # used for writing to data: specify output (string) presentation of arguments in tab. file
220    def val2filestr(self, val, example):
221        return str(val)
222
223    # used for writing to string
224    def val2str(self, val):
225        return str(val)
226
227
228class ArgumentFilter_hasSpecial:
229    def __call__(self, examples, attribute, target_class= -1, negate=0):
230        indices = [0] * len(examples)
231        for i in range(len(examples)):
232            if examples[i][attribute].isSpecial():
233                indices[i] = 1
234            elif target_class > -1 and not int(examples[i].getclass()) == target_class:
235                indices[i] = 1
236            elif len(examples[i][attribute].value.positive_arguments) == 0:
237                indices[i] = 1
238        return examples.select(indices, 0, negate=negate)
239
240def evaluateAndSortArguments(examples, argAtt, evaluateFunction=None, apriori=None):
241    """ Evaluate positive arguments and sort them by quality. """
242    if not apriori:
243        apriori = Orange.core.Distribution(examples.domain.classVar, examples)
244    if not evaluateFunction:
245        evaluateFunction = Orange.core.RuleEvaluator_Laplace()
246
247    for e in examples:
248        if not e[argAtt].isSpecial():
249            for r in e[argAtt].value.positive_arguments:
250                r.filterAndStore(examples, 0, e[examples.domain.classVar])
251                r.quality = evaluateFunction(r, examples, 0, int(e[examples.domain.classVar]), apriori)
252            e[argAtt].value.positive_arguments.sort(lambda x, y:-cmp(x.quality, y.quality))
253
254def isGreater(oper):
255    if oper == Orange.core.ValueFilter_continuous.Greater or \
256       oper == Orange.core.ValueFilter_continuous.GreaterEqual:
257        return True
258    return False
259
260def isLess(oper):
261    if oper == Orange.core.ValueFilter_continuous.Less or \
262       oper == Orange.core.ValueFilter_continuous.LessEqual:
263        return True
264    return False
265
266class ConvertCont:
267    def __init__(self, position, value, oper, newAtt):
268        self.value = value
269        self.oper = oper
270        self.position = position
271        self.newAtt = newAtt
272
273    def __call__(self, example, returnWhat):
274        if example[self.position].isSpecial():
275            return example[self.position]
276        if isLess(self.oper):
277            if example[self.position] < self.value:
278                return Orange.core.Value(self.newAtt, self.value)
279            else:
280                return Orange.core.Value(self.newAtt, float(example[self.position]))
281        else:
282            if example[self.position] > self.value:
283                return Orange.core.Value(self.newAtt, self.value)
284            else:
285                return Orange.core.Value(self.newAtt, float(example[self.position]))
286
287
288def addErrors(test_data, classifier):
289    """ Main task of this function is to add probabilistic errors to examples."""
290    for ex_i, ex in enumerate(test_data):
291        (cl, prob) = classifier(ex, Orange.core.GetBoth)
292        ex.setmeta("ProbError", float(ex.getmeta("ProbError")) + 1. - prob[ex.getclass()])
293
294def nCrossValidation(data, learner, weightID=0, folds=5, n=4, gen=0, argument_id="Arguments"):
295    """ Function performs n x fold crossvalidation. For each classifier
296        test set is updated by calling function addErrors. """
297    acc = 0.0
298    rules = {}
299    for d in data:
300        rules[float(d["SerialNumberPE"])] = []
301    pick = Orange.core.MakeRandomIndicesCV(folds=folds, randseed=gen, stratified=Orange.core.MakeRandomIndices.StratifiedIfPossible)
302    for n_i in range(n):
303        pick.randseed = gen + 10 * n_i
304        selection = pick(data)
305        for folds_i in range(folds):
306            for data_i, e in enumerate(data):
307                try:
308                    if e[argument_id]: # examples with arguments do not need to be tested
309                        selection[data_i] = folds_i + 1
310                except:
311                    pass
312            train_data = data.selectref(selection, folds_i, negate=1)
313            test_data = data.selectref(selection, folds_i, negate=0)
314            classifier = learner(train_data, weightID)
315            addErrors(test_data, classifier)
316            # add rules
317            for d in test_data:
318                for r in classifier.rules:
319                    if r(d):
320                        rules[float(d["SerialNumberPE"])].append(r)
321    # normalize prob errors
322    for d in data:
323        d["ProbError"] = d["ProbError"] / n
324    return rules
325
326def findProb(learner, examples, weightID=0, folds=5, n=4, gen=0, thr=0.5, argument_id="Arguments"):
327    """ General method for calling to find problematic example.
328        It returns all critial examples along with average probabilistic errors that ought to be higher then thr.
329        Taking the one with highest error is the same as taking the most
330        problematic example. """
331
332    newDomain = Orange.core.Domain(examples.domain.attributes, examples.domain.classVar)
333    newDomain.addmetas(examples.domain.getmetas())
334    newExamples = Orange.core.ExampleTable(newDomain, examples)
335    if not newExamples.domain.hasmeta("ProbError"):
336        newId = Orange.core.newmetaid()
337        newDomain.addmeta(newId, Orange.core.FloatVariable("ProbError"))
338        newExamples = Orange.core.ExampleTable(newDomain, examples)
339    if not newExamples.domain.hasmeta("SerialNumberPE"):
340        newId = Orange.core.newmetaid()
341        newDomain.addmeta(newId, Orange.core.FloatVariable("SerialNumberPE"))
342        newExamples = Orange.core.ExampleTable(newDomain, examples)
343    for i in range(len(newExamples)):
344        newExamples[i]["SerialNumberPE"] = float(i)
345        newExamples[i]["ProbError"] = 0.
346
347    # it returns a list of examples now: (index of example-starting with 0, example, prob error, rules covering example
348    rules = nCrossValidation(newExamples, learner, weightID=weightID, folds=folds, n=n, gen=gen, argument_id=argument_id)
349    return [(ei, examples[ei], float(e["ProbError"]), rules[float(e["SerialNumberPE"])]) for ei, e in enumerate(newExamples) if e["ProbError"] > thr]
350
Note: See TracBrowser for help on using the repository browser.