-
Notifications
You must be signed in to change notification settings - Fork 251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimization for MF #17
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -74,18 +74,46 @@ | |
class MFModel(torch.nn.Module, PyTorchModelHubMixin): | ||
def __init__( | ||
self, | ||
dim, | ||
num_models, | ||
text_dim, | ||
num_classes, | ||
use_proj, | ||
dim=128, | ||
num_models=64, | ||
text_dim=768, | ||
num_classes=1, | ||
use_proj=True, | ||
collapse_linear=False, | ||
embedding_model="all-mpnet-base-v2", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be better if we didn't set any default args here and specify default args only in |
||
): | ||
""" | ||
Args: | ||
dim: | ||
Dimension of the model embeddings, default to 128 | ||
num_models: | ||
Number of models, default to 64 | ||
text_dim: | ||
Dimension of the text embeddings | ||
1536 for OpenAI's text-embedding-3-small | ||
768 for all-mpnet-base-v2 | ||
1024 for infgrad/stella_en_400M_v5 | ||
num_classes: | ||
Number of classes, default to 1, output a scalar | ||
use_proj: | ||
Whether to use projection for the text embeddings | ||
This is set to be True in our pretrained models for better performance | ||
collapse_linear: | ||
Whether to collapse the linear transformations into a single linear layer | ||
Since the current pretrained models only consist of Linear layers, | ||
we can collapse them into a single layer for faster inference | ||
See https://github.com/lm-sys/RouteLLM/issues/9 | ||
embedding_model: | ||
Text embedding model for the prompt, should be the same as the one used in training | ||
Use all-mpnet-base-v2 to avoid OpenAI's key, however, slightly worse performance | ||
Use OpenAI's text-embedding-3-small for better performance | ||
""" | ||
super().__init__() | ||
self._name = "TextMF" | ||
self.use_proj = use_proj | ||
self.collapse_linear = collapse_linear # collapse the linear transformations into a single linear layer | ||
self.P = torch.nn.Embedding(num_models, dim) | ||
|
||
self.embedding_model = "text-embedding-3-small" | ||
self.embedding_model = embedding_model | ||
|
||
if self.use_proj: | ||
self.text_proj = torch.nn.Sequential( | ||
|
@@ -104,19 +132,35 @@ def get_device(self): | |
return self.P.weight.device | ||
|
||
def forward(self, model_id, prompt): | ||
if self.embedding_model == "text-embedding-3-small": | ||
prompt_embed = ( | ||
OPENAI_CLIENT.embeddings.create( | ||
input=[prompt], model=self.embedding_model | ||
) | ||
.data[0] | ||
.embedding | ||
) | ||
elif self.embedding_model == "all-mpnet-base-v2": | ||
prompt_embed = self._embedding_model.encode([prompt]) | ||
elif self.embedding_model == "infgrad/stella_en_400M_v5": | ||
prompt_embed = self._embedding_model.encode( | ||
[prompt], prompt_name="s2s_query" | ||
) | ||
else: | ||
raise ValueError( | ||
f"Unsupported embedding model {self.embedding_model}, " | ||
"should be one of text-embedding-3-small, all-mpnet-base-v2, infgrad/stella_en_400M_v5" | ||
) | ||
|
||
prompt_embed = torch.tensor(prompt_embed, device=self.get_device()) | ||
model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device()) | ||
|
||
model_embed = self.P(model_id) | ||
model_embed = torch.nn.functional.normalize(model_embed, p=2, dim=1) | ||
if self.collapse_linear: | ||
upscaled_model_embed = self.precompute_upscaled_embedding(model_id) | ||
return upscaled_model_embed @ prompt_embed.squeeze(-1) | ||
|
||
prompt_embed = ( | ||
OPENAI_CLIENT.embeddings.create(input=[prompt], model=self.embedding_model) | ||
.data[0] | ||
.embedding | ||
) | ||
prompt_embed = torch.tensor(prompt_embed, device=self.get_device()) | ||
model_embed = self.P(model_id) | ||
prompt_embed = self.text_proj(prompt_embed) | ||
|
||
return self.classifier(model_embed * prompt_embed).squeeze() | ||
|
||
@torch.no_grad() | ||
|
@@ -127,3 +171,32 @@ def pred_win_rate(self, model_a, model_b, prompt): | |
|
||
def load(self, path): | ||
self.load_state_dict(torch.load(path)) | ||
|
||
def post_process_weight(self): | ||
# since the current model consist of only linear transformations | ||
# we can collapse the linear transformations into a single linear layer | ||
# https://github.com/lm-sys/RouteLLM/issues/9 | ||
num_models = self.P.weight.shape[0] | ||
text_dim = self.text_proj[0].weight.shape[1] | ||
|
||
self.P.weight.data = torch.nn.functional.normalize( | ||
self.P.weight.data, p=2, dim=1 | ||
) | ||
|
||
if ( | ||
self.embedding_model == "all-mpnet-base-v2" | ||
or self.embedding_model == "infgrad/stella_en_400M_v5" | ||
): | ||
from sentence_transformers import SentenceTransformer | ||
|
||
self._embedding_model = SentenceTransformer( | ||
self.embedding_model, trust_remote_code=True | ||
).to("cuda") | ||
|
||
if self.collapse_linear: | ||
self.precompute_upscaled_embedding = torch.nn.Embedding( | ||
num_models, text_dim | ||
) | ||
self.precompute_upscaled_embedding.weight.data = ( | ||
self.P.weight * self.classifier[0].weight.data | ||
) @ self.text_proj[0].weight.data |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -217,9 +217,10 @@ def __init__( | |
weak_model="mixtral-8x7b-instruct-v0.1", | ||
hidden_size=128, | ||
num_models=64, | ||
text_dim=1536, | ||
text_dim=768, | ||
num_classes=1, | ||
use_proj=True, | ||
embedding_model="all-mpnet-base-v2", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe better to use OpenAI's embeddings by default to preserve our existing behavior? I'll add some docs to describe the different options here. |
||
): | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
|
@@ -230,7 +231,10 @@ def __init__( | |
text_dim=text_dim, | ||
num_classes=num_classes, | ||
use_proj=use_proj, | ||
embedding_model=embedding_model, | ||
) | ||
|
||
self.model.post_process_weight() | ||
self.model = self.model.eval().to(device) | ||
self.strong_model_id = MODEL_IDS[strong_model] | ||
self.weak_model_id = MODEL_IDS[weak_model] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we automatically set the text dim based on the selected embeddings?