source: orange/Orange/orng/orngABML.py @ 10518:9127eb1fe9a5

Revision 10518:9127eb1fe9a5, 17.3 KB checked in by martin@…, 2 years ago (diff)

Improved learning of rules with ABCN2.

Line 
1# This module is used to handle argumented examples.
2
3import Orange.core
4import re
5import string
6import warnings
7
8import numpy
9import math
10
11# regular expressions
12# exppression for testing validity of a set of arguments:
13testVal = re.compile(r"""[" \s]*                              # remove any special characters at the beginning
14                     ~?                                       # argument could be negative (~) or positive (without ~) 
15                     {                                        # left parenthesis of the argument
16                     \s*[\w\W]+                                # first attribute of the argument 
17                     (\s*,\s*[\w\W]+)*                         # following attributes in the argument 
18                     }                                        # right parenthesis of the argument 
19                     (\s*,\s*~?{\s*[\w\W]+(\s*,\s*[\w\W]+)*})*  # following arguments
20                     [" \s]*"""                               # remove any special characters at the end
21                     , re.VERBOSE)
22
23# splitting regular expressions
24argRE = re.compile(r'[,\s]*(\{[^{}]+\})[,\s]*')
25argAt = re.compile(r'[{},]+')
26argCompare = re.compile(r'[(<=)(>=)<>]')
27
28def strSign(oper):
29    if oper == Orange.core.ValueFilter_continuous.Less:
30        return "<"
31    elif oper == Orange.core.ValueFilter_continuous.LessEqual:
32        return "<="
33    elif oper == Orange.core.ValueFilter_continuous.Greater:
34        return ">"
35    elif oper == Orange.core.ValueFilter_continuous.GreaterEqual:
36        return ">="
37    else: return "="
38
39def strArg(arg, domain, leave_ref):
40    if type(arg) == Orange.core.ValueFilter_discrete:
41        return str(domain[arg.position].name)
42    else:
43        if leave_ref:
44            return str(domain[arg.position].name) + strSign(arg.oper)
45        else:
46            return str(domain[arg.position].name) + strSign(arg.oper) + str(arg.ref)
47
48def listOfAttributeNames(rule, leave_ref=False):
49    if not rule.filter.conditions:
50        return ""
51    list = ""
52    for val in rule.filter.conditions[:-1]:
53        lr = leave_ref or val.unspecialized_condition
54        list += strArg(val, rule.filter.domain, lr) + ","
55    lr = leave_ref or rule.filter.conditions[-1].unspecialized_condition
56    list += strArg(rule.filter.conditions[-1], rule.filter.domain, lr)
57    return list
58
59
60class Argumentation:
61    """ Class that describes a set of positive and negative arguments
62    this class is used as a value for ArgumentationVariable. """
63    def __init__(self):
64        self.positive_arguments = Orange.core.RuleList()
65        self.negative_arguments = Orange.core.RuleList()
66        self.not_yet_computed_arguments = [] # Arguments that need the whole data set
67                                          # when processing are stored here
68
69    # add an argument that supports the class of the example
70    def addPositive(self, argument):
71        self.positive_arguments.append(argument)
72
73    # add an argument that opposes the class of the example
74    def addNegative(self, argument):
75        self.negative_arguments.append(argument)
76
77    def addNotYetComputed(self, argument, notyet, positive):
78        self.not_yet_computed_arguments.append((argument, notyet, positive))
79
80    def __str__(self):
81        retValue = ""
82        # iterate through positive arguments (rules) and
83        # write them down as a text list
84        if len(self.positive_arguments) > 0:
85            for (i, pos) in enumerate(self.positive_arguments[:-1]):
86                retValue += "{" + listOfAttributeNames(pos) + "}"
87                retValue += ","
88            retValue += "{" + listOfAttributeNames(self.positive_arguments[-1]) + "}"
89            # do the same thing for negative argument,
90            # just that this time use sign "~" in front of the list
91        if len(self.negative_arguments) > 0:
92            if len(retValue) > 0:
93                retValue += ","
94            for (i, neg) in enumerate(self.negative_arguments[:-1]):
95                retValue += "~"
96                retValue += "{" + listOfAttributeNames(neg, leave_ref=True) + "}"
97                retValue += ","
98            retValue += "~{" + listOfAttributeNames(self.negative_arguments[-1], leave_ref=True) + "}"
99        return retValue
100
101POSITIVE = True
102NEGATIVE = False
103class ArgumentVariable(Orange.core.PythonVariable):
104    """ For writing and parsing arguments in .tab files. """
105    def str2val(self, strV):
106        """ convert str to val - used for creating variables. """
107        return self.filestr2val(strV, None)
108
109    def filestr2val(self, strV, example=None):
110        """ write arguments (from string in file) to value - used also as a function for reading from data. """
111        mt = testVal.match(strV)
112        if not mt or not mt.end() == len(strV):
113            warnings.warn(strV + " is a badly formed argument.")
114            return Orange.core.PythonValueSpecial(2) # return special if argument doesnt match the formal form
115
116        if not example:
117            example = self.example
118        domain = example.domain
119
120        # get a set of arguments
121        splitedSet = filter(lambda x:x != '' and x != '"', argRE.split(strV))
122
123        # create an Argumentation object - an empty set of arguments
124        argumentation = Argumentation()
125        type = POSITIVE # type of argument - positive = True / negative = False
126        for sp in splitedSet:
127            # for each argument determine whether it is positive or negative
128            if sp == '~':
129                type = NEGATIVE
130                continue
131            argument = Orange.core.Rule(filter=Orange.core.Filter_values(domain=domain))
132            argument.setattr("unspecialized_argument", False)
133
134            reasonValues = filter(lambda x:x != '' and x != '"', argAt.split(sp)) # reasons in this argument
135            # iterate through argument names
136            for r in reasonValues:
137                r = string.strip(r)   # Remove all white characters on both sides
138                try:
139                    attribute = domain[r]
140                except:
141                    attribute = None
142                if attribute: # only attribute name is mentioned as a reason
143                    if domain.index(attribute) < 0:
144                        warnings.warn("Meta attribute %s used in argument. Is this intentional?" % r)
145                        continue
146                    value = example[attribute]
147                    if attribute.varType == Orange.core.VarTypes.Discrete: # discrete argument
148                        argument.filter.conditions.append(Orange.core.ValueFilter_discrete(
149                                                            position=domain.attributes.index(attribute),
150                                                            values=[value],
151                                                            acceptSpecial=0))
152                        argument.filter.conditions[-1].setattr("unspecialized_condition", False)
153                    else: # continuous but without reference point
154                        warnings.warn("Continous attributes (%s) in arguments should not be used without a comparison sign (<,<=,>,>=)" % r)
155
156                else: # attribute and something name is the reason, probably cont. attribute
157                    # one of four possible delimiters should be found, <,>,<=,>=
158                    splitReason = filter(lambda x:x != '' and x != '"', argCompare.split(r))
159                    if len(splitReason) > 2 or len(splitReason) == 0:
160                        warnings.warn("Reason %s is a badly formed part of an argument." % r)
161                        continue
162                    # get attribute name and continous reference value
163                    attributeName = string.strip(splitReason[0])
164                    if len(splitReason) > 1:
165                        refValue = string.strip(splitReason[1])
166                    else:
167                        refValue = ""
168                    if refValue:
169                        sign = r[len(attributeName):-len(refValue)]
170                    else:
171                        sign = r[len(attributeName):]
172
173                    # evaluate name and value                       
174                    try:
175                        attribute = domain[attributeName]
176                    except:
177                        warnings.warn("Attribute %s is not a part of the domain" % attributeName)
178                        continue
179                    if domain.index(attribute) < 0:
180                        warnings.warn("Meta attribute %s used in argument. Is this intentional?" % r)
181                        continue
182                    if refValue:
183                        try:
184                            ref = eval(refValue)
185                        except:
186                            warnings.warn("Error occured while reading value by argument's reason. Argument: %s, value: %s" % (r, refValue))
187                            continue
188                    else:
189                        ref = 0.
190                    if sign == "<": oper = Orange.core.ValueFilter_continuous.Less
191                    elif sign == ">": oper = Orange.core.ValueFilter_continuous.Greater
192                    elif sign == "<=": oper = Orange.core.ValueFilter_continuous.LessEqual
193                    else: oper = Orange.core.ValueFilter_continuous.GreaterEqual
194                    argument.filter.conditions.append(Orange.core.ValueFilter_continuous(
195                                                position=domain.attributes.index(attribute),
196                                                oper=oper,
197                                                ref=ref,
198                                                acceptSpecial=0))
199                    if not refValue and type == POSITIVE:
200                        argument.filter.conditions[-1].setattr("unspecialized_condition", True)
201                        argument.setattr("unspecialized_argument", True)
202                    else:
203                        argument.filter.conditions[-1].setattr("unspecialized_condition", False)
204
205            if example.domain.classVar:
206                argument.classifier = Orange.core.DefaultClassifier(defaultVal=example.getclass())
207            argument.complexity = len(argument.filter.conditions)
208
209            if type: # and len(argument.filter.conditions):
210                argumentation.addPositive(argument)
211            else: # len(argument.filter.conditions):
212                argumentation.addNegative(argument)
213            type = POSITIVE
214        return argumentation
215
216
217    # used for writing to data: specify output (string) presentation of arguments in tab. file
218    def val2filestr(self, val, example):
219        return str(val)
220
221    # used for writing to string
222    def val2str(self, val):
223        return str(val)
224
225
226class ArgumentFilter_hasSpecial:
227    def __call__(self, examples, attribute, target_class= -1, negate=0):
228        indices = [0] * len(examples)
229        for i in range(len(examples)):
230            if examples[i][attribute].isSpecial():
231                indices[i] = 1
232            elif target_class > -1 and not int(examples[i].getclass()) == target_class:
233                indices[i] = 1
234            elif len(examples[i][attribute].value.positive_arguments) == 0:
235                indices[i] = 1
236        return examples.select(indices, 0, negate=negate)
237
238def evaluateAndSortArguments(examples, argAtt, evaluateFunction=None, apriori=None):
239    """ Evaluate positive arguments and sort them by quality. """
240    if not apriori:
241        apriori = Orange.core.Distribution(examples.domain.classVar, examples)
242    if not evaluateFunction:
243        evaluateFunction = Orange.core.RuleEvaluator_Laplace()
244
245    for e in examples:
246        if not e[argAtt].isSpecial():
247            for r in e[argAtt].value.positive_arguments:
248                r.filterAndStore(examples, 0, e[examples.domain.classVar])
249                r.quality = evaluateFunction(r, examples, 0, int(e[examples.domain.classVar]), apriori)
250            e[argAtt].value.positive_arguments.sort(lambda x, y:-cmp(x.quality, y.quality))
251
252def isGreater(oper):
253    if oper == Orange.core.ValueFilter_continuous.Greater or \
254       oper == Orange.core.ValueFilter_continuous.GreaterEqual:
255        return True
256    return False
257
258def isLess(oper):
259    if oper == Orange.core.ValueFilter_continuous.Less or \
260       oper == Orange.core.ValueFilter_continuous.LessEqual:
261        return True
262    return False
263
264class ConvertCont:
265    def __init__(self, position, value, oper, newAtt):
266        self.value = value
267        self.oper = oper
268        self.position = position
269        self.newAtt = newAtt
270
271    def __call__(self, example, returnWhat):
272        if example[self.position].isSpecial():
273            return example[self.position]
274        if isLess(self.oper):
275            if example[self.position] < self.value:
276                return Orange.core.Value(self.newAtt, self.value)
277            else:
278                return Orange.core.Value(self.newAtt, float(example[self.position]))
279        else:
280            if example[self.position] > self.value:
281                return Orange.core.Value(self.newAtt, self.value)
282            else:
283                return Orange.core.Value(self.newAtt, float(example[self.position]))
284
285
286
287class ConvertClass:
288    """ Converting class variables into dichotomous class variable. """
289    def __init__(self, classAtt, classValue, newClassAtt):
290        self.classAtt = classAtt
291        self.classValue = classValue
292        self.newClassAtt = newClassAtt
293
294    def __call__(self, example, returnWhat):
295        if example[self.classAtt] == self.classValue:
296            return Orange.data.Value(self.newClassAtt, self.classValue + "_")
297        else:
298            return Orange.data.Value(self.newClassAtt, "not " + self.classValue)
299
300
301def create_dichotomous_class(domain, att, value, negate, removeAtt=None):
302    # create new variable
303    newClass = Orange.feature.Discrete(att.name + "_", values=[str(value) + "_", "not " + str(value)])
304    positive = Orange.data.Value(newClass, str(value) + "_")
305    negative = Orange.data.Value(newClass, "not " + str(value))
306    newClass.getValueFrom = ConvertClass(att, str(value), newClass)
307
308    att = [a for a in domain.attributes]
309    newDomain = Orange.data.Domain(att + [newClass])
310    newDomain.addmetas(domain.getmetas())
311    if negate == 1:
312        return (newDomain, negative)
313    else:
314        return (newDomain, positive)
315
316
317
318def addErrors(test_data, classifier):
319    """ Main task of this function is to add probabilistic errors to examples."""
320    for ex_i, ex in enumerate(test_data):
321        (cl, prob) = classifier(ex, Orange.core.GetBoth)
322        ex.setmeta("ProbError", float(ex.getmeta("ProbError")) + 1. - prob[ex.getclass()])
323
324def nCrossValidation(data, learner, weightID=0, folds=5, n=4, gen=0, argument_id="Arguments"):
325    """ Function performs n x fold crossvalidation. For each classifier
326        test set is updated by calling function addErrors. """
327    acc = 0.0
328    rules = {}
329    for d in data:
330        rules[float(d["SerialNumberPE"])] = []
331    pick = Orange.core.MakeRandomIndicesCV(folds=folds, randseed=gen, stratified=Orange.core.MakeRandomIndices.StratifiedIfPossible)
332    for n_i in range(n):
333        pick.randseed = gen + 10 * n_i
334        selection = pick(data)
335        for folds_i in range(folds):
336            for data_i, e in enumerate(data):
337                try:
338                    if e[argument_id]: # examples with arguments do not need to be tested
339                        selection[data_i] = folds_i + 1
340                except:
341                    pass
342            train_data = data.selectref(selection, folds_i, negate=1)
343            test_data = data.selectref(selection, folds_i, negate=0)
344            classifier = learner(train_data, weightID)
345            addErrors(test_data, classifier)
346            # add rules
347            for d in test_data:
348                for r in classifier.rules:
349                    if r(d):
350                        rules[float(d["SerialNumberPE"])].append(r)
351    # normalize prob errors
352    for d in data:
353        d["ProbError"] = d["ProbError"] / n
354    return rules
355
356def findProb(learner, examples, weightID=0, folds=5, n=4, gen=0, thr=0.5, argument_id="Arguments"):
357    """ General method for calling to find problematic example.
358        It returns all critial examples along with average probabilistic errors that ought to be higher then thr.
359        Taking the one with highest error is the same as taking the most
360        problematic example. """
361
362    newDomain = Orange.core.Domain(examples.domain.attributes, examples.domain.classVar)
363    newDomain.addmetas(examples.domain.getmetas())
364    newExamples = Orange.core.ExampleTable(newDomain, examples)
365    if not newExamples.domain.hasmeta("ProbError"):
366        newId = Orange.core.newmetaid()
367        newDomain.addmeta(newId, Orange.core.FloatVariable("ProbError"))
368        newExamples = Orange.core.ExampleTable(newDomain, examples)
369    if not newExamples.domain.hasmeta("SerialNumberPE"):
370        newId = Orange.core.newmetaid()
371        newDomain.addmeta(newId, Orange.core.FloatVariable("SerialNumberPE"))
372        newExamples = Orange.core.ExampleTable(newDomain, examples)
373    for i in range(len(newExamples)):
374        newExamples[i]["SerialNumberPE"] = float(i)
375        newExamples[i]["ProbError"] = 0.
376
377    # it returns a list of examples now: (index of example-starting with 0, example, prob error, rules covering example
378    rules = nCrossValidation(newExamples, learner, weightID=weightID, folds=folds, n=n, gen=gen, argument_id=argument_id)
379    return [(ei, examples[ei], float(e["ProbError"]), rules[float(e["SerialNumberPE"])]) for ei, e in enumerate(newExamples) if e["ProbError"] > thr]
380
Note: See TracBrowser for help on using the repository browser.