From f2c9e19cd6e0007de0fde0d2d7f6484dfa96e266 Mon Sep 17 00:00:00 2001 From: Hause Lin Date: Mon, 22 Jul 2024 22:03:18 -0400 Subject: [PATCH] Add ... to chat function --- R/ollama.R | 17 +++++++++++++++-- README.Rmd | 7 +++++-- README.md | 7 +++++-- man/chat.Rd | 8 +++++++- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/R/ollama.R b/R/ollama.R index a53063a..c569a73 100644 --- a/R/ollama.R +++ b/R/ollama.R @@ -96,8 +96,10 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end #' @param messages A list with list of messages for the model (see examples below). #' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text". #' @param stream Enable response streaming. Default is FALSE. +#' @param keep_alive The duration to keep the model loaded in memory after the request. Default is "5m". #' @param endpoint The endpoint to chat with the model. Default is "/api/chat". #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL. +#' @param ... Additional options to pass to the model. #' #' @return A response in the format specified in the output parameter. #' @export @@ -121,14 +123,25 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end #' list(role = "user", content = "List all the previous messages.") #' ) #' chat("llama3", messages, stream = TRUE) -chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, endpoint = "/api/chat", host = NULL) { +chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/chat", host = NULL, ...) { req <- create_request(endpoint, host) req <- httr2::req_method(req, "POST") body_json <- list(model = model, + messages = messages, stream = stream, - messages = messages) + keep_alive = keep_alive + ) + opts <- list(...) 
+ if (length(opts) > 0) { + if (validate_options(...)) { + body_json$options <- opts + } else { + stop("Invalid model options passed to ... argument. Please check the model options and try again.") + } + } + req <- httr2::req_body_json(req, body_json) content <- "" diff --git a/README.Rmd b/README.Rmd index 0b54094..5cbe80a 100644 --- a/README.Rmd +++ b/README.Rmd @@ -102,10 +102,13 @@ messages <- list( ) resp <- chat("llama3", messages) # default returns httr2 response object resp # +resp_process(resp, "text") # process the response to return text/vector output + +# specify output type when calling the function chat("llama3", messages, output = "df") # data frame/tibble chat("llama3", messages, output = "raw") # raw string chat("llama3", messages, output = "jsonlist") # list -chat("llama3", messages, output = "text", stream = FALSE) # text vector +chat("llama3", messages, output = "text") # text vector messages <- list( list(role = "user", content = "Hello!"), @@ -114,7 +117,7 @@ messages <- list( list(role = "assistant", content = "Rishi Sunak"), list(role = "user", content = "List all the previous messages.") ) -chat("llama3", messages, output = "df") +cat(chat("llama3", messages, output = "text")) # print the formatted output ``` #### Streaming responses diff --git a/README.md b/README.md index 9b4aaae..a8de197 100644 --- a/README.md +++ b/README.md @@ -117,10 +117,13 @@ messages <- list( ) resp <- chat("llama3", messages) # default returns httr2 response object resp # +resp_process(resp, "text") # process the response to return text/vector output + +# specify output type when calling the function chat("llama3", messages, output = "df") # data frame/tibble chat("llama3", messages, output = "raw") # raw string chat("llama3", messages, output = "jsonlist") # list -chat("llama3", messages, output = "text", stream = FALSE) # text vector +chat("llama3", messages, output = "text") # text vector messages <- list( list(role = "user", content = "Hello!"), @@ -129,7 
+132,7 @@ messages <- list( list(role = "assistant", content = "Rishi Sunak"), list(role = "user", content = "List all the previous messages.") ) -chat("llama3", messages, output = "df") +cat(chat("llama3", messages, output = "text")) # print the formatted output ``` #### Streaming responses diff --git a/man/chat.Rd b/man/chat.Rd index c941ef5..82e2aee 100644 --- a/man/chat.Rd +++ b/man/chat.Rd @@ -9,8 +9,10 @@ chat( messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, + keep_alive = "5m", endpoint = "/api/chat", - host = NULL + host = NULL, + ... ) } \arguments{ @@ -22,9 +24,13 @@ chat( \item{stream}{Enable response streaming. Default is FALSE.} +\item{keep_alive}{The duration to keep the model loaded in memory after the request. Default is "5m".} + \item{endpoint}{The endpoint to chat with the model. Default is "/api/chat".} \item{host}{The base URL to use. Default is NULL, which uses Ollama's default base URL.} + +\item{...}{Additional options to pass to the model.} } \value{ A response in the format specified in the output parameter.