
Torch TRT ngc container changes (#3299)
Co-authored-by: Dheeraj Peri <[email protected]>
apbose and peri044 authored Dec 13, 2024
1 parent b4bc713 commit 3982401
Showing 3 changed files with 28 additions and 4 deletions.

core/util/Exception.cpp (10 additions, 0 deletions)
@@ -1,3 +1,13 @@
+#if defined(__GNUC__) && !defined(__clang__)
+#if __GNUC__ >= 13
+#include <cstdint>
+#endif
+#elif defined(__clang__)
+#if __clang_major__ >= 13
+#include <cstdint>
+#endif
+#endif
+
#include "core/util/Exception.h"

#include <iostream>

noxfile.py (3 additions, 4 deletions)
@@ -258,11 +258,12 @@ def run_dynamo_runtime_tests(session):
    tests = [
        "runtime",
    ]
+    skip_tests = "-k not hw_compat"
    for test in tests:
        if USE_HOST_DEPS:
-            session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
+            session.run_always("pytest", test, skip_tests, env={"PYTHONPATH": PYT_PATH})
        else:
-            session.run_always("pytest", test)
+            session.run_always("pytest", test, skip_tests)


def run_dynamo_model_compile_tests(session):

@@ -332,7 +333,6 @@ def run_int8_accuracy_tests(session):
    tests = [
        "ptq/test_ptq_to_backend.py",
        "ptq/test_ptq_dataloader_calibrator.py",
-        "qat/",
    ]
    for test in tests:
        if USE_HOST_DEPS:

@@ -473,7 +473,6 @@ def run_l1_int8_accuracy_tests(session):
    install_deps(session)
    install_torch_trt(session)
    train_model(session)
-    finetune_model(session)
    run_int8_accuracy_tests(session)
    cleanup(session)

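A note on the new skip_tests argument in run_dynamo_runtime_tests: pytest's -k option takes a keyword expression, and "not hw_compat" deselects every test whose name or id matches hw_compat. Below is a minimal sketch of an equivalent standalone invocation, assuming the working directory contains the runtime test folder the session targets; the split argument form is interchangeable in intent with the single "-k not hw_compat" string the session passes.

# Run the dynamo runtime suite while deselecting hardware-compatibility tests,
# mirroring what the nox session now does via skip_tests.
import subprocess

subprocess.run(
    ["pytest", "runtime", "-k", "not hw_compat"],
    check=True,  # raise CalledProcessError if pytest reports failures
)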

tests/py/dynamo/lowering/test_aten_lowering_passes.py (15 additions, 0 deletions)
@@ -3,10 +3,17 @@

import torch
import torch_tensorrt
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_utils import TestCase, run_tests

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing

+isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [
+    (8, 6),
+    (8, 7),
+    (8, 9),
+]
+

class TestInputAsOutput(TestCase):
    def test_input_as_output(self):

@@ -279,6 +286,10 @@ def forward(self, q, k, v):
"Test not supported on Windows",
)
class TestLowerFlashAttention(TestCase):
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
"Does not support fused SDPA or not SM86+ hardware",
)
def test_lower_flash_attention(self):
class FlashAttention(torch.nn.Module):
def forward(self, q, k, v):

@@ -348,6 +359,10 @@ def forward(self, q, k, v):
        )
        torch._dynamo.reset()

+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
+        "Does not support fused SDPA or not SM86+ hardware",
+    )
    def test_flash_attention_converter(self):
        class FlashAttention(torch.nn.Module):
            def forward(self, q, k, v):
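For context on the new gating above: torch.cuda.get_device_capability() returns the compute capability of the current CUDA device as a (major, minor) tuple, so isSM8XDevice is True only on SM 8.6, 8.7, or 8.9 GPUs, and unittest.skipIf then skips the flash-attention tests whenever fused SDPA is unsupported or the device falls outside that set. A minimal, self-contained sketch of the same check follows; the helper name describe_sm8x_gate is illustrative and not part of the test file.

import torch

def describe_sm8x_gate() -> str:
    # Report whether the current GPU falls in the SM 8.6 / 8.7 / 8.9 set used by the tests.
    if not torch.cuda.is_available():
        return "no CUDA device available; tests would be skipped"
    major, minor = torch.cuda.get_device_capability()
    is_sm8x = (major, minor) in [(8, 6), (8, 7), (8, 9)]
    return f"compute capability {major}.{minor}: gate {'passes' if is_sm8x else 'fails'}"

if __name__ == "__main__":
    print(describe_sm8x_gate())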
