diff --git a/R/ollama.R b/R/ollama.R
index 04191be..e2979b7 100644
--- a/R/ollama.R
+++ b/R/ollama.R
@@ -114,6 +114,7 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end
 #' )
 #' chat("llama3", messages) # returns response by default
 #' chat("llama3", messages, "text") # returns text/vector
+#' chat("llama3", messages, temperature = 2.8) # additional options
 #' chat("llama3", messages, stream = TRUE) # stream response
 #' chat("llama3", messages, output = "df", stream = TRUE) # stream and return dataframe
 #'
@@ -482,10 +483,15 @@ embed <- function(model, input, truncate = TRUE, normalize = TRUE, keep_alive =
 #'
 #' @param model A character string of the model name such as "llama3".
 #' @param prompt A character string of the promp like "The sky is..."
+#' @param system A character string of the system prompt (overrides what is defined in the Modelfile). Default is "".
+#' @param template A character string of the prompt template (overrides what is defined in the Modelfile). Default is "".
+#' @param raw If TRUE, no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API. Default is FALSE.
 #' @param output A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text".
 #' @param stream Enable response streaming. Default is FALSE.
+#' @param keep_alive The duration to keep the model loaded in memory after the request. Default is "5m" (5 minutes).
 #' @param endpoint The endpoint to generate the completion. Default is "/api/generate".
 #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
+#' @param ... Additional options to pass to the model.
 #'
 #'
 #' @return A response in the format specified in the output parameter.
@@ -493,18 +499,34 @@ embed <- function(model, input, truncate = TRUE, normalize = TRUE, keep_alive =
 #'
 #' @examplesIf test_connection()$status_code == 200
 #' generate("llama3", "The sky is...", stream = FALSE, output = "df")
-#' generate("llama3", "The sky is...", stream = TRUE, output = "df")
+#' generate("llama3", "The sky is...", stream = TRUE, output = "text")
+#' generate("llama3", "The sky is...", stream = TRUE, output = "text", temperature = 2.0)
 #' generate("llama3", "The sky is...", stream = FALSE, output = "jsonlist")
-generate <- function(model, prompt, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, endpoint = "/api/generate", host = NULL) {
+generate <- function(model, prompt, system = "", template = "", raw = FALSE, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/generate", host = NULL, ...) {
     req <- create_request(endpoint, host)
     req <- httr2::req_method(req, "POST")
 
     body_json <- list(
         model = model,
         stream = stream,
-        prompt = prompt
+        prompt = prompt,
+        system = system,
+        template = template,
+        raw = raw,
+        keep_alive = keep_alive
     )
-    req <- httr2::req_body_json(req, body_json)
+
+    opts <- list(...)
+    if (length(opts) > 0) {
+        if (validate_options(...)) {
+            body_json$options <- opts
+        } else {
+            stop("Invalid model options passed to ... argument. Please check the model options and try again.")
+        }
+    }
+
+    req <- httr2::req_body_json(req, body_json)
 
     content <- ""
     if (!stream) {
diff --git a/man/chat.Rd b/man/chat.Rd
index e993459..9435969 100644
--- a/man/chat.Rd
+++ b/man/chat.Rd
@@ -46,6 +46,7 @@ messages <- list(
 )
 chat("llama3", messages) # returns response by default
 chat("llama3", messages, "text") # returns text/vector
+chat("llama3", messages, temperature = 2.8) # additional options
 chat("llama3", messages, stream = TRUE) # stream response
 chat("llama3", messages, output = "df", stream = TRUE) # stream and return dataframe
 
diff --git a/man/generate.Rd b/man/generate.Rd
index 907ad68..353a2db 100644
--- a/man/generate.Rd
+++ b/man/generate.Rd
@@ -7,10 +7,15 @@
 generate(
   model,
   prompt,
+  system = "",
+  template = "",
+  raw = FALSE,
   output = c("resp", "jsonlist", "raw", "df", "text"),
   stream = FALSE,
+  keep_alive = "5m",
   endpoint = "/api/generate",
-  host = NULL
+  host = NULL,
+  ...
 )
 }
 \arguments{
@@ -18,13 +23,23 @@ generate(
 
 \item{prompt}{A character string of the promp like "The sky is..."}
 
+\item{system}{A character string of the system prompt (overrides what is defined in the Modelfile). Default is "".}
+
+\item{template}{A character string of the prompt template (overrides what is defined in the Modelfile). Default is "".}
+
+\item{raw}{If TRUE, no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API. Default is FALSE.}
+
 \item{output}{A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text".}
 
 \item{stream}{Enable response streaming. Default is FALSE.}
 
+\item{keep_alive}{The duration to keep the model loaded in memory after the request. Default is "5m" (5 minutes).}
+
 \item{endpoint}{The endpoint to generate the completion. Default is "/api/generate".}
 
 \item{host}{The base URL to use. Default is NULL, which uses Ollama's default base URL.}
+
+\item{...}{Additional options to pass to the model.}
 }
 \value{
 A response in the format specified in the output parameter.
@@ -35,7 +50,8 @@ Generate a response for a given prompt with a provided model.
 \examples{
 \dontshow{if (test_connection()$status_code == 200) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf}
 generate("llama3", "The sky is...", stream = FALSE, output = "df")
-generate("llama3", "The sky is...", stream = TRUE, output = "df")
+generate("llama3", "The sky is...", stream = TRUE, output = "text")
+generate("llama3", "The sky is...", stream = TRUE, output = "text", temperature = 2.0)
 generate("llama3", "The sky is...", stream = FALSE, output = "jsonlist")
 \dontshow{\}) # examplesIf}
 }
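For reviewers, a brief usage sketch of the interface this patch adds. It is not part of the patch itself; it assumes the patch is applied, the ollamar package is loaded, and a local Ollama server with the "llama3" model is available. validate_options() refers to the package-internal helper called in the new generate() body.

library(ollamar)

# New generate() arguments introduced by this patch: system, template, raw, keep_alive,
# plus arbitrary model options forwarded through `...` into body_json$options.
txt <- generate(
    "llama3", "The sky is...",
    system = "Answer in one short sentence.",
    keep_alive = "5m",
    output = "text",
    temperature = 0.7
)

# An unrecognized option name should fail validate_options() and trigger the new error:
# generate("llama3", "The sky is...", temperatures = 0.7)  # expected to error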