This is a PyTorch implementation of the ViTON paper link.
Currently this has the generator part of the paper. I am working on the Refinement network.
Preprocessing takes place in the following stages:
- First, the input human image is passed through a pretrained Human Parts Parsing model link.
- Then the required masks are extracted:
- A Face & Hair mask is extracted and used to crop the face and hair from the original image.
- A Body mask is extracted, excluding the Face & Hair region and the Background.
- A Cloth mask is also extracted by merging the Top clothing and Dress clothing classes.
- Pose heatmaps are also generated for the user image (a sketch of this step follows the list).
- First, pose keypoints are calculated for the input image using a pretrained Keypoint calculation network link. {Already included in the code}
- For each keypoint coordinate, a heatmap is created, and the heatmaps are concatenated along the channel dimension.
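As a rough illustration of the heatmap step, the sketch below converts a list of keypoint coordinates into per-keypoint Gaussian heatmaps stacked along the channel axis. The function name, the Gaussian radius, and the handling of missing keypoints are assumptions, not the repository's exact preprocessing code.

```python
import numpy as np

def keypoints_to_heatmaps(keypoints, height, width, sigma=6.0):
    """Build one Gaussian heatmap per keypoint and stack them along the channel axis.

    keypoints: iterable of (x, y) coordinates; a missing keypoint can be marked
    with negative coordinates, which yields an all-zero channel.
    Returns an array of shape (num_keypoints, height, width).
    """
    ys, xs = np.mgrid[0:height, 0:width]
    heatmaps = []
    for x, y in keypoints:
        if x < 0 or y < 0:                        # keypoint not detected
            heatmaps.append(np.zeros((height, width), dtype=np.float32))
            continue
        g = np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2.0 * sigma ** 2))
        heatmaps.append(g.astype(np.float32))
    return np.stack(heatmaps, axis=0)             # (17, H, W) for 17 keypoints
```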
It uses a U-Net architecture for the image and mask generation. The input to the Generator Network is a 24-channel tensor (a sketch of how it is assembled follows this list) consisting of:
- Pose heatmaps of the user in the target pose. (17 Channels)
- A downsampled Body mask, excluding the face and hair, which provides the overall shape of the output image. (1 Channel)
- The Face and Hair extracted from the human image, used to preserve identity in the target image. (3 Channels)
- The target clothing that needs to be applied to the human picture. (3 Channels)
The output is a 4-channel tensor:
- The target output image with the human wearing the target clothing item. (3 Channels)
- The generated mask of the target clothing item in the output image. (1 Channel)
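A minimal sketch of how the 24-channel input could be assembled and fed to the U-Net is shown below; the tensor names, the 256x192 resolution, and the UNet constructor are illustrative assumptions.

```python
import torch

# Hypothetical tensors for a single sample at an assumed 256x192 resolution;
# the variable names are illustrative, not the repository's.
pose_heatmaps = torch.randn(17, 256, 192)   # pose heatmaps (17 channels)
body_mask     = torch.randn(1, 256, 192)    # downsampled body-shape mask (1 channel)
face_hair     = torch.randn(3, 256, 192)    # RGB face-and-hair crop (3 channels)
target_cloth  = torch.randn(3, 256, 192)    # RGB target clothing image (3 channels)

# 24-channel generator input, concatenated along the channel dimension,
# plus a batch dimension -> (1, 24, 256, 192).
x = torch.cat([pose_heatmaps, body_mask, face_hair, target_cloth], dim=0).unsqueeze(0)

# The U-Net generator maps 24 input channels to 4 output channels:
# 3 for the synthesized person image and 1 for the clothing mask.
# (UNet is a hypothetical constructor standing in for the repository's generator.)
# generator = UNet(in_channels=24, out_channels=4)
# out = generator(x)
# image, cloth_mask = out[:, :3], out[:, 3:]
```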
The generator loss is a combination of a Perceptual Loss between the target output image and the image produced by the generator network, and an L1 loss between the target mask and the generated mask of the target clothing item on the human.
For the Perceptual Loss, this implementation uses 5 layers of a pretrained VGG16 (Conv1_2, Conv2_2, Conv3_2, Conv4_2, Conv5_2).
These layers create feature maps using the pretrained parameters of the VGG16 model.
The generated and target images are passed through the VGG16 model, and the feature maps from the specified layers are extracted. The loss between each pair of feature maps is calculated and summed to form the perceptual loss. This is added to the L1 loss calculated between the masks to give the total loss for the Generator Network.
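A possible implementation of this perceptual loss is sketched below. The torchvision layer indices used for Conv1_2 through Conv5_2, the use of an L1 distance between feature maps, and the `weights` argument (torchvision >= 0.13) are assumptions of this sketch.

```python
import torch.nn as nn
from torchvision import models

class PerceptualLoss(nn.Module):
    """L1 loss between VGG16 feature maps of the generated and target images.

    The default layer indices correspond to conv1_2, conv2_2, conv3_2, conv4_2
    and conv5_2 in torchvision's VGG16 `features` module (an assumption).
    """
    def __init__(self, layer_ids=(2, 7, 12, 19, 26)):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.eval()
        for p in vgg.parameters():
            p.requires_grad = False              # VGG is only a fixed feature extractor
        self.vgg = vgg
        self.layer_ids = set(layer_ids)
        self.criterion = nn.L1Loss()

    def forward(self, generated, target):
        loss = 0.0
        g, t = generated, target
        for i, layer in enumerate(self.vgg):
            g, t = layer(g), layer(t)
            if i in self.layer_ids:              # accumulate loss at the chosen layers
                loss = loss + self.criterion(g, t)
        return loss
```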
The discriminator takes two inputs, X and Y, and concatenates them along the channel dimension:
- X is the 24-channel input of the Generator Network.
- Y is the target image when calculating d_real_image, and the generated image when calculating d_fake_image.
The concatenated tensor is then passed into the discriminator network.
The loss function used for the Discriminator network is BCE. The loss is calculated for both d_real_image and d_fake_image, and the two are added to form the total loss for the discriminator.
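The following sketch shows one way to compute this conditional discriminator loss; the function signature and the use of BCEWithLogitsLoss (i.e., a discriminator that returns raw logits) are assumptions.

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()   # assumes the discriminator outputs raw logits

def discriminator_loss(discriminator, x, target_image, generated_image):
    """BCE on (X, real image) plus BCE on (X, generated image).

    x is the 24-channel generator input; each image is concatenated with it
    along the channel dimension before being scored by the discriminator.
    """
    d_real_image = discriminator(torch.cat([x, target_image], dim=1))
    d_fake_image = discriminator(torch.cat([x, generated_image.detach()], dim=1))
    loss_real = bce(d_real_image, torch.ones_like(d_real_image))
    loss_fake = bce(d_fake_image, torch.zeros_like(d_fake_image))
    return loss_real + loss_fake
```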
The data should be stored in the following format in the DataViTON directory (a minimal dataset-loading sketch follows the layout).
- train
├─ parts
├─ outputs
├─ original_cloth_mask
├─ original_cloth
├─ body_mask
├─ cloth_mask
├─ faces
├─ heatmaps
└─ image
- val
├─ parts
├─ outputs
├─ original_cloth_mask
├─ original_cloth
├─ body_mask
├─ cloth_mask
├─ faces
├─ heatmaps
└─ image
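For illustration, a minimal PyTorch Dataset over this layout might look like the sketch below; the file-naming and pairing conventions, and which subdirectories are loaded, are assumptions rather than the repository's actual loader.

```python
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class VitonFolderDataset(Dataset):
    """Minimal dataset sketch over the DataViTON layout shown above.

    Assumes matching filenames across subdirectories; adapt to the
    repository's actual naming and preprocessing.
    """
    def __init__(self, root="DataViTON", split="train"):
        self.split_dir = os.path.join(root, split)
        self.names = sorted(os.listdir(os.path.join(self.split_dir, "image")))
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.names)

    def _load(self, subdir, name):
        # Load one file from the given subdirectory and convert it to a tensor.
        return self.to_tensor(Image.open(os.path.join(self.split_dir, subdir, name)))

    def __getitem__(self, idx):
        name = self.names[idx]
        return {
            "image": self._load("image", name),
            "face": self._load("faces", name),
            "body_mask": self._load("body_mask", name),
            "cloth_mask": self._load("cloth_mask", name),
            "cloth": self._load("original_cloth", name),
        }
```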
The models should be stored in the following manner (a checkpoint-loading sketch follows the layout).
- model
├─ disc
│ └─ dimg.pth.tar
└─ gen
└─ gimg.pth.tar
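Loading these checkpoints might look like the sketch below; the "state_dict" key and the model constructors are assumptions about how the checkpoints were saved.

```python
import torch

# Load the saved generator and discriminator checkpoints from the layout above.
gen_ckpt = torch.load("model/gen/gimg.pth.tar", map_location="cpu")
disc_ckpt = torch.load("model/disc/dimg.pth.tar", map_location="cpu")

# generator.load_state_dict(gen_ckpt["state_dict"])        # hypothetical key and model
# discriminator.load_state_dict(disc_ckpt["state_dict"])   # hypothetical key and model
```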
The link to the pretrained generator network: generator
The link to the pretrained discriminator network: discriminator