Multi-input image-to-image translation with Pix2Pix and Aladin Sky Atlas

This project uses the conditional GAN Pix2Pix in combination with the Aladin Sky Atlas software to generate images of a portion of the sky in a desired wavelength from images of the same portion of the sky in other wavelengths. The data-loading code of the original Pix2Pix has been slightly modified to accept several input images, allowing the model to go not only from B to A but from B + C + D… to A. To use this project, just follow the instructions in the original Pix2Pix repository, but using the modified code and the multi-input dataset provided here.

Motivation and theoretical rationale

Different sky surveys have yielded data about portions of the sky in different wavelengths, with different resolutions and with different total coverage. Very few of them are all-sky surveys; most of the available data has either very limited coverage (HST) or low resolution (2MASS). As a result, there are parts of the sky for which data are available in some wavelengths and resolutions but not in others. This project explores whether it is feasible to use that available, complementary information to generate a target image in the unavailable wavelength/resolution.

From a theoretical point of view, the physical law most relevant to our use case is Planck’s law. This law describes the spectral density of electromagnetic radiation emitted by a black body in thermal equilibrium at a given temperature T. A black body is an idealized object that absorbs and emits radiation at all frequencies. Planck radiation has a maximum intensity at a wavelength that depends on the temperature of the body (Wikipedia).
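For reference, the spectral radiance of a black body (in its wavelength form) and the wavelength of peak emission given by Wien’s displacement law are

B_\lambda(\lambda, T) = \frac{2hc^2}{\lambda^5}\,\frac{1}{e^{hc/(\lambda k_B T)} - 1}, \qquad \lambda_{\mathrm{max}} = \frac{b}{T}

where h is Planck’s constant, c the speed of light, k_B the Boltzmann constant and b ≈ 2.898 × 10⁻³ m·K Wien’s displacement constant, so the emission peak shifts with the temperature of the body.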

[Figure: Family of curves for different temperatures (left) and the Sun approximated as a black body (right)]

In other words, Planck’s law shows that it may be possible to use complementary wavelengths (astronomical observations) to reconstruct a missing one.

Dataset creation and training

Each data point or element in the dataset should look like this:

[Figure: example data point, four horizontally concatenated observations of the same object]

A single image made up of several (four in this case) horizontally concatenated images of the same object from different observations/wavelengths; in the example above: SDSS9 (high-resolution optical), GALEX (UV), 2MASS (IR) and DSS2 (low-resolution optical), respectively.

A link to the complete dataset, together with the three pretrained models used in the section Some results and comments further below, can be found in the datasets folder.

To create a different dataset for training and/or a data point for inference, the desired images have to be extracted from Aladin, a sky atlas software. The script aladin_dataset.py in the datasets folder contains a basic example that automatically extracts the 7840 astronomical objects of the New General Catalogue (NGC) in four different wavelengths from Aladin and concatenates them into a dataset. Any other astronomical object or portion of the sky can be obtained by modifying the script; object names as well as specific coordinates can be used as input to extract images from Aladin. To learn more about how to use Aladin, please check the official documentation.
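For reference, the concatenation step on its own only needs a few lines of PIL. The sketch below is not the repository’s aladin_dataset.py; the folder layout, file names and the object NGC1234 are hypothetical. It assumes the per-survey images of each object have already been exported from Aladin and simply stitches them side by side, target first and then the inputs:

from pathlib import Path

from PIL import Image

# Hypothetical layout: one sub-folder per survey, each holding "<object>.jpg"
SURVEYS = ["SDSS9", "GALEX", "2MASS", "DSS2"]  # target first, then the input wavelengths
SIZE = (256, 256)                              # pix2pix default image size

def build_datapoint(obj_name: str, root: Path, out_dir: Path) -> None:
    """Horizontally concatenate the per-survey images of one object into a single data point."""
    tiles = [Image.open(root / survey / f"{obj_name}.jpg").convert("RGB").resize(SIZE)
             for survey in SURVEYS]
    canvas = Image.new("RGB", (SIZE[0] * len(tiles), SIZE[1]))
    for i, tile in enumerate(tiles):
        canvas.paste(tile, (i * SIZE[0], 0))
    out_dir.mkdir(parents=True, exist_ok=True)
    canvas.save(out_dir / f"{obj_name}.jpg")

if __name__ == "__main__":
    build_datapoint("NGC1234", Path("aladin_exports"), Path("datasets/ngc_multiband/train"))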

For training:

python train.py --dataroot /pathtodataset/ --name yourmodelname --model pix2pix --direction BtoA --use_wandb --input_nc 9

For testing:

python test.py --dataroot /pathtodataset/ --direction BtoA --model pix2pix --name pretrainedmodelname --input_nc 9

The two most important settings here are --direction BtoA and --input_nc 9. For testing, the pretrained model file (.pth format) should be correctly named and placed at checkpoints/pretrainedmodelname/latest_net_G.pth. For further details on training/testing, please check the original Pix2Pix repository.

Pix2Pix code modification

For the model to correctly accept data points like the example above and split them into several input images and one output image, the last part of aligned_dataset.py has been modified as follows. The code assumes that four 3-channel images are fed in; it crops and transforms them and concatenates the inputs as tensors (three inputs and one target, B + C + D => A). In principle, an arbitrary number of input images could be supported with a more general modification (see the sketch after the snippet):

[...]
import torch
[...]
        # split AB image into A, B, C and D
        w, h = AB.size
        w4 = int(w / 4)
        A = AB.crop((0, 0, w4, h))
        B = AB.crop((w4, 0, w4*2, h))
        C = AB.crop((w4*2, 0, w4*3, h))
        D = AB.crop((w4*3, 0, w, h))

        # apply the same transform to A, B, C and D and concatenate B as B+C+D
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
        B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))

        A = A_transform(A)
        B = B_transform(B)
        C = B_transform(C)
        D = B_transform(D)
        B = torch.cat((B, C, D))  # 9-channel input (3 images x 3 channels), matching --input_nc 9

        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
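The same splitting logic extends to any number of horizontally concatenated tiles. The following is a minimal, standalone sketch (not the repository code): it derives the tile width from the number of tiles and uses a fixed torchvision transform instead of the repository’s get_params/get_transform, so the paired random crop and flip are omitted:

import torch
from PIL import Image
from torchvision import transforms

def split_concatenated(ab: Image.Image, n_tiles: int = 4):
    """Split a horizontally concatenated data point into target A and stacked inputs B."""
    w, h = ab.size
    tile_w = w // n_tiles
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map to [-1, 1] as in pix2pix
    ])
    tiles = [to_tensor(ab.crop((i * tile_w, 0, (i + 1) * tile_w, h)).convert("RGB"))
             for i in range(n_tiles)]
    A = tiles[0]                     # target image, 3 channels
    B = torch.cat(tiles[1:], dim=0)  # inputs stacked along channels: 3 * (n_tiles - 1)
    return {"A": A, "B": B}

With the default n_tiles = 4 this reproduces the 9-channel input above and matches --input_nc 9.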

Additionally, the original Pix2Pix code does not support visualization of images whose number of channels is different from 1 or 3. To avoid an error, the function tensor2im in util.py has been modified by adding the following:

[...]
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        if image_numpy.shape[0] == 9:  # multi-input case: 9 concatenated channels
            image_numpy, b, c = np.vsplit(image_numpy, 3)  # keep only the first 3-channel image
[...]

This simply re-splits the input images that were concatenated to feed the network and returns the first of them for visualization as “real B” (the source) for reference.
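For reference, np.vsplit always splits along the first axis, so the (9, H, W) array produced above becomes three (3, H, W) arrays and only the first is kept for display; a minimal standalone check:

import numpy as np

stacked = np.zeros((9, 256, 256), dtype=np.float32)  # B + C + D as concatenated by the dataset code
first, second, third = np.vsplit(stacked, 3)          # three (3, 256, 256) arrays
print(first.shape)                                    # (3, 256, 256): shown as "real B"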

Some results and comments

A Pix2Pix model was trained on a dataset of 3019 images of NGC objects to go from low-quality optical (DSS2) to higher-quality optical (SDSS9), using (i) only low-quality optical, (ii) low-quality optical + IR and (iii) low-quality optical + IR + UV. Training was stopped at around 90 epochs because the adversarial loss diverged without improving, and the model from epoch 65 (before the divergence started) was used for testing. The results were compared on the W&B platform using the generator loss as a metric and by visual inspection.

The loss function of the generator in Pix2Pix is made up of two terms, the adversarial loss and the L1 loss:

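Written out as in the Pix2Pix paper (Isola et al., 2017, cited in the acknowledgments), the full objective is

G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G), \qquad \mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\big[\lVert y - G(x, z)\rVert_1\big]

where \mathcal{L}_{cGAN} is the conditional adversarial loss and \lambda weights the L1 term.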

According to Jason Brownlee from Machine Learning Mastery: “The adversarial loss influences whether the generator model can output images that are plausible in the target domain, whereas the L1 loss regularizes the generator model to output images that are a plausible translation of the source image”.

The parameter importance tool from W&B shows a relevant negative correlation between the adversarial loss and the parameter input_nc (number of input channels, i.e., number of input images): the higher the number of input images, the lower the loss, as expected. Regarding the L1 loss, the correlation varied during training between positive and negative when comparing only the three aforementioned runs (low-quality optical vs. low-quality optical + IR vs. low-quality optical + IR + UV), but was negative when comparing all other previous experimental runs (~7 smaller runs with fewer epochs, images, etc.). My best guess is that either the extra information is “confusing” the generator somehow or outliers are distorting the correlation; according to the W&B docs, “correlations are sensitive to outliers, which might turn a strong relationship to a moderate one, specially if the sample size of hyperparameters tried is small”.

[Figures: W&B parameter importance charts for the adversarial and L1 losses]

Although a visual comparison of the results is hard and subjective given the high similarity of the images, there seem to be several repeated patterns that could point to some conclusions about using additional information from several wavelengths (red circle = incorrect, green circle = correct):

  • It seems to reinforce generation confidence when there are overlapping astronomical objects.

  • Better rendering of the brightness and color of point-like objects.

  • Incorrect inclusion of extremely faint groups of small background point-like objects (possible confusion due to the additional information).

With regard to this last issue, many instances were detected where the additional information actually seemed to help “turn off” those faint background point-like objects in the optical + IR + UV runs, whereas they were (incorrectly) shown when using only optical as the source image.

Further experimentation with more complex images and additional wavelengths is needed but, judging by the results, the model seems to be using the additional information successfully to constrain and better render the target image. Especially interesting is the fact that color is translated better when using information from additional wavelengths; this is in line with Planck’s black-body radiation: visual appearance and the spectral density of the emitted radiation (and hence temperature) are interrelated.

Acknowledgments

@inproceedings{CycleGAN2017, title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks}, author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A}, booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on}, year={2017} }

@inproceedings{isola2017image, title={Image-to-Image Translation with Conditional Adversarial Networks}, author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A}, booktitle={Computer Vision and Pattern Recognition (CVPR), 2017 IEEE Conference on}, year={2017} }

This project has made use of "Aladin sky atlas" developed at CDS, Strasbourg Observatory, France → 2000A&AS..143...33B and 2014ASPC..485..277B.
