Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend context size without fine-tuning #705

Merged
merged 2 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions scripts/inference/gradio_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,36 @@
import argparse
import os

import transformers
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.dim = dim
self.base = base
old_init(self, dim, max_position_embeddings, base, device)

def adaptive_ntk_forward(self, x, seq_len=None):
if seq_len > self.max_seq_len_cached:
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
inv_freq = self.inv_freq
dim = self.dim
alpha = seq_len / 1024 - 1
base = self.base * alpha ** (dim / (dim-2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim ))

freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
cos_cached = emb.cos()[None, None, :, :]
sin_cached = emb.sin()[None, None, :, :]
return (
cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
)
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init

# Parse command-line arguments
parser = argparse.ArgumentParser()
Expand Down
30 changes: 30 additions & 0 deletions scripts/inference/inference_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,36 @@
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

import transformers
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.dim = dim
self.base = base
old_init(self, dim, max_position_embeddings, base, device)

def adaptive_ntk_forward(self, x, seq_len=None):
if seq_len > self.max_seq_len_cached:
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
inv_freq = self.inv_freq
dim = self.dim
alpha = seq_len / 1024 - 1
base = self.base * alpha ** (dim / (dim-2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim ))

freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
cos_cached = emb.cos()[None, None, :, :]
sin_cached = emb.sin()[None, None, :, :]
return (
cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
)
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init

generation_config = dict(
temperature=0.2,
Expand Down
31 changes: 31 additions & 0 deletions scripts/openai_server_demo/openai_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,37 @@
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
from peft import PeftModel

import transformers
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.dim = dim
self.base = base
old_init(self, dim, max_position_embeddings, base, device)

def adaptive_ntk_forward(self, x, seq_len=None):
if seq_len > self.max_seq_len_cached:
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
inv_freq = self.inv_freq
dim = self.dim
alpha = seq_len / 1024 - 1
base = self.base * alpha ** (dim / (dim-2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim ))

freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
cos_cached = emb.cos()[None, None, :, :]
sin_cached = emb.sin()[None, None, :, :]
return (
cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
)
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init

from openai_api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
Expand Down