Skip to content

Commit

Permalink
Update generate function
Browse files Browse the repository at this point in the history
  • Loading branch information
hauselin committed Jul 23, 2024
1 parent 544a431 commit 6581d02
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 6 deletions.
30 changes: 26 additions & 4 deletions R/ollama.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end
#' )
#' chat("llama3", messages) # returns response by default
#' chat("llama3", messages, "text") # returns text/vector
#' chat("llama3", messages, "hello!", temperature = 2.8) # additional options
#' chat("llama3", messages, stream = TRUE) # stream response
#' chat("llama3", messages, output = "df", stream = TRUE) # stream and return dataframe
#'
Expand Down Expand Up @@ -482,29 +483,50 @@ embed <- function(model, input, truncate = TRUE, normalize = TRUE, keep_alive =
#'
#' @param model A character string of the model name such as "llama3".
#' @param prompt A character string of the promp like "The sky is..."
#' @param system A character string of the system prompt (overrides what is defined in the Modelfile). Default is "".
#' @param template A character string of the prompt template (overrides what is defined in the Modelfile). Default is "".
#' @param raw If TRUE, no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API. Default is FALSE.
#' @param output A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text".
#' @param stream Enable response streaming. Default is FALSE.
#' @param keep_alive The time to keep the connection alive. Default is "5m" (5 minutes).
#' @param endpoint The endpoint to generate the completion. Default is "/api/generate".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#' @param ... Additional options to pass to the model.
#'
#'
#' @return A response in the format specified in the output parameter.
#' @export
#'
#' @examplesIf test_connection()$status_code == 200
#' generate("llama3", "The sky is...", stream = FALSE, output = "df")
#' generate("llama3", "The sky is...", stream = TRUE, output = "df")
#' generate("llama3", "The sky is...", stream = TRUE, output = "text")
#' generate("llama3", "The sky is...", stream = TRUE, output = "text", temperature = 2.0)
#' generate("llama3", "The sky is...", stream = FALSE, output = "jsonlist")
generate <- function(model, prompt, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, endpoint = "/api/generate", host = NULL) {
generate <- function(model, prompt, system = "", template = "", raw = FALSE, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/generate", host = NULL, ...) {
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

body_json <- list(
model = model,
stream = stream,
prompt = prompt
prompt = prompt,
system = system,
template = template,
raw = raw,
stream = stream,
keep_alive = keep_alive
)
req <- httr2::req_body_json(req, body_json)

opts <- list(...)
if (length(opts) > 0) {
if (validate_options(...)) {
body_json$options <- opts
} else {
stop("Invalid model options passed to ... argument. Please check the model options and try again.")
}
}

req <- httr2::req_body_json(req, body_json, stream = stream)

content <- ""
if (!stream) {
Expand Down
1 change: 1 addition & 0 deletions man/chat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 18 additions & 2 deletions man/generate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6581d02

Please sign in to comment.