diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py index c417446..1c37003 100644 --- a/scripts/inference/gradio_demo.py +++ b/scripts/inference/gradio_demo.py @@ -12,6 +12,36 @@ import argparse import os +import transformers +old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ +def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None): + self.dim = dim + self.base = base + old_init(self, dim, max_position_embeddings, base, device) + +def adaptive_ntk_forward(self, x, seq_len=None): + if seq_len > self.max_seq_len_cached: + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + dim = self.dim + alpha = seq_len / 1024 - 1 + base = self.base * alpha ** (dim / (dim-2)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim )) + + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + cos_cached = emb.cos()[None, None, :, :] + sin_cached = emb.sin()[None, None, :, :] + return ( + cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype) + ) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype) + ) +transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward +transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init # Parse command-line arguments parser = argparse.ArgumentParser() diff --git a/scripts/inference/inference_hf.py b/scripts/inference/inference_hf.py index a73ed5a..817a823 100644 --- a/scripts/inference/inference_hf.py +++ b/scripts/inference/inference_hf.py @@ -18,6 +18,36 @@ from transformers import LlamaForCausalLM, LlamaTokenizer from peft import PeftModel +import transformers +old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ +def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None): + self.dim = dim + self.base = base + old_init(self, dim, max_position_embeddings, base, device) + +def adaptive_ntk_forward(self, x, seq_len=None): + if seq_len > self.max_seq_len_cached: + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + dim = self.dim + alpha = seq_len / 1024 - 1 + base = self.base * alpha ** (dim / (dim-2)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim )) + + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + cos_cached = emb.cos()[None, None, :, :] + sin_cached = emb.sin()[None, None, :, :] + return ( + cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype) + ) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype) + ) +transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward +transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init generation_config = dict( temperature=0.2, diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py index 0a93a65..c91d541 100644 --- a/scripts/openai_server_demo/openai_api_server.py +++ b/scripts/openai_server_demo/openai_api_server.py @@ -21,6 +21,37 @@ from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig from peft import PeftModel +import transformers +old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ +def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None): + self.dim = dim + self.base = base + old_init(self, dim, max_position_embeddings, base, device) + +def adaptive_ntk_forward(self, x, seq_len=None): + if seq_len > self.max_seq_len_cached: + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + dim = self.dim + alpha = seq_len / 1024 - 1 + base = self.base * alpha ** (dim / (dim-2)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim )) + + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + cos_cached = emb.cos()[None, None, :, :] + sin_cached = emb.sin()[None, None, :, :] + return ( + cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype) + ) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype) + ) +transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward +transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init + from openai_api_protocol import ( ChatCompletionRequest, ChatCompletionResponse,