Optimization for MF #17

Open · wants to merge 3 commits into base: main
105 changes: 89 additions & 16 deletions routellm/routers/matrix_factorization/model.py
@@ -74,18 +74,46 @@
class MFModel(torch.nn.Module, PyTorchModelHubMixin):
def __init__(
self,
-        dim,
-        num_models,
-        text_dim,
-        num_classes,
-        use_proj,
+        dim=128,
+        num_models=64,
+        text_dim=768,
Collaborator: Could we automatically set the text dim based on the selected embeddings?
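One possible sketch of that suggestion (hypothetical, not part of this PR; the dimensions mirror those listed in the docstring below):

    # Hypothetical mapping from embedding model name to its output dimension
    EMBEDDING_DIMS = {
        "text-embedding-3-small": 1536,
        "all-mpnet-base-v2": 768,
        "infgrad/stella_en_400M_v5": 1024,
    }
    text_dim = EMBEDDING_DIMS[embedding_model]  # derived instead of passed in manually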

+        num_classes=1,
+        use_proj=True,
+        collapse_linear=False,
+        embedding_model="all-mpnet-base-v2",
Collaborator: I think it would be better if we didn't set any default args here and specified default args only in routers.py, so that it's easier to keep track of router configs.
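A sketch of what that could look like (hypothetical; shown with a direct constructor call for brevity, however the model is actually built in routers.py):

    # model.py: no defaults on the model itself
    def __init__(self, dim, num_models, text_dim, num_classes,
                 use_proj, collapse_linear, embedding_model):
        ...

    # routers.py: the single source of truth for router configs
    model = MFModel(dim=128, num_models=64, text_dim=768, num_classes=1,
                    use_proj=True, collapse_linear=False,
                    embedding_model="all-mpnet-base-v2")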

):
"""
Args:
dim:
Dimension of the model embeddings, default to 128
num_models:
Number of models, default to 64
text_dim:
Dimension of the text embeddings
1536 for OpenAI's text-embedding-3-small
768 for all-mpnet-base-v2
1024 for infgrad/stella_en_400M_v5
num_classes:
Number of classes, default to 1, output a scalar
use_proj:
Whether to use projection for the text embeddings
This is set to be True in our pretrained models for better performance
collapse_linear:
Whether to collapse the linear transformations into a single linear layer
Since the current pretrained models only consist of Linear layers,
we can collapse them into a single layer for faster inference
See https://github.com/lm-sys/RouteLLM/issues/9
embedding_model:
Text embedding model for the prompt, should be the same as the one used in training
Use all-mpnet-base-v2 to avoid OpenAI's key, however, slightly worse performance
Use OpenAI's text-embedding-3-small for better performance
"""
super().__init__()
self._name = "TextMF"
self.use_proj = use_proj
+        self.collapse_linear = collapse_linear  # collapse the linear transformations into a single linear layer
self.P = torch.nn.Embedding(num_models, dim)

-        self.embedding_model = "text-embedding-3-small"
+        self.embedding_model = embedding_model

if self.use_proj:
self.text_proj = torch.nn.Sequential(
@@ -104,19 +104,35 @@ def get_device(self):
return self.P.weight.device

def forward(self, model_id, prompt):
+        if self.embedding_model == "text-embedding-3-small":
+            prompt_embed = (
+                OPENAI_CLIENT.embeddings.create(
+                    input=[prompt], model=self.embedding_model
+                )
+                .data[0]
+                .embedding
+            )
+        elif self.embedding_model == "all-mpnet-base-v2":
+            prompt_embed = self._embedding_model.encode([prompt])
+        elif self.embedding_model == "infgrad/stella_en_400M_v5":
+            prompt_embed = self._embedding_model.encode(
+                [prompt], prompt_name="s2s_query"
+            )
+        else:
+            raise ValueError(
+                f"Unsupported embedding model {self.embedding_model}, "
+                "should be one of text-embedding-3-small, all-mpnet-base-v2, infgrad/stella_en_400M_v5"
+            )
+
+        prompt_embed = torch.tensor(prompt_embed, device=self.get_device())
model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device())

+        model_embed = self.P(model_id)
+        model_embed = torch.nn.functional.normalize(model_embed, p=2, dim=1)
+        if self.collapse_linear:
+            upscaled_model_embed = self.precompute_upscaled_embedding(model_id)
+            return upscaled_model_embed @ prompt_embed.squeeze(-1)

-        prompt_embed = (
-            OPENAI_CLIENT.embeddings.create(input=[prompt], model=self.embedding_model)
-            .data[0]
-            .embedding
-        )
-        prompt_embed = torch.tensor(prompt_embed, device=self.get_device())
-        model_embed = self.P(model_id)
prompt_embed = self.text_proj(prompt_embed)

return self.classifier(model_embed * prompt_embed).squeeze()

@torch.no_grad()
@@ -127,3 +171,32 @@ def pred_win_rate(self, model_a, model_b, prompt):

def load(self, path):
self.load_state_dict(torch.load(path))

+    def post_process_weight(self):
+        # Since the current model consists only of linear transformations,
+        # we can collapse them into a single linear layer.
+        # https://github.com/lm-sys/RouteLLM/issues/9
+        num_models = self.P.weight.shape[0]
+        text_dim = self.text_proj[0].weight.shape[1]
+
+        self.P.weight.data = torch.nn.functional.normalize(
+            self.P.weight.data, p=2, dim=1
+        )
+
+        if (
+            self.embedding_model == "all-mpnet-base-v2"
+            or self.embedding_model == "infgrad/stella_en_400M_v5"
+        ):
+            from sentence_transformers import SentenceTransformer
+
+            self._embedding_model = SentenceTransformer(
+                self.embedding_model, trust_remote_code=True
+            ).to("cuda")
+
+        if self.collapse_linear:
+            self.precompute_upscaled_embedding = torch.nn.Embedding(
+                num_models, text_dim
+            )
+            self.precompute_upscaled_embedding.weight.data = (
+                self.P.weight * self.classifier[0].weight.data
+            ) @ self.text_proj[0].weight.data
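For intuition, the collapse implemented above uses the identity w_cls · (p_m ⊙ (W_proj e)) = ((p_m ⊙ w_cls) W_proj) · e, which holds when the linear maps carry no bias terms, as the collapse in post_process_weight assumes. A quick sanity-check sketch (illustrative names, assuming the layer shapes used above):

    import torch

    dim, text_dim = 128, 768
    p_m = torch.randn(dim, dtype=torch.double)               # one (normalized) row of P.weight
    w_cls = torch.randn(dim, dtype=torch.double)             # classifier[0].weight.squeeze(0)
    W_proj = torch.randn(dim, text_dim, dtype=torch.double)  # text_proj[0].weight
    e = torch.randn(text_dim, dtype=torch.double)            # prompt embedding

    original = w_cls @ (p_m * (W_proj @ e))   # classifier(model_embed * text_proj(e))
    collapsed = ((p_m * w_cls) @ W_proj) @ e  # the precomputed single linear map
    assert torch.allclose(original, collapsed)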
6 changes: 5 additions & 1 deletion routellm/routers/routers.py
@@ -217,9 +217,10 @@ def __init__(
weak_model="mixtral-8x7b-instruct-v0.1",
hidden_size=128,
num_models=64,
-        text_dim=1536,
+        text_dim=768,
num_classes=1,
use_proj=True,
+        embedding_model="all-mpnet-base-v2",
Collaborator: Maybe better to use OpenAI's embeddings by default to preserve our existing behavior? I'll add some docs to describe the different options here.

):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -230,7 +231,10 @@ def __init__(
text_dim=text_dim,
num_classes=num_classes,
use_proj=use_proj,
+            embedding_model=embedding_model,
)

+        self.model.post_process_weight()
self.model = self.model.eval().to(device)
self.strong_model_id = MODEL_IDS[strong_model]
self.weak_model_id = MODEL_IDS[weak_model]
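Usage note: with the new embedding_model argument, a router that avoids the OpenAI dependency could be configured roughly as follows (a sketch; it assumes the router class keeps its existing name and checkpoint_path argument, and the checkpoint id shown is illustrative):

    # Hypothetical configuration: local sentence-transformers embeddings, no OpenAI key
    router = MatrixFactorizationRouter(
        checkpoint_path="routellm/mf_gpt4_augmented",  # assumed checkpoint id
        embedding_model="all-mpnet-base-v2",           # new argument from this PR
        text_dim=768,                                  # must match the embedding model
    )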