Changeset 11626:bc36685ddf73 in orange for Orange/classification/tree.py
 Timestamp:
 04/17/12 15:14:10 (2 years ago)
 Branch:
 default
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

Orange/classification/tree.py
r11459 r11626 2758 2758 return count 2759 2759 2760 2761 def _simple_tree_convert(tree, domain, training_data=None, weight_id=None): 2762 """ 2763 Convert an :class:`SimpleTreeClassifier` to a :class:`TreeClassifier`. 2764 2765 The domain used to build it must be supplied with the `domain` 2766 parameter. If `training_data` is not None it is split and assigned 2767 to the tree's nodes. 2768 2769 :param SimpleTreeClassifeir tree: 2770 The :class:`SimpleTreeClassifier` instance. 2771 :param Orange.data.Domain domain: 2772 The domain on which the `tree` was built. 2773 :param Orange.data.Table training_data: 2774 Optional training data do assign to the nodes of the newly 2775 constructed TreeClassifier. 2776 :param int weight_id: 2777 The weight (if any) used when training the `tree`. 2778 :rval: TreeClassifier 2779 2780 """ 2781 import json 2782 Distribution = Orange.statistics.distribution.Distribution 2783 2784 if not isinstance(tree, SimpleTreeClassifier): 2785 raise TypeError("SimpleTreeClassifier instance expected (got %s)" % 2786 type(tree).__name__) 2787 2788 def is_discrete(var): 2789 return isinstance(var, Orange.feature.Discrete) 2790 2791 def is_continuous(var): 2792 return isinstance(var, Orange.feature.Continuous) 2793 2794 # Get the string representation as used by pickle. 2795 _, (tree_string, ), _ = tree.__reduce__() 2796 # convert it to a valid json string 2797 tree_string = "[ %s ]" % (tree_string.replace(" ", ",") 2798 .replace("{,", "[") 2799 .replace(",}", "]") 2800 .rstrip(",")) 2801 2802 tree_list = json.loads(tree_string) 2803 2804 node_type, child_count, branches = tree_list 2805 # node_type 0 is a classifier, 1 a regression tree 2806 if node_type == 0 and not is_discrete(domain.class_var): 2807 raise ValueError 2808 elif node_type == 1 and not is_continuous(domain.class_var): 2809 raise ValueError 2810 2811 def discrete_dist(var, values): 2812 """ 2813 Create a discrete distribution containing `values`. 2814 """ 2815 dist = Distribution(var) 2816 for i, val in enumerate(values): 2817 dist.add(i, val) 2818 return dist 2819 2820 def continuous_dist(var, count, value): 2821 """ 2822 Create a continuous distribution with `count` points at `value`. 2823 """ 2824 dist = Distribution(var) 2825 dist.add(value, count) 2826 return dist 2827 2828 if is_discrete(domain.class_var): 2829 def node_distribution(values): 2830 return discrete_dist(domain.class_var, values) 2831 else: 2832 def node_distribution(count_valuesum): 2833 count, valuesum = count_valuesum 2834 return continuous_dist(domain.class_var, count, valuesum / count) 2835 2836 def build_tree(branch_list): 2837 """ 2838 Recursivly build a tree for a `branch_list`. 2839 """ 2840 node_type = branch_list[0] 2841 node = Orange.core.TreeNode() 2842 2843 if node_type in [0, 1]: 2844 # Internal split node 2845 branch_count, split_var_ind, split_val = branch_list[1:4] 2846 sub_branches = branch_list[4: 4 + branch_count] 2847 distribution = branch_list[4 + branch_count:] 2848 split_var = domain[split_var_ind] 2849 2850 node.distribution = node_distribution(distribution) 2851 2852 node.branches = map(build_tree, sub_branches) 2853 2854 node.branch_sizes = \ 2855 [sum(branch.branch_sizes or [branch.distribution.abs]) 2856 for branch in node.branches] 2857 2858 if node_type == 0: 2859 # Discrete split node 2860 node.branch_descriptions = split_var.values 2861 2862 node.branch_selector = \ 2863 Orange.core.ClassifierFromVarFD( 2864 class_var=split_var, 2865 position=split_var_ind, 2866 domain=domain, 2867 distribution_for_unknown=node.distribution) 2868 2869 else: 2870 # Continuous split node 2871 node.branch_descriptions = \ 2872 ["<=%.3f" % split_val, ">%.3f" % split_val] 2873 2874 transformer = \ 2875 Orange.feature.discretization.ThresholdDiscretizer( 2876 threshold=split_val) 2877 2878 selector_var = Orange.feature.Discrete( 2879 split_var.name, values=node.branch_descriptions) 2880 2881 unknown_dist = discrete_dist(selector_var, node.branch_sizes) 2882 2883 node.branch_selector = \ 2884 Orange.core.ClassifierFromVarFD( 2885 class_var=selector_var, 2886 domain=domain, 2887 position=split_var_ind, 2888 transformer=transformer, 2889 transform_unknowns=False, 2890 distribution_for_unknown=unknown_dist) 2891 2892 elif node_type == 2: 2893 # Leaf predictor node 2894 distribution = branch_list[2:] 2895 node.distribution = node_distribution(distribution) 2896 2897 # Node classifier 2898 if is_continuous(domain.class_var): 2899 default_val = node.distribution.average() 2900 else: 2901 default_val = node.distribution.modus() 2902 2903 node.node_classifier = \ 2904 Orange.classification.ConstantClassifier( 2905 class_var=domain.class_var, 2906 default_val=default_val, 2907 default_distribution=node.distribution) 2908 return node 2909 2910 def descend_assign_instances(node, instances, splitter, weight_id=None): 2911 node.instances = node.examples = instances 2912 if len(instances): 2913 node.distribution = Distribution(domain.class_var, instances, 2914 weight_id) 2915 2916 node.node_classifier = \ 2917 Orange.classification.majority.MajorityLearner( 2918 instances, weight_id) 2919 2920 if node.branches: 2921 split_instances, weights = splitter(node, instances, weight_id) 2922 2923 if weights is None: 2924 weights = [None] * len(node.branches) 2925 2926 for branch, branch_instances, weight_id in \ 2927 zip(node.branches, split_instances, weights): 2928 descend_assign_instances(branch, branch_instances, splitter, 2929 weight_id) 2930 2931 tree_root = build_tree(branches) 2932 2933 if training_data: 2934 splitter = Splitter_UnknownsAsSelector() 2935 descend_assign_instances(tree_root, training_data, splitter, weight_id) 2936 2937 tree_c = _TreeClassifier(domain=domain, class_var=domain.class_var) 2938 tree_c.descender = Orange.core.TreeDescender_UnknownMergeAsSelector() 2939 tree_c.tree = tree_root 2940 return TreeClassifier(base_classifier=tree_c)
Note: See TracChangeset
for help on using the changeset viewer.