From b67c9b64fdf138079508a0bbfe5b691c1256ec0a Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 23 Sep 2024 11:16:29 +0200 Subject: [PATCH] FIX: Bug in find_minimal_target_modules (#2083) This bug was reported by Sayak and would occur if a required suffix had itself as suffix a string that was already determined to be required, in which case this required suffix would not be added. The fix consists of prefixing a "." to the suffix before checking if it is required or not. On top of this, the algorithm has been changed to be deterministic. Previously, it was not deterministic because a dictionary that was looped over was built from a set, and sets don't guarantee order. This would result in the loop being in arbitrary order. As long as the algorithm is 100% correct, the order should not matter. But in case we find bugs like this, the order does matter. We don't want bugs to be flaky, therefore it is best to sort the dict and remove randomness from the function. --------- Co-authored-by: Sayak Paul --- src/peft/tuners/tuners_utils.py | 6 +++-- tests/test_tuners_utils.py | 45 +++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 2f195b2cbe..03b8531bfd 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -885,14 +885,16 @@ def generate_suffixes(s): # Initialize a set for required suffixes required_suffixes = set() - for item, suffixes in target_modules_suffix_map.items(): + # We sort the target_modules_suffix_map simply to get deterministic behavior, since sets have no order. In theory + # the order should not matter but in case there is a bug, it's better for the bug to be deterministic. 
+ for item, suffixes in sorted(target_modules_suffix_map.items(), key=lambda tup: tup[1]): # Go through target_modules items, shortest suffixes first for suffix in suffixes: # If the suffix is already in required_suffixes or matches other_module_names, skip it if suffix in required_suffixes or suffix in other_module_suffixes: continue # Check if adding this suffix covers the item - if not any(item.endswith(req_suffix) for req_suffix in required_suffixes): + if not any(item.endswith("." + req_suffix) for req_suffix in required_suffixes): required_suffixes.add(suffix) break diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py index 4072098493..90dbea8d70 100644 --- a/tests/test_tuners_utils.py +++ b/tests/test_tuners_utils.py @@ -1282,3 +1282,48 @@ def test_get_peft_model_applies_find_target_modules(self): # check that the resulting model is still the same model_check_after = sum(p.sum() for p in model.parameters()) assert model_check_sum_before == model_check_after + + def test_suffix_is_substring_of_other_suffix(self): + # This test is based on a real world bug found in diffusers. The issue was that we needed the suffix + # 'time_emb_proj' in the minimal target modules. However, if there already was the suffix 'proj' in the + # required_suffixes, 'time_emb_proj' would not be added because the test was `endswith(suffix)` and + # 'time_emb_proj' ends with 'proj'. The correct logic is to test if `endswith("." + suffix)`. The module names + # chosen here are only a subset of the hundreds of actual module names but this subset is sufficient to + # replicate the bug. 
+ target_modules = [ + "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj", + "mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj", + "up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj", + "mid_block.attentions.0.proj_out", + "up_blocks.0.attentions.0.proj_out", + "down_blocks.1.attentions.0.proj_out", + "up_blocks.0.resnets.0.time_emb_proj", + "down_blocks.0.resnets.0.time_emb_proj", + "mid_block.resnets.0.time_emb_proj", + ] + other_module_names = [ + "conv_in", + "time_proj", + "time_embedding", + "time_embedding.linear_1", + "add_time_proj", + "add_embedding", + "add_embedding.linear_1", + "add_embedding.linear_2", + "down_blocks", + "down_blocks.0", + "down_blocks.0.resnets", + "down_blocks.0.resnets.0", + "up_blocks", + "up_blocks.0", + "up_blocks.0.attentions", + "up_blocks.0.attentions.0", + "up_blocks.0.attentions.0.norm", + "up_blocks.0.attentions.0.transformer_blocks", + "up_blocks.0.attentions.0.transformer_blocks.0", + "up_blocks.0.attentions.0.transformer_blocks.0.norm1", + "up_blocks.0.attentions.0.transformer_blocks.0.attn1", + ] + expected = {"time_emb_proj", "proj", "proj_out"} + result = find_minimal_target_modules(target_modules, other_module_names) + assert result == expected