From 159d0dbc7d2995294290b577b4c8fe41f5a268dc Mon Sep 17 00:00:00 2001
From: Hause Lin
Date: Thu, 25 Jul 2024 15:02:23 -0400
Subject: [PATCH] Clean up code

---
 R/ollama.R        | 258 +++++++++++++++++++++++-----------------
 R/utils.R         |  16 +--
 README.Rmd        |  13 +--
 README.md         |  16 +--
 man/embeddings.Rd |   2 +-
 5 files changed, 145 insertions(+), 160 deletions(-)

diff --git a/R/ollama.R b/R/ollama.R
index fa96730..318a397 100644
--- a/R/ollama.R
+++ b/R/ollama.R
@@ -60,6 +60,7 @@ create_request <- function(endpoint, host = NULL) {
 #' list_models("jsonlist")
 #' list_models("raw")
 list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), endpoint = "/api/tags", host = NULL) {
+
   if (!output[1] %in% c("df", "resp", "jsonlist", "raw", "text")) {
     stop("Invalid output format specified. Supported formats are 'df', 'resp', 'jsonlist', 'raw', 'text'.")
   }
@@ -68,7 +69,6 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end
   tryCatch(
     {
       resp <- httr2::req_perform(req)
-      print(resp)
       return(resp_process(resp = resp, output = output[1]))
     },
     error = function(e) {
@@ -79,125 +79,6 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end
 
 
 
-#' Chat with Ollama models
-#'
-#' @param model A character string of the model name such as "llama3".
-#' @param messages A list with list of messages for the model (see examples below).
-#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text".
-#' @param stream Enable response streaming. Default is FALSE.
-#' @param keep_alive The duration to keep the connection alive. Default is "5m".
-#' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
-#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
-#' @param ... Additional options to pass to the model.
-#'
-#' @return A response in the format specified in the output parameter.
-#' @export
-#'
-#' @examplesIf test_connection()$status_code == 200
-#' # one message
-#' messages <- list(
-#'   list(role = "user", content = "How are you doing?")
-#' )
-#' chat("llama3", messages) # returns response by default
-#' chat("llama3", messages, "text") # returns text/vector
-#' chat("llama3", messages, "hello!", temperature = 2.8) # additional options
-#' chat("llama3", messages, stream = TRUE) # stream response
-#' chat("llama3", messages, output = "df", stream = TRUE) # stream and return dataframe
-#'
-#' # multiple messages
-#' messages <- list(
-#'   list(role = "user", content = "Hello!"),
-#'   list(role = "assistant", content = "Hi! How are you?"),
-#'   list(role = "user", content = "Who is the prime minister of the uk?"),
-#'   list(role = "assistant", content = "Rishi Sunak"),
-#'   list(role = "user", content = "List all the previous messages.")
-#' )
-#' chat("llama3", messages, stream = TRUE)
-chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/chat", host = NULL, ...) {
-  req <- create_request(endpoint, host)
-  req <- httr2::req_method(req, "POST")
-
-  body_json <- list(
-    model = model,
-    messages = messages,
-    stream = stream,
-    keep_alive = keep_alive
-  )
-  opts <- list(...)
-  if (length(opts) > 0) {
-    if (validate_options(...)) {
-      body_json$options <- opts
-    } else {
-      stop("Invalid model options passed to ... argument. Please check the model options and try again.")
-    }
-  }
-
-  req <- httr2::req_body_json(req, body_json)
-
-  content <- ""
-  if (!stream) {
-    tryCatch(
-      {
-        resp <- httr2::req_perform(req)
-        print(resp)
-        return(resp_process(resp = resp, output = output[1]))
-      },
-      error = function(e) {
-        stop(e)
-      }
-    )
-  }
-
-  # streaming
-  env <- new.env()
-  env$buffer <- ""
-  env$content <- ""
-  env$accumulated_data <- raw()
-  wrapped_handler <- function(x) stream_handler(x, env, endpoint)
-  resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1)
-  cat("\n\n")
-
-  # process streaming output
-  json_lines <- strsplit(rawToChar(env$accumulated_data), "\n")[[1]]
-  json_lines_output <- vector("list", length = length(json_lines))
-  df_response <- tibble::tibble(
-    model = character(length(json_lines_output)),
-    role = character(length(json_lines_output)),
-    content = character(length(json_lines_output)),
-    created_at = character(length(json_lines_output))
-  )
-
-  if (output[1] == "raw") {
-    return(rawToChar(env$accumulated_data))
-  }
-
-  for (i in seq_along(json_lines)) {
-    json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]])
-    df_response$model[i] <- json_lines_output[[i]]$model
-    df_response$role[i] <- json_lines_output[[i]]$message$role
-    df_response$content[i] <- json_lines_output[[i]]$message$content
-    df_response$created_at[i] <- json_lines_output[[i]]$created_at
-  }
-
-  if (output[1] == "jsonlist") {
-    return(json_lines_output)
-  }
-
-  if (output[1] == "df") {
-    return(df_response)
-  }
-
-  if (output[1] == "text") {
-    return(paste0(df_response$content, collapse = ""))
-  }
-
-  return(resp)
-}
-
-
-
-
-
 
 #' Pull/download a model
 #'
@@ -216,18 +97,17 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
 #' pull("llama3")
 #' pull("all-minilm", stream = FALSE)
 pull <- function(model, stream = TRUE, insecure = FALSE, endpoint = "/api/pull", host = NULL) {
+
   req <- create_request(endpoint, host)
   req <- httr2::req_method(req, "POST")
 
   body_json <- list(model = model, stream = stream, insecure = insecure)
   req <- httr2::req_body_json(req, body_json)
 
-  content <- ""
   if (!stream) {
     tryCatch(
       {
         resp <- httr2::req_perform(req)
-        print(resp)
         return(resp)
       },
       error = function(e) {
@@ -272,7 +152,6 @@ delete <- function(model, endpoint = "/api/delete", host = NULL) {
   tryCatch(
     {
       resp <- httr2::req_perform(req)
-      print(resp)
       return(resp)
     },
     error = function(e) {
@@ -294,6 +173,8 @@ normalize <- function(x) {
 
 #' Get vector embedding for a single prompt
 #'
+#' This function will be deprecated over time and has been superseded by `embed()`. See `embed()` for more details.
+#'
 #' @param model A character string of the model name such as "llama3".
 #' @param prompt A character string of the prompt that you want to get the vector embedding for.
 #' @param normalize Normalize the vector to length 1. Default is TRUE.
@@ -328,7 +209,6 @@ embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpo
   tryCatch(
     {
       resp <- httr2::req_perform(req)
-      print(resp)
       v <- unlist(resp_process(resp, "jsonlist")$embedding)
       if (normalize) {
         v <- normalize(v)
@@ -390,7 +270,6 @@ embed <- function(model, input, truncate = TRUE, normalize = TRUE, keep_alive = 
   tryCatch(
     {
       resp <- httr2::req_perform(req)
-      print(resp)
       json_body <- httr2::resp_body_json(resp)$embeddings
       m <- do.call(cbind, lapply(json_body, function(x) {
         v <- unlist(x)
@@ -475,12 +354,10 @@ generate <- function(model, prompt, system = "", template = "", raw = FALSE, out
 
   req <- httr2::req_body_json(req, body_json, stream = stream)
 
-  content <- ""
   if (!stream) {
     tryCatch(
       {
         resp <- httr2::req_perform(req)
-        print(resp)
         return(resp_process(resp = resp, output = output[1]))
       },
       error = function(e) {
@@ -532,3 +409,130 @@ generate <- function(model, prompt, system = "", template = "", raw = FALSE, out
 
   return(resp)
 }
+
+
+
+
+
+
+
+
+
+
+
+#' Chat with Ollama models
+#'
+#' @param model A character string of the model name such as "llama3".
+#' @param messages A list of lists of messages for the model (see examples below).
+#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text".
+#' @param stream Enable response streaming. Default is FALSE.
+#' @param keep_alive The duration to keep the connection alive. Default is "5m".
+#' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
+#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
+#' @param ... Additional options to pass to the model.
+#'
+#' @return A response in the format specified in the output parameter.
+#' @export
+#'
+#' @examplesIf test_connection()$status_code == 200
+#' # one message
+#' messages <- list(
+#'   list(role = "user", content = "How are you doing?")
+#' )
+#' chat("llama3", messages) # returns response by default
+#' chat("llama3", messages, "text") # returns text/vector
+#' chat("llama3", messages, temperature = 2.8) # additional options
+#' chat("llama3", messages, stream = TRUE) # stream response
+#' chat("llama3", messages, output = "df", stream = TRUE) # stream and return dataframe
+#'
+#' # multiple messages
+#' messages <- list(
+#'   list(role = "user", content = "Hello!"),
+#'   list(role = "assistant", content = "Hi! How are you?"),
+#'   list(role = "user", content = "Who is the prime minister of the uk?"),
+#'   list(role = "assistant", content = "Rishi Sunak"),
+#'   list(role = "user", content = "List all the previous messages.")
+#' )
+#' chat("llama3", messages, stream = TRUE)
+chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/chat", host = NULL, ...) {
+  req <- create_request(endpoint, host)
+  req <- httr2::req_method(req, "POST")
+
+  body_json <- list(
+    model = model,
+    messages = messages,
+    stream = stream,
+    keep_alive = keep_alive
+  )
+
+  opts <- list(...)
+  if (length(opts) > 0) {
+    if (validate_options(...)) {
+      body_json$options <- opts
+    } else {
+      stop("Invalid model options passed to ... argument. Please check the model options and try again.")
+    }
+  }
+
+  req <- httr2::req_body_json(req, body_json)
+
+  if (!stream) {
+    tryCatch(
+      {
+        resp <- httr2::req_perform(req)
+        return(resp_process(resp = resp, output = output[1]))
+      },
+      error = function(e) {
+        stop(e)
+      }
+    )
+  }
+
+  # streaming
+  env <- new.env()
+  env$buffer <- ""
+  env$content <- ""
+  env$accumulated_data <- raw()
+  wrapped_handler <- function(x) stream_handler(x, env, endpoint)
+  resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1)
+  cat("\n\n")
+
+  # process streaming output
+  json_lines <- strsplit(rawToChar(env$accumulated_data), "\n")[[1]]
+  json_lines_output <- vector("list", length = length(json_lines))
+  df_response <- tibble::tibble(
+    model = character(length(json_lines_output)),
+    role = character(length(json_lines_output)),
+    content = character(length(json_lines_output)),
+    created_at = character(length(json_lines_output))
+  )
+
+  if (output[1] == "raw") {
+    return(rawToChar(env$accumulated_data))
+  }
+
+  for (i in seq_along(json_lines)) {
+    json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]])
+    df_response$model[i] <- json_lines_output[[i]]$model
+    df_response$role[i] <- json_lines_output[[i]]$message$role
+    df_response$content[i] <- json_lines_output[[i]]$message$content
+    df_response$created_at[i] <- json_lines_output[[i]]$created_at
+  }
+
+  if (output[1] == "jsonlist") {
+    return(json_lines_output)
+  }
+
+  if (output[1] == "df") {
+    return(df_response)
+  }
+
+  if (output[1] == "text") {
+    return(paste0(df_response$content, collapse = ""))
+  }
+
+  return(resp)
+}
+
+
+
diff --git a/R/utils.R b/R/utils.R
index 43e95a6..b0da046 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -11,24 +11,18 @@ stream_handler <- function(x, env, endpoint) {
     tryCatch(
       {
         json_string <- paste0(env$buffer, json_strings[i], "\n", collapse = "")
-
         if (endpoint == "/api/generate") {
           stream_content <- jsonlite::fromJSON(json_string)$response
-          env$content <- c(env$content, stream_content)
-          env$buffer <- ""
-          cat(stream_content) # stream/print stream
         } else if (endpoint == "/api/chat") {
           stream_content <- jsonlite::fromJSON(json_string)$message$content
-          env$content <- c(env$content, stream_content)
-          env$buffer <- ""
-          cat(stream_content)
         } else if (endpoint == "/api/pull") {
-          json_string <- paste0(env$buffer, json_strings[i], "\n", collapse = "")
           stream_content <- jsonlite::fromJSON(json_string)$status
-          env$content <<- c(env$content, stream_content)
-          env$buffer <<- ""
-          cat(stream_content, "\n")
+          stream_content <- paste0(stream_content, "\n")
         }
+        # concatenate the content
+        env$content <- c(env$content, stream_content)
+        env$buffer <- ""
+        cat(stream_content) # stream/print stream
       },
       error = function(e) {
         env$buffer <- paste0(env$buffer, json_strings[i])
diff --git a/README.Rmd b/README.Rmd
index 8567c9c..0ddc4cc 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -89,7 +89,6 @@ resp_process(resp, "text") # process the response to return text/vector output
 generate("llama3", "Tomorrow is a...", output = "text") # directly return text/vector output
 generate("llama3", "Tomorrow is a...", stream = TRUE) # return httr2 response object and stream output
 generate("llama3", "Tomorrow is a...", output = "df", stream = TRUE)
-generate("llama3", "Tomorrow is a...", "text", TRUE) # return text/vector output and stream output
 ```
 
 ### Chat
@@ -105,11 +104,13 @@ resp
 # resp_process(resp, "text") # process the response to return text/vector output
 
 # specify output type when calling the function
+chat("llama3", messages, output = "text") # text vector
"text") # text vector chat("llama3", messages, output = "df") # data frame/tibble -chat("llama3", messages, output = "raw") # raw string chat("llama3", messages, output = "jsonlist") # list -chat("llama3", messages, output = "text") # text vector +chat("llama3", messages, output = "raw") # raw string +chat("llama3", messages, stream = TRUE) # stream output and return httr2 response object +# list of messages messages <- list( list(role = "user", content = "Hello!"), list(role = "assistant", content = "Hi! How are you?"), @@ -162,16 +163,10 @@ sum(e1 * e1) # 1 (identical vectors/embeddings) # non-normalized embeddings e3 <- embed("llama3", "Hello, how are you?", normalize = FALSE) e4 <- embed("llama3", "Hi, how are you?", normalize = FALSE) -sum(e3 * e4) # 23695.96 -sum(e3 * e3) # 24067.32 ``` ### Notes -#### Optional/advanced parameters - -Optional/advanced parameters (see [API docs](https://github.com/ollama/ollama/blob/main/docs/api.md)) such as `temperature` are not yet implemented as of now but will be added in the near future. - If you don't have the Ollama app running, you'll get an error. Make sure to open the Ollama app before using this library. ```{r eval=FALSE} diff --git a/README.md b/README.md index 449cc37..b45c3ea 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,6 @@ resp_process(resp, "text") # process the response to return text/vector output generate("llama3", "Tomorrow is a...", output = "text") # directly return text/vector output generate("llama3", "Tomorrow is a...", stream = TRUE) # return httr2 response object and stream output generate("llama3", "Tomorrow is a...", output = "df", stream = TRUE) -generate("llama3", "Tomorrow is a...", "text", TRUE) # return text/vector output and stream output ``` ### Chat @@ -120,11 +119,13 @@ resp # resp_process(resp, "text") # process the response to return text/vector output # specify output type when calling the function +chat("llama3", messages, output = "text") # text vector chat("llama3", messages, output = "df") # data frame/tibble -chat("llama3", messages, output = "raw") # raw string chat("llama3", messages, output = "jsonlist") # list -chat("llama3", messages, output = "text") # text vector +chat("llama3", messages, output = "raw") # raw string +chat("llama3", messages, stream = TRUE) # stream output and return httr2 response object +# list of messages messages <- list( list(role = "user", content = "Hello!"), list(role = "assistant", content = "Hi! How are you?"), @@ -182,19 +183,10 @@ sum(e1 * e1) # 1 (identical vectors/embeddings) # non-normalized embeddings e3 <- embed("llama3", "Hello, how are you?", normalize = FALSE) e4 <- embed("llama3", "Hi, how are you?", normalize = FALSE) -sum(e3 * e4) # 23695.96 -sum(e3 * e3) # 24067.32 ``` ### Notes -#### Optional/advanced parameters - -Optional/advanced parameters (see [API -docs](https://github.com/ollama/ollama/blob/main/docs/api.md)) such as -`temperature` are not yet implemented as of now but will be added in the -near future. - If you don’t have the Ollama app running, you’ll get an error. Make sure to open the Ollama app before using this library. diff --git a/man/embeddings.Rd b/man/embeddings.Rd index 19e74bb..3efc9df 100644 --- a/man/embeddings.Rd +++ b/man/embeddings.Rd @@ -33,7 +33,7 @@ embeddings( A numeric vector of the embedding. } \description{ -Get vector embedding for a single prompt +This function will be deprecated over time and has been superceded by \code{embed()}. See \code{embed()} for more details. 
} \examples{ \dontshow{if (test_connection()$status_code == 200) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf}