Skip to content

Commit

Permalink
Override NVIDIA environment variable for CUDA 12.1 images (#2742)
Browse files Browse the repository at this point in the history
* Update NVIDIA_REQUIRE_CUDA_OVERRIDE env variable for CUDA 12.1 Docker images
  • Loading branch information
bandish-shah authored Nov 28, 2023
1 parent 2b3e2a6 commit 66e14ea
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 14 deletions.
30 changes: 28 additions & 2 deletions docker/build_matrix.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,20 @@
CUDA_VERSION: 12.1.0
IMAGE_NAME: torch-2-1-0-cu121
MOFED_VERSION: 5.5-1.0.3.2
NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471
brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471
brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 brand=nvidia,driver>=510,driver<511
brand=nvidiartx,driver>=510,driver<511 brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511
brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 brand=titan,driver>=510,driver<511
brand=titanrtx,driver>=510,driver<511 brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516
brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=515,driver<516
brand=geforcertx,driver>=515,driver<516 brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516
brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 brand=tesla,driver>=525,driver<526
brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526
brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526
brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526
PYTHON_VERSION: '3.10'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
Expand All @@ -19,7 +32,20 @@
CUDA_VERSION: 12.1.0
IMAGE_NAME: torch-2-1-0-cu121-aws
MOFED_VERSION: ''
NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471
brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471
brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 brand=nvidia,driver>=510,driver<511
brand=nvidiartx,driver>=510,driver<511 brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511
brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 brand=titan,driver>=510,driver<511
brand=titanrtx,driver>=510,driver<511 brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516
brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=515,driver<516
brand=geforcertx,driver>=515,driver<516 brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516
brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 brand=tesla,driver>=525,driver<526
brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526
brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526
brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526
PYTHON_VERSION: '3.10'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
Expand Down
55 changes: 43 additions & 12 deletions docker/generate_build_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,48 @@ def _get_cuda_version_tag(cuda_version: str):
return 'cu' + ''.join(cuda_version.split('.')[:2])


def _get_cuda_override(cuda_version: str):
if cuda_version == '12.1.0':
cuda_121_override_string = ('cuda>=12.1 brand=tesla,driver>=450,driver<451 '
'brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 '
'brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 '
'brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 '
'brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 '
'brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 '
'brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 '
'brand=nvidia,driver>=510,driver<511 brand=nvidiartx,driver>=510,driver<511 '
'brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511 '
'brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 '
'brand=titan,driver>=510,driver<511 brand=titanrtx,driver>=510,driver<511 '
'brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516 '
'brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 '
'brand=geforce,driver>=515,driver<516 brand=geforcertx,driver>=515,driver<516 '
'brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516 '
'brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 '
'brand=tesla,driver>=525,driver<526 brand=unknown,driver>=525,driver<526 '
'brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 '
'brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 '
'brand=quadro,driver>=525,driver<526 brand=quadrortx,driver>=525,driver<526 '
'brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526')

return cuda_121_override_string

if cuda_version == '11.8.0':
cuda_118_override_string = ('cuda>=11.8 brand=tesla,driver>=470,driver<471 '
'brand=tesla,driver>=515,driver<516 brand=unknown,driver>=470,driver<471 '
'brand=unknown,driver>=515,driver<516 brand=nvidia,driver>=470,driver<471 '
'brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=470,driver<471 '
'brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=470,driver<471 '
'brand=geforce,driver>=515,driver<516 brand=quadro,driver>=470,driver<471 '
'brand=quadro,driver>=515,driver<516 brand=titan,driver>=470,driver<471 '
'brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=470,driver<471 '
'brand=titanrtx,driver>=515,driver<516')

return cuda_118_override_string

return ''


def _get_pytorch_tags(python_version: str, pytorch_version: str, cuda_version: str, stage: str, interconnect: str):
if stage == 'pytorch_stage':
base_image_name = 'mosaicml/pytorch'
Expand Down Expand Up @@ -136,17 +178,6 @@ def _main():

cuda_version = _get_cuda_version(pytorch_version=pytorch_version, use_cuda=use_cuda)

override_string = ('cuda>=11.8 brand=tesla,driver>=470,driver<471 '
'brand=tesla,driver>=515,driver<516 brand=unknown,driver>=470,driver<471 '
'brand=unknown,driver>=515,driver<516 brand=nvidia,driver>=470,driver<471 '
'brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=470,driver<471 '
'brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=470,driver<471 '
'brand=geforce,driver>=515,driver<516 brand=quadro,driver>=470,driver<471 '
'brand=quadro,driver>=515,driver<516 brand=titan,driver>=470,driver<471 '
'brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=470,driver<471 '
'brand=titanrtx,driver>=515,driver<516')
nvidia_require_cuda_override = '' if cuda_version != '11.8.0' else override_string

entry = {
'IMAGE_NAME':
_get_image_name(pytorch_version, cuda_version, stage, interconnect),
Expand Down Expand Up @@ -175,7 +206,7 @@ def _main():
'PYTORCH_NIGHTLY_VERSION':
'',
'NVIDIA_REQUIRE_CUDA_OVERRIDE':
nvidia_require_cuda_override,
_get_cuda_override(cuda_version),
}

# Only build EFA image on latest python with cuda on pytorch_stage
Expand Down

0 comments on commit 66e14ea

Please sign in to comment.