This repository contains the official PyTorch implementation of Grounded Diffusion (Open-vocabulary Object Segmentation with Diffusion Models): https://arxiv.org/abs/2301.05221.
A suitable conda environment named grounded-diffusion can be created and activated with:
conda env create -f environment.yaml
conda activate grounded-diffusion
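As a quick sanity check after activation, the minimal sketch below confirms that the environment's PyTorch build can see a CUDA device, which running Stable Diffusion practically requires:

import torch

# Verify the PyTorch install from environment.yaml and GPU visibility.
print("torch version:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))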
Before training, please download the checkpoint of the off-the-shelf detector from https://drive.google.com/drive/folders/1HlagN6jVhmC_UbrOAy133LkN4Qgf2Scv?usp=sharing into a folder called mmdetection/checkpoint/.
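A minimal sketch of that download step, assuming the third-party gdown package (pip install gdown) rather than any helper shipped with this repo; the filenames inside the Drive folder are whatever the authors uploaded:

from pathlib import Path
import gdown  # third-party Google Drive downloader, not part of this repo

ckpt_dir = Path("mmdetection/checkpoint")
ckpt_dir.mkdir(parents=True, exist_ok=True)
# Fetch the detector checkpoint folder linked above.
gdown.download_folder(
    url="https://drive.google.com/drive/folders/1HlagN6jVhmC_UbrOAy133LkN4Qgf2Scv?usp=sharing",
    output=str(ckpt_dir),
    quiet=False,
)
print("downloaded:", [p.name for p in ckpt_dir.iterdir()])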
To train the grounding module, run, for example:
python train.py --class_split 1 --train_data random --save_name pascal_1_random
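To train one model per class split, a minimal sketch that shells out to train.py with the flags above; whether --class_split accepts values beyond 1 is an assumption about this repo:

import subprocess

# Hypothetical sweep over PASCAL class splits; split ids 2-4 are assumed.
for split in (1, 2, 3, 4):
    subprocess.run(
        [
            "python", "train.py",
            "--class_split", str(split),
            "--train_data", "random",
            "--save_name", f"pascal_{split}_random",
        ],
        check=True,  # stop the sweep if any training run fails
    )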
To run inference with a trained grounding module, generating an image from a prompt and segmenting the named category, run:
python test.py --sd_ckpt 'xxx/stable_diffusion.ckpt' \
--grounding_ckpt 'xxx/grounding_module.pth' \
--prompt "a photo of a lion on a mountain top at sunset" \
--category "lion"
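To segment several prompt/category pairs in one go, a minimal sketch that loops over test.py with the same flags; the second pair is a made-up example, and the 'xxx/...' checkpoint paths are the same placeholders as above:

import subprocess

PAIRS = [
    ("a photo of a lion on a mountain top at sunset", "lion"),
    ("a photo of a zebra crossing a river", "zebra"),  # hypothetical example
]
for prompt, category in PAIRS:
    # Invoke test.py once per pair; replace the placeholder checkpoint
    # paths with your local files before running.
    subprocess.run(
        [
            "python", "test.py",
            "--sd_ckpt", "xxx/stable_diffusion.ckpt",
            "--grounding_ckpt", "xxx/grounding_module.pth",
            "--prompt", prompt,
            "--category", category,
        ],
        check=True,
    )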
If you use this code for your research or project, please cite:
@inproceedings{li2023grounded,
  title     = {Open-vocabulary Object Segmentation with Diffusion Models},
  author    = {Li, Ziyi and Zhou, Qinye and Zhang, Xiaoyun and Zhang, Ya and Wang, Yanfeng and Xie, Weidi},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year      = {2023}
}
Many thanks to the codebases of Stable Diffusion, CLIP, and taming-transformers, on which this implementation builds.