Ticket #1229 (assigned wish)

Opened 2 years ago

Last modified 2 years ago

TreeClassifier to_networkx_tree() function

Reported by: yazan.boshmaf Owned by: marko
Milestone: 2.6 Component: library
Severity: minor Keywords: TreeClassifier, to_string(), networkx
Cc: miha, janez Blocking:
Blocked By:

Description (last modified by yazan.boshmaf)

The TreeClassifier has a to_string() function that prints the decision tree in text format. Even though this function can be customized via formatting arguments, the output is still unstructured and the tree cannot be easily traversed and inspected.

Below, I propose to add a new function called to_networkx_tree() that converts the internal Orange tree structure to a NetworkX directed graph, where internal nodes are features and leaves are decisions. Internal node labels are feature names, and edge labels are the feature values on which the tree splits. Leaf labels are the decisions along with their probabilities.

import networkx

def to_networkx_tree(pnode, tree=None, level=0, local_index=0):
    '''Returns a networkx tree of the orange classification tree'''

    # create a new directed graph if this is the first call
    # (compare with None: an empty DiGraph is falsy because its len() is 0)
    if tree is None:
        tree = networkx.DiGraph()

    # reached null nodes
    if not pnode:
        return tree

    # if the node has children (not leaves), traverse
    if pnode.branch_selector:

        pnode_desc = pnode.branch_selector.class_var.name

        # create a root node only on first call
        if level == 0:
            pnode_id = str(pnode_desc) + '_%i' % len(tree)
            tree.add_node(pnode_id, shape='none', label=str(pnode_desc))
        else:
            pnode_id = str(pnode_desc) + '_%i' % (len(tree) - 1)

        # iterate through children and update branch
        for i in range(len(pnode.branches)):
            cnode = pnode.branches[i]

            # if the child is not a null node, update the tree
            if cnode:

                # if the child is a parent, update a branch
                if cnode.branch_selector:
                    # add child node
                    cnode_desc = cnode.branch_selector.class_var.name
                    cnode_id = str(cnode_desc) + '_%i' % len(tree)
                    tree.add_node(cnode_id, shape='none', label=str(cnode_desc))

                    # add edge between parent and child
                    edge_desc = str(pnode.branch_descriptions[i])
                    tree.add_edge(pnode_id, cnode_id, label=edge_desc)

                    # recursively update this branch of the tree
                    tree = to_networkx_tree(cnode, tree, level+1, i)

                # if the child is a leaf, then add class info
                else:
                    leaf_dist = cnode.node_classifier.default_distribution
                    leaf_prob = leaf_dist[int(cnode.node_classifier.default_val)]
                    leaf_class = cnode.node_classifier.default_value
                    leaf_id = str(leaf_class) + '_%i' % len(tree)
                    leaf_label = str(leaf_class) + ' (%.2f%%)' % (leaf_prob*100)
                    tree.add_node(leaf_id, shape='box', label=leaf_label)

                    # add edge between parent and child
                    edge_desc = str(pnode.branch_descriptions[i])
                    tree.add_edge(pnode_id, leaf_id, label=edge_desc)

            # null node reached
            else:
                pass

        # once done with children for this parent, return tree
        return tree

    # if it's a leaf and it's a root, update the node with class info
    else:
        if level == 0:
            leaf_dist = pnode.node_classifier.default_distribution
            leaf_prob = leaf_dist[int(pnode.node_classifier.default_val)]
            leaf_class = pnode.node_classifier.default_value
            leaf_id = str(leaf_class) + '_%i' % len(tree)
            leaf_label = str(leaf_class) + ' (%.2f%%)' % (leaf_prob*100)
            tree.add_node(leaf_id, shape='box', label=leaf_label)
        return tree

Here's how the code could be used and tested:

import orange, Orange
import networkx
import networkx.algorithms as nx_algo

# use voting dataset to train a decision tree
data = orange.ExampleTable('voting')
tree_classifier = orange.TreeLearner(data)

# convert the orange tree to networkx tree
networkx_tree = to_networkx_tree(tree_classifier.tree)

# assert that both trees have the same order, and that the graph has one connected component and is a DAG
assert len(networkx_tree) == tree_classifier.tree.tree_size()
assert nx_algo.number_connected_components(networkx_tree.to_undirected()) == 1
assert nx_algo.is_directed_acyclic_graph(networkx_tree) == True

# convert the tree to AGraph and save it as a Graphviz DOT file
agraph = networkx.to_agraph(networkx_tree)
agraph.write('networkx_tree.dot')
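
Once the tree is a NetworkX graph, it can be traversed with ordinary graph calls, which is the main point of this wish. As a minimal sketch, the snippet below lists every root-to-leaf path as a decision rule. The decision_rules() helper and the hand-built toy graph are hypothetical and only mimic the node and edge 'label' attributes that to_networkx_tree() sets; the feature names, values and percentages are made up, and the attribute access via tree.nodes[...] assumes NetworkX 2.0 or later.

```python
import networkx as nx


def decision_rules(tree):
    """Return an iterator of (conditions, leaf_label), one per root-to-leaf path.

    Hypothetical helper: assumes every node and edge carries a 'label'
    attribute, as in the graph built by to_networkx_tree().
    """
    # the root is the only node without an incoming edge
    roots = [n for n in tree if tree.in_degree(n) == 0]
    if len(roots) != 1:
        raise ValueError("expected exactly one root, found %d" % len(roots))

    def walk(node, conditions):
        children = list(tree.successors(node))
        if not children:
            # a leaf: its label holds the decision and its probability
            yield conditions, tree.nodes[node]["label"]
            return
        feature = tree.nodes[node]["label"]
        for child in children:
            # the edge label is the feature value this branch splits on
            test = "%s = %s" % (feature, tree[node][child]["label"])
            yield from walk(child, conditions + [test])

    return walk(roots[0], [])


# toy graph with made-up labels, shaped like the converter's output
toy = nx.DiGraph()
toy.add_node("f1_0", shape="none", label="f1")
toy.add_node("yes_1", shape="box", label="yes (90.00%)")
toy.add_edge("f1_0", "yes_1", label="a")
toy.add_node("f2_2", shape="none", label="f2")
toy.add_edge("f1_0", "f2_2", label="b")
toy.add_node("no_3", shape="box", label="no (75.00%)")
toy.add_edge("f2_2", "no_3", label="x")
toy.add_node("yes_4", shape="box", label="yes (60.00%)")
toy.add_edge("f2_2", "yes_4", label="y")

for conditions, leaf in decision_rules(toy):
    print("IF %s THEN %s" % (" AND ".join(conditions), leaf))
```

The same helper should work unchanged on the networkx_tree object from the test above, since it relies only on the 'label' attributes and on the graph having a single root.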

Change History

comment:1 Changed 2 years ago by yazan.boshmaf

  • Description modified

comment:2 Changed 2 years ago by ales

  • Cc miha, janez added
  • Status changed from new to assigned
  • Owner changed from Yazan Boshmaf to marko