DeepSpeed with trl #6852

Open
sagie-dekel opened this issue Dec 11, 2024 · 6 comments
Labels
bug Something isn't working training

Comments

@sagie-dekel

Describe the bug
I am trying to train meta-llama/Llama-3.1-8B-Instruct with trl DPOTrainer.
After creating the trainer and starting the training loop, I'm getting the following error (in the forward pass):

[rank0]: Traceback (most recent call last):
[rank0]:   File "/rg/kurland_prj/sagie.dekel/RLRF/base_run/optimize_HP_for_RLRF.py", line 771, in <module>
[rank0]:     main()
[rank0]:   File "/rg/kurland_prj/sagie.dekel/RLRF/base_run/optimize_HP_for_RLRF.py", line 759, in main
[rank0]:     RLRF_Pipeline(
[rank0]:   File "/rg/kurland_prj/sagie.dekel/RLRF/base_run/optimize_HP_for_RLRF.py", line 345, in RLRF_Pipeline
[rank0]:     RLRF_manager.RLRF_DPO(model_save_path=RLRF_model_save_path)
[rank0]:   File "/rg/kurland_prj/sagie.dekel/RLRF/base_run/RLRF_main.py", line 367, in RLRF_DPO
[rank0]:     self.FTTrainer.train()
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/trainer.py", line 2123, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/trainer.py", line 3579, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1371, in compute_loss
[rank0]:     loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1330, in get_batch_loss_metrics
[rank0]:     ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 950, in compute_ref_log_probs
[rank0]:     ref_model_output = self.concatenated_forward(self.ref_model, batch)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1260, in concatenated_forward
[rank0]:     outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/runtime/engine.py", line 1909, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1190, in forward
[rank0]:     outputs = self.model(
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 891, in forward
[rank0]:     inputs_embeds = self.embed_tokens(input_ids)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1779, in inner
[rank0]:     args_result = hook(self, args)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/runtime/zero/parameter_offload.py", line 285, in _pre_forward_module_hook
[rank0]:     self.pre_sub_module_forward_function(module)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/runtime/zero/parameter_offload.py", line 460, in pre_sub_module_forward_function
[rank0]:     param_coordinator.fetch_sub_module(sub_module, forward=True)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 331, in fetch_sub_module
[rank0]:     assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
[rank0]: AssertionError: {'id': 682, 'status': 'NOT_AVAILABLE', 'numel': 0, 'ds_numel': 0, 'shape': (0,), 'ds_shape': (0,), 'requires_grad': False, 'grad_shape': None, 'persist': True, 'active_sub_modules': {2}, 'ds_tensor.shape': torch.Size([0])}
  0%|          | 0/168 [00:03<?, ?it/s]

I also tried downgrading transformers, without success.

System info:

  • OS: Rocky Linux 8 (with Portable Batch System)
  • GPU count and types: one machine with 3 NVIDIA A100 GPUs
  • Python version: 3.10.15
  • Transformers: 4.46.3
  • Torch: 2.5.1+cu118
  • CUDA (python -c 'import torch; print(torch.version.cuda)'): 11.8
  • Package versions:
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
_sysroot_linux-64_curr_repodata_hack 3                   haa98f57_10
absl-py                   2.1.0                    pypi_0    pypi
accelerate                1.2.0                    pypi_0    pypi
aiohappyeyeballs          2.4.4                    pypi_0    pypi
aiohttp                   3.11.10                  pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
annotated-types           0.6.0           py310h06a4308_0
async-timeout             5.0.1                    pypi_0    pypi
attrs                     24.2.0                   pypi_0    pypi
binutils_impl_linux-64    2.40                 h5293946_0
blas                      1.0                         mkl
brotli-python             1.0.9           py310h6a678d5_8
bzip2                     1.0.8                h5eee18b_6
ca-certificates           2024.11.26           h06a4308_0
certifi                   2024.8.30       py310h06a4308_0
charset-normalizer        3.3.2              pyhd3eb1b0_0
compilers                 0.0.0                    pypi_0    pypi
cuda                      11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-cccl                 11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-command-line-tools   11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-compiler             11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-cudart               11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-cudart-dev           11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-cuobjdump            11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-cupti                11.8.87                       0    nvidia/label/cuda-11.8.0
cuda-cuxxfilt             11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-demo-suite           11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-documentation        11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-driver-dev           11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-gdb                  11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-libraries            11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-libraries-dev        11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-memcheck             11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nsight               11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nsight-compute       11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-nvcc                 11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-nvdisasm             11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nvml-dev             11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nvprof               11.8.87                       0    nvidia/label/cuda-11.8.0
cuda-nvprune              11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nvrtc                11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-nvrtc-dev            11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-nvtx                 11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nvvp                 11.8.87                       0    nvidia/label/cuda-11.8.0
cuda-profiler-api         11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-runtime              11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-sanitizer-api        11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-toolkit              11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-tools                11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-visual-tools         11.8.0                        0    nvidia/label/cuda-11.8.0
datasets                  3.1.0                    pypi_0    pypi
deepspeed                 0.16.2+9ca60160          pypi_0    pypi
dill                      0.3.8                    pypi_0    pypi
einops                    0.8.0                    pypi_0    pypi
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.16.1                   pypi_0    pypi
freetype                  2.12.1               h4a9f257_0
frozenlist                1.5.0                    pypi_0    pypi
fsspec                    2024.9.0                 pypi_0    pypi
gcc                       11.4.0              h602e360_13    conda-forge
gcc_impl_linux-64         11.4.0              h00c12a0_13    conda-forge
gds-tools                 1.4.0.31                      0    nvidia/label/cuda-11.8.0
giflib                    5.2.2                h5eee18b_0
gmp                       6.2.1                h295c915_3
gmpy2                     2.1.2           py310heeb90bb_0
gnutls                    3.6.15               he1e5248_0
grpcio                    1.68.1                   pypi_0    pypi
hjson                     3.1.0                    pypi_0    pypi
huggingface-hub           0.26.5                   pypi_0    pypi
idna                      3.7             py310h06a4308_0
intel-openmp              2023.1.0         hdb19cb5_46306
jinja2                    3.1.4           py310h06a4308_1
jpeg                      9e                   h5eee18b_3
kernel-headers_linux-64   3.10.0              h57e8cba_10
lame                      3.100                h7b6447c_0
lcms2                     2.12                 h3be6417_0
ld_impl_linux-64          2.40                 h12ee557_0
lerc                      3.0                  h295c915_0
libaio                    0.3.113              h5eee18b_0
libcublas                 11.11.3.6                     0    nvidia/label/cuda-11.8.0
libcublas-dev             11.11.3.6                     0    nvidia/label/cuda-11.8.0
libcufft                  10.9.0.58                     0    nvidia/label/cuda-11.8.0
libcufft-dev              10.9.0.58                     0    nvidia/label/cuda-11.8.0
libcufile                 1.4.0.31                      0    nvidia/label/cuda-11.8.0
libcufile-dev             1.4.0.31                      0    nvidia/label/cuda-11.8.0
libcurand                 10.3.0.86                     0    nvidia/label/cuda-11.8.0
libcurand-dev             10.3.0.86                     0    nvidia/label/cuda-11.8.0
libcusolver               11.4.1.48                     0    nvidia/label/cuda-11.8.0
libcusolver-dev           11.4.1.48                     0    nvidia/label/cuda-11.8.0
libcusparse               11.7.5.86                     0    nvidia/label/cuda-11.8.0
libcusparse-dev           11.7.5.86                     0    nvidia/label/cuda-11.8.0
libdeflate                1.17                 h5eee18b_1
libffi                    3.4.4                h6a678d5_1
libgcc                    14.2.0               h77fa898_1    conda-forge
libgcc-devel_linux-64     11.4.0             h8f596e0_113    conda-forge
libgcc-ng                 14.2.0               h69a702a_1    conda-forge
libgomp                   14.2.0               h77fa898_1    conda-forge
libiconv                  1.16                 h5eee18b_3
libidn2                   2.3.4                h5eee18b_0
libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
libnpp                    11.8.0.86                     0    nvidia/label/cuda-11.8.0
libnpp-dev                11.8.0.86                     0    nvidia/label/cuda-11.8.0
libnvjpeg                 11.9.0.86                     0    nvidia/label/cuda-11.8.0
libnvjpeg-dev             11.9.0.86                     0    nvidia/label/cuda-11.8.0
libpng                    1.6.39               h5eee18b_0
libsanitizer              11.4.0              h5763a12_13    conda-forge
libstdcxx                 14.2.0               hc0a3c3a_1    conda-forge
libstdcxx-ng              14.2.0               h4852527_1    conda-forge
libtasn1                  4.19.0               h5eee18b_0
libtiff                   4.5.1                h6a678d5_0
libunistring              0.9.10               h27cfd23_0
libuuid                   1.41.5               h5eee18b_0
libwebp                   1.3.2                h11a3e52_0
libwebp-base              1.3.2                h5eee18b_1
llvm-openmp               14.0.6               h9e868ea_0
lz4-c                     1.9.4                h6a678d5_1
markdown                  3.7                      pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                3.0.2                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
mkl                       2023.1.0         h213fc3f_46344
mkl-service               2.4.0           py310h5eee18b_1
mkl_fft                   1.3.11          py310h5eee18b_0
mkl_random                1.2.8           py310h1128e8f_0
mpc                       1.1.0                h10f8cd9_1
mpfr                      4.0.2                hb69a4c5_1
mpmath                    1.3.0           py310h06a4308_0
msgpack                   1.1.0                    pypi_0    pypi
multidict                 6.1.0                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
ncurses                   6.4                  h6a678d5_0
nettle                    3.7.3                hbbd107a_1
networkx                  3.4.2                    pypi_0    pypi
ninja                     1.11.1.2                 pypi_0    pypi
ninja-base                1.12.1               hdb19cb5_0
nsight-compute            2022.3.0.22                   0    nvidia/label/cuda-11.8.0
numpy                     1.22.4                   pypi_0    pypi
nvidia-cublas-cu11        11.11.3.6                pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu11    11.8.87                  pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.8.89                  pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu11  11.8.89                  pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu11         9.1.0.70                 pypi_0    pypi
nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu11        10.3.0.86                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu11      11.4.1.48                pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu11      11.7.5.86                pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu11          2.21.5                   pypi_0    pypi
nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
nvidia-nvtx-cu11          11.8.86                  pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
openh264                  2.1.1                h4ff587b_0
openjpeg                  2.5.2                he7f1fd0_0
openssl                   3.4.0                hb9d3cd8_0    conda-forge
pandas                    2.0.0                    pypi_0    pypi
pillow                    11.0.0          py310hfdbf927_0
pip                       24.2            py310h06a4308_0
propcache                 0.2.1                    pypi_0    pypi
protobuf                  5.29.1                   pypi_0    pypi
psutil                    6.1.0                    pypi_0    pypi
py-cpuinfo                9.0.0           py310h06a4308_0
pyarrow                   18.1.0                   pypi_0    pypi
pydantic                  2.8.2           py310h06a4308_0
pydantic-core             2.20.1          py310hb02cf49_0
pygments                  2.18.0                   pypi_0    pypi
pysocks                   1.7.1           py310h06a4308_0
python                    3.10.15              he870216_1
python-dateutil           2.9.0.post0              pypi_0    pypi
pytorch-cuda              11.8                 h7e8668a_6    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2024.2                   pypi_0    pypi
pyyaml                    6.0.2           py310h5eee18b_0
readline                  8.2                  h5eee18b_0
regex                     2024.11.6                pypi_0    pypi
requests                  2.32.3          py310h06a4308_1
rich                      13.9.4                   pypi_0    pypi
safetensors               0.4.5                    pypi_0    pypi
setuptools                75.6.0                   pypi_0    pypi
six                       1.17.0                   pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0
sympy                     1.13.1                   pypi_0    pypi
sysroot_linux-64          2.17                h57e8cba_10
tbb                       2021.8.0             hdb19cb5_0
tensorboard               2.18.0                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
tk                        8.6.14               h39e8969_0
tokenizers                0.20.3                   pypi_0    pypi
torch                     2.5.1+cu118              pypi_0    pypi
torchaudio                2.5.1                    pypi_0    pypi
torchvision               0.20.1                   pypi_0    pypi
tqdm                      4.67.1                   pypi_0    pypi
transformers              4.46.3                   pypi_0    pypi
triton                    3.1.0                    pypi_0    pypi
trl                       0.12.2                   pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
typing_extensions         4.11.0          py310h06a4308_0
tzdata                    2024.2                   pypi_0    pypi
urllib3                   2.2.3           py310h06a4308_0
werkzeug                  3.1.3                    pypi_0    pypi
wheel                     0.44.0          py310h06a4308_0
xxhash                    3.5.0                    pypi_0    pypi
xz                        5.4.6                h5eee18b_1
yaml                      0.2.5                h7b6447c_0
yarl                      1.18.3                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_1
zstd                      1.5.6                hc292b87_0
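The full conda listing is long; for a bug report the handful of packages in the failing call stack usually matter most. A minimal sketch, standard library only (`importlib.metadata`), that prints just those versions. The package selection and the `collect_versions` name are illustrative choices, not part of the original report; DeepSpeed's own `ds_report` command gives a fuller environment summary.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages that appear in the traceback above (selection is illustrative).
RELEVANT = ("torch", "transformers", "trl", "accelerate", "deepspeed", "datasets")


def collect_versions(names=RELEVANT):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name:<14} {ver or 'not installed'}")
```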

My accelerate config:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: DEEPSPEED
downcast_bf16: false
gpu_ids: all
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: 3
rdzv_backend: static
same_network: true
fsdp_config: {}
mixed_precision: bf16
use_cpu: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
@sagie-dekel sagie-dekel added bug Something isn't working training labels Dec 11, 2024
@jomayeri
Contributor

Can you try different offload settings, offloading only the optimizer or only the parameters but not both? How much CPU memory does the system have?
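As a sketch of the two variants being suggested, assuming the standard DeepSpeed JSON config schema: the ZeRO-3 section can offload only the optimizer state or only the parameters, with `"none"` keeping that state on the GPU. In the accelerate YAML above, the equivalent is setting one of `offload_optimizer_device` / `offload_param_device` to `none`. The `zero3_config` helper is hypothetical and builds only a minimal fragment, not a complete tuned config.

```python
import json


def zero3_config(offload_optimizer: bool, offload_param: bool) -> dict:
    """Build a minimal DeepSpeed ZeRO stage-3 config dict (hypothetical helper).

    "cpu" moves the corresponding state to host memory; "none" keeps it on the GPU.
    """
    return {
        "bf16": {"enabled": True},  # mirrors mixed_precision: bf16 in the accelerate config
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {"device": "cpu" if offload_optimizer else "none"},
            "offload_param": {"device": "cpu" if offload_param else "none"},
        },
        # "auto" is filled in by the Hugging Face Trainer/Accelerate integration.
        "train_micro_batch_size_per_gpu": "auto",
    }


# The two variants to try instead of offloading both.
optimizer_only = zero3_config(offload_optimizer=True, offload_param=False)
param_only = zero3_config(offload_optimizer=False, offload_param=True)

if __name__ == "__main__":
    print(json.dumps(optimizer_only, indent=2))
```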

@jomayeri jomayeri self-assigned this Dec 12, 2024
@sagie-dekel
Author

sagie-dekel commented Dec 12, 2024

Hi @jomayeri, thanks for answering.

I have already tried all the offload combinations and got the same error.

I don't know the exact amount of CPU memory, but the machine has 32 CPUs, and during the run I got the following log line:

[2024-12-11 10:55:07,014] [INFO] [utils.py:789:see_memory_usage] CPU Virtual Memory:  used = 45.89 GB, percent = 4.6%

So I don't think CPU memory is the problem.
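The log line itself pins down the total, assuming the reported `percent` is relative to total system RAM: 45.89 GB used at 4.6% works out to roughly 1 TB. A minimal sketch that does this arithmetic, plus a standard-library way to read the total directly on Linux via `os.sysconf` (`SC_PAGE_SIZE` × `SC_PHYS_PAGES`). Both function names are hypothetical.

```python
import os
import re

LOG_LINE = ("[2024-12-11 10:55:07,014] [INFO] [utils.py:789:see_memory_usage] "
            "CPU Virtual Memory:  used = 45.89 GB, percent = 4.6%")


def estimate_total_gb(line: str) -> float:
    """Estimate total RAM from a see_memory_usage log line as used / (percent / 100)."""
    match = re.search(r"used = ([\d.]+) GB, percent = ([\d.]+)%", line)
    if match is None:
        raise ValueError("not a see_memory_usage line")
    used_gb, percent = float(match.group(1)), float(match.group(2))
    return used_gb / (percent / 100.0)


def physical_memory_gib() -> float:
    """Total physical RAM of the current machine in GiB (POSIX sysconf names, e.g. Linux)."""
    return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / 1024**3


if __name__ == "__main__":
    print(f"estimated from log: {estimate_total_gb(LOG_LINE):.0f} GB")
    print(f"this machine:       {physical_memory_gib():.0f} GiB")
```

Because the percentage is printed with one decimal, the estimate is only good to about ±1%, which is enough to rule out CPU memory pressure here.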

@tjruwase
Contributor

tjruwase commented Dec 12, 2024

[rank0]: File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 331, in fetch_sub_module
[rank0]: assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
[rank0]: AssertionError: {'id': 682, 'status': 'NOT_AVAILABLE', 'numel': 0, 'ds_numel': 0, 'shape': (0,), 'ds_shape': (0,), 'requires_grad': False, 'grad_shape': None, 'persist': True, 'active_sub_modules': {2}, 'ds_tensor.shape': torch.Size([0])}

This indicates that the parameter was not fetched (all-gathered) as required before use, which is a very strange failure for ZeRO stage 3. Are you able to share full repro steps, including the command line and datasets?
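One way to narrow this down is to inspect the reference model's parameters before calling `train()`. In the assertion summary, `numel`, `ds_numel` and `ds_shape` are all empty. If `ds_numel` records the full, unpartitioned size (my understanding of DeepSpeed's bookkeeping, not something confirmed in this thread), the embedding was registered as an empty tensor rather than merely partitioned away. The helper below is hypothetical. It only reads the `ds_numel` attribute that ZeRO-3 attaches to managed parameters, and it is exercised with stand-in objects rather than real DeepSpeed parameters.

```python
from types import SimpleNamespace


def empty_zero3_params(named_params):
    """Return names of ZeRO-3 parameters whose recorded full size (ds_numel) is zero.

    Parameters without a ds_numel attribute (not managed by ZeRO-3) are skipped.
    """
    return [
        name
        for name, param in named_params
        if getattr(param, "ds_numel", None) == 0
    ]


if __name__ == "__main__":
    # Stand-ins loosely mimicking the summary in the assertion above (illustrative only).
    fake = [
        ("model.embed_tokens.weight", SimpleNamespace(ds_numel=0)),
        ("lm_head.weight", SimpleNamespace(ds_numel=1024)),
        ("plain.weight", SimpleNamespace()),
    ]
    print(empty_zero3_params(fake))
```

On a real run one might pass `trainer.ref_model.named_parameters()`, since `self.ref_model` is the module the failing `concatenated_forward` call receives.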

@sagie-dekel
Author

sagie-dekel commented Dec 12, 2024

perfernces_dataset_from_ranker_train_queries_and_baseline_doc.csv
@tjruwase

Command line:

accelerate launch --config_file=/home/sagie.dekel/.cache/huggingface/accelerate/accelerate_deepspeed_config.yaml path_to_program_file.py path_to_param_file.json

The dataset is a small CSV file with prompts, each paired with an accepted (chosen) and a rejected sample.
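TRL's `DPOTrainer` works with preference data in the `prompt` / `chosen` / `rejected` column format. A hypothetical preflight check, standard library only, that a CSV has those columns and no empty cells, which helps rule out the data before suspecting the ZeRO-3 setup. The column names of the attached file are not shown in the thread, so a rename may be needed first.

```python
import csv
import io

REQUIRED = ("prompt", "chosen", "rejected")


def check_preference_csv(text: str) -> int:
    """Validate CSV text in prompt/chosen/rejected format and return the record count."""
    reader = csv.DictReader(io.StringIO(text))
    missing = [col for col in REQUIRED if col not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    count = 0
    for record_no, row in enumerate(reader, start=1):
        # Short rows get None for absent fields, so treat None like an empty string.
        empty = [col for col in REQUIRED if not (row[col] or "").strip()]
        if empty:
            raise ValueError(f"record {record_no} has empty {empty}")
        count += 1
    return count
```

For a file on disk, read it with `open(path, newline="")` and pass the contents in.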

@jomayeri
Contributor

jomayeri commented Dec 13, 2024

What does the program file consist of? And what's in param_file.json?

@sagie-dekel
Author

Hi again @jomayeri

What do you mean by "consist of"? It's a regular Python file that runs a DPO pipeline.
The param_file.json contains parameters for the program (e.g., model name) and hyperparameters for the DPO trainer (e.g., learning rate or beta).
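A hypothetical sketch of how such a parameter file could be split into program settings and trainer hyperparameters. The key names and values are invented for illustration and are not the author's actual file. In TRL, hyperparameters such as `learning_rate` and `beta` are fields of `DPOConfig`.

```python
import json

# Hypothetical program-level keys; everything else is treated as a trainer hyperparameter.
PROGRAM_KEYS = {"model_name", "dataset_path", "output_dir"}


def split_params(text: str):
    """Parse the JSON parameter file and return (program_params, trainer_hparams)."""
    params = json.loads(text)
    program = {k: v for k, v in params.items() if k in PROGRAM_KEYS}
    hparams = {k: v for k, v in params.items() if k not in PROGRAM_KEYS}
    return program, hparams


# Illustrative file contents (values are made up).
example = '{"model_name": "meta-llama/Llama-3.1-8B-Instruct", "learning_rate": 5e-7, "beta": 0.1}'
program, hparams = split_params(example)
# In a real script, hparams might be forwarded as DPOConfig(output_dir=..., **hparams).
```

Sharing a trimmed version of this split (program keys versus trainer keys) is usually enough for maintainers to reproduce a trainer setup without the full script.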
