
Adversarial Continual Learning for Multi-Domain Hippocampal Segmentation


MECLabTUDA/ACS


Abstract

Deep learning for medical imaging suffers from temporal and privacy-related restrictions on data availability. To still obtain viable models, continual learning aims to train in sequential order, as and when data is available. The main challenge that continual learning methods face is to prevent catastrophic forgetting, i.e., a decrease in performance on the data encountered earlier. This issue makes continuous training of segmentation models for medical applications extremely difficult. Yet data from at least two different domains is often available, which we can exploit to train the model so that it disregards domain-specific information. We propose an architecture that leverages the simultaneous availability of two or more datasets to learn a disentanglement between content and domain in an adversarial fashion. The domain-invariant content representation then lays the foundation for continual semantic segmentation. Our approach takes inspiration from domain adaptation and combines it with continual learning for hippocampal segmentation in brain MRI. We show that our method reduces catastrophic forgetting and outperforms state-of-the-art continual learning methods.
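The definition of catastrophic forgetting given in the abstract (a decrease in performance on data encountered earlier) can be made concrete with a small bookkeeping sketch. This is only an illustration, not code from this repository: the function name `forgetting_per_dataset`, the score layout, and all numbers are assumptions, and the scores are toy values rather than results from the paper.

```python
from typing import Dict, List


def forgetting_per_dataset(
    scores: List[Dict[str, float]], order: List[str]
) -> Dict[str, float]:
    """Return the score drop for every dataset except the last one trained.

    scores[i][name] is the score (for example a Dice score) on dataset `name`
    measured after training stage i has finished; order[i] is the dataset
    trained in stage i. The drop is the score right after a dataset was
    learned minus its score after the final stage.
    """
    final = scores[-1]
    drops = {}
    # The dataset of the final stage has had no later stage to be forgotten in.
    for stage, name in enumerate(order[:-1]):
        drops[name] = scores[stage][name] - final[name]
    return drops


# Toy numbers for illustration only (hypothetical, not measured results).
order = ["A", "B", "C"]
scores = [
    {"A": 0.90, "B": 0.40, "C": 0.30},  # after training on A
    {"A": 0.70, "B": 0.85, "C": 0.35},  # after training on B
    {"A": 0.60, "B": 0.75, "C": 0.80},  # after training on C
]
drops = forgetting_per_dataset(scores, order)
for name, drop in drops.items():
    print(f"{name}: score dropped by {drop:.2f}")
```

A positive value means the model got worse on a dataset it had already learned; a continual learning method aims to keep these values small.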

For more information please refer to our paper.

ACS Architecture

[Figure: ACS architecture]

Qualitative Results

[Figure: qualitative results]

Legend: VP MRI (original), GT (ground truth segmentation), ACS (segmentation), GAN O/P (output of GAN generator)

Setup

This repository builds on medical_pytorch and torchio. Please install it by following the installation instructions in medical_pytorch. We provide an implementation of the adversarial continual segmenter (ACS) as well as two baselines: memory aware synapses (MAS) and knowledge distillation (KD).
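For readers unfamiliar with the MAS baseline: as described in the original MAS paper, it adds a penalty to the training loss that discourages changes to parameters estimated as important for earlier data, where importance is derived from how strongly the model output reacts to each parameter. The sketch below shows only that quadratic penalty on plain lists of floats. It is not the `MASAgent` from this repository; the function name `mas_penalty` and the weight `lam` are hypothetical, and the importance values are assumed to be given.

```python
from typing import Sequence


def mas_penalty(
    params: Sequence[float],
    old_params: Sequence[float],
    importance: Sequence[float],
    lam: float = 1.0,
) -> float:
    """Compute lam * sum_i importance_i * (params_i - old_params_i) ** 2.

    `old_params` are the parameter values stored after the previous training
    stage, and `importance` holds one non-negative weight per parameter.
    """
    if not (len(params) == len(old_params) == len(importance)):
        raise ValueError("all three sequences must have the same length")
    return lam * sum(
        w * (p - q) ** 2 for p, q, w in zip(params, old_params, importance)
    )


# Hypothetical values: the first parameter is far more important than the second.
old = [1.0, 0.0]
weights = [5.0, 0.5]
penalty_unimportant_moved = mas_penalty([1.0, 2.0], old, weights, lam=2.0)
penalty_important_moved = mas_penalty([2.0, 0.0], old, weights, lam=2.0)
print(penalty_unimportant_moved, penalty_important_moved)
```

In an actual training loop this term would be added to the segmentation loss of the current stage, so that moving an important parameter costs more than moving an unimportant one.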

Our core implementation consists of the following structure:

mp
├── agents
│   ├── ACSAgent
│   ├── KDAgent
│   ├── MASAgent
│   └── UNETAgent
├── data
│   └── PytorchSeg2DDatasetDomain
├── models
│   └── continual
│       ├── ACS
│       ├── MAS
│       └── KD
└── *_train.py

Usage

Use the *_train.py scripts to run the experiments, setting the arguments defined in args.py. The exact commands we used to produce our results are listed in experiments.txt, and the console logs of those runs are provided in logs/.

Datasets

Datasets should be placed in storage/data/ and can be loaded via dataloaders provided in mp.data.datasets or by custom implementations. We use the following three datasets:

  • DecathlonHippocampus (A)
  • DryadHippocampus (B)
  • HarP (C)
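Before starting a run, it can help to confirm that the data is where the loaders expect it. The following is a minimal sketch using only the standard library; the helper `missing_datasets` is hypothetical, and the assumption that each dataset lives in a folder named after it under storage/data/ is ours, so adjust the names to the layout your dataloaders actually use.

```python
from pathlib import Path
from typing import Iterable, List


def missing_datasets(data_root: str, expected: Iterable[str]) -> List[str]:
    """Return the names in `expected` that have no directory under `data_root`."""
    root = Path(data_root)
    return [name for name in expected if not (root / name).is_dir()]


# Folder names are assumed to match the dataset names (hypothetical layout).
expected_folders = ["DecathlonHippocampus", "DryadHippocampus", "HarP"]
missing = missing_datasets("storage/data", expected_folders)
if missing:
    print("Missing dataset folders:", ", ".join(missing))
else:
    print("All expected dataset folders were found.")
```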

Additional Features

  • extensive logging via TensorBoard
  • loading, saving, and resuming training
  • multi-GPU support

Acknowledgements

Supported by the Bundesministerium für Gesundheit (BMG) with grant [ZMVI1-2520DAT03A].
