From 8168a441afe1b222855125cf884bcb5e34f03b29 Mon Sep 17 00:00:00 2001
From: Ziqing Yang
Date: Mon, 3 Jul 2023 09:37:10 +0800
Subject: [PATCH 1/2] add Position Interpolation for inference scripts

---
 scripts/inference/gradio_demo.py              | 20 ++++++++++++++++++
 scripts/inference/inference_hf.py             | 20 ++++++++++++++++++
 .../openai_server_demo/openai_api_server.py   | 21 +++++++++++++++++++
 3 files changed, 61 insertions(+)

diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py
index c417446..e9aac18 100644
--- a/scripts/inference/gradio_demo.py
+++ b/scripts/inference/gradio_demo.py
@@ -12,6 +12,26 @@
 import argparse
 import os
 
+import transformers
+def pi_forward(self, x, seq_len=None):
+    if seq_len > self.max_seq_len_cached: # seq_len > 2048
+        print(f"Perform position interpolation for length {seq_len}")
+        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+        scale = self.max_seq_len_cached / seq_len
+        t *= scale
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+        cos_cached = emb.cos()[None, None, :, :]
+        sin_cached = emb.sin()[None, None, :, :]
+        return (
+            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
+        )
+    return (
+        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
+    )
+transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = pi_forward
 
 # Parse command-line arguments
 parser = argparse.ArgumentParser()
diff --git a/scripts/inference/inference_hf.py b/scripts/inference/inference_hf.py
index a73ed5a..b9fc505 100644
--- a/scripts/inference/inference_hf.py
+++ b/scripts/inference/inference_hf.py
@@ -18,6 +18,26 @@
 from transformers import LlamaForCausalLM, LlamaTokenizer
 from peft import PeftModel
 
+import transformers
+def pi_forward(self, x, seq_len=None):
+    if seq_len > self.max_seq_len_cached: # seq_len > 2048
+        print(f"Perform position interpolation for length {seq_len}")
+        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+        scale = self.max_seq_len_cached / seq_len
+        t *= scale
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+        cos_cached = emb.cos()[None, None, :, :]
+        sin_cached = emb.sin()[None, None, :, :]
+        return (
+            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
+        )
+    return (
+        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
+    )
+transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = pi_forward
 
 generation_config = dict(
     temperature=0.2,
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index 0a93a65..a6a70c5 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -21,6 +21,27 @@
 from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
 from peft import PeftModel
 
+import transformers
+def pi_forward(self, x, seq_len=None):
+    if seq_len > self.max_seq_len_cached: # seq_len > 2048
+        print(f"Perform position interpolation for length {seq_len}")
+        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+        scale = self.max_seq_len_cached / seq_len
+        t *= scale
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+        cos_cached = emb.cos()[None, None, :, :]
+        sin_cached = emb.sin()[None, None, :, :]
+        return (
+            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
+        )
+    return (
+        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
+    )
+transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = pi_forward
+
 from openai_api_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
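
Both commits in this series patch LlamaRotaryEmbedding at import time rather than forking the model code. In the patch above, whenever the requested seq_len exceeds max_seq_len_cached (2048 for the original LLaMA), pi_forward multiplies the position indices by max_seq_len_cached / seq_len, squeezing them back into the range the model was trained on before the rotary angles are computed. Below is a minimal standalone sketch of the same idea outside the transformers patching machinery; the head dimension of 128 and trained length of 2048 are illustrative assumptions (LLaMA-7B values), not read from the patch itself.

import torch

def pi_rope_angles(seq_len, dim=128, base=10000.0, trained_len=2048):
    # Position Interpolation: squeeze seq_len positions into the trained range.
    # dim=128 and trained_len=2048 are assumed values for illustration only.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=torch.float32)
    if seq_len > trained_len:
        t = t * (trained_len / seq_len)  # the `t *= scale` step from the patch
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()  # each of shape (seq_len, dim)

cos, sin = pi_rope_angles(4096)  # positions 0..4095 are mapped into [0, 2048)

The compression is uniform across every frequency component, which is the usual motivation for the NTK-aware variant that the follow-up commit switches to.
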
From 8556b8f616ee159f507fa2384b199577c5ebeb52 Mon Sep 17 00:00:00 2001
From: Ziqing Yang
Date: Wed, 5 Jul 2023 16:39:09 +0800
Subject: [PATCH 2/2] replace position interpolation with NTK method

---
 scripts/inference/gradio_demo.py              | 24 +++++++++++++------
 scripts/inference/inference_hf.py             | 24 +++++++++++++------
 .../openai_server_demo/openai_api_server.py   | 24 +++++++++++++------
 3 files changed, 51 insertions(+), 21 deletions(-)

diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py
index e9aac18..1c37003 100644
--- a/scripts/inference/gradio_demo.py
+++ b/scripts/inference/gradio_demo.py
@@ -13,13 +13,22 @@
 import os
 
 import transformers
-def pi_forward(self, x, seq_len=None):
-    if seq_len > self.max_seq_len_cached: # seq_len > 2048
-        print(f"Perform position interpolation for length {seq_len}")
+old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
+def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    self.dim = dim
+    self.base = base
+    old_init(self, dim, max_position_embeddings, base, device)
+
+def adaptive_ntk_forward(self, x, seq_len=None):
+    if seq_len > self.max_seq_len_cached:
         t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
-        scale = self.max_seq_len_cached / seq_len
-        t *= scale
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        inv_freq = self.inv_freq
+        dim = self.dim
+        alpha = seq_len / 1024 - 1
+        base = self.base * alpha ** (dim / (dim-2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))
+
+        freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
         cos_cached = emb.cos()[None, None, :, :]
         sin_cached = emb.sin()[None, None, :, :]
@@ -31,7 +40,8 @@ def pi_forward(self, x, seq_len=None):
         self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
         self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
     )
-transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = pi_forward
+transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
+transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init
 
 # Parse command-line arguments
 parser = argparse.ArgumentParser()
diff --git a/scripts/inference/inference_hf.py b/scripts/inference/inference_hf.py
index b9fc505..817a823 100644
--- a/scripts/inference/inference_hf.py
+++ b/scripts/inference/inference_hf.py
@@ -19,13 +19,22 @@
 from peft import PeftModel
 
 import transformers
-def pi_forward(self, x, seq_len=None):
-    if seq_len > self.max_seq_len_cached: # seq_len > 2048
-        print(f"Perform position interpolation for length {seq_len}")
+old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
+def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    self.dim = dim
+    self.base = base
+    old_init(self, dim, max_position_embeddings, base, device)
+
+def adaptive_ntk_forward(self, x, seq_len=None):
+    if seq_len > self.max_seq_len_cached:
         t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
-        scale = self.max_seq_len_cached / seq_len
-        t *= scale
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        inv_freq = self.inv_freq
+        dim = self.dim
+        alpha = seq_len / 1024 - 1
+        base = self.base * alpha ** (dim / (dim-2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))
+
+        freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
         cos_cached = emb.cos()[None, None, :, :]
         sin_cached = emb.sin()[None, None, :, :]
@@ -37,7 +46,8 @@ def pi_forward(self, x, seq_len=None):
         self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
         self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
     )
-transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = pi_forward
+transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
+transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init
 
 generation_config = dict(
     temperature=0.2,
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index a6a70c5..c91d541 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -22,13 +22,22 @@
 from peft import PeftModel
 
 import transformers
-def pi_forward(self, x, seq_len=None):
-    if seq_len > self.max_seq_len_cached: # seq_len > 2048
-        print(f"Perform position interpolation for length {seq_len}")
+old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
+def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    self.dim = dim
+    self.base = base
+    old_init(self, dim, max_position_embeddings, base, device)
+
+def adaptive_ntk_forward(self, x, seq_len=None):
+    if seq_len > self.max_seq_len_cached:
         t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
-        scale = self.max_seq_len_cached / seq_len
-        t *= scale
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        inv_freq = self.inv_freq
+        dim = self.dim
+        alpha = seq_len / 1024 - 1
+        base = self.base * alpha ** (dim / (dim-2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))
+
+        freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
         cos_cached = emb.cos()[None, None, :, :]
         sin_cached = emb.sin()[None, None, :, :]
@@ -40,7 +49,8 @@ def pi_forward(self, x, seq_len=None):
         self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
         self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
     )
-transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = pi_forward
+transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
+transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init
 
 from openai_api_protocol import (
     ChatCompletionRequest,
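
The replacement commit leaves positions untouched and instead enlarges the RoPE base once the input outgrows the cache: alpha = seq_len / 1024 - 1 grows with the request, and raising the base to base * alpha ** (dim / (dim-2)) stretches the low-frequency components while keeping the high-frequency ones close to their trained values (the NTK-aware approach). __init__ is wrapped only to stash dim and base on the module, which the stock LlamaRotaryEmbedding of that era does not store. A standalone sketch under the same illustrative assumptions as before (head dimension 128 and trained length 2048 are assumed, not read from the patch):

import torch

def adaptive_ntk_angles(seq_len, dim=128, base=10000.0, trained_len=2048):
    # NTK-aware scaling: positions stay as-is; the base grows with seq_len.
    # dim=128 and trained_len=2048 are assumed values for illustration only.
    if seq_len > trained_len:
        alpha = seq_len / 1024 - 1                # adaptive schedule from the patch
        base = base * alpha ** (dim / (dim - 2))  # enlarged RoPE base
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=torch.float32)  # positions are NOT rescaled
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()  # each of shape (seq_len, dim)

cos, sin = adaptive_ntk_angles(8192)  # alpha = 7, so the base grows roughly 7.2x

Because alpha is recomputed from seq_len on every call, the stretch adapts to however far past the trained length a given request goes, rather than committing to a single fixed factor.
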