Fixed FIGS plotting #134
Conversation
…tion that returns a real sklearn DecisionTree object
You can see the correct FIGS plots in my notebook here. We can remove the notebook from the PR when ready.
Yes, this looks great! Thank you for this lovely fix + nice modifications 🤗. Your impurity calculation is correct (at least for the object sklearn is expecting). Our recalculation of the impurity reduction is for downstream use by the FIGS algorithm when considering opening this node, but the impurity at this node is exactly as you have it here.
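For reference, the per-node impurity sklearn stores (with the default Gini criterion) can be computed from the class counts alone. A minimal sketch, with a hypothetical helper name, not code from this PR:

```python
import numpy as np

def gini_from_counts(neg_count, pos_count):
    """Gini impurity of a binary-classification node from its class
    counts -- the per-node quantity a sklearn tree_ object stores.
    (Hypothetical helper, for illustration only.)"""
    total = neg_count + pos_count
    p = np.array([neg_count, pos_count], dtype=float) / total
    return float(1.0 - np.sum(p ** 2))

# A node with 6 negatives and 4 positives: 1 - (0.6^2 + 0.4^2) ≈ 0.48
print(gini_from_counts(6, 4))
```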
…ize parameter to plot()
Great! I will leave the impurity calculation alone then. I tried saving the pos and neg counts while creating each node, and I also added an optional parameter to plot().
Wonderful!
Importing
I checked the method in my current version of imodels.
Hi @Manuelhrokr sorry about this and for the delay! Do you know what version of imodels you are using? I think your issue should be fixed if you bump to the new version: |
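The exact version number was lost from the comment above; presumably the suggested bump is the usual pip upgrade from PyPI, e.g.:

```shell
pip install --upgrade imodels
```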
Hi @csinva, sorry for my late response. Thanks for the feedback; I updated as you suggested and it fixed the problem.
Fixed Issue 132, FIGS plots not appearing correctly.

The primary bug was in the assignment of node ids here. They were being improperly set during the recursion of `_update_node(nd)`. I've fixed this by assigning a new `node_num` variable after the trees are created during `fit()` here and using that instead.

I also took the opportunity to return a real sklearn `DecisionTreeClassifier` or `DecisionTreeRegressor` object, filling the parameters, including `tree_`, with the `__setstate__()` method, building on this SO question. In order to do this, I needed the impurity at each node and the "value" as expected by sklearn, i.e. `value = np.array([neg_count, pos_count], dtype=float)`. If we further rewrite the `FIGS` class to save this 2D "value" alongside the current value, perhaps as `value_sklearn`, I wouldn't need `X_train, y_train` for the `extract_sklearn_tree_from_figs` function and the subsequent plotting functions.

@csinva does my implementation of the `impurity` variable look correct? I see the impurities are recomputed after I grab my impurity values, so I expect not. Perhaps you could fix this, or let me know the best way to get the final impurity at each node? I'll also wait for the go-ahead on adding the `value_sklearn` variable and refactoring away the dependence on `X_train, y_train` in the plotting functions.
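The `__setstate__()` technique described above can be sketched in isolation. This is a hedged standalone example, not the PR's actual code: the depth-1 tree, its node values, and the helper name are invented for illustration, and the node dtype is read back from an empty `Tree` because sklearn's internal node fields vary across versions.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._tree import Tree  # sklearn's internal tree container

def build_single_split_tree():
    """Hypothetical sketch: construct a depth-1 binary classifier
    (one split at x0 <= 0.5, two leaves) by filling Tree state
    via __setstate__() instead of fitting on data."""
    n_features, n_classes, n_outputs = 1, 2, 1
    classes = np.array([n_classes], dtype=np.intp)
    tree = Tree(n_features, classes, n_outputs)

    # Read the node dtype from an empty tree so this tolerates sklearn
    # versions that append fields (e.g. missing_go_to_left); extra
    # trailing fields are simply zero-filled below.
    node_dtype = tree.__getstate__()["nodes"].dtype
    n_extra = len(node_dtype) - 7
    nodes = np.zeros(3, dtype=node_dtype)
    # Root: children 1 and 2, split on feature 0 at 0.5, impurity 0.48,
    # 10 samples. Leaves use -1 children and -2 feature/threshold.
    nodes[0] = tuple([1, 2, 0, 0.5, 0.48, 10, 10.0] + [0] * n_extra)
    nodes[1] = tuple([-1, -1, -2, -2.0, 0.0, 6, 6.0] + [0] * n_extra)
    nodes[2] = tuple([-1, -1, -2, -2.0, 0.0, 4, 4.0] + [0] * n_extra)

    # Per-node class counts, shape (n_nodes, n_outputs, n_classes) --
    # the 2D "value" discussed above, i.e. [neg_count, pos_count].
    values = np.array([[[6.0, 4.0]], [[6.0, 0.0]], [[0.0, 4.0]]])

    tree.__setstate__(
        {"max_depth": 1, "node_count": 3, "nodes": nodes, "values": values}
    )

    clf = DecisionTreeClassifier()
    clf.n_features_in_ = n_features
    clf.n_outputs_ = n_outputs
    clf.n_classes_ = n_classes
    clf.classes_ = np.array([0, 1])
    clf.tree_ = tree
    return clf

clf = build_single_split_tree()
print(clf.predict(np.array([[0.0], [1.0]])))
```

The resulting object works with `sklearn.tree.plot_tree` and `predict` like any fitted tree, which is the point of returning a real `DecisionTreeClassifier` rather than a custom plotting shim.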