Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
Masaaki-75 authored Jul 24, 2023
1 parent e5f98d9 commit a576d0f
Show file tree
Hide file tree
Showing 45 changed files with 32,856 additions and 1 deletion.
54 changes: 53 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,53 @@
# freeseed_
# FreeSeed: Frequency-band-aware and Self-guided Network for Sparse-view CT Reconstruction
This is the official implementation of the paper "FreeSeed: Frequency-band-aware and Self-guided Network for Sparse-view CT Reconstruction" ([arxiv](https://arxiv.org/abs/2307.05890)) based on [torch-radon toolbox](https://github.com/matteo-ronchetti/torch-radon/tree/master).


## Requirements
```
- Linux Platform
- python==3.7.16
- torch==1.7.1+cu110 # depends on the CUDA version of your machine
- torchaudio==0.7.2
- torchvision==0.8.2+cu110
- torch-radon==1.0.0
- monai==1.0.1
- scipy==1.7.3
- einops==0.6.1
- opencv-python==4.7.0.72
- SimpleITK==2.2.1
- numpy==1.21.6
- pandas==1.3.5 # optional
- tensorboard==2.11.2 # optional
- wandb==0.15.2 # optional
- tqdm==4.65.0 # optional
```


## Data Preparation
The AAPM-Myo dataset can be downloaded from: [CT Clinical Innovation Center](https://ctcicblog.mayo.edu/2016-low-dose-ct-grand-challenge/)
(or the [box link](https://aapm.app.box.com/s/eaw4jddb53keg1bptavvvd1sf4x3pe9h/folder/144594475090)). Please walk through `./datasets/process_aapm.ipynb` for more details on preparing the dataset.



## Training & Inference
Please check `train.sh` (or `test.sh`) for the corresponding scripts once the data is well prepared. Specify the dataset path and other setting in the script, and simply run the script in the terminal.



## Other Notes
We choose torch-radon toolbox because it processes tomography real fast! For those who have problems installing torch-radon toolbox:
- There's other forks of torch-radon like [this](https://github.com/faebstn96/torch-radon) that can be installed via `python setup.py install` without triggering too many compilation errors🤔.
- Check the [issues](https://github.com/matteo-ronchetti/torch-radon/issues) of torch-radon (both open & closed), since there is discussion about any possible errors you may encountered when installing it.



## Citation
If you find our work and code helpful, please kindly cite the corresponding paper:
```
@inproceedings{ma2023freeseed,
title={FreeSeed: Frequency-band-aware and Self-guided Network for Sparse-view CT Reconstruction},
author={Ma, Chenglong and Li, Zilong and Zhang, Yi and Zhang, Junping and Shan, Hongming},
booktitle={Medical Image Computing and Computer Assisted Intervention -- MICCAI 2023},
year={2023}
}
```
5,936 changes: 5,936 additions & 0 deletions datasets/aapm.txt

Large diffs are not rendered by default.

110 changes: 110 additions & 0 deletions datasets/aapmmyo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset


class CTTools:
def __init__(self, mu_water=0.192):
self.mu_water = mu_water

def HU2mu(self, hu_img):
mu_img = hu_img / 1000 * self.mu_water + self.mu_water
return mu_img

def mu2HU(self, mu_img):
hu_img = (mu_img - self.mu_water) / self.mu_water * 1000
return hu_img

def window_transform(self, hu_img, width=3000, center=500, norm=False):
# HU -> 0-1 normalized
min_window = float(center) - 0.5 * float(width)
win_img = (hu_img - min_window) / float(width)
win_img[win_img < 0] = 0
win_img[win_img > 1] = 1
if norm:
print('normalize to 0-255')
win_img = (win_img * 255).astype('float')
return win_img

def back_window_transform(self, win_img, width=3000, center=500, norm=False):
# 0-1 normalized -> HU
min_window = float(center) - 0.5 * float(width)
win_img = win_img / 255 if norm else win_img
hu_img = win_img * float(width) + min_window
return hu_img


class AAPMMyoDataset(Dataset):
def __init__(self, src_path_list, dataset_shape=512, spatial_dims=2, mode='train', num_train=5410, num_val=526):
assert mode in ['train', 'val'], f'Invalid mode: {mode}.'
self.mode = mode
self.num_train = num_train
self.num_val = num_val
self.cttool = CTTools()
if not isinstance(dataset_shape, (list, tuple)):
dataset_shape = (dataset_shape,) * spatial_dims
self.dataset_shape = dataset_shape

self.src_path_list = self.get_path_list_from_dir(src_path_list, ext='.npy', keyword='')
print(f'finish loading AAPM-myo {mode} dataset, total images {len(self.src_path_list)}')

def get_path_list_from_dir(self, src_dir, ext='.npy', keyword=''):
assert isinstance(src_dir, (list, tuple)) or (isinstance(src_dir, str) and os.path.isdir(src_dir)), \
f'Input should either be a directory containing taget files or a list containing paths of target files, got {src_dir}.'

if isinstance(src_dir, str) and os.path.isdir(src_dir):
src_path_list = sorted([os.path.join(src_dir, _) for _ in os.listdir(src_dir) if ext in _ and keyword in _])
elif isinstance(src_dir, (list, tuple)):
src_path_list = src_dir

mode = self.mode
train_path_list, val_path_list = self.simple_split(src_path_list, self.num_train, self.num_val)
return train_path_list if mode == 'train' else val_path_list

@staticmethod
def simple_split(path_list, num_train=None, num_val=None):
num_imgs = len(path_list)
if num_train is None or num_val is None:
num_train = int(num_imgs * 0.8)
num_val = num_imgs - num_train

if num_train > num_val:
train_list = path_list[:num_train]
val_list = path_list[-num_val:]
print('dataset:{}, training set:{}, val set:{}'.format(len(path_list), len(train_list), len(val_list)))
else:
raise ValueError(f'aapm_myo dataset simple_split() error. num_imgs={num_imgs}, while num_train={num_train}, num_val={num_val}.')
return train_list, val_list

def __getitem__(self, idx):
src_path = self.src_path_list[idx]
src_hu = np.load(src_path).squeeze()
W, H = src_hu.shape[-1], src_hu.shape[-2]

if H != self.dataset_shape[0] or W != self.dataset_shape[1]:
src_hu = cv2.resize(src_hu, self.dataset_shape, cv2.INTER_CUBIC)

src_mu = self.cttool.HU2mu(src_hu)
src_mu = torch.from_numpy(src_mu).unsqueeze(0).float()
return src_mu

def __len__(self):
return len(self.src_path_list)




if __name__ == '__main__':
from torch.utils.data import DataLoader
import tqdm
root_path = '/mnt/data_jixie1/clma/aapm_tr5410_te526'
# image_list_path = './image_list.txt'
aapm_dataset = AAPMMyoDataset(root_path, mode='val', dataset_shape=256, num_train=5410, num_val=526)
val_loader = DataLoader(aapm_dataset, batch_size=1, num_workers=2)
pbar = tqdm.tqdm(val_loader, ncols=60)
for i, data in enumerate(pbar):
print(i)
print(data.shape)
break
22,310 changes: 22,310 additions & 0 deletions datasets/process_aapm.ipynb

Large diffs are not rendered by default.

167 changes: 167 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import os
import sys
import random
import argparse
import numpy as np
import torch
import torch.nn as nn

from networks.freeseed import FreeNet, SeedNet
from networks.dudofree import DuDoFreeNet
from trainers.simple_trainer import SimpleTrainer
from trainers.freeseed_trainer import FreeSeedTrainer
from trainers.simple_tester import SimpleTester
from trainers.dudo_trainer import DuDoFreeNetTrainer


def get_parser():
parser = argparse.ArgumentParser(description='Sparse CT Main')
# logging interval by iteration
parser.add_argument('--log_interval', type=int, default=400, help='logging interval by iteration')
# tensorboard config
parser.add_argument('--checkpoint_root', type=str, default='', help='where to save the checkpoint')
parser.add_argument('--checkpoint_dir', type=str, default='test', help='detail folder of checkpoint')
parser.add_argument('--use_tensorboard', action='store_true', default=False, help='whether to use tensorboard')
parser.add_argument('--tensorboard_root', type=str, default='', help='root path of tensorboard, project path')
parser.add_argument('--tensorboard_dir', type=str, required=True, help='detail folder of tensorboard')
# wandb config
parser.add_argument('--use_tqdm', action='store_true', default=False, help='whether to use tqdm')
parser.add_argument('--use_wandb', action='store_true', default=False, help='whether to use wandb')
parser.add_argument('--wandb_project', type=str, default='Sparse_CT')
parser.add_argument('--wandb_root', type=str, default='')
parser.add_argument('--wandb_dir', type=str, default='')
# DDP
parser.add_argument('--local_rank', type=int, default=-1, help='node rank for torch distributed training')
# data_path
parser.add_argument('--dataset_path', type=str, default='', help='dataset path')
parser.add_argument('--dataset_name', default='aapm', type=str, help='which dataset, size640,size320,deepleision.etc.')
parser.add_argument('--dataset_shape', type=int, default=512, help='modify shape in dataset')
parser.add_argument('--num_train', default=5410, type=int, help='number of training examples')
parser.add_argument('--num_val', default=526, type=int, help='number of validation examples')
# dataloader
parser.add_argument('--batch_size', default=4, type=int, help='batch_size')
parser.add_argument('--shuffle', default=True, type=bool, help='dataloader shuffle, False if test and val')
parser.add_argument('--num_workers', default=4, type=int, help='dataloader num_workers, 4 is a good choice')
parser.add_argument('--drop_last', default=False, type=bool, help='dataloader droplast')
# optimizer
parser.add_argument('--optimizer', default='adam', type=str, help='name of the optimizer')
parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
parser.add_argument('--beta1', default=0.5, type=float, help='Adam beta1')
parser.add_argument('--beta2', default=0.999, type=float, help='Adam beta2')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum for SGD optimizer')
parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay for optimizer')
parser.add_argument('--epochs', default=30, type=int, help='number of training epochs')
parser.add_argument('--save_epochs', default=10, type=int)
# scheduler
parser.add_argument('--scheduler', default='step', type=str, help='name of the scheduler')
parser.add_argument('--step_size', default=10, type=int, help='step size for StepLR')
parser.add_argument('--milestones', nargs='+', type=int, help='milestones for MultiStepLR')
parser.add_argument('--step_gamma', default=0.5, type=float, help='learning rate reduction factor')
parser.add_argument('--poly_iters', default=10, type=int, help='the number of steps that the scheduler decays the learning rate')
parser.add_argument('--poly_power', default=2, type=float, help='the power of the polynomial')

# checkpath && resume training
parser.add_argument('--resume', default=False, action='store_true', help='resume network training or not, load network param')
parser.add_argument('--resume_opt', default=False, action='store_true', help='resume optimizer or not, load opt param')
parser.add_argument('--net_checkpath', default='', type=str, help='network checkpoint path')
parser.add_argument('--opt_checkpath', default='', type=str, help='optimizer checkpath')
parser.add_argument('--net_checkpath2', default='', type=str, help='another network checkpoint path')

# network hyper args
parser.add_argument('--trainer_mode', default='train', type=str, help='train or test')
parser.add_argument('--ablation_mode', default='sparse', type=str, help='default sparse, cycle: cycle_sparse')
parser.add_argument('--loss', default='l1', type=str, help='loss type')
parser.add_argument('--loss2', default='l2', type=str, help='another loss type')
parser.add_argument('--network', default='', type=str, help='networkname')

# tester args
parser.add_argument('--tester_save_name', default='default_save', type=str, help='name of test' )
parser.add_argument('--tester_save_image', default=False, action='store_true', help='whether to save visualization result' )
parser.add_argument('--tester_save_path', default='', type=str, help='path for saving tester result' )
# sparse ct args
parser.add_argument('--num_views', default=18, type=int, help='common setting: 18/36/72/144 out of 720')
parser.add_argument('--num_full_views', default=720, type=int, help='720 for fanbeam 2D')

# network args
parser.add_argument('--net_dict', default='{}', type=str, help='string of dict containing network arguments')
# freeseed args
parser.add_argument('--use_mask', default=False, type=bool,)
parser.add_argument('--soft_mask', default=True, type=bool,)
return parser


def sparse_main(opt):
net_name = opt.network
net2 = None
print('Network name: ', net_name)
wrapper_kwargs = {
'num_views': opt.num_views,
'num_full_views': opt.num_full_views,
'img_size': opt.dataset_shape}

if net_name == 'fbp':
net = nn.Identity() # only for test
elif net_name == 'freenet':
mask_type = 'bp-gaussian-mc'
net_dict = dict(
ratio_ginout=0.5,
mask_type_init=mask_type, mask_type_down=mask_type,
fft_size=(opt.dataset_shape, opt.dataset_shape // 2 + 1))
net_dict.update(eval(opt.net_dict))
net = FreeNet(1, 1, **net_dict, **wrapper_kwargs)
elif net_name == 'seednet':
net_dict = dict(
ngf=64, n_downsample=1, n_blocks=3, ratio_gin=0.5, ratio_gout=0.5,
enable_lfu=False, gated=False, global_skip=True)
net_dict.update(eval(opt.net_dict))
net = SeedNet(1, 1, **net_dict, **wrapper_kwargs)
elif net_name == 'dudofreenet':
mask_type = 'bp-gaussian-mc'
net_dict = dict(ratio_ginout=0.5, mask_type=mask_type)
net_dict.update(eval(opt.net_dict))
print(net_dict)
net = DuDoFreeNet(**net_dict, **wrapper_kwargs)
elif 'freeseed' in net_name:
# use net_dict to specify some arguments for FreeNet
# use 'freeseed_0.5_1_5' as an example to specify arguments for SeedNet
mask_type = 'bp-gaussian-mc'
freenet_dict = dict(
ratio_ginout=0.5,
mask_type_init=mask_type, mask_type_down=mask_type,
fft_size=(opt.dataset_shape, opt.dataset_shape // 2 + 1))
freenet_dict.update(eval(opt.net_dict))

elems = net_name.split('_')
ratio = float(elems[1])
n_downsampling = int(elems[2])
n_blocks = int(elems[3])
net = FreeNet(1, 1, **freenet_dict, **wrapper_kwargs)
net2 = SeedNet(
1, 1, ratio_gin=ratio, ratio_gout=ratio, n_downsampling=n_downsampling,
n_blocks=n_blocks, enable_lfu=False, gated=False, **wrapper_kwargs)
else:
raise ValueError(f'opt.network selected error, network: {opt.network} not defined')

if opt.trainer_mode == 'train':
if 'freeseed' in net_name:
trainer = FreeSeedTrainer(opt=opt, net1=net, net2=net2, loss_type=opt.loss)
elif net_name in ['freenet', 'seednet']:
trainer = SimpleTrainer(opt=opt, net=net, loss_type=opt.loss)
elif net_name in ['dudofreenet']:
trainer = DuDoFreeNetTrainer(opt=opt, net=net, loss_type=opt.loss)
trainer.fit()
elif opt.trainer_mode == 'test':
tester = SimpleTester(opt=opt, net=net, test_window=None, net2=net2)
tester.run()
else:
raise ValueError('opt trainer mode error: must be train or test, not {}'.format(opt.trainer_mode))

print('finish')


if __name__ == '__main__':
parser = get_parser()
opt = parser.parse_args()
sparse_main(opt)


Empty file added modules/__init__.py
Empty file.
Loading

0 comments on commit a576d0f

Please sign in to comment.