Metric logger improvements #831

Merged · 13 commits into pytorch:main on Apr 25, 2024
Conversation

@ebsmothers (Contributor) commented on Apr 22, 2024

This is inspired by @tcapelle's changes in #730 so props to him for getting the ball rolling here.

Quick primer/definition to make what follows clearer: by iteration I mean a single batch from our dataloader; by step I mean a single step with our optimizer. So without gradient accumulation, iteration == step; with gradient accumulation there are gradient_accumulation_steps iterations per step (maybe the naming of this field is itself confusing tbh).
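For concreteness, a minimal skeleton of the distinction (compute_loss and the other names are placeholders, not the recipe's actual code):

```python
# Illustration only: with gradient_accumulation_steps = 4, four dataloader
# iterations feed a single optimizer step.
for idx, batch in enumerate(dataloader):              # one "iteration" per batch
    loss = compute_loss(model, batch) / gradient_accumulation_steps
    loss.backward()
    if (idx + 1) % gradient_accumulation_steps == 0:
        optimizer.step()                               # one "step" per optimizer update
        optimizer.zero_grad()
```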

The main set of changes here is:

  • Actually log once on every step rather than on every iteration of a step (as is being done here). This only matters when using gradient accumulation, but that's enabled in a lot of our default configs, so we need to get it right.
  • Accumulate loss correctly over iterations for logging: instead of just taking the loss on the final iteration, we sum it up over each iteration (normalized by gradient_accumulation_steps). See the sketch after this list.
  • Add a tokens_per_second metric. Similarly, this is accumulated over iterations to log the per-step average.
  • Add a config field to control whether we log peak memory or not. Previously this was just a hardcoded log_peak_memory_every_n_steps=100 field, which (a) was completely arbitrary and (b) bloated up our training loop with multiple log_dict calls at multiple different logging frequencies. For BC with existing configs, I do a default check in the recipe and set it to False.
  • Fix our pbar to also be per-step instead of per-iteration. Otherwise it doesn't line up with our metric logger, which is confusing.
  • Skip our wandb logger test (for now). It was never actually running and it's broken 😢
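To make the per-step bookkeeping concrete, here is a rough sketch of the accumulation described above (names like running_loss, compute_loss, and the memory_stats() helper are illustrative placeholders, not the recipe's exact code):

```python
import time

running_loss, num_tokens, step = 0.0, 0, 0
t0 = time.perf_counter()
for idx, batch in enumerate(dataloader):
    tokens = batch["tokens"]
    num_tokens += tokens.numel()
    loss = compute_loss(model, batch) / gradient_accumulation_steps
    running_loss += loss.item()                  # accumulate normalized loss across iterations
    loss.backward()
    if (idx + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        step += 1
        pbar.update(1)                           # progress bar ticks per step, not per iteration
        if step % log_every_n_steps == 0:
            time_per_step = time.perf_counter() - t0
            log_dict = {
                "loss": running_loss,
                "lr": optimizer.param_groups[0]["lr"],
                "tokens_per_second": num_tokens / time_per_step,
            }
            if log_peak_memory_stats:            # new config field, defaults to False
                log_dict.update(memory_stats())  # placeholder for a peak-memory helper
            metric_logger.log_dict(log_dict, step=step)
        running_loss, num_tokens = 0.0, 0
        t0 = time.perf_counter()
```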

Test plan

Full finetune single device

tune run full_finetune_single_device --config llama3/8B_full_single_device metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=lora-debug optimizer_in_bwd=False gradient_accumulation_steps=2 log_every_n_steps=1 log_peak_memory_stats=True max_steps_per_epoch=10 epochs=1
...
1|10|Loss: 0.9859869480133057: 100%|█████████████| 10/10 [00:15<00:00,  1.14s/it]

Metric logger outputs

Full finetune distributed

tune run --nproc_per_node 2 full_finetune_distributed --config gemma/2B_full metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=lora-debug gradient_accumulation_steps=2 log_every_n_steps=1 log_peak_memory_stats=True max_steps_per_epoch=10 epochs=1
...
1|10|Loss: 1.1142131090164185: 100%|█████████| 10/10 [00:15<00:00,  1.50s/it]

Metric logger outputs

LoRA finetune single device

tune run lora_finetune_single_device --config llama2/7B_lora_single_device metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=lora-debug gradient_accumulation_steps=2 log_every_n_steps=1 log_peak_memory_stats=True max_steps_per_epoch=10 epochs=1
...
1|10|Loss: 1.6541345119476318: 100%|██████████████| 10/10 [00:07<00:00,  1.59it/s]

Metric logger outputs

LoRA finetune distributed

tune run --nproc_per_node 2 lora_finetune_distributed --config mistral/7B_lora metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=lora-debug gradient_accumulation_steps=2 log_every_n_steps=1 log_peak_memory_stats=True max_steps_per_epoch=10 epochs=1
...
1|5|Loss: 2.1742944717407227:  50%|█████                                                                 | 5/10 [00:16<00:16,  3.24s/it]

Metric logger outputs

LoRA DPO single device

tune run lora_dpo_single_device --config llama2/7B_lora_dpo_single_device metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=lora-debug gradient_accumulation_steps=2 log_every_n_steps=1 log_peak_memory_stats=True max_steps_per_epoch=10 epochs=1
...
1|7|Loss: 0.6841517686843872:  70%|█████████                                       | 7/10 [00:14<00:05,  1.91s/it]

Metric logger outputs

LoRA DPO distributed

tune run --nproc_per_node 2 lora_dpo_distributed --config llama2/7B_lora_dpo metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=lora-debug gradient_accumulation_steps=2 log_every_n_steps=1 log_peak_memory_stats=True max_steps_per_epoch=10 epochs=1
...
1|8|Loss: 0.6825119256973267:  80%|████████                          | 8/10 [00:20<00:04,  2.42s/it]

Metric logger outputs

pytorch-bot bot commented Apr 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/831

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d66919e with merge base a46560e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 22, 2024
@tcapelle (Contributor) commented:

Hey, can I help you here? Looks similar to what I was working on: #730

tcapelle added a commit to tcapelle/torchtune that referenced this pull request Apr 22, 2024
@ebsmothers (Contributor, Author) replied:

> Hey, can I help you here? Looks similar to what I was working on: #730

@tcapelle thanks, yeah actually this started from trying to get the gradient accumulation test on #730 to pass and kinda expanded from there. If it's easiest for you, I am happy to just let you commandeer this PR so you don't have to go adding the changes to all the other recipes. Let me know what you'd prefer.

@ebsmothers ebsmothers mentioned this pull request Apr 22, 2024
@ebsmothers ebsmothers changed the title very wip metric logger improvements Metric logger improvements Apr 24, 2024
@ebsmothers ebsmothers marked this pull request as ready for review April 24, 2024 17:48
@@ -107,7 +107,11 @@ def __init__(self, cfg: DictConfig) -> None:
# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1
self._log_peak_memory_every_n_steps = 100
self._log_peak_memory_stats = (
Reviewer (Contributor) suggested change:
self._log_peak_memory_stats = (
self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

@@ -107,7 +107,11 @@ def __init__(self, cfg: DictConfig) -> None:
# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1
Reviewer (Contributor): while you're here. Suggested change:
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
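For context, OmegaConf's DictConfig supports dict-style get with a default value, which is what lets the suggestion above also cover the case where the key is absent from the config entirely (a minimal illustration, not recipe code):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"output_dir": "/tmp/run"})  # no logging keys set
log_every_n_steps = cfg.get("log_every_n_steps", 1)               # -> 1
log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)   # -> False
```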

log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second": num_tokens / time_per_step,
Reviewer (Contributor) commented:
you need to multiply num_tokens per rank, no?

@ebsmothers (Author) replied:
Yeah this is a good point. Actually I'm not sure what the right thing to do here is for two reasons. (1) We only log rank-zero loss, and (2) I think a naive multiplication with num_devices will not work. If I understand correctly DistributedSampler will apply collate_fn to each worker's batch separately, in which case we actually have different numbers of tokens on each rank, right?

@ebsmothers (Author) added:
Yeah, the more I think about it the less I want to change this: (1) it will introduce unneeded syncs to calculate this value properly, and (2) we could just log on every rank if we really want to and aggregate post hoc.
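For illustration, properly aggregating token counts across ranks would require something like the following extra collective on every logged step, which is the unneeded sync referred to above (hypothetical helper, not part of this PR):

```python
import torch
import torch.distributed as dist

def global_tokens_per_second(num_tokens: int, time_per_step: float, device) -> float:
    # Sum per-rank token counts; this all_reduce is an additional collective
    # on every logged step.
    t = torch.tensor(num_tokens, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t.item() / time_per_step
```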

@@ -126,7 +127,8 @@ def test_log_dict(self) -> None:
assert_expected(tensor_tag.step, 1)


class WandBLoggerTest:
@pytest.mark.skip(reason="This was never running and needs to be fixed")
Reviewer (Contributor):
lol

@ebsmothers ebsmothers merged commit 6cf1a7d into pytorch:main Apr 25, 2024
27 checks passed
@ebsmothers ebsmothers deleted the logging-fixes branch April 25, 2024 02:03
@tcapelle (Contributor): great work!
