Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Torch Training code for UNet Model #68

Merged
merged 9 commits into from
Dec 17, 2021
Merged

Conversation

quietrex
Copy link
Contributor

Description

  • This repo comprises of the training, inference and model conversion (pytorch model to onnx model) code for UNet masking model.

Changes made

  • Added the component required for the training and inferencing.
  1. ConvertTorchModelToOnnx
  2. CreateUnetModel
  3. ImageTrainTestSplit
  4. PrepareUnetDataLoader
  5. TrainUnet
  6. UNetModel
  7. UnetPredict
  8. ReadMaskDataset
  • Added xpipes sample.
  1. PyTorchUnetInferenceSample.xpipes - To allow inferencing of trained Unet model in either pytorch or onnx format.
  2. PyTorchUnetTrainSample.xpipes - To allow training of binary class masking with Unet model
  3. ConvertToOnnxSample.xpipes - To allow conversion of torch model to onnx model.

To Test

  1. To perform the training in PyTorchUnetTrainSample.xpipes. Please download any mask dataset that contains the images and segmentation images.
  • The dataset that I used would be Leeds butterfly dataset and ensure the tree follows as below:
    image
  • The trained model should be saved in the examples folder by default but you are allowed to modify the save path.
  • To allow it to work, you must have the same image name with the image and mask image. I have a sample preprocessing code that should ensure the image name is same.
import os
for folder in os.listdir("segmentations"):
    path = os.path.join("segmentations", folder)
    for mask_image in os.listdir(path):
        old_path = os.path.join(path, mask_image)
        new_path = os.path.join(path, mask_image.replace("_seg0",""))
        os.rename(old_path, new_path)
  1. To perform the inference code in PyTorchUnetInferenceSample.xpipes. Please put the path that model has be saved as model_path and inference image as image_path.

  2. To perform the model conversion code in ConvertToOnnxSample.xpipes. Please put the pth format model as input_model_path and onnx format path as output_model_path.

Citing Leeds buttefly dataset
Josiah Wang, Katja Markert, and Mark Everingham
Learning Models for Object Recognition from Natural Language Descriptions
In Proceedings of the 20th British Machine Vision Conference (BMVC2009)

http://www.josiahwang.com/dataset/leedsbutterfly/

@quietrex quietrex added the enhancement New feature or request label Dec 15, 2021
@MFA-X-AI MFA-X-AI self-requested a review December 16, 2021 03:51
Copy link
Member

@MFA-X-AI MFA-X-AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! I think this will take a bit of time to review in depth, so I'll do a quick one now on PyTorchUnetTrainSample.xpipes.

In general I have the philosophy of making things very fast for the users to run, so usually I set almost every inPort to have a default value. I see that you've implemented it already like verbose = self.verbose.value if self.verbose.value else True, but there are some more that you could provide.

PyTorchUnetTrainSample.xpipes

ImageTrainTestSplit

  1. 0.8 sounds like a good default split ratio.

CreateUnetModel

  1. no_epochs -> epochs
  2. set default values for learning rate, patience, early stop.
  3. it's kinda odd to see setting the number of epochs in both CreateUnetModeland TrainUnet, but I guess it's needed for the early stop?

TrainUnet

  1. For modelname, I used filename = Path(sys.argv[0]).stem so it matches the xpipes name if the user does not provide.

Errors:

ModuleNotFoundError: No module named 'xai_components.xai_torch'
The training.py still uses the old imports.

Copy link
Member

@MFA-X-AI MFA-X-AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Can confirm the 3 examples work.

@MFA-X-AI MFA-X-AI merged commit 613c93e into master Dec 17, 2021
@MFA-X-AI MFA-X-AI deleted the rex-torch-unet-training branch December 17, 2021 08:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants