From 1ac5a28fa3121b4eb6c8b82efd8ff9bef8cefe5a Mon Sep 17 00:00:00 2001
From: moussaKam
Date: Wed, 23 Oct 2024 15:13:30 +0000
Subject: [PATCH 1/2] fix Kd_loss with normalize

---
 torchtune/modules/loss/kd_losses.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/torchtune/modules/loss/kd_losses.py b/torchtune/modules/loss/kd_losses.py
index bdeff3aa5c..53875dc813 100644
--- a/torchtune/modules/loss/kd_losses.py
+++ b/torchtune/modules/loss/kd_losses.py
@@ -30,6 +30,7 @@ def forward(
         student_logits: torch.Tensor,
         teacher_logits: torch.Tensor,
         labels: torch.Tensor,
+        normalize: bool = True,
     ) -> torch.Tensor:
         """
         Args:
@@ -39,6 +40,7 @@ def forward(
                 (batch_size*num_tokens, vocab_size).
             labels (torch.Tensor): Ground truth labels of shape
                 (batch_size, vocab_size).
+            normalize (bool): Whether to normalize the loss by the number of unmasked elements.
 
         Returns:
             torch.Tensor: KL divergence loss of shape (1,).
@@ -50,6 +52,8 @@ def forward(
         prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0)
         x = torch.sum(prod_probs, dim=-1).view(-1)
         mask = (labels != self.ignore_index).int()
+        if not normalize:
+            return -torch.sum(x * mask.view(-1), dim=0)
         if torch.sum(mask.view(-1), dim=0) == 0:
             return torch.tensor(0.0, device=x.device)
         return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
@@ -118,15 +122,19 @@ def forward(
             student_logits_chunk.reshape(-1, student_logits_chunk.size(-1))
             for student_logits_chunk in student_logits
         ]
+        mask = (labels != self.ignore_index).int()
         # chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)]
         labels = [
            target_chunk.reshape(-1)
            for target_chunk in labels.chunk(self.num_output_chunks, dim=1)
         ]
+
         total_fkl_loss = 0.0
         for student_chunk, teacher_chunk, label_chunk in zip(
             student_logits, teacher_logits, labels
         ):
-            total_fkl_loss += self.fkl_loss(student_chunk, teacher_chunk, label_chunk)
+            total_fkl_loss += self.fkl_loss(
+                student_chunk, teacher_chunk, label_chunk, normalize=False
+            )
 
-        return total_fkl_loss / self.num_output_chunks
+        return total_fkl_loss / torch.sum(mask.view(-1), dim=0)

From aee21d4a1f5d6e00ef014fac733f33dccdf2ed85 Mon Sep 17 00:00:00 2001
From: moussaKam
Date: Wed, 23 Oct 2024 17:22:58 +0000
Subject: [PATCH 2/2] Update loss values in test

---
 tests/recipes/test_knowledge_distillation_single_device.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/recipes/test_knowledge_distillation_single_device.py b/tests/recipes/test_knowledge_distillation_single_device.py
index e389460b71..f53fec4e9e 100644
--- a/tests/recipes/test_knowledge_distillation_single_device.py
+++ b/tests/recipes/test_knowledge_distillation_single_device.py
@@ -47,7 +47,7 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
 
     def _fetch_expected_loss_values(self, model_type):
         loss_values_map = {
-            "llama3": [11.0651, 11.0577, 11.0540, 11.7671],
+            "llama3": [11.7898, 11.7825, 11.7788, 11.7671],
         }
         return loss_values_map[model_type]
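
Note on the normalization fix, as a minimal self-contained sketch (not the torchtune implementation: the `forward_kl` helper, `IGNORE_INDEX`, the random logits, the seed, and the label layout are all illustrative assumptions). Before the patch, the chunked loss averaged the per-chunk means, which weights every chunk equally even when the unmasked (non-ignored) tokens are spread unevenly across chunks; the patch instead sums the unnormalized per-chunk losses and divides once by the global unmasked-token count, which matches the unchunked loss:

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed ignore value, mirroring self.ignore_index

def forward_kl(student_logits, teacher_logits, labels, normalize=True):
    # Forward KL cross term between teacher and student, masked by labels,
    # mirroring the structure of ForwardKLLoss.forward in the patch above.
    teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    student_logprob = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    inf_mask = torch.isinf(student_logits)
    prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0)
    x = torch.sum(prod_probs, dim=-1).view(-1)
    mask = (labels.view(-1) != IGNORE_INDEX).int()
    if not normalize:
        return -torch.sum(x * mask, dim=0)  # raw sum; caller normalizes globally
    if mask.sum() == 0:
        return torch.tensor(0.0, device=x.device)
    return -torch.sum(x * mask, dim=0) / mask.sum()

torch.manual_seed(0)
vocab, num_chunks = 16, 2
student = torch.randn(8, vocab)
teacher = torch.randn(8, vocab)
# Unmasked tokens split unevenly across chunks: 4 in the first, 1 in the second.
labels = torch.tensor([1, 2, 3, 4, 5, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX])

# Old behavior: average of per-chunk means (each chunk weighted equally).
old = sum(
    forward_kl(s, t, l)
    for s, t, l in zip(student.chunk(num_chunks), teacher.chunk(num_chunks), labels.chunk(num_chunks))
) / num_chunks

# Patched behavior: sum unnormalized chunk losses, divide by global unmasked count.
new = sum(
    forward_kl(s, t, l, normalize=False)
    for s, t, l in zip(student.chunk(num_chunks), teacher.chunk(num_chunks), labels.chunk(num_chunks))
) / (labels != IGNORE_INDEX).sum()

reference = forward_kl(student, teacher, labels)  # unchunked ground truth
print(old.item(), new.item(), reference.item())  # new == reference; old differs

With this label layout the old path gives the lone unmasked token of the second chunk the same total weight as the four unmasked tokens of the first chunk, so `old` deviates from the unchunked `reference`, while `new` matches it exactly. That shift in the computed loss is also why [PATCH 2/2] updates the expected values in the recipe test.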