8.3 Pre-trained convolutional neural networks

Slides

Important: If you rent a GPU from a cloud provider (such as AWS), don't forget to turn it off after you finish. It's not free and you might get a large bill at the end of the month.

Links

  • The keras.applications module provides pre-trained models with different architectures. We'll use Xception, which takes input images of size (299, 299) and expects each pixel value to be scaled to the range from -1 to 1.
  • We create an instance of the pre-trained model using model = Xception(weights='imagenet', input_shape=(299, 299, 3)). The model uses weights pre-trained on ImageNet and expects images of shape (299, 299, 3).
  • Besides the image size, the model expects a batch dimension: its input is an array of shape (batch_size, 299, 299, 3). Even a single image has to be passed as a batch of one, so the input shape is (1, 299, 299, 3). By default, model.predict processes the input in batches of 32.
  • During the pre-training of Xception, the images were preprocessed with the preprocess_input function. Therefore, we have to apply the same function to our data before making predictions: X = preprocess_input(X).
  • pred = model.predict(X) returns a 2D array of shape (1, 1000): one row per image, with one probability for each of the 1000 ImageNet classes. decode_predictions(pred) converts it into class names and their probabilities in a readable format.
  • To make the pre-trained model useful for our specific case, we'll have to make some adjustments, which we'll do in the coming sections.
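The shape and scaling requirements described above can be illustrated without loading the model. The sketch below uses only NumPy and a synthetic image; the helper names to_batch and scale_to_unit_range are made up for this example, and the formula x / 127.5 - 1 is an assumption about what Xception's preprocess_input does to map pixel values from [0, 255] to [-1, 1]. In real code, use preprocess_input itself.

```python
import numpy as np


def to_batch(image):
    """Turn one (299, 299, 3) image into a batch of shape (1, 299, 299, 3)."""
    return np.expand_dims(image, axis=0)


def scale_to_unit_range(X):
    """Map pixel values from [0, 255] to [-1, 1] (assumed Xception scaling)."""
    return X.astype("float32") / 127.5 - 1.0


# Synthetic image standing in for a real photo (hypothetical data).
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(299, 299, 3), dtype=np.uint8)

X = scale_to_unit_range(to_batch(image))
print(X.shape)  # (1, 299, 299, 3)
```

The leading 1 in the shape is the batch dimension: the model always receives a batch, even when it contains a single image.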

Classes, functions, and methods:

  • from tensorflow.keras.applications.xception import Xception: import the model from keras applications
  • from tensorflow.keras.applications.xception import preprocess_input: function to preprocess images the way Xception expects
  • from tensorflow.keras.applications.xception import decode_predictions: returns, for each image, a list of (class ID, class name, probability) tuples for the top classes
  • model.predict(X): function to make predictions on the test images
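Put together, the functions listed above could be used roughly as follows. This is a minimal sketch, not the exact code from the video: the function name classify_image and its path argument are hypothetical, and load_img from tensorflow.keras.preprocessing.image is used here to read and resize the file. TensorFlow is imported inside the function because it is slow to load, and the ImageNet weights are downloaded the first time the model is created.

```python
import numpy as np


def classify_image(path, top=5):
    """Classify one image file with an ImageNet-pretrained Xception model."""
    # Lazy imports: TensorFlow is heavy, so load it only when it is needed.
    from tensorflow.keras.applications.xception import (
        Xception,
        decode_predictions,
        preprocess_input,
    )
    from tensorflow.keras.preprocessing.image import load_img

    # Downloads the ImageNet weights on first use.
    model = Xception(weights="imagenet", input_shape=(299, 299, 3))

    img = load_img(path, target_size=(299, 299))  # PIL image, resized
    x = np.array(img)        # shape (299, 299, 3), values in 0..255
    X = np.array([x])        # add the batch dimension: (1, 299, 299, 3)
    X = preprocess_input(X)  # scale pixel values to the range [-1, 1]

    pred = model.predict(X)  # shape (1, 1000): one score per ImageNet class

    # decode_predictions returns one list per image; each entry is a
    # (class ID, class name, probability) tuple.
    return decode_predictions(pred, top=top)[0]
```

Calling classify_image("example.jpg"), where the path is a placeholder for any local image, should return the top five ImageNet classes for that image. Creating the model once and reusing it for several images would be faster than rebuilding it on every call.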

Notes

Add notes from the video (PRs are welcome)

⚠️ The notes are written by the community.
If you see an error here, please create a PR with a fix.
