Send custom model without having to call before model.predict(dummy_data) #34

Open
yanndupis opened this issue Oct 22, 2019 · 2 comments
Labels
Good first issue 🎓 Perfect for beginners, welcome to OpenMined!

Comments

@yanndupis
Contributor

In the Part 2 tutorial, for custom models (subclasses of tf.keras.models.Model), we need to run model.predict(dummy_data) before sending the model to the worker. This sets the input_shape, which is required by tf.keras.models.save_model.

Ideally we would remove this step, or at most require a call to model(dummy_data) before sending the model. You can find more information in this conversation.

@yanndupis yanndupis added the Good first issue 🎓 Perfect for beginners, welcome to OpenMined! label Oct 22, 2019
@arshjot
Contributor

arshjot commented Nov 5, 2019

We can set the input shape while defining the model as shown below:

import tensorflow as tf

class CustomModel(tf.keras.Model):

    def __init__(self, num_classes=10):
        super(CustomModel, self).__init__(name='custom_model')
        self.num_classes = num_classes

        self.flatten = tf.keras.layers.Flatten()
        self.dense_1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense_2 = tf.keras.layers.Dense(num_classes, activation='softmax')

        # set input shape (note: _set_inputs is a private Keras method)
        self._set_inputs(tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32))

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense_1(x)
        return self.dense_2(x)

model = CustomModel(10)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# bob, x_train_ptr and y_train_ptr are assumed to come from the Part 2 tutorial setup
model_ptr = model.send(bob)
model_ptr.fit(x_train_ptr, y_train_ptr, epochs=2, validation_split=0.2)

Or we can just replace model.predict(dummy_data) with model._set_inputs(tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32)).

Would either of these be a satisfactory solution?

@jvmncs
Contributor

jvmncs commented Nov 7, 2019

Hmm, I don't think this is ideal, since _set_inputs is meant to be an internal method and not exposed to the user. Then again, I do like placing that in the constructor a bit more than model.predict(x) for the tutorial. I just reviewed the conversation @yanndupis & I had in the original PR: if Keras explicitly requires that its users call fit, predict, or _set_inputs, then I think it's okay for us to expect the same.

The only thing left to change here would be to handle this a bit more cleanly in the case of model.send(bob). It would be great if we had our own error to report & redirect, since a user might not realize that sending a model involves a call to save_model, and the resulting failure could be confusing.
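
To make the "own error" idea concrete, here is a rough sketch of what such a check might look like. Everything in it is hypothetical: ModelNotBuiltError and ensure_inputs_set are illustrative names, not part of this project or of Keras, and the check assumes that a subclassed tf.keras model which has never been called exposes inputs as None (which I believe holds for TF 2.x-era Keras, but treat it as an assumption). A stub class stands in for the Keras model so the snippet runs without TensorFlow.

```python
class ModelNotBuiltError(RuntimeError):
    """Hypothetical error: the model was sent before its inputs were set."""


def ensure_inputs_set(model):
    """Return the model unchanged, or raise a descriptive error.

    Assumption: a subclassed tf.keras model that has never been called
    (and never had _set_inputs applied) exposes `inputs` as None.
    """
    if getattr(model, "inputs", None) is None:
        raise ModelNotBuiltError(
            "Sending a model serializes it with tf.keras.models.save_model, "
            "which needs a known input shape. Call model.predict(dummy_data) "
            "or model.fit(...) once before model.send(worker)."
        )
    return model


class _StubModel:
    """Stand-in for a Keras model so this sketch runs without TensorFlow."""

    def __init__(self, inputs=None):
        self.inputs = inputs


# A model whose inputs were never set triggers the redirecting error.
try:
    ensure_inputs_set(_StubModel())
    message = None
except ModelNotBuiltError as err:
    message = str(err)

# A model with inputs set passes through untouched.
ready = ensure_inputs_set(_StubModel(inputs=["input_1"]))
```

A check like this could run at the top of model.send, so the user sees a message pointing at predict/fit instead of a failure surfacing from inside save_model.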
