From 66e14ead782cbd9f4edb15ee301ef29393dcb15b Mon Sep 17 00:00:00 2001 From: bandish-shah <86627118+bandish-shah@users.noreply.github.com> Date: Mon, 27 Nov 2023 20:10:28 -0800 Subject: [PATCH] Override NVIDIA environment variable for CUDA 12.1 images (#2742) * Update NVIDIA_REQUIRE_CUDA_OVERRIDE env variable for CUDA 12.1 Docker images --- docker/build_matrix.yaml | 30 ++++++++++++++++-- docker/generate_build_matrix.py | 55 ++++++++++++++++++++++++++------- 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/docker/build_matrix.yaml b/docker/build_matrix.yaml index 645094868d..c6ba59e547 100644 --- a/docker/build_matrix.yaml +++ b/docker/build_matrix.yaml @@ -4,7 +4,20 @@ CUDA_VERSION: 12.1.0 IMAGE_NAME: torch-2-1-0-cu121 MOFED_VERSION: 5.5-1.0.3.2 - NVIDIA_REQUIRE_CUDA_OVERRIDE: '' + NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471 + brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 + brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 + brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 + brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 brand=nvidia,driver>=510,driver<511 + brand=nvidiartx,driver>=510,driver<511 brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511 + brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 brand=titan,driver>=510,driver<511 + brand=titanrtx,driver>=510,driver<511 brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516 + brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=515,driver<516 + brand=geforcertx,driver>=515,driver<516 brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516 + brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 brand=tesla,driver>=525,driver<526 + brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 + brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526 + brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526 PYTHON_VERSION: '3.10' PYTORCH_NIGHTLY_URL: '' PYTORCH_NIGHTLY_VERSION: '' @@ -19,7 +32,20 @@ CUDA_VERSION: 12.1.0 IMAGE_NAME: torch-2-1-0-cu121-aws MOFED_VERSION: '' - NVIDIA_REQUIRE_CUDA_OVERRIDE: '' + NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471 + brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 + brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 + brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 + brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 brand=nvidia,driver>=510,driver<511 + brand=nvidiartx,driver>=510,driver<511 brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511 + brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 brand=titan,driver>=510,driver<511 + brand=titanrtx,driver>=510,driver<511 brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516 + brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=515,driver<516 + brand=geforcertx,driver>=515,driver<516 brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516 + brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 brand=tesla,driver>=525,driver<526 + brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 + brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526 + brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526 PYTHON_VERSION: '3.10' PYTORCH_NIGHTLY_URL: '' PYTORCH_NIGHTLY_VERSION: '' diff --git a/docker/generate_build_matrix.py b/docker/generate_build_matrix.py index 690b71fec8..3e8556307f 100644 --- a/docker/generate_build_matrix.py +++ b/docker/generate_build_matrix.py @@ -56,6 +56,48 @@ def _get_cuda_version_tag(cuda_version: str): return 'cu' + ''.join(cuda_version.split('.')[:2]) +def _get_cuda_override(cuda_version: str): + if cuda_version == '12.1.0': + cuda_121_override_string = ('cuda>=12.1 brand=tesla,driver>=450,driver<451 ' + 'brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 ' + 'brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 ' + 'brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 ' + 'brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 ' + 'brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 ' + 'brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 ' + 'brand=nvidia,driver>=510,driver<511 brand=nvidiartx,driver>=510,driver<511 ' + 'brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511 ' + 'brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 ' + 'brand=titan,driver>=510,driver<511 brand=titanrtx,driver>=510,driver<511 ' + 'brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516 ' + 'brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 ' + 'brand=geforce,driver>=515,driver<516 brand=geforcertx,driver>=515,driver<516 ' + 'brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516 ' + 'brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 ' + 'brand=tesla,driver>=525,driver<526 brand=unknown,driver>=525,driver<526 ' + 'brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 ' + 'brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 ' + 'brand=quadro,driver>=525,driver<526 brand=quadrortx,driver>=525,driver<526 ' + 'brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526') + + return cuda_121_override_string + + if cuda_version == '11.8.0': + cuda_118_override_string = ('cuda>=11.8 brand=tesla,driver>=470,driver<471 ' + 'brand=tesla,driver>=515,driver<516 brand=unknown,driver>=470,driver<471 ' + 'brand=unknown,driver>=515,driver<516 brand=nvidia,driver>=470,driver<471 ' + 'brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=470,driver<471 ' + 'brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=470,driver<471 ' + 'brand=geforce,driver>=515,driver<516 brand=quadro,driver>=470,driver<471 ' + 'brand=quadro,driver>=515,driver<516 brand=titan,driver>=470,driver<471 ' + 'brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=470,driver<471 ' + 'brand=titanrtx,driver>=515,driver<516') + + return cuda_118_override_string + + return '' + + def _get_pytorch_tags(python_version: str, pytorch_version: str, cuda_version: str, stage: str, interconnect: str): if stage == 'pytorch_stage': base_image_name = 'mosaicml/pytorch' @@ -136,17 +178,6 @@ def _main(): cuda_version = _get_cuda_version(pytorch_version=pytorch_version, use_cuda=use_cuda) - override_string = ('cuda>=11.8 brand=tesla,driver>=470,driver<471 ' - 'brand=tesla,driver>=515,driver<516 brand=unknown,driver>=470,driver<471 ' - 'brand=unknown,driver>=515,driver<516 brand=nvidia,driver>=470,driver<471 ' - 'brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=470,driver<471 ' - 'brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=470,driver<471 ' - 'brand=geforce,driver>=515,driver<516 brand=quadro,driver>=470,driver<471 ' - 'brand=quadro,driver>=515,driver<516 brand=titan,driver>=470,driver<471 ' - 'brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=470,driver<471 ' - 'brand=titanrtx,driver>=515,driver<516') - nvidia_require_cuda_override = '' if cuda_version != '11.8.0' else override_string - entry = { 'IMAGE_NAME': _get_image_name(pytorch_version, cuda_version, stage, interconnect), @@ -175,7 +206,7 @@ def _main(): 'PYTORCH_NIGHTLY_VERSION': '', 'NVIDIA_REQUIRE_CUDA_OVERRIDE': - nvidia_require_cuda_override, + _get_cuda_override(cuda_version), } # Only build EFA image on latest python with cuda on pytorch_stage