Skip to content

Commit

Permalink
fix stop in /v1/completions, fix ts_first_token
Browse files Browse the repository at this point in the history
  • Loading branch information
olegklimov committed Sep 22, 2023
1 parent 43aeb95 commit eb927bc
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
8 changes: 5 additions & 3 deletions refact_scratchpads/scratchpad_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(

self._stop_lf = False
self._stop_lf_lf = False
self._prev_token: Optional[int] = None
self._stop_tokens: Set[int] = set()
if isinstance(stop_tokens, str):
stop_tokens = [stop_tokens]
Expand Down Expand Up @@ -80,10 +81,11 @@ def after_token_selection(self, m, chosen_token: th.Tensor, **unused) -> Dict[st
if t in self._stop_tokens:
self.finish_reason = "stoptoken"

t_str = self._tokenizer.decode([t])
if self._stop_lf and t_str.startswith("\n"):
couple_of_tokens_decoded = self._tokenizer.decode(([self._prev_token] if self._prev_token is not None else []) + [t])
self._prev_token = t
if self._stop_lf and ("\n" in couple_of_tokens_decoded):
self.finish_reason = "stop-lf"
if self._stop_lf_lf and t_str.startswith("\n\n"):
if self._stop_lf_lf and ("\n\n" in couple_of_tokens_decoded):
self.finish_reason = "stop-lflf"

self._tokens_produced += 1
Expand Down
2 changes: 2 additions & 0 deletions self_hosting_machinery/inference/inference_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def __init__(self, tokenizer, request_id: str, upload_proxy: UploadProxy,
self.upload_proxy_args = upload_proxy_args

def put(self, value):
    """Stamp the first-token timestamp if unset, then forward to the parent.

    When ``upload_proxy_args["ts_first_token"]`` is missing or equal to 0,
    it is set to the current ``time.time()``. ``value`` is then handed
    unchanged to the parent class's ``put``.
    """
    proxy_args = self.upload_proxy_args
    first_token_unstamped = proxy_args.get("ts_first_token", 0) == 0
    if first_token_unstamped:
        # Only the first call records a timestamp; later calls leave it alone.
        proxy_args["ts_first_token"] = time.time()
    super().put(value)

def on_finalized_text(self, text: str, stream_end: bool = False):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def clamp(self):


class NlpCompletion(NlpSamplingParams):
model: str = Query(default="", regex="^[a-z/A-Z0-9_\.]+$")
model: str = Query(default=Required, regex="^[a-z/A-Z0-9_\.\-]+$")
prompt: str
n: int = 1
echo: bool = False
Expand Down

0 comments on commit eb927bc

Please sign in to comment.