Skip to content

Commit

Permalink
Add ... to chat function
Browse files Browse the repository at this point in the history
  • Loading branch information
hauselin committed Jul 23, 2024
1 parent b1f4bf4 commit f2c9e19
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 7 deletions.
17 changes: 15 additions & 2 deletions R/ollama.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end
#' @param messages A list with list of messages for the model (see examples below).
#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text".
#' @param stream Enable response streaming. Default is FALSE.
#' @param keep_alive The duration to keep the connection alive. Default is "5m".
#' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#' @param ... Additional options to pass to the model.
#'
#' @return A response in the format specified in the output parameter.
#' @export
Expand All @@ -121,14 +123,25 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end
#' list(role = "user", content = "List all the previous messages.")
#' )
#' chat("llama3", messages, stream = TRUE)
chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, endpoint = "/api/chat", host = NULL) {
chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/chat", host = NULL, ...) {

req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

body_json <- list(model = model,
messages = messages,
stream = stream,
messages = messages)
keep_alive = keep_alive
)
opts <- list(...)
if (length(opts) > 0) {
if (validate_options(...)) {
body_json$options <- opts
} else {
stop("Invalid model options passed to ... argument. Please check the model options and try again.")
}
}

req <- httr2::req_body_json(req, body_json)

content <- ""
Expand Down
7 changes: 5 additions & 2 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,13 @@ messages <- list(
)
resp <- chat("llama3", messages) # default returns httr2 response object
resp # <httr2_response>
resp_process(resp, "text") # process the response to return text/vector output
# specify output type when calling the function
chat("llama3", messages, output = "df") # data frame/tibble
chat("llama3", messages, output = "raw") # raw string
chat("llama3", messages, output = "jsonlist") # list
chat("llama3", messages, output = "text", stream = FALSE) # text vector
chat("llama3", messages, output = "text") # text vector
messages <- list(
list(role = "user", content = "Hello!"),
Expand All @@ -114,7 +117,7 @@ messages <- list(
list(role = "assistant", content = "Rishi Sunak"),
list(role = "user", content = "List all the previous messages.")
)
chat("llama3", messages, output = "df")
cat(chat("llama3", messages, output = "text")) # print the formatted output
```

#### Streaming responses
Expand Down
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,13 @@ messages <- list(
)
resp <- chat("llama3", messages) # default returns httr2 response object
resp # <httr2_response>
resp_process(resp, "text") # process the response to return text/vector output

# specify output type when calling the function
chat("llama3", messages, output = "df") # data frame/tibble
chat("llama3", messages, output = "raw") # raw string
chat("llama3", messages, output = "jsonlist") # list
chat("llama3", messages, output = "text", stream = FALSE) # text vector
chat("llama3", messages, output = "text") # text vector

messages <- list(
list(role = "user", content = "Hello!"),
Expand All @@ -129,7 +132,7 @@ messages <- list(
list(role = "assistant", content = "Rishi Sunak"),
list(role = "user", content = "List all the previous messages.")
)
chat("llama3", messages, output = "df")
cat(chat("llama3", messages, output = "text")) # print the formatted output
```

#### Streaming responses
Expand Down
8 changes: 7 additions & 1 deletion man/chat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit f2c9e19

Please sign in to comment.