Deep learning for medical imaging suffers from temporal and privacy-related restrictions on data availability. To still obtain viable models, continual learning aims to train sequentially, as and when data becomes available. The main challenge continual learning methods face is preventing catastrophic forgetting, i.e., a decrease in performance on data encountered earlier. This issue makes continuous training of segmentation models for medical applications extremely difficult. Yet, often, data from at least two different domains is available, which we can exploit to train the model in a way that disregards domain-specific information. We propose an architecture that leverages the simultaneous availability of two or more datasets to learn a disentanglement between content and domain in an adversarial fashion. The domain-invariant content representation then forms the basis for continual semantic segmentation. Our approach takes inspiration from domain adaptation and combines it with continual learning for hippocampal segmentation in brain MRI. We show that our method reduces catastrophic forgetting and outperforms state-of-the-art continual learning methods.
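The core idea of learning a domain-invariant content representation adversarially can be illustrated with a gradient-reversal layer: the domain classifier is trained to predict the domain from the content features, while the reversed gradient pushes the encoder to remove exactly that information. This is only a minimal sketch, not the architecture from the paper; all module names (`content_encoder`, `seg_head`, `domain_clf`), shapes, and the toy data are illustrative assumptions.

```python
import torch
import torch.nn as nn

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; flips the gradient sign in backprop."""
    @staticmethod
    def forward(ctx, x, lamb):
        ctx.lamb = lamb
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lamb * grad_output, None

# Hypothetical toy modules, standing in for the real encoder/heads.
content_encoder = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU())
seg_head = nn.Conv2d(8, 2, 1)   # 2 classes: background / hippocampus
domain_clf = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))

x = torch.randn(4, 1, 32, 32)            # toy batch of 2D MRI slices
domain = torch.randint(0, 2, (4,))       # domain label per slice
mask = torch.randint(0, 2, (4, 32, 32))  # toy segmentation target

feat = content_encoder(x)
seg_loss = nn.functional.cross_entropy(seg_head(feat), mask)
# The domain classifier is trained normally, but the reversed gradient
# drives the encoder to strip domain information from `feat`.
dom_logits = domain_clf(GradReverse.apply(feat, 1.0))
dom_loss = nn.functional.cross_entropy(dom_logits, domain)
(seg_loss + dom_loss).backward()
```

With this setup a single backward pass updates both objectives; the segmentation head keeps the features task-relevant while the adversarial branch makes them domain-invariant.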
For more information, please refer to our paper.
Legend: VP MRI (original), GT (ground truth segmentation), ACS (segmentation), GAN O/P (output of the GAN generator)
This repository builds on medical_pytorch and torchio. Please install this repository as explained in medical_pytorch. We provide an implementation of the Adversarial Continual Segmenter (ACS) as well as the baselines Memory Aware Synapses (MAS) and Knowledge Distillation (KD).
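The MAS baseline can be summarized in a few lines: parameter importance is estimated from the gradient magnitude of the squared output norm, and a quadratic penalty keeps important parameters close to their values from the previous task. This is a generic sketch of the MAS technique, not the repository's implementation; the helper names and the toy linear model are assumptions.

```python
import torch
import torch.nn as nn

def mas_importance(model, batches):
    """MAS importance: mean gradient magnitude of the squared output norm."""
    omega = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    for x in batches:
        model.zero_grad()
        model(x).pow(2).sum().backward()  # squared L2 norm of the output
        for n, p in model.named_parameters():
            omega[n] += p.grad.abs() / len(batches)
    return omega

def mas_penalty(model, omega, old_params, lamb=1.0):
    """Regularizer keeping important parameters near their old values."""
    return lamb * sum(
        (omega[n] * (p - old_params[n]).pow(2)).sum()
        for n, p in model.named_parameters()
    )

model = nn.Linear(4, 2)  # toy stand-in for the segmentation network
omega = mas_importance(model, [torch.randn(8, 4) for _ in range(3)])
old = {n: p.detach().clone() for n, p in model.named_parameters()}
# When training on the next domain, add mas_penalty(model, omega, old)
# to the task loss to mitigate catastrophic forgetting.
```

The penalty is zero as long as the parameters stay at their stored values and grows quadratically as important parameters drift, which is what mitigates forgetting.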
Our core implementation consists of the following structure:
mp
├── agents
│   ├── ACSAgent
│   ├── KDAgent
│   ├── MASAgent
│   └── UNETAgent
├── data
│   └── PytorchSeg2DDatasetDomain
├── models
│   └── continual
│       ├── ACS
│       ├── MAS
│       └── KD
└── *_train.py
Use the *_train.py scripts to run the experiments, setting arguments as described in args.py. We also provide the execution commands used to produce the results in experiments.txt, and console logs of the runs in logs/.
Datasets should be placed in storage/data/ and can be loaded via the dataloaders provided in mp.data.datasets or by custom implementations. We use the following three datasets:
Further features of the code base:
- extensive logging via TensorBoard
- loading, saving, and resuming of training
- multi-GPU support
Supported by the Bundesministerium für Gesundheit (BMG) with grant [ZMVI1-2520DAT03A].