
Error during training #3

Closed
hanshuo-shuo opened this issue Nov 7, 2023 · 13 comments

@hanshuo-shuo

hanshuo-shuo commented Nov 7, 2023

Hi, thanks for your code~
I tried to apply the code to my environment. During training, I ran into the following errors:

[two screenshots of the error]

I was wondering whether you have run into this too. Is it caused by a module that is not installed?

@hanshuo-shuo
Author

I used the following to work around the problem, but I am not sure it is correct:
[screenshot of the modified code]

@aditya-spood

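The vmap needs to be re-created after changing the dropout rate. Try this:

    import torch
    from torch import nn
    from functorch import combine_state_for_ensemble

    # Ensemble method; assumes self.original_modules and self.kwargs are set in __init__.
    def set_drop_out(self, dropout):
        for m in self.original_modules.modules():
            if isinstance(m, nn.Dropout):
                m.p = dropout
        # Rebuild so the vmapped fn sees the new rate; the returned params are not used here.
        fn, _, _ = combine_state_for_ensemble(self.original_modules)
        self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **self.kwargs)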
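Why rebuilding is needed: as far as I can tell, functorch's combine_state_for_ensemble builds its functional model from a deep copy of the first module, so the vmapped function holds its own Dropout layers and never sees later edits to the original modules. The sketch below mimics that with copy.deepcopy; set_dropout and the two-layer model are hypothetical stand-ins, not code from the repo.

```python
import copy

from torch import nn


def set_dropout(module: nn.Module, p: float) -> None:
    """Set the drop probability of every nn.Dropout inside `module`."""
    for m in module.modules():
        if isinstance(m, nn.Dropout):
            m.p = p


# Hypothetical ensemble member.
original = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.0))

# Stand-in for the private copy wrapped by the functional model.
template = copy.deepcopy(original)

set_dropout(original, 0.5)
print(original[1].p, template[1].p)  # 0.5 0.0: the copy is unaffected

# Rebuilding from the updated module (what re-running
# combine_state_for_ensemble does) picks up the new rate.
template = copy.deepcopy(original)
print(template[1].p)  # 0.5
```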

@hanshuo-shuo
Author

@aditya-spood Thanks a lot!

@nicklashansen
Owner

Hi @hanshuo-shuo @aditya-spood , thanks for initiating the discussion. Is this an issue with the current code, or does it only occur with some changes? I'd like to dig into this and fix any issues with the codebase (or anything that would make it easier to use/extend). It is possible that some behaviors are inconsistent across package versions, in which case I'd like for the code to be less sensitive to the exact versions installed.

nicklashansen reopened this Nov 9, 2023
@hanshuo-shuo
Author

hanshuo-shuo commented Nov 9, 2023

@nicklashansen Hi, thanks for your reply.

https://github.com/hanshuo-shuo/tdmpc2-prey/tree/main

  • I made some changes to the envs file and changed config.yaml:
    - override hydra/launcher: basic
  • I also changed all devices to CPU: I set every device to self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'), and I store the memory on CPU only.
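A minimal sketch of the device setup described in the second bullet, assuming a hypothetical storage tensor that stays in CPU memory while sampled batches move to whichever device is selected:

```python
import torch

# Use the GPU when present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical replay storage kept in CPU memory regardless of `device`.
storage = torch.zeros(1000, 8)  # tensors default to the CPU


def sample_batch(batch_size: int) -> torch.Tensor:
    """Sample rows on the CPU, then move only the batch to `device`."""
    idx = torch.randint(0, storage.shape[0], (batch_size,))
    return storage[idx].to(device, non_blocking=True)


batch = sample_batch(32)
print(batch.shape, batch.device)
```

Keeping the full storage on the CPU bounds GPU memory use; only the sampled batch crosses devices.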

With these changes, training produces the errors above. I checked the PyTorch documentation, and I don't think torch.vmap has a __wrapped__ attribute. Even so, with all the changes listed above, I get a really good training result with the smallest model configuration after training for 100,000 steps.
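For reference, __wrapped__ is the attribute that functools.wraps (via functools.update_wrapper) sets on a wrapper to point at the wrapped function, and nothing guarantees that the object returned by torch.vmap in a given PyTorch release carries it. A version-tolerant sketch reads it with a fallback instead of assuming it exists; unwrap_or_self is a hypothetical helper name:

```python
import functools
from typing import Any, Callable


def unwrap_or_self(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Return fn.__wrapped__ when it exists, otherwise fn itself."""
    return getattr(fn, "__wrapped__", fn)


def square(x: int) -> int:
    return x * x


@functools.wraps(square)
def traced_square(x: int) -> int:
    # functools.wraps sets traced_square.__wrapped__ = square.
    return square(x)


def plain_wrapper(x: int) -> int:
    # No functools.wraps here, so there is no __wrapped__ attribute.
    return square(x)


print(unwrap_or_self(traced_square) is square)  # True
print(unwrap_or_self(plain_wrapper) is plain_wrapper)  # True
```

The standard library's inspect.unwrap behaves similarly but follows the whole chain of __wrapped__ attributes.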

@nicklashansen
Copy link
Owner

Thanks for sharing. In that case, I suspect that it might be due to package versions. Would you mind copy/pasting the output of conda list or conda env export here? I'll see if I can reproduce the error on my end.
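If conda is not available, a quick alternative (a sketch, not something used in this thread) is to print the relevant versions straight from Python with importlib.metadata. The package list below is an assumption about what matters for this error:

```python
from importlib.metadata import PackageNotFoundError, version

# Packages whose versions plausibly affect this error (assumed list).
PACKAGES = ["torch", "functorch", "tensordict", "torchrl", "gym", "hydra-core"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:12s} {ver or 'not installed'}")
```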

@hanshuo-shuo
Copy link
Author

Hi @nicklashansen, the output of conda env export is:

channels:
  - pytorch
  - conda-forge
dependencies:
  - aom=3.6.1=hb765f3a_0
  - brotli-python=1.1.0=py39hb198ff7_1
  - bzip2=1.0.8=h93a5062_5
  - ca-certificates=2023.7.22=hf0a4a13_0
  - cairo=1.18.0=hd1e100b_0
  - dav1d=1.2.1=hb547adb_0
  - expat=2.5.0=hb7217d7_1
  - ffmpeg=6.0.0=gpl_h1ceb99f_105
  - filelock=3.13.1=pyhd8ed1ab_0
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=hab24e00_0
  - fontconfig=2.14.2=h82840c6_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - freetype=2.12.1=hadb7bae_2
  - fribidi=1.0.10=h27ca646_0
  - gettext=0.21.1=h0186832_0
  - gmp=6.2.1=h9f76cd9_0
  - gmpy2=2.1.2=py39h0b4f9c6_1
  - gnutls=3.7.8=h9f1a10d_0
  - graphite2=1.3.13=h9f76cd9_1001
  - harfbuzz=8.2.1=hf1a6348_0
  - icu=73.2=hc8870d7_0
  - idna=3.4=pyhd8ed1ab_0
  - jinja2=3.1.2=pyhd8ed1ab_1
  - lame=3.100=h1a8c8d9_1003
  - lcms2=2.15=hf2736f0_3
  - lerc=4.0.0=h9a09cb3_0
  - libass=0.17.1=hf7da4fe_1
  - libblas=3.9.0=19_osxarm64_openblas
  - libcblas=3.9.0=19_osxarm64_openblas
  - libcxx=16.0.6=h4653b0c_0
  - libdeflate=1.19=hb547adb_0
  - libexpat=2.5.0=hb7217d7_1
  - libffi=3.4.2=h3422bc3_5
  - libgfortran=5.0.0=13_2_0_hd922786_1
  - libgfortran5=13.2.0=hf226fd6_1
  - libglib=2.78.1=hd9b11f9_0
  - libiconv=1.17=he4db4b2_0
  - libidn2=2.3.4=h1a8c8d9_0
  - libjpeg-turbo=3.0.0=hb547adb_1
  - liblapack=3.9.0=19_osxarm64_openblas
  - libopenblas=0.3.24=openmp_hd76b1f2_0
  - libopus=1.3.1=h27ca646_1
  - libpng=1.6.39=h76d750c_0
  - libsqlite=3.43.2=h091b4b1_0
  - libtasn1=4.19.0=h1a8c8d9_0
  - libtiff=4.6.0=ha8a6c65_2
  - libunistring=0.9.10=h3422bc3_0
  - libvpx=1.13.1=hb765f3a_0
  - libwebp-base=1.3.2=hb547adb_0
  - libxcb=1.15=hf346824_0
  - libxml2=2.11.5=h25269f3_1
  - libzlib=1.2.13=h53f4e23_5
  - llvm-openmp=17.0.4=hcd81f8e_0
  - markupsafe=2.1.3=py39h0f82c59_1
  - mpc=1.3.1=h91ba8db_0
  - mpfr=4.2.1=h9546428_0
  - mpmath=1.3.0=pyhd8ed1ab_0
  - ncurses=6.4=h7ea286d_0
  - nettle=3.8.1=h63371fa_1
  - openh264=2.3.1=hb7217d7_2
  - openjpeg=2.5.0=h4c1507b_3
  - openssl=3.1.4=h0d3ecfb_0
  - p11-kit=0.24.1=h29577a5_0
  - pcre2=10.40=hb34f9b4_0
  - pillow=10.1.0=py39h755f0b7_0
  - pip=23.2.1=pyhd8ed1ab_0
  - pixman=0.42.2=h13dd4ca_0
  - pthread-stubs=0.4=h27ca646_1001
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.9.7=hc0da0df_3_cpython
  - python_abi=3.9=4_cp39
  - pytorch=2.1.0=py3.9_0
  - pyyaml=6.0.1=py39h0f82c59_1
  - readline=8.2=h92ec313_1
  - requests=2.31.0=pyhd8ed1ab_0
  - setuptools=68.2.2=pyhd8ed1ab_0
  - sqlite=3.43.2=hf2abe2d_0
  - svt-av1=1.7.0=hb765f3a_0
  - sympy=1.12=pypyh9d50eac_103
  - tk=8.6.13=hb31c410_0
  - torchaudio=2.1.0=py39_cpu
  - torchvision=0.16.0=py39_cpu
  - typing_extensions=4.8.0=pyha770c72_0
  - wheel=0.41.2=pyhd8ed1ab_0
  - x264=1!164.3095=h57fd34a_2
  - x265=3.5=hbc6ce65_3
  - xorg-libxau=1.0.11=hb547adb_0
  - xorg-libxdmcp=1.1.3=h27ca646_0
  - xz=5.2.6=h57fd34a_0
  - yaml=0.2.5=h3422bc3_2
  - zlib=1.2.13=h53f4e23_5
  - zstd=1.5.5=h4f39d0f_0
  - pip:
    - absl-py==2.0.0
    - antlr4-python3-runtime==4.9.3
    - astunparse==1.6.3
    - cachetools==5.3.1
    - cellworld==0.0.376
    - certifi==2023.7.22
    - charset-normalizer==3.3.0
    - chex==0.1.83
    - cloudpickle==2.2.1
    - contourpy==1.1.1
    - crafter==1.8.1
    - cv==1.0.0
    - cycler==0.12.1
    - decorator==5.1.1
    - dm-tree==0.1.8
    - farama-notifications==0.0.4
    - flatbuffers==23.5.26
    - fonttools==4.43.1
    - fsspec==2023.10.0
    - gast==0.5.4
    - google-auth==2.23.3
    - google-auth-oauthlib==1.0.0
    - google-pasta==0.2.0
    - grpcio==1.59.0
    - gym==0.26.2
    - gym-notices==0.0.8
    - gymnasium==0.29.1
    - h5py==3.10.0
    - huggingface-hub==0.17.3
    - hydra-core==1.3.2
    - imageio==2.31.5
    - importlib-metadata==6.8.0
    - importlib-resources==6.1.0
    - jax==0.4.18
    - jaxlib==0.4.18
    - json-cpp==1.0.91
    - keras==2.14.0
    - kiwisolver==1.4.5
    - libclang==16.0.6
    - markdown==3.5
    - markdown-it-py==3.0.0
    - matplotlib==3.8.0
    - mdurl==0.1.2
    - ml-dtypes==0.2.0
    - networkx==3.1
    - numpy==1.26.1
    - oauthlib==3.2.2
    - omegaconf==2.3.0
    - opensimplex==0.4.5
    - opt-einsum==3.3.0
    - optax==0.1.7
    - packaging==23.2
    - pandas==2.1.2
    - pettingzoo==1.24.1
    - protobuf==4.24.4
    - pyasn1==0.5.0
    - pyasn1-modules==0.3.0
    - pygments==2.16.1
    - pyparsing==3.1.1
    - python-dateutil==2.8.2
    - pytz==2023.3.post1
    - regex==2023.10.3
    - requests-oauthlib==1.3.1
    - rich==13.6.0
    - rsa==4.9
    - ruamel-yaml==0.17.35
    - ruamel-yaml-clib==0.2.8
    - safetensors==0.4.0
    - scipy==1.11.3
    - six==1.16.0
    - stable-baselines3==2.1.0
    - supersuit==3.9.0
    - tcp-messages==1.0.45
    - tensorboard==2.14.1
    - tensorboard-data-server==0.7.1
    - tensordict==0.2.1
    - tensordict-nightly==2023.6.8
    - tensorflow==2.14.0
    - tensorflow-estimator==2.14.0
    - tensorflow-io-gcs-filesystem==0.34.0
    - tensorflow-macos==2.14.0
    - tensorflow-probability==0.22.0
    - termcolor==2.3.0
    - tinyscaler==1.2.7
    - tokenizers==0.14.1
    - toolz==0.12.0
    - torch==2.1.0
    - torchrl==0.2.1
    - tqdm==4.66.1
    - transformers==4.34.1
    - typing-extensions==4.5.0
    - tzdata==2023.3
    - urllib3==2.0.6
    - werkzeug==3.0.0
    - wrapt==1.14.1
    - zipp==3.17.0

Sorry, my env is a mess.
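With an export this long, one thing worth checking is whether a package is pinned in both the conda and pip sections (here pytorch/torch and typing_extensions/typing-extensions appear in both, the latter at different versions). A minimal sketch that flags such overlaps, assuming the export is available as a string; names are normalized PEP 503-style, and the conda-to-PyPI name map is a hypothetical, partial list:

```python
import re

# Conda names that differ from their PyPI names (hypothetical, partial).
CONDA_TO_PIP = {"pytorch": "torch"}


def normalize(name: str) -> str:
    """PEP 503-style normalization: lowercase, runs of -_. become '-'."""
    return re.sub(r"[-_.]+", "-", name).lower()


def duplicated_packages(env_yaml: str) -> set:
    """Names pinned in both the conda and pip sections of conda env export."""
    conda, pip = set(), set()
    in_pip = False
    for raw in env_yaml.splitlines():
        line = raw.strip()
        if line == "- pip:":
            in_pip = True
        elif in_pip and line.startswith("- ") and "==" in line:
            pip.add(normalize(line[2:].split("==")[0]))
        elif not in_pip and line.startswith("- ") and "=" in line:
            name = line[2:].split("=")[0]
            conda.add(normalize(CONDA_TO_PIP.get(name, name)))
    return conda & pip


SAMPLE = """channels:
  - pytorch
dependencies:
  - pytorch=2.1.0=py3.9_0
  - typing_extensions=4.8.0=pyha770c72_0
  - zlib=1.2.13=h53f4e23_5
  - pip:
    - torch==2.1.0
    - typing-extensions==4.5.0
    - tqdm==4.66.1
"""

print(sorted(duplicated_packages(SAMPLE)))  # ['torch', 'typing-extensions']
```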

@purewater0901

purewater0901 commented Nov 15, 2023

Hello! Thank you for sharing your great work with all of us.

I also got the same errors when I ran the following command from the README instructions, without changing anything in the code:

python train.py task=dog-run steps=7000000

The advice in this thread solved the problem.

I wonder whether these revisions change the final results.

@VitaLemonTea1

@purewater0901 I get the same error too, but it still doesn't work after I edit the code as above. Could you please tell me how you fixed it?

@hanshuo-shuo
Copy link
Author

@VitaLemonTea1 I only made this change and it works quite well (just use it to overwrite the existing layer code). Quoting my earlier comment:

"I used this to solve the problem, but I am not sure it is correct: [screenshot of the modified code]"

@purewater0901

@VitaLemonTea1 @hanshuo-shuo
Yes, I changed the code as @hanshuo-shuo showed here.

You can also add the new function suggested earlier in this thread (set_drop_out) to this class, but I'm not sure it's necessary.

@VitaLemonTea1

@purewater0901 @hanshuo-shuo
Thanks a lot!

@nicklashansen
Owner

Hi all,

Thank you for your patience. I spent a few hours investigating this today, and it appears that the existence of vmap.__wrapped__ in torch/functorch was very short-lived. I have pushed a fix in f313929 that removes the modules() function call altogether. This specific part of the implementation is not critical, and I have verified that results are reproduced on a handful of tasks.
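As a hedged illustration of the general idea (not the code from f313929), here is a sketch of a vmapped ensemble built with torch.func.stack_module_state and functional_call, following the pattern of PyTorch's model-ensembling tutorial. It records its description at construction time, so nothing ever needs to read __wrapped__ from the vmapped function. The class layout and attribute names are assumptions.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state


class Ensemble(nn.Module):
    """Hypothetical vmap ensemble that never inspects the vmapped function."""

    def __init__(self, modules):
        super().__init__()
        modules = list(modules)
        params, _ = stack_module_state(modules)  # buffers ignored in this sketch
        self._names = list(params)
        self.params = nn.ParameterList(nn.Parameter(params[n]) for n in self._names)
        # Stateless template kept in a tuple so nn.Module does not register it.
        # Note: its train/eval mode is not synced with the Ensemble here.
        self._template = (copy.deepcopy(modules[0]).to("meta"),)
        # Record the description up front instead of reading vmap internals later.
        self._repr = str(nn.ModuleList(modules))

    def _single(self, flat_params, x):
        named = dict(zip(self._names, flat_params))
        return functional_call(self._template[0], named, (x,))

    def forward(self, x):
        # Map over the leading (member) dim of each parameter; share the input.
        batched = torch.vmap(self._single, in_dims=(0, None), randomness="different")
        return batched(tuple(self.params), x)

    def __repr__(self):
        return "Vectorized " + self._repr


members = [nn.Linear(3, 2) for _ in range(4)]
ens = Ensemble(members)
print(ens(torch.randn(5, 3)).shape)  # torch.Size([4, 5, 2])
```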

This commit also updates environment.yaml to work around the broken gym setup mentioned in openai/gym#3176 and #2.

I'll close this issue, but please feel free to re-open if you run into any other issues.

DarrienMcKenzie added a commit to DarrienMcKenzie/tdmpc2 that referenced this issue Nov 3, 2024