Commit a576d0f (1 parent: e5f98d9). Showing 45 changed files with 32,856 additions and 1 deletion.
@@ -1 +1,53 @@
# FreeSeed: Frequency-band-aware and Self-guided Network for Sparse-view CT Reconstruction

This is the official implementation of the paper "FreeSeed: Frequency-band-aware and Self-guided Network for Sparse-view CT Reconstruction" ([arXiv](https://arxiv.org/abs/2307.05890)) based on the [torch-radon toolbox](https://github.com/matteo-ronchetti/torch-radon/tree/master).

## Requirements
```
- Linux Platform
- python==3.7.16
- torch==1.7.1+cu110  # depends on the CUDA version of your machine
- torchaudio==0.7.2
- torchvision==0.8.2+cu110
- torch-radon==1.0.0
- monai==1.0.1
- scipy==1.7.3
- einops==0.6.1
- opencv-python==4.7.0.72
- SimpleITK==2.2.1
- numpy==1.21.6
- pandas==1.3.5  # optional
- tensorboard==2.11.2  # optional
- wandb==0.15.2  # optional
- tqdm==4.65.0  # optional
```
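
A quick way to verify the toolbox after installation is a tiny forward-projection smoke test. The sketch below follows the torch-radon usage examples (it assumes a CUDA device is available; exact shapes and arguments may differ slightly across torch-radon versions):

```
import numpy as np
import torch
from torch_radon import Radon

image_size, n_angles = 256, 720
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles)  # 2D parallel-beam operator

x = torch.randn(1, image_size, image_size, device='cuda')
sinogram = radon.forward(x)        # one row of detector readings per angle
print(sinogram.shape)
```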
## Data Preparation
The AAPM-Myo dataset can be downloaded from the [CT Clinical Innovation Center](https://ctcicblog.mayo.edu/2016-low-dose-ct-grand-challenge/) (or this [box link](https://aapm.app.box.com/s/eaw4jddb53keg1bptavvvd1sf4x3pe9h/folder/144594475090)). Please walk through `./datasets/process_aapm.ipynb` for more details on preparing the dataset.
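
After processing, the training data is a folder of `.npy` slices stored in HU, which the `AAPMMyoDataset` class added in this commit consumes directly. A minimal loading sketch (the module path and data directory are placeholders):

```
from torch.utils.data import DataLoader
from datasets.aapm_myo import AAPMMyoDataset  # hypothetical module path

dataset = AAPMMyoDataset('/path/to/aapm_npy', mode='train',
                         dataset_shape=256, num_train=5410, num_val=526)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
mu_batch = next(iter(loader))  # (4, 1, 256, 256) attenuation (mu) images
```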
## Training & Inference | ||
Please check `train.sh` (or `test.sh`) for the corresponding scripts once the data is well prepared. Specify the dataset path and other setting in the script, and simply run the script in the terminal. | ||
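
For illustration, a training run might be launched like this (the flag names come from `main.py` in this commit; the paths and values are placeholders, and the maintained commands live in `train.sh`/`test.sh`):

```
python main.py --network freeseed_0.5_1_5 --trainer_mode train \
    --dataset_path /path/to/aapm_npy --dataset_shape 256 \
    --num_views 72 --epochs 30 --batch_size 4 \
    --tensorboard_dir freeseed_72views
```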
## Other Notes
We chose the torch-radon toolbox because it processes tomographic projections really fast! For those who have trouble installing torch-radon:
- There are other forks of torch-radon, like [this one](https://github.com/faebstn96/torch-radon), that can be installed via `python setup.py install` without triggering too many compilation errors🤔 (see the sketch below).
- Check the [issues](https://github.com/matteo-ronchetti/torch-radon/issues) of torch-radon (both open and closed), since most errors you may encounter during installation have already been discussed there.
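
If you go with the fork, a from-source install is the usual routine (a sketch, assuming a CUDA toolchain compatible with your torch build; the last line only checks that the extension loads):

```
git clone https://github.com/faebstn96/torch-radon.git
cd torch-radon
python setup.py install
python -c "from torch_radon import Radon; print('ok')"
```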
## Citation
If you find our work and code helpful, please kindly cite the corresponding paper:
```
@inproceedings{ma2023freeseed,
  title={FreeSeed: Frequency-band-aware and Self-guided Network for Sparse-view CT Reconstruction},
  author={Ma, Chenglong and Li, Zilong and Zhang, Yi and Zhang, Junping and Shan, Hongming},
  booktitle={Medical Image Computing and Computer Assisted Intervention -- MICCAI 2023},
  year={2023}
}
```
@@ -0,0 +1,110 @@
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset


class CTTools:
    def __init__(self, mu_water=0.192):
        self.mu_water = mu_water

    def HU2mu(self, hu_img):
        # Hounsfield units -> linear attenuation coefficient
        mu_img = hu_img / 1000 * self.mu_water + self.mu_water
        return mu_img

    def mu2HU(self, mu_img):
        # linear attenuation coefficient -> Hounsfield units
        hu_img = (mu_img - self.mu_water) / self.mu_water * 1000
        return hu_img

    def window_transform(self, hu_img, width=3000, center=500, norm=False):
        # HU -> [0, 1] within the given display window (clipped)
        min_window = float(center) - 0.5 * float(width)
        win_img = (hu_img - min_window) / float(width)
        win_img[win_img < 0] = 0
        win_img[win_img > 1] = 1
        if norm:
            print('normalize to 0-255')
            win_img = (win_img * 255).astype('float')
        return win_img

    def back_window_transform(self, win_img, width=3000, center=500, norm=False):
        # [0, 1] (or 0-255 if norm) windowed image -> HU
        min_window = float(center) - 0.5 * float(width)
        win_img = win_img / 255 if norm else win_img
        hu_img = win_img * float(width) + min_window
        return hu_img
class AAPMMyoDataset(Dataset):
    def __init__(self, src_path_list, dataset_shape=512, spatial_dims=2, mode='train', num_train=5410, num_val=526):
        assert mode in ['train', 'val'], f'Invalid mode: {mode}.'
        self.mode = mode
        self.num_train = num_train
        self.num_val = num_val
        self.cttool = CTTools()
        if not isinstance(dataset_shape, (list, tuple)):
            dataset_shape = (dataset_shape,) * spatial_dims
        self.dataset_shape = dataset_shape

        self.src_path_list = self.get_path_list_from_dir(src_path_list, ext='.npy', keyword='')
        print(f'finish loading AAPM-myo {mode} dataset, total images {len(self.src_path_list)}')

    def get_path_list_from_dir(self, src_dir, ext='.npy', keyword=''):
        assert isinstance(src_dir, (list, tuple)) or (isinstance(src_dir, str) and os.path.isdir(src_dir)), \
            f'Input should either be a directory containing target files or a list of paths to target files, got {src_dir}.'

        if isinstance(src_dir, str) and os.path.isdir(src_dir):
            src_path_list = sorted([os.path.join(src_dir, _) for _ in os.listdir(src_dir) if ext in _ and keyword in _])
        elif isinstance(src_dir, (list, tuple)):
            src_path_list = src_dir

        mode = self.mode
        train_path_list, val_path_list = self.simple_split(src_path_list, self.num_train, self.num_val)
        return train_path_list if mode == 'train' else val_path_list

    @staticmethod
    def simple_split(path_list, num_train=None, num_val=None):
        num_imgs = len(path_list)
        if num_train is None or num_val is None:
            num_train = int(num_imgs * 0.8)
            num_val = num_imgs - num_train

        if num_train > num_val:
            train_list = path_list[:num_train]
            val_list = path_list[-num_val:]
            print('dataset:{}, training set:{}, val set:{}'.format(len(path_list), len(train_list), len(val_list)))
        else:
            raise ValueError(f'aapm_myo dataset simple_split() error. num_imgs={num_imgs}, while num_train={num_train}, num_val={num_val}.')
        return train_list, val_list

    def __getitem__(self, idx):
        src_path = self.src_path_list[idx]
        src_hu = np.load(src_path).squeeze()
        W, H = src_hu.shape[-1], src_hu.shape[-2]

        if H != self.dataset_shape[0] or W != self.dataset_shape[1]:
            # the third positional argument of cv2.resize is `dst`,
            # so the interpolation flag must be passed by keyword
            src_hu = cv2.resize(src_hu, self.dataset_shape, interpolation=cv2.INTER_CUBIC)

        src_mu = self.cttool.HU2mu(src_hu)
        src_mu = torch.from_numpy(src_mu).unsqueeze(0).float()
        return src_mu

    def __len__(self):
        return len(self.src_path_list)


if __name__ == '__main__':
    from torch.utils.data import DataLoader
    import tqdm
    root_path = '/mnt/data_jixie1/clma/aapm_tr5410_te526'
    # image_list_path = './image_list.txt'
    aapm_dataset = AAPMMyoDataset(root_path, mode='val', dataset_shape=256, num_train=5410, num_val=526)
    val_loader = DataLoader(aapm_dataset, batch_size=1, num_workers=2)
    pbar = tqdm.tqdm(val_loader, ncols=60)
    for i, data in enumerate(pbar):
        print(i)
        print(data.shape)
        break
@@ -0,0 +1,167 @@
import os
import sys
import random
import argparse
import numpy as np
import torch
import torch.nn as nn

from networks.freeseed import FreeNet, SeedNet
from networks.dudofree import DuDoFreeNet
from trainers.simple_trainer import SimpleTrainer
from trainers.freeseed_trainer import FreeSeedTrainer
from trainers.simple_tester import SimpleTester
from trainers.dudo_trainer import DuDoFreeNetTrainer

def get_parser():
    parser = argparse.ArgumentParser(description='Sparse CT Main')
    # logging interval by iteration
    parser.add_argument('--log_interval', type=int, default=400, help='logging interval by iteration')
    # tensorboard config
    parser.add_argument('--checkpoint_root', type=str, default='', help='where to save the checkpoint')
    parser.add_argument('--checkpoint_dir', type=str, default='test', help='detail folder of checkpoint')
    parser.add_argument('--use_tensorboard', action='store_true', default=False, help='whether to use tensorboard')
    parser.add_argument('--tensorboard_root', type=str, default='', help='root path of tensorboard, project path')
    parser.add_argument('--tensorboard_dir', type=str, required=True, help='detail folder of tensorboard')
    # wandb config
    parser.add_argument('--use_tqdm', action='store_true', default=False, help='whether to use tqdm')
    parser.add_argument('--use_wandb', action='store_true', default=False, help='whether to use wandb')
    parser.add_argument('--wandb_project', type=str, default='Sparse_CT')
    parser.add_argument('--wandb_root', type=str, default='')
    parser.add_argument('--wandb_dir', type=str, default='')
    # DDP
    parser.add_argument('--local_rank', type=int, default=-1, help='node rank for torch distributed training')
    # data path
    parser.add_argument('--dataset_path', type=str, default='', help='dataset path')
    parser.add_argument('--dataset_name', default='aapm', type=str, help='which dataset: aapm, size640, size320, deeplesion, etc.')
    parser.add_argument('--dataset_shape', type=int, default=512, help='modify shape in dataset')
    parser.add_argument('--num_train', default=5410, type=int, help='number of training examples')
    parser.add_argument('--num_val', default=526, type=int, help='number of validation examples')
    # dataloader
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    # note: argparse's type=bool does not parse the string 'False' as False; any non-empty value is truthy
    parser.add_argument('--shuffle', default=True, type=bool, help='dataloader shuffle, False if test and val')
    parser.add_argument('--num_workers', default=4, type=int, help='dataloader num_workers, 4 is a good choice')
    parser.add_argument('--drop_last', default=False, type=bool, help='dataloader drop_last')
    # optimizer
    parser.add_argument('--optimizer', default='adam', type=str, help='name of the optimizer')
    parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
    parser.add_argument('--beta1', default=0.5, type=float, help='Adam beta1')
    parser.add_argument('--beta2', default=0.999, type=float, help='Adam beta2')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum for SGD optimizer')
    parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay for optimizer')
    parser.add_argument('--epochs', default=30, type=int, help='number of training epochs')
    parser.add_argument('--save_epochs', default=10, type=int)
    # scheduler
    parser.add_argument('--scheduler', default='step', type=str, help='name of the scheduler')
    parser.add_argument('--step_size', default=10, type=int, help='step size for StepLR')
    parser.add_argument('--milestones', nargs='+', type=int, help='milestones for MultiStepLR')
    parser.add_argument('--step_gamma', default=0.5, type=float, help='learning rate reduction factor')
    parser.add_argument('--poly_iters', default=10, type=int, help='the number of steps over which the scheduler decays the learning rate')
    parser.add_argument('--poly_power', default=2, type=float, help='the power of the polynomial')

    # checkpoint && resume training
    parser.add_argument('--resume', default=False, action='store_true', help='resume network training or not, load network param')
    parser.add_argument('--resume_opt', default=False, action='store_true', help='resume optimizer or not, load opt param')
    parser.add_argument('--net_checkpath', default='', type=str, help='network checkpoint path')
    parser.add_argument('--opt_checkpath', default='', type=str, help='optimizer checkpoint path')
    parser.add_argument('--net_checkpath2', default='', type=str, help='another network checkpoint path')

    # network hyper args
    parser.add_argument('--trainer_mode', default='train', type=str, help='train or test')
    parser.add_argument('--ablation_mode', default='sparse', type=str, help='default sparse, cycle: cycle_sparse')
    parser.add_argument('--loss', default='l1', type=str, help='loss type')
    parser.add_argument('--loss2', default='l2', type=str, help='another loss type')
    parser.add_argument('--network', default='', type=str, help='network name')

    # tester args
    parser.add_argument('--tester_save_name', default='default_save', type=str, help='name of test')
    parser.add_argument('--tester_save_image', default=False, action='store_true', help='whether to save visualization result')
    parser.add_argument('--tester_save_path', default='', type=str, help='path for saving tester result')
    # sparse ct args
    parser.add_argument('--num_views', default=18, type=int, help='common setting: 18/36/72/144 out of 720')
    parser.add_argument('--num_full_views', default=720, type=int, help='720 for fanbeam 2D')

    # network args
    parser.add_argument('--net_dict', default='{}', type=str, help='string of a dict containing network arguments')
    # freeseed args
    parser.add_argument('--use_mask', default=False, type=bool)
    parser.add_argument('--soft_mask', default=True, type=bool)
    return parser
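

# Illustrative note (not in the original file): sparse_main() below applies
# eval() to --net_dict, so any Python dict literal given on the command line
# overrides the per-network defaults, e.g.
#   --net_dict "{'ratio_ginout': 0.25}"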
def sparse_main(opt):
    net_name = opt.network
    net2 = None
    print('Network name: ', net_name)
    wrapper_kwargs = {
        'num_views': opt.num_views,
        'num_full_views': opt.num_full_views,
        'img_size': opt.dataset_shape}

    if net_name == 'fbp':
        net = nn.Identity()  # only for test
    elif net_name == 'freenet':
        mask_type = 'bp-gaussian-mc'
        net_dict = dict(
            ratio_ginout=0.5,
            mask_type_init=mask_type, mask_type_down=mask_type,
            fft_size=(opt.dataset_shape, opt.dataset_shape // 2 + 1))
        net_dict.update(eval(opt.net_dict))
        net = FreeNet(1, 1, **net_dict, **wrapper_kwargs)
    elif net_name == 'seednet':
        net_dict = dict(
            ngf=64, n_downsample=1, n_blocks=3, ratio_gin=0.5, ratio_gout=0.5,
            enable_lfu=False, gated=False, global_skip=True)
        net_dict.update(eval(opt.net_dict))
        net = SeedNet(1, 1, **net_dict, **wrapper_kwargs)
    elif net_name == 'dudofreenet':
        mask_type = 'bp-gaussian-mc'
        net_dict = dict(ratio_ginout=0.5, mask_type=mask_type)
        net_dict.update(eval(opt.net_dict))
        print(net_dict)
        net = DuDoFreeNet(**net_dict, **wrapper_kwargs)
    elif 'freeseed' in net_name:
        # use net_dict to specify some arguments for FreeNet;
        # a name like 'freeseed_0.5_1_5' encodes the SeedNet arguments
        # (ratio_gin/gout=0.5, n_downsampling=1, n_blocks=5)
        mask_type = 'bp-gaussian-mc'
        freenet_dict = dict(
            ratio_ginout=0.5,
            mask_type_init=mask_type, mask_type_down=mask_type,
            fft_size=(opt.dataset_shape, opt.dataset_shape // 2 + 1))
        freenet_dict.update(eval(opt.net_dict))

        elems = net_name.split('_')
        ratio = float(elems[1])
        n_downsampling = int(elems[2])
        n_blocks = int(elems[3])
        net = FreeNet(1, 1, **freenet_dict, **wrapper_kwargs)
        net2 = SeedNet(
            1, 1, ratio_gin=ratio, ratio_gout=ratio, n_downsampling=n_downsampling,
            n_blocks=n_blocks, enable_lfu=False, gated=False, **wrapper_kwargs)
    else:
        raise ValueError(f'opt.network selection error: network {opt.network} not defined')

    if opt.trainer_mode == 'train':
        # note: 'fbp' has no trainer and is intended for test mode only
        if 'freeseed' in net_name:
            trainer = FreeSeedTrainer(opt=opt, net1=net, net2=net2, loss_type=opt.loss)
        elif net_name in ['freenet', 'seednet']:
            trainer = SimpleTrainer(opt=opt, net=net, loss_type=opt.loss)
        elif net_name in ['dudofreenet']:
            trainer = DuDoFreeNetTrainer(opt=opt, net=net, loss_type=opt.loss)
        trainer.fit()
    elif opt.trainer_mode == 'test':
        tester = SimpleTester(opt=opt, net=net, test_window=None, net2=net2)
        tester.run()
    else:
        raise ValueError('opt.trainer_mode error: must be train or test, not {}'.format(opt.trainer_mode))

    print('finish')

if __name__ == '__main__':
    parser = get_parser()
    opt = parser.parse_args()
    sparse_main(opt)