
Update quantization to use tensor subclasses #1403

Merged: 1 commit into pytorch:main on Sep 12, 2024

Conversation

@andrewor14 (Contributor) commented on Aug 23, 2024:

**Summary:** In torchao, we are migrating our quantization flows from module swap to tensor subclasses. The existing `Int8DynActInt4WeightQuantizer` will be deprecated in the near future in favor of quantizing using the `quantize_` API, so we should do the same in torchtune. This quantizer is currently only used by QAT, which also recently migrated to a tensor subclass implementation.
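
For context, a rough sketch of the old and new flows (API names as they exist in torchao around this PR; exact signatures may vary across versions):

```python
import torch
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)

model = torch.nn.Sequential(torch.nn.Linear(256, 256))

# Old flow (module swap), slated for deprecation:
#   from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
#   model = Int8DynActInt4WeightQuantizer().quantize(model)

# New flow: quantize_ swaps each Linear's weight in place for an
# int8-dynamic-activation / int4-weight tensor subclass.
quantize_(model, int8_dynamic_activation_int4_weight())
```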

This also changes the eval script slightly, since models quantized through the torchao tensor subclasses are expected to be loaded with `assign=True` (see this test: https://github.com/pytorch/ao/blob/9a56e80cb6070599701b8f5f587bd8187c8dccb4/test/quantization/test_quant_api.py#L610). We should load the model similarly in torchtune.
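
Concretely, the load looks something like the sketch below (the checkpoint path comes from the test plan; only the `assign=True` argument is the point here):

```python
import torch

def load_quantized_checkpoint(model: torch.nn.Module, path: str) -> None:
    quantized_sd = torch.load(path)
    # Without assign=True, load_state_dict copies values into the existing
    # plain-tensor parameters, silently dropping the tensor-subclass wrappers.
    model.load_state_dict(quantized_sd, assign=True)

# e.g. load_quantized_checkpoint(
#     model, "/tmp/Meta-Llama-3-8B-Instruct/original/consolidated-8da4w.pt")
```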

**Test Plan:**

Quantized and evaluated the base Llama3-8B model on 1 A100 GPU:

```
CUDA_VISIBLE_DEVICES=1 tune run quantize --config recipes/configs/quantization.yaml \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
    checkpointer.checkpoint_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.output_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.checkpoint_files=[consolidated.00.pth] \
    checkpointer.model_type=LLAMA3

CUDA_VISIBLE_DEVICES=1 tune run eleuther_eval --config eleuther_evaluation \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.output_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.checkpoint_files=[consolidated-8da4w.pt] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer
```

Reviewers: ebsmothers, kartikayk, RdoubleA

Subscribers: ebsmothers, kartikayk, RdoubleA

pytorch-bot commented on Aug 23, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1403

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d29833d with merge base 7c51100:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Aug 23, 2024.
```diff
@@ -22,7 +22,19 @@

 if TORCH_VERSION_AFTER_2_3:
-    from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+    from torchao.quantization.quant_api import (
```
A contributor asked:

@ebsmothers Can you remind me what our principles are around version guards? now that 2.4 is launched do we just claim we work with stable and remove such guards? or what's the downside of doing this?

A contributor replied:

Yeah we can remove, our assumption is that users are at least on the latest stable version of PyTorch (so at this moment 2.4)
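
Presumably that just means dropping the guard and importing unconditionally, e.g. (a sketch, assuming PyTorch >= 2.4 as the supported floor):

```python
# Before (guarded, since the quantizer only exists on newer PyTorch):
#   if TORCH_VERSION_AFTER_2_3:
#       from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# After (latest stable PyTorch assumed):
from torchao.quantization.quant_api import quantize_
```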

Review thread on torchtune/utils/quantization.py (outdated):
```python
# importing TORCH_VERSION_AFTER_2_3 because `Int8DynActInt4WeightQuantizer`
# is only available after 2.3 so we have to guard the pytorch versions to decide
# the list of supported quantizers
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
```
@andrewor14 (Author) commented:

cc @msaroufim @ebsmothers Just wanted to confirm this is OK

A contributor replied:

Sorry to ask for the 100th time, but TORCH_VERSION_AFTER_2_4 returns True if and only if the PyTorch version is >= 2.4, right? If so we can remove both of these, since we assume everyone is on at least latest stable PyTorch

@andrewor14 (Author) replied:

Mark can confirm, but yes, I think TORCH_VERSION_AFTER_2_4 means >= 2.4.

A contributor replied:

I believe that is only true in current torchao nightlies; this was fixed recently. Before the fix it meant > 2.4.
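
In other words (my reading of the thread, not a quote of the torchao source), the intended semantics in current nightlies are a >= check, while pre-fix releases behaved like a strict >:

```python
import torch
from packaging.version import parse

def torch_version_after_2_4() -> bool:
    # Expected semantics in current torchao nightlies: PyTorch >= 2.4.
    # (Strip any local suffix such as "+cu121" before comparing.)
    return parse(torch.__version__.split("+")[0]) >= parse("2.4")

# Pre-fix torchao releases behaved like the strict comparison instead:
#   parse(torch.__version__) > parse("2.4")
```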

@ebsmothers left a review comment:

Overall the changes look reasonable to me. My main question is whether there is any logic that's BC breaking (i.e. can I run this as is on ao's 0.3.1, latest nightly, and anything in between?)

@andrewor14 (Author) replied:

> Overall the changes look reasonable to me. My main question is whether there is any logic that's BC breaking (i.e. can I run this as is on ao's 0.3.1, latest nightly, and anything in between?)

Good question. I haven't tested an older version of ao, but this change is needed to run the full end-to-end flow after the QAT subclass refactor. If there are concerns about breaking BC, maybe we should wait until we upgrade to torchao 0.5.0? (expected maybe mid-September)

Commit message (d29833d):

**Summary:** In torchao, we are migrating our quantization flows
from module swap to tensor subclasses. The existing
`Int8DynActInt4WeightQuantizer` will be deprecated in the near
future in favor of quantizing using the `quantize_` API,
so we should do the same in torchtune. This quantizer is
currently only used by QAT, which also recently migrated to
a tensor subclass implementation.

This also changes the eval script slightly since models
quantized through the torchao tensor subclasses are expected
to be loaded with `assign=True`: https://github.com/pytorch/ao/blob/9a56e80cb6070599701b8f5f587bd8187c8dccb4/test/quantization/test_quant_api.py#L610.
We should load the model similarly in torchtune.

**Test Plan:**

Quantized and evaluated the base Llama3-8B model on 1 A100 GPU:

```
CUDA_VISIBLE_DEVICES=1 tune run quantize --config recipes/configs/my_quantization.yaml \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer.checkpoint_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.output_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.checkpoint_files=[consolidated.00.pth] \
    checkpointer.model_type=LLAMA3

CUDA_VISIBLE_DEVICES=1 tune run eleuther_eval --config eleuther_evaluation \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.output_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.checkpoint_files=[consolidated-8da4w.pt] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer
```

Reviewers: ebsmothers, kartikayk, RdoubleA

Subscribers: ebsmothers, kartikayk, RdoubleA

@andrewor14 (Author) commented:

Hi @ebsmothers, I think this is ready from my side. I saw that the torchao version was actually removed from pyproject.toml. Do we expect torchtune to work only with the latest version? Either way, I tested the quantization and eval with both the module-swap QAT (old torchao version) and the tensor-subclass QAT (latest torchao version), and both work. Please take another look. Thanks.

@ebsmothers replied:

@andrewor14 thanks for the ping. Yeah currently we will install whatever the latest stable version of torchao is. This is what we do in our CI and what we recommend our users to do as well (ofc we also support nightlies). So yes, the latest version of torchao should be sufficient here.

@ebsmothers left a review:

This looks good to me, thanks! Once CI is green I think we're good to merge

@ebsmothers merged commit 6b43a1c into pytorch:main on Sep 12, 2024 (17 checks passed).
@andrewor14 deleted the quantize-subclass branch on September 12, 2024 at 22:46.
@gau-nernst (Contributor) commented on Sep 13, 2024:

Just curious, does this work with the FSDP2 recipes? I'm trying out INT8 mixed-precision with torchtune, and I saw that in the distributed recipes, loading the state dict uses assign=True, which will override the tensor subclass:

```python
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
```

This is because we call torchao's quantize_() before loading state dict. @andrewor14 Did you check that the subclass ops are actually called?

One fix for this is to add DTensor support to the quantize_() API (pytorch/ao#803), which would allow us to call quantize_() after fully_shard() and after loading the state dict.

Update: I tried manually swapping in the tensor subclass after fully_shard(), but was not successful. It seems the naive approach erases some FSDP attributes/classes. I also noticed that there is special logic to handle NF4 within load_from_full_model_state_dict().
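
To make the ordering concrete, here is a sketch of the sequence described above; the `int8_dynamic_activation_int8_weight` config is illustrative rather than the exact INT8 mixed-precision recipe in question:

```python
import torch
from torch.distributed._composable.fsdp import fully_shard
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int8_weight,
)

def quantize_shard_load(model: torch.nn.Module, sharded_sd: dict) -> None:
    # Current distributed-recipe order: quantize, then shard, then load.
    quantize_(model, int8_dynamic_activation_int8_weight())  # install subclasses
    fully_shard(model)  # wrap parameters as DTensors (FSDP2)
    # assign=True then replaces the subclass parameters with the plain
    # tensors from the checkpoint, silently undoing the quantization.
    model.load_state_dict(sharded_sd, strict=True, assign=True)
    # Proposed fix (pytorch/ao#803): make quantize_ DTensor-aware so it can
    # run after fully_shard() and after the state dict is loaded.
```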

@ebsmothers mentioned this pull request on Sep 25, 2024.