Changeset 10518:9127eb1fe9a5 in orange


Ignore:
Timestamp:
03/14/12 00:15:17 (2 years ago)
Author:
martin@…
Branch:
default
Message:

Improved learning of rules with ABCN2.

Location:
Orange
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • Orange/classification/rules.py

    r10407 r10518  
    3939Validator_LRS = Orange.core.RuleValidator_LRS 
    4040     
     41from Orange.orng.orngABML import \ 
     42    ArgumentFilter_hasSpecial, \ 
     43    create_dichotomous_class, \ 
     44    evaluateAndSortArguments 
    4145from Orange.misc import deprecated_keywords 
    4246from Orange.misc import deprecated_members 
    4347 
    44  
    45 class ConvertClass: 
    46     """ Converting class variables into dichotomous class variable. """ 
    47     def __init__(self, classAtt, classValue, newClassAtt): 
    48         self.classAtt = classAtt 
    49         self.classValue = classValue 
    50         self.newClassAtt = newClassAtt 
    51  
    52     def __call__(self, example, returnWhat): 
    53         if example[self.classAtt] == self.classValue: 
    54             return Orange.data.Value(self.newClassAtt, self.classValue + "_") 
    55         else: 
    56             return Orange.data.Value(self.newClassAtt, "not " + self.classValue) 
    57  
    58  
    59 def create_dichotomous_class(domain, att, value, negate, removeAtt=None): 
    60     # create new variable 
    61     newClass = Orange.feature.Discrete(att.name + "_", values=[str(value) + "_", "not " + str(value)]) 
    62     positive = Orange.data.Value(newClass, str(value) + "_") 
    63     negative = Orange.data.Value(newClass, "not " + str(value)) 
    64     newClass.getValueFrom = ConvertClass(att, str(value), newClass) 
    65  
    66     att = [a for a in domain.attributes] 
    67     newDomain = Orange.data.Domain(att + [newClass]) 
    68     newDomain.addmetas(domain.getmetas()) 
    69     if negate == 1: 
    70         return (newDomain, negative) 
    71     else: 
    72         return (newDomain, positive) 
    7348 
    7449 
     
    559534                 rule_sig=1.0, att_sig=1.0, postpruning=None, min_quality=0., min_coverage=1, min_improved=1, min_improved_perc=0.0, 
    560535                 learn_for_class=None, learn_one_rule=False, evd=None, evd_arguments=None, prune_arguments=False, analyse_argument= -1, 
    561                  alternative_learner=None, min_cl_sig=0.5, min_beta=0.0, set_prefix_rules=False, add_sub_rules=False, debug=False, 
     536                 alternative_learner=None, min_cl_sig=0.5, min_beta=0.0, set_prefix_rules=False, add_sub_rules=True, debug=False, 
    562537                 **kwds): 
    563538 
     
    652627            while aes: 
    653628                if self.analyse_argument > -1 and \ 
    654                    (isinstance(self.analyse_argument, Orange.data.Instance) and not Orange.data.Instance(dich_data.domain, self.analyse_argument) == aes[0] or \ 
     629                   (isinstance(self.analyse_argument, Orange.core.Example) and not Orange.core.Example(dich_data.domain, self.analyse_argument) == aes[0] or \ 
    655630                    isinstance(self.analyse_argument, int) and not dich_data[self.analyse_argument] == aes[0]): 
    656631                    aes = aes[1:] 
     
    667642                else: 
    668643                    aes = aes[1:] 
    669                 aes = aes[1:] 
    670644 
    671645            if not progress and self.debug: 
     
    674648            # remove all examples covered by rules 
    675649            for rule in rules: 
    676                 dich_data = self.remove_covered_examples(rule, dich_data, weight_id) 
     650                dich_data = self.remove_covered_examples(rule, dich_data, weight_id, True) 
    677651            if progress: 
    678652                progress(self.remaining_probability(dich_data), None) 
     
    683657                while dich_data: 
    684658                    # learn a rule 
    685                     rule = self.learn_normal_rule(dich_data, weight_id, self.apriori) 
     659                    rule, good_rule = self.learn_normal_rule(dich_data, weight_id, self.apriori) 
    686660                    if not rule: 
    687661                        break 
    688662                    if self.debug: 
    689                         print "rule learned: ", Orange.classification.rules.rule_to_string(rule), rule.quality 
    690                     dich_data = self.remove_covered_examples(rule, dich_data, weight_id) 
     663                        if good_rule: 
     664                            print "rule learned: ", rule_to_string(rule), rule.quality 
     665                        else: 
     666                            print "rule only to influence learning: ", rule_to_string(rule), rule.quality 
     667                             
     668                    dich_data = self.remove_covered_examples(rule, dich_data, weight_id, good_rule) 
     669 
    691670                    if progress: 
    692671                        progress(self.remaining_probability(dich_data), None) 
    693                     rules.append(rule) 
     672                    if good_rule: 
     673                        rules.append(rule) 
    694674                    if self.learn_one_rule: 
    695675                        break 
     
    715695        if not positive_args: # something wrong 
    716696            raise "There is a problem with argumented example %s" % str(ae) 
     697            return None 
     698        if False in [p(ae) for p in positive_args]: # a positive argument is not covering this example 
     699            raise "One argument does not cover critical example: %s!"%str(ae) 
    717700            return None 
    718701        negative_args = self.init_neg_args(ae, examples, weight_id) 
     
    804787        self.rule_finder.ruleFilter = self.ruleFilter 
    805788 
     789 
    806790    def learn_normal_rule(self, examples, weight_id, apriori): 
    807791        if hasattr(self.rule_finder.evaluator, "bestRule"): 
    808792            self.rule_finder.evaluator.bestRule = None 
    809         rule = self.rule_finder(examples, weight_id, 0, RuleList()) 
     793        rule = self.rule_finder(examples,weight_id,0,RuleList()) 
    810794        if hasattr(self.rule_finder.evaluator, "bestRule") and self.rule_finder.evaluator.returnExpectedProb: 
     795            if not self.rule_finder.evaluator.bestRule and rule.quality > 0: 
     796                return (rule, False) 
    811797            rule = self.rule_finder.evaluator.bestRule 
    812798            self.rule_finder.evaluator.bestRule = None 
    813799        if self.postpruning: 
    814             rule = self.postpruning(rule, examples, weight_id, 0, aprior) 
    815         return rule 
    816  
    817     def remove_covered_examples(self, rule, examples, weight_id): 
    818         nexamples, nweight = self.cover_and_remove(rule, examples, weight_id, 0) 
     800            rule = self.postpruning(rule,examples,weight_id,0, aprior) 
     801        return (rule, True) 
     802     
     803 
     804    def remove_covered_examples(self, rule, examples, weight_id, good_rule): 
     805        if good_rule: 
     806            nexamples, nweight = self.cover_and_remove(rule, examples, weight_id, 0) 
     807        else: 
     808            nexamples, nweight = self.cover_and_remove.mark_examples_solved(rule,examples,weight_id,0) 
    819809        return nexamples 
    820810 
     
    12251215            if r and not rule_in_set(r, best_rules) and int(examples[r_i].getclass()) == int(r.classifier.default_value): 
    12261216                if hasattr(r.learner, "arg_example"): 
    1227                     setattr(r, "best_example", r.learner.arg_example) 
     1217                    r.setattr("best_example", r.learner.arg_example) 
    12281218                else: 
    1229                     setattr(r, "best_example", examples[r_i]) 
     1219                    r.setattr("best_example", examples[r_i]) 
    12301220                best_rules.append(r) 
    12311221        return best_rules 
    12321222 
    1233     def __call__(self, rule, examples, weights, target_class): 
     1223 
    12341224        """ if example has an argument, then the rule must be consistent with the argument. """ 
    12351225        example = getattr(rule.learner, "arg_example", None) 
     1226        if example: 
     1227            for ei, e in enumerate(examples): 
     1228                if e == example: 
     1229                    e[self.prob_attribute] = rule.quality+0.001 # 0.001 is added to avoid numerical errors 
     1230                    self.best_rule[ei]=rule 
     1231        else:         
     1232            for ei, e in enumerate(examples): 
     1233                if rule(e) and rule.quality>e[self.prob_attribute]: 
     1234                    e[self.prob_attribute] = rule.quality+0.001 # 0.001 is added to avoid numerical errors 
     1235                    self.best_rule[ei]=rule 
     1236        return (examples, weights) 
     1237 
     1238    def mark_examples_solved(self, rule, examples, weights, target_class): 
    12361239        for ei, e in enumerate(examples): 
    1237             if e == example: 
     1240            if rule(e): 
    12381241                e[self.prob_attribute] = 1.0 
    1239                 self.best_rule[ei] = rule 
    1240             elif rule(e) and rule.quality > e[self.prob_attribute]: 
    1241                 e[self.prob_attribute] = rule.quality + 0.001 # 0.001 is added to avoid numerical errors 
    1242                 self.best_rule[ei] = rule 
    12431242        return (examples, weights) 
    1244  
    1245     def filter_covers_example(self, example, filter): 
    1246         filter_indices = CoversArguments.filterIndices(filter) 
    1247         if filter(example): 
    1248             try: 
    1249                 if example[self.argument_id].value and len(example[self.argument_id].value.positive_arguments) > 0: # example has positive arguments 
    1250                     # conditions should cover at least one of the positive arguments 
    1251                     one_arg_covered = False 
    1252                     for pA in example[self.argument_id].value.positive_arguments: 
    1253                         arg_covered = [self.condIn(c, filter_indices) for c in pA.filter.conditions] 
    1254                         one_arg_covered = one_arg_covered or len(arg_covered) == sum(arg_covered) #arg_covered 
    1255                         if one_arg_covered: 
    1256                             break 
    1257                     if not one_arg_covered: 
    1258                         return False 
    1259                 if example[self.argument_id].value and len(example[self.argument_id].value.negative_arguments) > 0: # example has negative arguments 
    1260                     # condition should not cover neither of negative arguments 
    1261                     for pN in example[self.argument_id].value.negative_arguments: 
    1262                         arg_covered = [self.condIn(c, filter_indices) for c in pN.filter.conditions] 
    1263                         if len(arg_covered) == sum(arg_covered): 
    1264                             return False 
    1265             except: 
    1266                 return True 
    1267             return True 
    1268         return False 
    1269  
    1270     def condIn(self, cond, filter_indices): # is condition in the filter? 
    1271         condInd = CoversArguments.conditionIndex(cond) 
    1272         if operator.or_(condInd, filter_indices[cond.position]) == filter_indices[cond.position]: 
    1273             return True 
    1274         return False 
    12751243 
    12761244 
     
    17421710# This filter is the ugliest code ever! Problem is with Orange, I had some problems with inheriting deepCopy 
    17431711# I should take another look at it. 
    1744 class ArgFilter(Orange.data.filter.Filter): 
     1712class ArgFilter(Orange.core.Filter): 
    17451713    """ This class implements AB-covering principle. """ 
    1746     def __init__(self, argument_id=None, filter=Orange.data.filter.Values(), arg_example=None): 
     1714    def __init__(self, argument_id=None, filter = Orange.core.Filter_values(), arg_example = None): 
    17471715        self.filter = filter 
    1748         self.indices = getattr(filter, "indices", []) 
    1749         if not self.indices and len(filter.conditions) > 0: 
    1750             self.indices = CoversArguments.filterIndices(filter) 
     1716        self.indices = getattr(filter,"indices",[]) 
     1717        if not self.indices and len(filter.conditions)>0: 
     1718            self.indices = RuleCoversArguments.filterIndices(filter) 
    17511719        self.argument_id = argument_id 
    17521720        self.domain = self.filter.domain 
    17531721        self.conditions = filter.conditions 
    17541722        self.arg_example = arg_example 
    1755  
    1756     def condIn(self, cond): # is condition in the filter? 
    1757         condInd = ruleCoversArguments.conditionIndex(cond) 
    1758         if operator.or_(condInd, self.indices[cond.position]) == self.indices[cond.position]: 
     1723        self.only_arg_example = True 
     1724         
     1725    def condIn(self,cond): # is condition in the filter? 
     1726        condInd = RuleCoversArguments.conditionIndex(cond) 
     1727        if operator.or_(condInd,self.indices[cond.position]) == self.indices[cond.position]: 
    17591728            return True 
    17601729        return False 
    1761  
    1762     def __call__(self, example): 
    1763 ##        print "in", self.filter(example)#, self.filter.conditions[0](example) 
    1764 ##        print self.filter.conditions[1].values 
    1765         if self.filter(example) and example != self.arg_example: 
    1766             return True 
    1767         elif self.filter(example): 
    1768             try: 
    1769                 if example[self.argument_id].value and len(example[self.argument_id].value.positiveArguments) > 0: # example has positive arguments 
    1770                     # conditions should cover at least one of the positive arguments 
    1771                     oneArgCovered = False 
    1772                     for pA in example[self.argument_id].value.positiveArguments: 
    1773                         argCovered = [self.condIn(c) for c in pA.filter.conditions] 
    1774                         oneArgCovered = oneArgCovered or len(argCovered) == sum(argCovered) #argCovered 
    1775                         if oneArgCovered: 
    1776                             break 
    1777                     if not oneArgCovered: 
     1730     
     1731    def __call__(self,example): 
     1732        if not self.filter(example): 
     1733            return False 
     1734        elif (not self.only_arg_example or example == self.arg_example): 
     1735            if example[self.argument_id].value and len(example[self.argument_id].value.positive_arguments)>0: # example has positive arguments 
     1736                # conditions should cover at least one of the positive arguments 
     1737                oneArgCovered = False 
     1738                for pA in example[self.argument_id].value.positive_arguments: 
     1739                    argCovered = [self.condIn(c) for c in pA.filter.conditions] 
     1740                    oneArgCovered = oneArgCovered or len(argCovered) == sum(argCovered) #argCovered 
     1741                    if oneArgCovered: 
     1742                        break 
     1743                if not oneArgCovered: 
     1744                    return False 
     1745            if example[self.argument_id].value and len(example[self.argument_id].value.negative_arguments)>0: # example has negative arguments 
     1746                # condition should not cover neither of negative arguments 
     1747                for pN in example[self.argument_id].value.negative_arguments: 
     1748                    argCovered = [self.condIn(c) for c in pN.filter.conditions] 
     1749                    if len(argCovered)==sum(argCovered): 
    17781750                        return False 
    1779                 if example[self.argument_id].value and len(example[self.argument_id].value.negativeArguments) > 0: # example has negative arguments 
    1780                     # condition should not cover neither of negative arguments 
    1781                     for pN in example[self.argument_id].value.negativeArguments: 
    1782                         argCovered = [self.condIn(c) for c in pN.filter.conditions] 
    1783                         if len(argCovered) == sum(argCovered): 
    1784                             return False 
    1785             except: 
    1786                 return True 
    1787             return True 
    1788         else: 
    1789             return False 
    1790  
    1791     def __setattr__(self, name, obj): 
    1792         self.__dict__[name] = obj 
    1793         self.filter.setattr(name, obj) 
     1751        return True 
     1752 
     1753    def __setattr__(self,name,obj): 
     1754        self.__dict__[name]=obj 
     1755        self.filter.setattr(name,obj) 
    17941756 
    17951757    def deep_copy(self): 
    17961758        newFilter = ArgFilter(argument_id=self.argument_id) 
    1797         newFilter.filter = Orange.data.filter.Values() #self.filter.deepCopy() 
     1759        newFilter.filter = Orange.core.Filter_values() #self.filter.deepCopy() 
    17981760        newFilter.filter.conditions = self.filter.conditions[:] 
    17991761        newFilter.domain = self.filter.domain 
     
    18031765        newFilter.conditions = newFilter.filter.conditions 
    18041766        newFilter.indices = self.indices[:] 
     1767        newFilter.arg_example = self.arg_example 
    18051768        return newFilter 
    1806  
    18071769ArgFilter = deprecated_members({"argumentID": "argument_id"})(ArgFilter) 
    18081770 
     
    18881850        self.optimize_betas = optimize_betas 
    18891851        self.selected_evaluation = CrossValidation(folds=5) 
     1852        self.penalty = penalty 
    18901853 
    18911854    def __call__(self, rules, examples, weight=0): 
     
    18971860##            for e in examples: 
    18981861##                prob_dist.append(classifier(e,Orange.core.GetProbabilities)) 
    1899             cl = RuleClassifier_logit(rules, self.min_cl_sig, self.min_beta, examples, weight, self.set_prefix_rules, self.optimize_betas, classifier, prob_dist) 
     1862            cl = Orange.core.RuleClassifier_logit(rules, self.min_cl_sig, self.min_beta, self.penalty, examples, weight, self.set_prefix_rules, self.optimize_betas, classifier, prob_dist) 
    19001863        else: 
    1901             cl = RuleClassifier_logit(rules, self.min_cl_sig, self.min_beta, examples, weight, self.set_prefix_rules, self.optimize_betas) 
    1902  
    1903 ##        print "result" 
     1864            cl = Orange.core.RuleClassifier_logit(rules, self.min_cl_sig, self.min_beta, self.penalty, examples, weight, self.set_prefix_rules, self.optimize_betas) 
     1865 
    19041866        for ri, r in enumerate(cl.rules): 
    19051867            cl.rules[ri].setattr("beta", cl.ruleBetas[ri]) 
    1906 ##            if cl.ruleBetas[ri] > 0: 
    1907 ##                print Orange.classification.rules.rule_to_string(r), r.quality, cl.ruleBetas[ri] 
    1908         cl.all_rules = cl.rules 
     1868        cl.setattr("all_rules", cl.rules) 
    19091869        cl.rules = self.sort_rules(cl.rules) 
    19101870        cl.ruleBetas = [r.beta for r in cl.rules] 
  • Orange/orng/orngABML.py

    r9737 r10518  
    99import math 
    1010 
    11 from Orange.classification.rules import create_dichotomous_class as createDichotomousClass 
    12 from Orange.classification.rules import ConvertClass 
    1311# regular expressions 
    1412# exppression for testing validity of a set of arguments: 
     
    286284 
    287285 
     286 
     287class ConvertClass: 
     288    """ Converting class variables into dichotomous class variable. """ 
     289    def __init__(self, classAtt, classValue, newClassAtt): 
     290        self.classAtt = classAtt 
     291        self.classValue = classValue 
     292        self.newClassAtt = newClassAtt 
     293 
     294    def __call__(self, example, returnWhat): 
     295        if example[self.classAtt] == self.classValue: 
     296            return Orange.data.Value(self.newClassAtt, self.classValue + "_") 
     297        else: 
     298            return Orange.data.Value(self.newClassAtt, "not " + self.classValue) 
     299 
     300 
     301def create_dichotomous_class(domain, att, value, negate, removeAtt=None): 
     302    # create new variable 
     303    newClass = Orange.feature.Discrete(att.name + "_", values=[str(value) + "_", "not " + str(value)]) 
     304    positive = Orange.data.Value(newClass, str(value) + "_") 
     305    negative = Orange.data.Value(newClass, "not " + str(value)) 
     306    newClass.getValueFrom = ConvertClass(att, str(value), newClass) 
     307 
     308    att = [a for a in domain.attributes] 
     309    newDomain = Orange.data.Domain(att + [newClass]) 
     310    newDomain.addmetas(domain.getmetas()) 
     311    if negate == 1: 
     312        return (newDomain, negative) 
     313    else: 
     314        return (newDomain, positive) 
     315 
     316 
     317 
    288318def addErrors(test_data, classifier): 
    289319    """ Main task of this function is to add probabilistic errors to examples.""" 
Note: See TracChangeset for help on using the changeset viewer.