Ignore:
Timestamp:
04/17/12 15:14:10 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Added a 'SimpleTreeClassifier' to 'TreeClassifier' conversion function.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/classification/tree.py

    r11459 r11626  
    27582758    return count 
    27592759 
     2760 
     2761def _simple_tree_convert(tree, domain, training_data=None, weight_id=None): 
     2762    """ 
     2763    Convert an :class:`SimpleTreeClassifier` to a :class:`TreeClassifier`. 
     2764 
     2765    The domain used to build it must be supplied with the `domain` 
     2766    parameter. If `training_data` is not None it is split and assigned 
     2767    to the tree's nodes. 
     2768 
     2769    :param SimpleTreeClassifeir tree: 
     2770        The :class:`SimpleTreeClassifier` instance. 
     2771    :param Orange.data.Domain domain: 
     2772        The domain on which the `tree` was built. 
     2773    :param Orange.data.Table training_data: 
     2774        Optional training data do assign to the nodes of the newly 
     2775        constructed TreeClassifier. 
     2776    :param int weight_id: 
     2777        The weight (if any) used when training the `tree`. 
     2778    :rval: TreeClassifier 
     2779 
     2780    """ 
     2781    import json 
     2782    Distribution = Orange.statistics.distribution.Distribution 
     2783 
     2784    if not isinstance(tree, SimpleTreeClassifier): 
     2785        raise TypeError("SimpleTreeClassifier instance expected (got %s)" % 
     2786                        type(tree).__name__) 
     2787 
     2788    def is_discrete(var): 
     2789        return isinstance(var, Orange.feature.Discrete) 
     2790 
     2791    def is_continuous(var): 
     2792        return isinstance(var, Orange.feature.Continuous) 
     2793 
     2794    # Get the string representation as used by pickle. 
     2795    _, (tree_string, ), _ = tree.__reduce__() 
     2796    # convert it to a valid json string 
     2797    tree_string = "[ %s ]" % (tree_string.replace(" ", ",") 
     2798                              .replace("{,", "[") 
     2799                              .replace(",}", "]") 
     2800                              .rstrip(",")) 
     2801 
     2802    tree_list = json.loads(tree_string) 
     2803 
     2804    node_type, child_count, branches = tree_list 
     2805    # node_type 0 is a classifier, 1 a regression tree 
     2806    if node_type == 0 and not is_discrete(domain.class_var): 
     2807        raise ValueError 
     2808    elif node_type == 1 and not is_continuous(domain.class_var): 
     2809        raise ValueError 
     2810 
     2811    def discrete_dist(var, values): 
     2812        """ 
     2813        Create a discrete distribution containing `values`. 
     2814        """ 
     2815        dist = Distribution(var) 
     2816        for i, val in enumerate(values): 
     2817            dist.add(i, val) 
     2818        return dist 
     2819 
     2820    def continuous_dist(var, count, value): 
     2821        """ 
     2822        Create a continuous distribution with `count` points at `value`. 
     2823        """ 
     2824        dist = Distribution(var) 
     2825        dist.add(value, count) 
     2826        return dist 
     2827 
     2828    if is_discrete(domain.class_var): 
     2829        def node_distribution(values): 
     2830            return discrete_dist(domain.class_var, values) 
     2831    else: 
     2832        def node_distribution(count_valuesum): 
     2833            count, valuesum = count_valuesum 
     2834            return continuous_dist(domain.class_var, count, valuesum / count) 
     2835 
     2836    def build_tree(branch_list): 
     2837        """ 
     2838        Recursivly build a tree for a `branch_list`. 
     2839        """ 
     2840        node_type = branch_list[0] 
     2841        node = Orange.core.TreeNode() 
     2842 
     2843        if node_type in [0, 1]: 
     2844            # Internal split node 
     2845            branch_count, split_var_ind, split_val = branch_list[1:4] 
     2846            sub_branches = branch_list[4: 4 + branch_count] 
     2847            distribution = branch_list[4 + branch_count:] 
     2848            split_var = domain[split_var_ind] 
     2849 
     2850            node.distribution = node_distribution(distribution) 
     2851 
     2852            node.branches = map(build_tree, sub_branches) 
     2853 
     2854            node.branch_sizes = \ 
     2855                [sum(branch.branch_sizes or [branch.distribution.abs]) 
     2856                 for branch in node.branches] 
     2857 
     2858            if node_type == 0: 
     2859                # Discrete split node 
     2860                node.branch_descriptions = split_var.values 
     2861 
     2862                node.branch_selector = \ 
     2863                    Orange.core.ClassifierFromVarFD( 
     2864                        class_var=split_var, 
     2865                        position=split_var_ind, 
     2866                        domain=domain, 
     2867                        distribution_for_unknown=node.distribution) 
     2868 
     2869            else: 
     2870                # Continuous split node 
     2871                node.branch_descriptions = \ 
     2872                    ["<=%.3f" % split_val, ">%.3f" % split_val] 
     2873 
     2874                transformer = \ 
     2875                    Orange.feature.discretization.ThresholdDiscretizer( 
     2876                        threshold=split_val) 
     2877 
     2878                selector_var = Orange.feature.Discrete( 
     2879                    split_var.name, values=node.branch_descriptions) 
     2880 
     2881                unknown_dist = discrete_dist(selector_var, node.branch_sizes) 
     2882 
     2883                node.branch_selector = \ 
     2884                    Orange.core.ClassifierFromVarFD( 
     2885                        class_var=selector_var, 
     2886                        domain=domain, 
     2887                        position=split_var_ind, 
     2888                        transformer=transformer, 
     2889                        transform_unknowns=False, 
     2890                        distribution_for_unknown=unknown_dist) 
     2891 
     2892        elif node_type == 2: 
     2893            # Leaf predictor node 
     2894            distribution = branch_list[2:] 
     2895            node.distribution = node_distribution(distribution) 
     2896 
     2897        # Node classifier 
     2898        if is_continuous(domain.class_var): 
     2899            default_val = node.distribution.average() 
     2900        else: 
     2901            default_val = node.distribution.modus() 
     2902 
     2903        node.node_classifier = \ 
     2904            Orange.classification.ConstantClassifier( 
     2905                class_var=domain.class_var, 
     2906                default_val=default_val, 
     2907                default_distribution=node.distribution) 
     2908        return node 
     2909 
     2910    def descend_assign_instances(node, instances, splitter, weight_id=None): 
     2911        node.instances = node.examples = instances 
     2912        if len(instances): 
     2913            node.distribution = Distribution(domain.class_var, instances, 
     2914                                             weight_id) 
     2915 
     2916            node.node_classifier = \ 
     2917                Orange.classification.majority.MajorityLearner( 
     2918                    instances, weight_id) 
     2919 
     2920        if node.branches: 
     2921            split_instances, weights = splitter(node, instances, weight_id) 
     2922 
     2923            if weights is None: 
     2924                weights = [None] * len(node.branches) 
     2925 
     2926            for branch, branch_instances, weight_id in \ 
     2927                    zip(node.branches, split_instances, weights): 
     2928                descend_assign_instances(branch, branch_instances, splitter, 
     2929                                         weight_id) 
     2930 
     2931    tree_root = build_tree(branches) 
     2932 
     2933    if training_data: 
     2934        splitter = Splitter_UnknownsAsSelector() 
     2935        descend_assign_instances(tree_root, training_data, splitter, weight_id) 
     2936 
     2937    tree_c = _TreeClassifier(domain=domain, class_var=domain.class_var) 
     2938    tree_c.descender = Orange.core.TreeDescender_UnknownMergeAsSelector() 
     2939    tree_c.tree = tree_root 
     2940    return TreeClassifier(base_classifier=tree_c) 
Note: See TracChangeset for help on using the changeset viewer.