Commit

update empirical formula
airaria committed Jul 15, 2023
1 parent 1509d39 commit 1f9b872
Showing 5 changed files with 11 additions and 6 deletions.
scripts/inference/gradio_demo.py: 2 changes (1 addition, 1 deletion)
@@ -57,7 +57,7 @@
args.gpus = ""

from patches import apply_attention_patch, apply_ntk_scaling_patch
-#apply_attention_patch()
+apply_attention_patch(use_memory_efficient_attention=True)
apply_ntk_scaling_patch(args.alpha)

# Set CUDA devices if available
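gradio_demo.py, inference_hf.py, and openai_api_server.py all switch to the same call: the attention patch is no longer commented out or invoked with defaults, but enabled with xformers' memory-efficient kernel. As a minimal sketch of the call the patched forward ultimately relies on, assuming xformers is installed, a CUDA device is available, and tensors are laid out as (batch, seq, heads, head_dim), which is what memory_efficient_attention expects; the shapes and names below are illustrative only:

import torch
from xformers import ops as xops

# Illustrative shapes: batch of 1, 16 tokens, 8 heads, head_dim 64.
q = torch.randn(1, 16, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal self-attention with no dropout, as used by the patched forward.
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask(), p=0)
print(out.shape)  # torch.Size([1, 16, 8, 64])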
scripts/inference/inference_hf.py: 2 changes (1 addition, 1 deletion)
@@ -20,7 +20,7 @@
from peft import PeftModel

from patches import apply_attention_patch, apply_ntk_scaling_patch
-#apply_attention_patch()
+apply_attention_patch(use_memory_efficient_attention=True)
apply_ntk_scaling_patch(args.alpha)

generation_config = dict(
scripts/inference/patches.py: 9 changes (7 additions, 2 deletions)
@@ -21,11 +21,15 @@ def apply_attention_patch(

try:
from xformers import ops as xops
+if use_memory_efficient_attention is True:
+print("Use memory_efficient_attention from xformers")
except ImportError:
xops = None
print(
"Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
)
+if store_kv_before_rope is True:
+print("Store KV before rope")

def xformers_forward(
self,
@@ -76,8 +80,9 @@ def xformers_forward(
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
+attn_bias = None if (query_states.size(1)==1 and key_states.size(1)>1) else xops.LowerTriangularMask()
attn_output = xops.memory_efficient_attention(
-query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask(), p=0)
+query_states, key_states, value_states, attn_bias=attn_bias, p=0)
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
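The attn_bias line added above distinguishes the two phases of generation. During prefill, the query covers the full prompt and the lower-triangular mask enforces causality; during incremental decoding with a KV cache, the query has length 1 while the cached keys are longer, and the single new token may attend to every past position, so the causal mask is skipped there. A small self-contained restatement of that selection, with illustrative (batch, seq, heads, head_dim) shapes:

import torch
from xformers import ops as xops

# Decode step with a KV cache: 1 new query token, 17 cached key/value tokens.
query_states = torch.randn(1, 1, 8, 64)
key_states = torch.randn(1, 17, 8, 64)

# Same condition as the patch: skip the causal mask only for single-token decoding.
attn_bias = None if (query_states.size(1) == 1 and key_states.size(1) > 1) else xops.LowerTriangularMask()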

@@ -161,7 +166,7 @@ def adaptive_ntk_forward(self, x, seq_len=None):
elif self.alpha=='auto':
t = torch.arange(seq_len, device=x.device, dtype=self.ntk_inv_freq.dtype)
dim = self.dim
-alpha = seq_len / 1024 - 1
+alpha = (seq_len / 1024 - 1) * 1.1
base = self.base * alpha ** (dim / (dim-2))
ntk_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim ))

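The "empirical formula" in the commit title refers to this adaptive NTK scaling factor: in the alpha='auto' branch, the scale derived from the current sequence length is now multiplied by 1.1, i.e. alpha = (seq_len / 1024 - 1) * 1.1, and the RoPE base is rescaled as base * alpha ** (dim / (dim - 2)) before the inverse frequencies are recomputed. A small worked sketch of the new rule; dim = 128 and base = 10000 are assumed here as typical LLaMA per-head defaults, and the numbers are only illustrative:

import torch

dim, base = 128, 10000.0   # assumed per-head dimension and RoPE base
seq_len = 4096             # example context length

alpha = (seq_len / 1024 - 1) * 1.1                  # (4 - 1) * 1.1 = 3.3
scaled_base = base * alpha ** (dim / (dim - 2))     # about 10000 * 3.363
ntk_inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))
print(f"alpha={alpha:.2f} base={scaled_base:.0f}")  # roughly alpha=3.30 base=33631

A larger base stretches the rotary frequencies so positions beyond the original training length stay distinguishable; the 1.1 factor simply makes that stretch slightly more aggressive than the previous seq_len / 1024 - 1 rule.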
scripts/openai_server_demo/openai_api_server.py: 2 changes (1 addition, 1 deletion)
@@ -23,7 +23,7 @@
from peft import PeftModel

from patches import apply_attention_patch, apply_ntk_scaling_patch
-apply_attention_patch()
+apply_attention_patch(use_memory_efficient_attention=True)
apply_ntk_scaling_patch(args.alpha)

from openai_api_protocol import (
scripts/openai_server_demo/patches.py: 2 changes (1 addition, 1 deletion)
@@ -166,7 +166,7 @@ def adaptive_ntk_forward(self, x, seq_len=None):
elif self.alpha=='auto':
t = torch.arange(seq_len, device=x.device, dtype=self.ntk_inv_freq.dtype)
dim = self.dim
-alpha = seq_len / 1024 - 1
+alpha = (seq_len / 1024 - 1) * 1.1
base = self.base * alpha ** (dim / (dim-2))
ntk_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim ))

