-
Notifications
You must be signed in to change notification settings - Fork 702
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] make autobatching matrix multiplies more flexible #566
base: master
Are you sure you want to change the base?
Conversation
Former-commit-id: 46c9a58
Hmm, I really don't like the special handling of matrix multiplies here. |
We can easily count how many different nodes a given node is an arg of, instead of how many matmuls. But why is it better? Where is this relevant besides matmul (and maybe affine transform in the future)? |
Ones that are highly relevant in addition to matmul are affine transform, conv2d, tensor contraction, etc. (I'm probably forgetting several). All of these will be much faster if you share parameters vs. iterating over them. Everything else with an arity over 2 is somewhat relevant. Sharing parameters will reduce the need for a memory copy at least. Basically my main priority is either that all nodes be treated the same, or alternatively that we have a couple of equivalence classes like "benefit largely from grouping" or "don't benefit much from grouping", so when we implement a new node we can specify which one it belongs to. |
I thought about this a little more. What about keeping a reference count for each node, then providing this reference count to the |
I don't fully understand the details of your proposal, but it sounds like extending the count from being first args of matmul to being an arg of anything, and a cleverer way of setting the threshold. sure, sounds good for me, go for it. another option I thought about (which is a little less automatic) is to add a "shared" flag to nodes. Param nodes will have this on by default, for others it can be turned on. The autobatch_sig and autobatch_concat methods will look at this flag for the args and act accordingly. But if you can get your proposal to work, lets go for it. |
Cool, I'll give this a shot. |
46c9a58
to
06ec749
Compare
This introduces two batchable versions of matmul: one in which the first argument is part of the signature, and a second in which it is not. The first version is triggered for cases where the first argument is shared with >2 matmul nodes.