Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add a memory usage regression test to the OSS benchmark #62

Merged
merged 10 commits into from
Sep 3, 2020
51 changes: 33 additions & 18 deletions benchmarks/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import os
import time
from typing import Any, List
from typing import Any, List, Union, cast

import torch
import torch.distributed as dist
Expand All @@ -19,6 +19,7 @@
from fairscale.optim.oss import OSS

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
OPTIM = torch.optim.RMSprop


def dist_init(rank, world_size):
Expand All @@ -36,7 +37,9 @@ def train(
use_oss: bool = True,
check_regression: bool = True,
reference_speed: float = -1.0,
reference_memory: float = -1.0,
):

# DDP
dist_init(rank, world_size)

Expand All @@ -50,21 +53,18 @@ def collate(inputs: List[Any]):
"label": torch.stack([i[1] for i in inputs]).to(rank),
}

def print_(msg):
if dist.get_rank() == 0:
print(msg)

dataloader = DataLoader(
dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
)
loss_fn = nn.CrossEntropyLoss()

# Reset the memory use counter
torch.cuda.reset_peak_memory_stats(rank)

# Shard the optimizer
optimizer = (
OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
if use_oss
else torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
)
optimizer: Union[OSS, OPTIM] = OSS(
params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9
) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)

# Dummy training loop
torch.cuda.synchronize(rank)
Expand All @@ -90,8 +90,19 @@ def closure():
optimizer.step(closure)

epoch_end = time.monotonic()

if use_oss:
# Check the checkpointing in the case of the OSS optimizer
# Memory usage could spill over from there
optimizer = cast(OSS, optimizer)
# optimizer.consolidate_state_dict()
if dist.get_rank() == 0:
# _ = optimizer.state_dict()
print("... State dict collected")

measurements.append(data_size / (epoch_end - epoch_start))
print_(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
if dist.get_rank() == 0:
print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

torch.cuda.synchronize(rank)
training_stop = time.monotonic()
Expand All @@ -101,13 +112,15 @@ def closure():
print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

# Compute the mean and average img per second
mean = sum(measurements) / len(measurements)
diff = map(lambda x: pow(x - mean, 2.0), measurements)
std = math.sqrt(sum(diff) / (len(measurements) - 1))
print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

if use_oss and check_regression and dist.get_rank() == 0:
# Compute the mean and average img per second
mean = sum(measurements) / len(measurements)
diff = map(lambda x: pow(x - mean, 2.0), measurements)
std = math.sqrt(sum(diff) / (len(measurements) - 1))
print(f"[Regression Test] Mean: {mean:.2f} +/- {std:.2f}")
assert (mean - 3.0 * std) < reference_speed, "Regression detected"
assert (mean - 3.0 * std) < reference_speed, "Speed regression detected"
assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a 5% tolerance here, guessing that some CUDA or torch version changes could affect (I don't have any STD for this value, it's the max memory used over the whole run)

print("[Regression Test] VALID")


Expand All @@ -122,10 +135,11 @@ def closure():
parser.add_argument("--data_size", action="store", default=512, type=int)
parser.add_argument("--check_regression", action="store", default=True, type=bool)
parser.add_argument("--reference_speed", action="store", default=39.82, type=float)
parser.add_argument("--reference_memory", action="store", default=4475, type=float)

args = parser.parse_args()

print("\nBenchmark vanilla SGD")
print("\nBenchmark vanilla optimizer")
mp.spawn(
train,
args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
Expand All @@ -144,6 +158,7 @@ def closure():
True,
args.check_regression,
args.reference_speed,
args.reference_memory,
),
nprocs=args.world_size,
join=True,
Expand Down