
Fixed FIGS plotting #134

Merged · 5 commits · Sep 15, 2022

Conversation

@mepland (Collaborator) commented Sep 14, 2022

Fixed Issue #132, FIGS plots not appearing correctly.

The primary bug was in the assignment of node IDs here:

            right = next(node_counter)
            left = next(node_counter)

They were being improperly set during the recursion of _update_node(nd). I've fixed this by assigning a new node_num variable after the trees are created during fit() here, and using that instead:

        # add node_num to final tree
        for tree_ in self.trees_:
            node_counter = iter(range(0, int(1e06)))
            def _add_node_num(node: Node):
                if node is None:
                    return
                node.setattrs(node_num=next(node_counter))
                _add_node_num(node.left)
                _add_node_num(node.right)

            _add_node_num(tree_)

I also took the opportunity to return a real sklearn DecisionTreeClassifier or DecisionTreeRegressor object, filling in its parameters, including tree_, via the __setstate__() method, building on this SO question. To do this, I needed the impurity at each node, as well as the "value" in the format sklearn expects, i.e. value = np.array([neg_count, pos_count], dtype=float). If we further rewrite the FIGS class to save this 2D "value" alongside the current value, perhaps as value_sklearn, I wouldn't need X_train and y_train for the extract_sklearn_tree_from_figs function or the subsequent plotting functions.
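
As a rough illustration of the __setstate__() technique, here is a hedged sketch (not this PR's actual code): it hand-builds a single-split tree, and borrows the version-specific node record dtype from a throwaway fitted tree, since sklearn's internal layout has changed across versions (e.g. newer versions add a missing_go_to_left field).

```python
# Hedged sketch of rebuilding a sklearn tree via __setstate__ -- an
# illustration of the technique, not the PR's implementation.
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._tree import Tree

# Borrow the version-specific node dtype from a throwaway fitted tree.
X_stub = np.array([[0.0], [1.0], [2.0], [3.0]])
y_stub = np.array([0, 0, 1, 1])
stub = DecisionTreeClassifier(max_depth=1).fit(X_stub, y_stub)
node_dtype = stub.tree_.__getstate__()["nodes"].dtype

# Hand-build one split node (index 0) and two leaves (1 = left, 2 = right).
nodes = np.zeros(3, dtype=node_dtype)
nodes[0]["left_child"], nodes[0]["right_child"] = 1, 2
nodes[0]["feature"], nodes[0]["threshold"] = 0, 1.5
nodes[0]["impurity"] = 0.5
nodes[0]["n_node_samples"], nodes[0]["weighted_n_node_samples"] = 4, 4.0
for leaf in (1, 2):
    nodes[leaf]["left_child"] = nodes[leaf]["right_child"] = -1   # TREE_LEAF
    nodes[leaf]["feature"], nodes[leaf]["threshold"] = -2, -2.0   # TREE_UNDEFINED
    nodes[leaf]["impurity"] = 0.0
    nodes[leaf]["n_node_samples"], nodes[leaf]["weighted_n_node_samples"] = 2, 2.0

# Per-node "value" = [neg_count, pos_count], shape (node_count, n_outputs, n_classes).
values = np.array([[[2.0, 2.0]], [[2.0, 0.0]], [[0.0, 2.0]]])

tree = Tree(1, np.array([2], dtype=np.intp), 1)  # n_features, n_classes, n_outputs
tree.__setstate__({"max_depth": 1, "node_count": 3, "nodes": nodes, "values": values})

# Wrap the Tree in a classifier, filling in the fitted attributes by hand.
clf = DecisionTreeClassifier()
clf.n_features_in_, clf.n_outputs_ = 1, 1
clf.n_classes_, clf.classes_ = 2, np.array([0, 1])
clf.tree_ = tree
```

Samples with feature value at most 1.5 route to the left leaf (class 0), the rest to the right leaf (class 1), so the hand-built tree predicts like the stub it mimics.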

@csinva does my implementation of the impurity variable look correct? I see that the impurities are recomputed after I grab my impurity values, so I suspect not. Perhaps you could fix this, or let me know the best way to get the final impurity at each node? I'll also wait for the go-ahead on adding the value_sklearn variable and on refactoring away the dependence on X_train and y_train in the plotting functions.

@mepland (Collaborator, Author) commented Sep 14, 2022

You can see the correct FIGS plots in my notebook here. We can remove the notebook from the PR when ready.

Glucose concentration test <= 99.500 65/192 (33.85%)
	ΔRisk = 0.07 4/59 (6.78%)
	Glucose concentration test <= 168.500 61/133 (45.86%)
		#Pregnant <= 6.500 44/112 (39.29%)
			Body mass index <= 30.850 21/76 (27.63%)
				ΔRisk = 0.06 2/31 (6.45%)
				Blood pressure(mmHg) <= 67.000 19/45 (42.22%)
					ΔRisk = 0.71 10/14 (71.43%)
					ΔRisk = 0.30 9/31 (29.03%)
			ΔRisk = 0.64 23/36 (63.89%)
		Blood pressure(mmHg) <= 93.000 17/21 (80.95%)
			ΔRisk = 0.86 17/19 (89.47%)
			ΔRisk = -0.01 0/2 (0.0%)

[image: corrected FIGS plot]

@csinva (Owner) commented Sep 14, 2022

Yes this looks great! Thank you for this lovely fix + nice modifications 🤗 .

Your impurity calculation is correct (at least for the object sklearn is expecting). Our recalculation of the impurity reduction is for downstream use by the FIGS algorithm when considering opening this node, but the impurity at this node is exactly as you have it here.
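
For reference, with sklearn's default Gini criterion the impurity at a node can be computed directly from the per-node class counts discussed above. A small illustrative helper (not code from this PR):

```python
import numpy as np

def gini_impurity(value_sklearn):
    """Gini impurity from a node's class counts, e.g. [neg_count, pos_count]."""
    counts = np.asarray(value_sklearn, dtype=float)
    p = counts / counts.sum()          # class proportions at this node
    return float(1.0 - np.sum(p ** 2))
```

For example, a perfectly mixed binary node `[2, 2]` gives an impurity of 0.5, while a pure node `[4, 0]` gives 0.0.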

@mepland (Collaborator, Author) commented Sep 15, 2022

Great! I will leave the impurity calculation alone then, and have moved the values_sklearn creation into fit().

I tried saving the pos and neg counts while creating each node as calculated from y[idxs], y[idxs_left], y[idxs_right] as appropriate, but was getting inconsistent results. Instead I just added values_sklearn after the final tree has been grown by FIGS, at the same time that I add the node_id values.

I also added an optional fig_size parameter to plot() to control the size of the figure, and thus the spacing between nodes when plotted.

@mepland (Collaborator, Author) commented Sep 15, 2022

dtreeviz is now working for FIGS! For more plots, see my demo notebook here. This closes Issue #135 for FIGS trees.

from dtreeviz import trees
from dtreeviz.models.sklearn_decision_trees import ShadowSKDTree
from imodels.tree.viz_utils import extract_sklearn_tree_from_figs

# convert tree 0 of the fitted FIGS model into a sklearn DecisionTreeClassifier
dt = extract_sklearn_tree_from_figs(model_figs, tree_num=0, n_classes=2)
# wrap it in a dtreeviz shadow tree for rendering
sk_dtree = ShadowSKDTree(dt, X_train, y_train, feat_names, 'y', [0, 1])

[image: dtreeviz rendering of a FIGS tree]

@mepland marked this pull request as ready for review September 15, 2022 04:05
@csinva (Owner) commented Sep 15, 2022

Wonderful!

@csinva merged commit 13ea245 into csinva:master Sep 15, 2022
@mattheweplandKH deleted the viz branch September 15, 2022 20:22
@Manuelhrokr commented

Importing extract_sklearn_tree_from_figs from imodels.tree.viz_utils, as indicated in the demo notebook, gives the following error:

ImportError: cannot import name 'extract_sklearn_tree_from_figs' from 'imodels.tree.viz_utils' (/home/manuelzl/miniconda3/envs/manuelenv_python/lib/python3.7/site-packages/imodels/tree/viz_utils.py)

I checked the viz_utils.py file at that path, and the code there is different from the one shown in the expanded source code.

The method in my current version of viz_utils.py is:
extract_sklearn_tree_from_figs_tree(node, n_classes)

@csinva (Owner) commented Oct 17, 2022

Hi @Manuelhrokr, sorry about this and for the delay! Do you know what version of imodels you are using? I think your issue should be fixed if you bump to the new version: pip install --upgrade imodels.

@Manuelhrokr commented

Hi @csinva, sorry for my late response. Thanks for the feedback; I updated as you suggested and it fixed the problem.

Successfully merging this pull request may close these issues:

FIGS print and plot return different trees