Skip to content

Commit

Permalink
Add stream code
Browse files Browse the repository at this point in the history
  • Loading branch information
hauselin committed Jul 23, 2024
1 parent 6581d02 commit 6f912ff
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 99 deletions.
121 changes: 24 additions & 97 deletions R/ollama.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,6 @@ package_config <- list(

















#' Create a httr2 request object.
#'
#' Creates a httr2 request object with base URL, headers and endpoint. Used by other functions in the package and not intended to be used directly.
Expand Down Expand Up @@ -114,7 +100,7 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end
#' )
#' chat("llama3", messages) # returns response by default
#' chat("llama3", messages, "text") # returns text/vector
#' chat("llama3", messages, "hello!", temperature = 2.8) # additional options
#' chat("llama3", messages, "hello!", temperature = 2.8) # additional options
#' chat("llama3", messages, stream = TRUE) # stream response
#' chat("llama3", messages, output = "df", stream = TRUE) # stream and return dataframe
#'
Expand Down Expand Up @@ -163,36 +149,16 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
}

# streaming
buffer <- ""
content <- ""
accumulated_data <- raw()
stream_handler <- function(x) {
s <- rawToChar(x)
accumulated_data <<- append(accumulated_data, x)
json_strings <- strsplit(s, "\n")[[1]]

for (i in seq_along(json_strings)) {
tryCatch(
{
json_string <- paste0(buffer, json_strings[i], "\n", collapse = "")
stream_content <- jsonlite::fromJSON(json_string)$message$content
content <<- c(content, stream_content)
buffer <<- ""
# stream/print stream
cat(stream_content)
},
error = function(e) {
buffer <<- paste0(buffer, json_strings[i])
}
)
}
return(TRUE)
}
resp <- httr2::req_perform_stream(req, stream_handler, buffer_kb = 1)
env <- new.env()
env$buffer <- ""
env$content <- ""
env$accumulated_data <- raw()
wrapped_handler <- function(x) stream_handler(x, env, endpoint)
resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1)
cat("\n\n")

# process streaming output
json_lines <- strsplit(rawToChar(accumulated_data), "\n")[[1]]
json_lines <- strsplit(rawToChar(env$accumulated_data), "\n")[[1]]
json_lines_output <- vector("list", length = length(json_lines))
df_response <- tibble::tibble(
model = character(length(json_lines_output)),
Expand All @@ -202,7 +168,7 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
)

if (output[1] == "raw") {
return(rawToChar(accumulated_data))
return(rawToChar(env$accumulated_data))
}

for (i in seq_along(json_lines)) {
Expand Down Expand Up @@ -248,7 +214,7 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
#'
#' @examplesIf test_connection()$status_code == 200
#' pull("llama3")
# " pull("all-minilm", stream = FALSE)
#' pull("all-minilm", stream = FALSE)
pull <- function(model, stream = TRUE, insecure = FALSE, endpoint = "/api/pull", host = NULL) {
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")
Expand All @@ -271,31 +237,12 @@ pull <- function(model, stream = TRUE, insecure = FALSE, endpoint = "/api/pull",
}

# streaming
buffer <- ""
content <- ""
accumulated_data <- raw()
stream_handler <- function(x) {
s <- rawToChar(x)
accumulated_data <<- append(accumulated_data, x)
json_strings <- strsplit(s, "\n")[[1]]
for (i in seq_along(json_strings)) {
tryCatch(
{
json_string <- paste0(buffer, json_strings[i], "\n", collapse = "")
stream_content <- jsonlite::fromJSON(json_string)$status
content <<- c(content, stream_content)
buffer <<- ""
# stream/print stream
cat(stream_content, "\n")
},
error = function(e) {
buffer <<- paste0(buffer, json_strings[i])
}
)
}
return(TRUE)
}
resp <- httr2::req_perform_stream(req, stream_handler, buffer_kb = 1)
env <- new.env()
env$buffer <- ""
env$content <- ""
env$accumulated_data <- raw()
wrapped_handler <- function(x) stream_handler(x, env, endpoint)
resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1)
return(resp)
}

Expand Down Expand Up @@ -543,36 +490,16 @@ generate <- function(model, prompt, system = "", template = "", raw = FALSE, out
}

# streaming
buffer <- ""
content <- ""
accumulated_data <- raw()
stream_handler <- function(x) {
s <- rawToChar(x)
accumulated_data <<- append(accumulated_data, x)
json_strings <- strsplit(s, "\n")[[1]]

for (i in seq_along(json_strings)) {
tryCatch(
{
json_string <- paste0(buffer, json_strings[i], "\n", collapse = "")
stream_content <- jsonlite::fromJSON(json_string)$response
content <<- c(content, stream_content)
buffer <<- ""
# stream/print stream
cat(stream_content)
},
error = function(e) {
buffer <<- paste0(buffer, json_strings[i])
}
)
}
return(TRUE)
}
resp <- httr2::req_perform_stream(req, stream_handler, buffer_kb = 1)
env <- new.env()
env$buffer <- ""
env$content <- ""
env$accumulated_data <- raw()
wrapped_handler <- function(x) stream_handler(x, env, endpoint)
resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1)
cat("\n\n")

# process streaming output
json_lines <- strsplit(rawToChar(accumulated_data), "\n")[[1]]
json_lines <- strsplit(rawToChar(env$accumulated_data), "\n")[[1]]
json_lines_output <- vector("list", length = length(json_lines))
df_response <- tibble::tibble(
model = character(length(json_lines_output)),
Expand All @@ -581,7 +508,7 @@ generate <- function(model, prompt, system = "", template = "", raw = FALSE, out
)

if (output[1] == "raw") {
return(rawToChar(accumulated_data))
return(rawToChar(env$accumulated_data))
}

for (i in seq_along(json_lines)) {
Expand Down
43 changes: 43 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,46 @@
#' Stream handler helper function
#'
#' Function to handle streaming.
#'
#' @keywords internal
stream_handler <- function(x, env, endpoint) {
s <- rawToChar(x)
env$accumulated_data <- append(env$accumulated_data, x)
json_strings <- strsplit(s, "\n")[[1]]
for (i in seq_along(json_strings)) {
tryCatch(
{
json_string <- paste0(env$buffer, json_strings[i], "\n", collapse = "")

if (endpoint == "/api/generate") {
stream_content <- jsonlite::fromJSON(json_string)$response
env$content <- c(env$content, stream_content)
env$buffer <- ""
cat(stream_content) # stream/print stream
} else if (endpoint == "/api/chat") {
stream_content <- jsonlite::fromJSON(json_string)$message$content
env$content <- c(env$content, stream_content)
env$buffer <- ""
cat(stream_content)
} else if (endpoint == "/api/pull") {
json_string <- paste0(env$buffer, json_strings[i], "\n", collapse = "")
stream_content <- jsonlite::fromJSON(json_string)$status
env$content <<- c(env$content, stream_content)
env$buffer <<- ""
cat(stream_content, "\n")
}
},
error = function(e) {
env$buffer <- paste0(env$buffer, json_strings[i])
}
)
}
return(TRUE)
}




#' Process httr2 response object.
#'
#' @param resp A httr2 response object.
Expand Down
3 changes: 2 additions & 1 deletion _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ reference:
- delete_message
- insert_message

- subtitle: Miscellaneous functions and variables
- subtitle: Internal functions and variables
desc: Functions and variables used internally by the package
contents:
- package_config
- model_options
- stream_handler
2 changes: 1 addition & 1 deletion man/chat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/pull.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions man/stream_handler.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6f912ff

Please sign in to comment.