diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index fc32834..40eeb3f 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -2,9 +2,9 @@ # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help on: push: - branches: [main, master, dev, test, tawab, stream] + branches: [main, master, dev, generate] pull_request: - branches: [main, master, dev, test, tawab, stream] + branches: [main, master, dev, generate] name: R-CMD-check diff --git a/.gitignore b/.gitignore index 457525e..ddbce56 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ .DS_Store .quarto docs + +R/scratch.R diff --git a/R/ollama.R b/R/ollama.R index 318a397..9dc3773 100644 --- a/R/ollama.R +++ b/R/ollama.R @@ -44,71 +44,170 @@ create_request <- function(endpoint, host = NULL) { -#' Get available local models + + + + + + + + +#' Generate a completion. #' -#' @param output The output format. Default is "df". Other options are "resp", "jsonlist", "raw", "text". -#' @param endpoint The endpoint to get the models. Default is "/api/tags". +#' Generate a response for a given prompt with a provided model. +#' +#' @param model A character string of the model name such as "llama3". +#' @param prompt A character string of the promp like "The sky is..." +#' @param suffix A character string after the model response. Default is "". +#' @param images A path to an image file to include in the prompt. Default is "". +#' @param system A character string of the system prompt (overrides what is defined in the Modelfile). Default is "". +#' @param template A character string of the prompt template (overrides what is defined in the Modelfile). Default is "". +#' @param context A list of context from a previous response to include previous conversation in the prompt. Default is an empty list. +#' @param stream Enable response streaming. Default is FALSE. +#' @param raw If TRUE, no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API. Default is FALSE. +#' @param keep_alive The time to keep the connection alive. Default is "5m" (5 minutes). +#' @param output A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text". +#' @param endpoint The endpoint to generate the completion. Default is "/api/generate". #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL. +#' @param ... Additional options to pass to the model. #' #' @return A response in the format specified in the output parameter. 
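# Illustrative sketch (not part of the patch): carrying a conversation forward
# with the new `context` argument of generate(). Assumes a running Ollama server
# with the "llama3" model pulled, the package loaded (e.g. via devtools::load_all()
# during development), and that the non-streamed /api/generate reply exposes a
# `context` field, as described in the Ollama API docs.
resp <- generate("llama3", "My name is Jo. Please remember that.", output = "jsonlist")
ctx <- resp$context  # token context returned by the server, if present

# Feed the context back so the follow-up prompt sees the earlier exchange.
generate("llama3", "What is my name?", context = ctx, output = "text")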
#' @export #' +#' @references +#' [API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion) +#' #' @examplesIf test_connection()$status_code == 200 -#' list_models() # returns dataframe/tibble by default -#' list_models("df") -#' list_models("resp") # httr2 response object -#' list_models("jsonlist") -#' list_models("raw") -list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), endpoint = "/api/tags", host = NULL) { +#' generate("llama3", "The sky is...", stream = FALSE, output = "df") +#' generate("llama3", "The sky is...", stream = TRUE, output = "text") +#' generate("llama3", "The sky is...", stream = TRUE, output = "text", temperature = 2.0) +#' generate("llama3", "The sky is...", stream = FALSE, output = "jsonlist") +generate <- function(model, prompt, suffix = "", images = list(), system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text"), endpoint = "/api/generate", host = NULL, ...) { - if (!output[1] %in% c("df", "resp", "jsonlist", "raw", "text")) { - stop("Invalid output format specified. Supported formats are 'df', 'resp', 'jsonlist', 'raw', 'text'.") - } req <- create_request(endpoint, host) - req <- httr2::req_method(req, "GET") - tryCatch( - { - resp <- httr2::req_perform(req) - return(resp_process(resp = resp, output = output[1])) - }, - error = function(e) { - stop(e) - } + req <- httr2::req_method(req, "POST") + body_json <- list( + model = model, + prompt = prompt, + suffix = suffix, + system = system, + template = template, + context = context, + stream = stream, + raw = raw, + images = images, + stream = stream, + keep_alive = keep_alive ) + + # check if model options are passed and specified correctly + opts <- list(...) + if (length(opts) > 0) { + if (validate_options(...)) { + body_json$options <- opts + } else { + stop("Invalid model options passed to ... argument. Please check the model options and try again.") + } + } + + req <- httr2::req_body_json(req, body_json, stream = stream) + + if (!stream) { + tryCatch( + { + resp <- httr2::req_perform(req) + return(resp_process(resp = resp, output = output[1])) + }, + error = function(e) { + stop(e) + } + ) + } + + # streaming + env <- new.env() + env$buffer <- "" + env$content <- "" + env$accumulated_data <- raw() + wrapped_handler <- function(x) stream_handler(x, env, endpoint) + resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1) + cat("\n\n") + resp$body <- env$accumulated_data + + return(resp_process(resp = resp, output = output[1])) } -#' Pull/download a model -#' -#' See https://ollama.com/library for a list of available models. Use the list_models() function to get the list of models already downloaded/installed on your machine. + + + + + + + +#' Chat with Ollama models #' -#' @param model A character string of the model name to download/pull, such as "llama3". -#' @param stream Enable response streaming. Default is TRUE. -#' @param insecure Allow insecure connections Only use this if you are pulling from your own library during development. Default is FALSE. -#' @param endpoint The endpoint to pull the model. Default is "/api/pull". +#' @param model A character string of the model name such as "llama3". +#' @param messages A list with list of messages for the model (see examples below). +#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text". +#' @param stream Enable response streaming. 
Default is FALSE. +#' @param keep_alive The duration to keep the connection alive. Default is "5m". +#' @param endpoint The endpoint to chat with the model. Default is "/api/chat". #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL. +#' @param ... Additional options to pass to the model. #' -#' @return A httr2 response object. +#' @return A response in the format specified in the output parameter. #' @export #' #' @examplesIf test_connection()$status_code == 200 -#' pull("llama3") -#' pull("all-minilm", stream = FALSE) -pull <- function(model, stream = TRUE, insecure = FALSE, endpoint = "/api/pull", host = NULL) { - +#' # one message +#' messages <- list( +#' list(role = "user", content = "How are you doing?") +#' ) +#' chat("llama3", messages) # returns response by default +#' chat("llama3", messages, "text") # returns text/vector +#' chat("llama3", messages, "hello!", temperature = 2.8) # additional options +#' chat("llama3", messages, stream = TRUE) # stream response +#' chat("llama3", messages, output = "df", stream = TRUE) # stream and return dataframe +#' +#' # multiple messages +#' messages <- list( +#' list(role = "user", content = "Hello!"), +#' list(role = "assistant", content = "Hi! How are you?"), +#' list(role = "user", content = "Who is the prime minister of the uk?"), +#' list(role = "assistant", content = "Rishi Sunak"), +#' list(role = "user", content = "List all the previous messages.") +#' ) +#' chat("llama3", messages, stream = TRUE) +chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/chat", host = NULL, ...) { req <- create_request(endpoint, host) req <- httr2::req_method(req, "POST") - body_json <- list(model = model, stream = stream, insecure = insecure) + body_json <- list( + model = model, + messages = messages, + stream = stream, + keep_alive = keep_alive + ) + + opts <- list(...) + if (length(opts) > 0) { + if (validate_options(...)) { + body_json$options <- opts + } else { + stop("Invalid model options passed to ... argument. 
Please check the model options and try again.") + } + } + req <- httr2::req_body_json(req, body_json) if (!stream) { tryCatch( { resp <- httr2::req_perform(req) - return(resp) + return(resp_process(resp = resp, output = output[1])) }, error = function(e) { stop(e) @@ -123,11 +222,107 @@ pull <- function(model, stream = TRUE, insecure = FALSE, endpoint = "/api/pull", env$accumulated_data <- raw() wrapped_handler <- function(x) stream_handler(x, env, endpoint) resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1) + cat("\n\n") + + # process streaming output + json_lines <- strsplit(rawToChar(env$accumulated_data), "\n")[[1]] + json_lines_output <- vector("list", length = length(json_lines)) + df_response <- tibble::tibble( + model = character(length(json_lines_output)), + role = character(length(json_lines_output)), + content = character(length(json_lines_output)), + created_at = character(length(json_lines_output)) + ) + + if (output[1] == "raw") { + return(rawToChar(env$accumulated_data)) + } + + for (i in seq_along(json_lines)) { + json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]]) + df_response$model[i] <- json_lines_output[[i]]$model + df_response$role[i] <- json_lines_output[[i]]$message$role + df_response$content[i] <- json_lines_output[[i]]$message$content + df_response$created_at[i] <- json_lines_output[[i]]$created_at + } + + if (output[1] == "jsonlist") { + return(json_lines_output) + } + + if (output[1] == "df") { + return(df_response) + } + + if (output[1] == "text") { + return(paste0(df_response$content, collapse = "")) + } + return(resp) } + + + + + + + + + + + + + + + + + + + +#' Get available local models +#' +#' @param output The output format. Default is "df". Other options are "resp", "jsonlist", "raw", "text". +#' @param endpoint The endpoint to get the models. Default is "/api/tags". +#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL. +#' +#' @return A response in the format specified in the output parameter. +#' @export +#' +#' @examplesIf test_connection()$status_code == 200 +#' list_models() # returns dataframe/tibble by default +#' list_models("df") +#' list_models("resp") # httr2 response object +#' list_models("jsonlist") +#' list_models("raw") +list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), endpoint = "/api/tags", host = NULL) { + + if (!output[1] %in% c("df", "resp", "jsonlist", "raw", "text")) { + stop("Invalid output format specified. Supported formats are 'df', 'resp', 'jsonlist', 'raw', 'text'.") + } + req <- create_request(endpoint, host) + req <- httr2::req_method(req, "GET") + tryCatch( + { + resp <- httr2::req_perform(req) + return(resp_process(resp = resp, output = output[1])) + }, + error = function(e) { + stop(e) + } + ) +} + + + + + + + + + #' Delete a model #' #' Delete a model from your local machine that you downlaoded using the pull() function. To see which models are available, use the list_models() function. @@ -161,64 +356,62 @@ delete <- function(model, endpoint = "/api/delete", host = NULL) { } -normalize <- function(x) { - norm <- sqrt(sum(x^2)) - normalized_vector <- x / norm - return(normalized_vector) -} - - -#' Get vector embedding for a single prompt +#' Pull/download a model #' -#' This function will be deprecated over time and has been superceded by `embed()`. See `embed()` for more details. +#' See https://ollama.com/library for a list of available models. 
Use the list_models() function to get the list of models already downloaded/installed on your machine. #' -#' @param model A character string of the model name such as "llama3". -#' @param prompt A character string of the prompt that you want to get the vector embedding for. -#' @param normalize Normalize the vector to length 1. Default is TRUE. -#' @param keep_alive The time to keep the connection alive. Default is "5m" (5 minutes). -#' @param endpoint The endpoint to get the vector embedding. Default is "/api/embeddings". +#' @param model A character string of the model name to download/pull, such as "llama3". +#' @param stream Enable response streaming. Default is TRUE. +#' @param insecure Allow insecure connections Only use this if you are pulling from your own library during development. Default is FALSE. +#' @param endpoint The endpoint to pull the model. Default is "/api/pull". #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL. -#' @param ... Additional options to pass to the model. #' -#' @return A numeric vector of the embedding. +#' @return A httr2 response object. #' @export #' #' @examplesIf test_connection()$status_code == 200 -#' embeddings("nomic-embed-text:latest", "The quick brown fox jumps over the lazy dog.") -#' # pass model options to the model -#' embeddings("nomic-embed-text:latest", "Hello!", temperature = 0.1, num_predict = 3) -embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpoint = "/api/embeddings", host = NULL, ...) { +#' pull("llama3") +#' pull("all-minilm", stream = FALSE) +pull <- function(model, stream = TRUE, insecure = FALSE, endpoint = "/api/pull", host = NULL) { + req <- create_request(endpoint, host) req <- httr2::req_method(req, "POST") - body_json <- list(model = model, prompt = prompt, keep_alive = keep_alive) - opts <- list(...) - if (length(opts) > 0) { - if (validate_options(...)) { - body_json$options <- opts - } else { - stop("Invalid model options passed to ... argument. Please check the model options and try again.") - } + body_json <- list(model = model, stream = stream, insecure = insecure) + req <- httr2::req_body_json(req, body_json) + + if (!stream) { + tryCatch( + { + resp <- httr2::req_perform(req) + return(resp) + }, + error = function(e) { + stop(e) + } + ) } - req <- httr2::req_body_json(req, body_json) + # streaming + env <- new.env() + env$buffer <- "" + env$content <- "" + env$accumulated_data <- raw() + wrapped_handler <- function(x) stream_handler(x, env, endpoint) + resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1) + return(resp) +} - tryCatch( - { - resp <- httr2::req_perform(req) - v <- unlist(resp_process(resp, "jsonlist")$embedding) - if (normalize) { - v <- normalize(v) - } - return(v) - }, - error = function(e) { - stop(e) - } - ) + + + +normalize <- function(x) { + norm <- sqrt(sum(x^2)) + normalized_vector <- x / norm + return(normalized_vector) } @@ -229,6 +422,12 @@ embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpo + + + + + + #' Get embedding for inputs #' #' Supercedes the `embeddings()` function. @@ -298,50 +497,29 @@ embed <- function(model, input, truncate = TRUE, normalize = TRUE, keep_alive = - - - - - -#' Generate a completion. +#' Get vector embedding for a single prompt #' -#' Generate a response for a given prompt with a provided model. +#' This function will be deprecated over time and has been superceded by `embed()`. See `embed()` for more details. 
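# Illustrative sketch (not part of the patch): a small model-management helper
# built on list_models(), pull(), and delete(). Assumes a running Ollama server;
# list_models("text") returning the installed model names follows from the
# api/tags handling in resp_process(). ensure_model() is a hypothetical wrapper,
# not a package function.
ensure_model <- function(model) {
  installed <- tryCatch(list_models("text"), error = function(e) character(0))
  if (!any(grepl(model, installed, fixed = TRUE))) {
    pull(model, stream = FALSE)  # download the model only if it is not present
  }
  invisible(model)
}

ensure_model("all-minilm")
# delete("all-minilm")  # remove it again once it is no longer needed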
#' #' @param model A character string of the model name such as "llama3". -#' @param prompt A character string of the promp like "The sky is..." -#' @param system A character string of the system prompt (overrides what is defined in the Modelfile). Default is "". -#' @param template A character string of the prompt template (overrides what is defined in the Modelfile). Default is "". -#' @param raw If TRUE, no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API. Default is FALSE. -#' @param output A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text". -#' @param stream Enable response streaming. Default is FALSE. +#' @param prompt A character string of the prompt that you want to get the vector embedding for. +#' @param normalize Normalize the vector to length 1. Default is TRUE. #' @param keep_alive The time to keep the connection alive. Default is "5m" (5 minutes). -#' @param endpoint The endpoint to generate the completion. Default is "/api/generate". +#' @param endpoint The endpoint to get the vector embedding. Default is "/api/embeddings". #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL. #' @param ... Additional options to pass to the model. #' -#' -#' @return A response in the format specified in the output parameter. +#' @return A numeric vector of the embedding. #' @export #' #' @examplesIf test_connection()$status_code == 200 -#' generate("llama3", "The sky is...", stream = FALSE, output = "df") -#' generate("llama3", "The sky is...", stream = TRUE, output = "text") -#' generate("llama3", "The sky is...", stream = TRUE, output = "text", temperature = 2.0) -#' generate("llama3", "The sky is...", stream = FALSE, output = "jsonlist") -generate <- function(model, prompt, system = "", template = "", raw = FALSE, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/generate", host = NULL, ...) { +#' embeddings("nomic-embed-text:latest", "The quick brown fox jumps over the lazy dog.") +#' # pass model options to the model +#' embeddings("nomic-embed-text:latest", "Hello!", temperature = 0.1, num_predict = 3) +embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpoint = "/api/embeddings", host = NULL, ...) { req <- create_request(endpoint, host) req <- httr2::req_method(req, "POST") - - body_json <- list( - model = model, - stream = stream, - prompt = prompt, - system = system, - template = template, - raw = raw, - stream = stream, - keep_alive = keep_alive - ) + body_json <- list(model = model, prompt = prompt, keep_alive = keep_alive) opts <- list(...) 
if (length(opts) > 0) { @@ -352,62 +530,21 @@ generate <- function(model, prompt, system = "", template = "", raw = FALSE, out } } - req <- httr2::req_body_json(req, body_json, stream = stream) + req <- httr2::req_body_json(req, body_json) - if (!stream) { - tryCatch( - { - resp <- httr2::req_perform(req) - return(resp_process(resp = resp, output = output[1])) - }, - error = function(e) { - stop(e) + tryCatch( + { + resp <- httr2::req_perform(req) + v <- unlist(resp_process(resp, "jsonlist")$embedding) + if (normalize) { + v <- normalize(v) } - ) - } - - # streaming - env <- new.env() - env$buffer <- "" - env$content <- "" - env$accumulated_data <- raw() - wrapped_handler <- function(x) stream_handler(x, env, endpoint) - resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1) - cat("\n\n") - - # process streaming output - json_lines <- strsplit(rawToChar(env$accumulated_data), "\n")[[1]] - json_lines_output <- vector("list", length = length(json_lines)) - df_response <- tibble::tibble( - model = character(length(json_lines_output)), - response = character(length(json_lines_output)), - created_at = character(length(json_lines_output)) + return(v) + }, + error = function(e) { + stop(e) + } ) - - if (output[1] == "raw") { - return(rawToChar(env$accumulated_data)) - } - - for (i in seq_along(json_lines)) { - json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]]) - df_response$model[i] <- json_lines_output[[i]]$model - df_response$response[i] <- json_lines_output[[i]]$response - df_response$created_at[i] <- json_lines_output[[i]]$created_at - } - - if (output[1] == "jsonlist") { - return(json_lines_output) - } - - if (output[1] == "df") { - return(df_response) - } - - if (output[1] == "text") { - return(paste0(df_response$response, collapse = "")) - } - - return(resp) } @@ -420,119 +557,10 @@ generate <- function(model, prompt, system = "", template = "", raw = FALSE, out -#' Chat with Ollama models -#' -#' @param model A character string of the model name such as "llama3". -#' @param messages A list with list of messages for the model (see examples below). -#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text". -#' @param stream Enable response streaming. Default is FALSE. -#' @param keep_alive The duration to keep the connection alive. Default is "5m". -#' @param endpoint The endpoint to chat with the model. Default is "/api/chat". -#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL. -#' @param ... Additional options to pass to the model. -#' -#' @return A response in the format specified in the output parameter. -#' @export -#' -#' @examplesIf test_connection()$status_code == 200 -#' # one message -#' messages <- list( -#' list(role = "user", content = "How are you doing?") -#' ) -#' chat("llama3", messages) # returns response by default -#' chat("llama3", messages, "text") # returns text/vector -#' chat("llama3", messages, "hello!", temperature = 2.8) # additional options -#' chat("llama3", messages, stream = TRUE) # stream response -#' chat("llama3", messages, output = "df", stream = TRUE) # stream and return dataframe -#' -#' # multiple messages -#' messages <- list( -#' list(role = "user", content = "Hello!"), -#' list(role = "assistant", content = "Hi! 
How are you?"), -#' list(role = "user", content = "Who is the prime minister of the uk?"), -#' list(role = "assistant", content = "Rishi Sunak"), -#' list(role = "user", content = "List all the previous messages.") -#' ) -#' chat("llama3", messages, stream = TRUE) -chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/chat", host = NULL, ...) { - req <- create_request(endpoint, host) - req <- httr2::req_method(req, "POST") - - body_json <- list( - model = model, - messages = messages, - stream = stream, - keep_alive = keep_alive - ) - - opts <- list(...) - if (length(opts) > 0) { - if (validate_options(...)) { - body_json$options <- opts - } else { - stop("Invalid model options passed to ... argument. Please check the model options and try again.") - } - } - - req <- httr2::req_body_json(req, body_json) - - if (!stream) { - tryCatch( - { - resp <- httr2::req_perform(req) - return(resp_process(resp = resp, output = output[1])) - }, - error = function(e) { - stop(e) - } - ) - } - - # streaming - env <- new.env() - env$buffer <- "" - env$content <- "" - env$accumulated_data <- raw() - wrapped_handler <- function(x) stream_handler(x, env, endpoint) - resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1) - cat("\n\n") - - # process streaming output - json_lines <- strsplit(rawToChar(env$accumulated_data), "\n")[[1]] - json_lines_output <- vector("list", length = length(json_lines)) - df_response <- tibble::tibble( - model = character(length(json_lines_output)), - role = character(length(json_lines_output)), - content = character(length(json_lines_output)), - created_at = character(length(json_lines_output)) - ) - - if (output[1] == "raw") { - return(rawToChar(env$accumulated_data)) - } - - for (i in seq_along(json_lines)) { - json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]]) - df_response$model[i] <- json_lines_output[[i]]$model - df_response$role[i] <- json_lines_output[[i]]$message$role - df_response$content[i] <- json_lines_output[[i]]$message$content - df_response$created_at[i] <- json_lines_output[[i]]$created_at - } - if (output[1] == "jsonlist") { - return(json_lines_output) - } - if (output[1] == "df") { - return(df_response) - } - if (output[1] == "text") { - return(paste0(df_response$content, collapse = "")) - } - return(resp) -} diff --git a/R/utils.R b/R/utils.R index b0da046..06d275b 100644 --- a/R/utils.R +++ b/R/utils.R @@ -45,13 +45,12 @@ stream_handler <- function(x, env, endpoint) { #' #' @examplesIf test_connection()$status_code == 200 #' resp <- list_models("resp") -#' resp_process(resp, "df") # parse response to dataframe/tibble -#' resp_process(resp, "jsonlist") # parse response to list -#' resp_process(resp, "raw") # parse response to raw string -#' resp_process(resp, "resp") # return input response object -#' resp_process(resp, "text") # return text/character vector +#' resp_process(resp, "df") # parse response to dataframe/tibble +#' resp_process(resp, "jsonlist") # parse response to list +#' resp_process(resp, "raw") # parse response to raw string +#' resp_process(resp, "resp") # return input response object +#' resp_process(resp, "text") # return text/character vector resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text")) { - if (is.null(resp) || resp$status_code != 200) { warning("Cannot process response") return(NULL) @@ -60,30 +59,68 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text output 
<- output[1] if (output == "resp") { return(resp) - } else if (output == "raw") { + } + + # process stream resp separately + stream <- FALSE + headers <- httr2::resp_headers(resp) + transfer_encoding <- headers$`Transfer-Encoding` # if response is chunked, then it was a streamed output + if (!is.null(transfer_encoding)) stream <- grepl("chunked", transfer_encoding) + if (stream) { + return(resp_process_stream(resp, output)) + } + + # process non-stream response below + if (output == "raw") { return(httr2::resp_raw(resp)) } else if (output == "jsonlist") { return(httr2::resp_body_json(resp)) } - # convert data to data frame # process different endpoints - if (grepl("api/tags", resp$url)) { + if (grepl("api/generate", resp$url)) { # process generate endpoint + json_body <- httr2::resp_body_json(resp) + df_response <- tibble::tibble( + model = json_body$model, + response = json_body$response, + created_at = json_body$created_at + ) + + if (output == "df") { + return(df_response) + } else if (output == "text") { + return(df_response$response) + } + } else if (grepl("api/chat", resp$url)) { # process chat endpoint + json_body <- httr2::resp_body_json(resp) + df_response <- tibble::tibble( + model = json_body$model, + role = json_body$message$role, + content = json_body$message$content, + created_at = json_body$created_at + ) + if (output == "df") { + return(df_response) + } else if (output == "text") { + return(df_response$content) + } + } else if (grepl("api/tags", resp$url)) { # process tags endpoint json_body <- httr2::resp_body_json(resp)[[1]] df_response <- tibble::tibble( name = character(length(json_body)), size = character(length(json_body)), parameter_size = character(length(json_body)), quantization_level = character(length(json_body)), - modified = character(length(json_body))) + modified = character(length(json_body)) + ) for (i in seq_along(json_body)) { - df_response[i, 'name'] <- json_body[[i]]$name + df_response[i, "name"] <- json_body[[i]]$name size <- json_body[[i]]$size / 10^9 df_response[i, "size"] <- ifelse(size > 1, paste0(round(size, 1), " GB"), paste0(round(size * 1000), " MB")) - df_response[i, 'parameter_size'] <- json_body[[i]]$details$parameter_size - df_response[i, 'quantization_level'] <- json_body[[i]]$details$quantization_level - df_response[i, 'modified'] <- strsplit(json_body[[i]]$modified_at, ".", fixed = TRUE)[[1]][1] + df_response[i, "parameter_size"] <- json_body[[i]]$details$parameter_size + df_response[i, "quantization_level"] <- json_body[[i]]$details$quantization_level + df_response[i, "modified"] <- strsplit(json_body[[i]]$modified_at, ".", fixed = TRUE)[[1]][1] } if (output == "df") { @@ -91,40 +128,63 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text } else if (output == "text") { return(df_response$name) } + } +} - # process chat endpoint - } else if (grepl("api/chat", resp$url)) { - json_body <- httr2::resp_body_json(resp) - df_response <- tibble::tibble(model = json_body$model, - role = json_body$message$role, - content = json_body$message$content, - created_at = json_body$created_at) - if (output == "df") { - return(df_response) - } else if (output == "text") { - return(df_response$content) - } - # process generate endpoint - } else if (grepl("api/generate", resp$url)) { +#' Process httr2 response object for streaming. 
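# Illustrative sketch (not part of the patch): because resp_process() now detects
# chunked (streamed) responses via the Transfer-Encoding header, a single streamed
# response object can be re-processed into several formats after the fact.
# Assumes a running Ollama server with "llama3" available.
resp <- generate("llama3", "Name three colours.", stream = TRUE, output = "resp")

resp_process(resp, "text")     # the full streamed text, concatenated
resp_process(resp, "df")       # one row per streamed JSON line
resp_process(resp, "jsonlist") # the parsed JSON lines as a list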
+#' +#' @keywords internal +resp_process_stream <- function(resp, output) { + if (output == "raw") { + return(rawToChar(resp$body)) + } - json_body <- httr2::resp_body_json(resp) - df_response <- tibble::tibble(model = json_body$model, - response = json_body$response, - created_at = json_body$created_at) + if (grepl("api/generate", resp$url)) { # process generate endpoint + json_lines <- strsplit(rawToChar(resp$body), "\n")[[1]] + json_lines_output <- vector("list", length = length(json_lines)) + df_response <- tibble::tibble( + model = character(length(json_lines_output)), + response = character(length(json_lines_output)), + created_at = character(length(json_lines_output)) + ) + for (i in seq_along(json_lines)) { + json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]]) + df_response$model[i] <- json_lines_output[[i]]$model + df_response$response[i] <- json_lines_output[[i]]$response + df_response$created_at[i] <- json_lines_output[[i]]$created_at + } + + if (output == "jsonlist") { + return(json_lines_output) + } if (output == "df") { return(df_response) - } else if (output == "text") { - return(df_response$response) } + if (output == "text") { + return(paste0(df_response$response, collapse = "")) + } + } else if (grepl("api/chat", resp$url)) { # process chat endpoint + return(NULL) # TODO fill in + } else if (grepl("api/tags", resp$url)) { # process tags endpoint { + return(NULL) # TODO fill in } + } + + + + + + + + #' Create a message #' #' @param content The content of the message. @@ -188,7 +248,7 @@ prepend_message <- function(content, role = "user", x = NULL) { x <- list() } new_message <- list(role = role, content = content) - x <- c(list(new_message), x) # Prepend by combining the new message with the existing list + x <- c(list(new_message), x) # Prepend by combining the new message with the existing list return(x) } @@ -211,18 +271,21 @@ prepend_message <- function(content, role = "user", x = NULL) { #' messages <- list( #' list(role = "system", content = "Be friendly"), #' list(role = "user", content = "How are you?") -#' ) +#' ) #' insert_message("INSERT MESSAGE AT THE END", "user", messages) #' insert_message("INSERT MESSAGE AT THE BEGINNING", "user", messages, 2) insert_message <- function(content, role = "user", x = NULL, position = -1) { - if (position == -1) position <- length(x) + 1 new_message <- list(role = role, content = content) if (is.null(x)) { return(list(new_message)) } - if (position == 1) return(prepend_message(content, role, x)) - if (position == length(x) + 1) return(append_message(content, role, x)) + if (position == 1) { + return(prepend_message(content, role, x)) + } + if (position == length(x) + 1) { + return(append_message(content, role, x)) + } x <- c(x[1:(position - 1)], list(new_message), x[position:length(x)]) return(x) @@ -247,11 +310,11 @@ insert_message <- function(content, role = "user", x = NULL, position = -1) { #' messages <- list( #' list(role = "system", content = "Be friendly"), #' list(role = "user", content = "How are you?") -#' ) -#' delete_message(messages, 1) # delete first message -#' delete_message(messages, -2) # same as above (delete first message) -#' delete_message(messages, 2) # delete second message -#' delete_message(messages, -1) # same as above (delete second message) +#' ) +#' delete_message(messages, 1) # delete first message +#' delete_message(messages, -2) # same as above (delete first message) +#' delete_message(messages, 2) # delete second message +#' delete_message(messages, -1) # same as above (delete second 
message) delete_message <- function(x, position = -1) { if (position == 0 || abs(position) > length(x)) { stop("Position out of valid range.") @@ -259,8 +322,3 @@ delete_message <- function(x, position = -1) { if (position < 0) position <- length(x) + position + 1 return(x[-position]) } - - - - - diff --git a/R/zzz.R b/R/zzz.R new file mode 100644 index 0000000..e69de29 diff --git a/_pkgdown.yml b/_pkgdown.yml index 8a8fabd..e6ace5c 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -7,11 +7,11 @@ reference: - title: API calls desc: Functions to make calls to the Ollama server/API contents: - - list_models - - pull - - delete - generate - chat + - list_models + - delete + - pull - embed - embeddings - test_connection @@ -41,3 +41,4 @@ reference: - package_config - model_options - stream_handler + - resp_process_stream diff --git a/inst/extdata/image1.png b/inst/extdata/image1.png new file mode 100644 index 0000000..d8e59d6 Binary files /dev/null and b/inst/extdata/image1.png differ diff --git a/inst/extdata/image2.png b/inst/extdata/image2.png new file mode 100644 index 0000000..aeb3bd5 Binary files /dev/null and b/inst/extdata/image2.png differ diff --git a/man/delete_message.Rd b/man/delete_message.Rd index 1c35f10..c6a9d22 100644 --- a/man/delete_message.Rd +++ b/man/delete_message.Rd @@ -23,9 +23,9 @@ elements/messages from the end of the sequence. messages <- list( list(role = "system", content = "Be friendly"), list(role = "user", content = "How are you?") - ) -delete_message(messages, 1) # delete first message -delete_message(messages, -2) # same as above (delete first message) -delete_message(messages, 2) # delete second message -delete_message(messages, -1) # same as above (delete second message) +) +delete_message(messages, 1) # delete first message +delete_message(messages, -2) # same as above (delete first message) +delete_message(messages, 2) # delete second message +delete_message(messages, -1) # same as above (delete second message) } diff --git a/man/generate.Rd b/man/generate.Rd index 353a2db..cc7c96f 100644 --- a/man/generate.Rd +++ b/man/generate.Rd @@ -7,12 +7,15 @@ generate( model, prompt, + suffix = "", + images = list(), system = "", template = "", - raw = FALSE, - output = c("resp", "jsonlist", "raw", "df", "text"), + context = list(), stream = FALSE, + raw = FALSE, keep_alive = "5m", + output = c("resp", "jsonlist", "raw", "df", "text"), endpoint = "/api/generate", host = NULL, ... @@ -23,18 +26,24 @@ generate( \item{prompt}{A character string of the promp like "The sky is..."} +\item{suffix}{A character string after the model response. Default is "".} + +\item{images}{A path to an image file to include in the prompt. Default is "".} + \item{system}{A character string of the system prompt (overrides what is defined in the Modelfile). Default is "".} \item{template}{A character string of the prompt template (overrides what is defined in the Modelfile). Default is "".} -\item{raw}{If TRUE, no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API. Default is FALSE.} - -\item{output}{A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text".} +\item{context}{A list of context from a previous response to include previous conversation in the prompt. Default is an empty list.} \item{stream}{Enable response streaming. Default is FALSE.} +\item{raw}{If TRUE, no formatting will be applied to the prompt. 
You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API. Default is FALSE.} + \item{keep_alive}{The time to keep the connection alive. Default is "5m" (5 minutes).} +\item{output}{A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text".} + \item{endpoint}{The endpoint to generate the completion. Default is "/api/generate".} \item{host}{The base URL to use. Default is NULL, which uses Ollama's default base URL.} @@ -55,3 +64,6 @@ generate("llama3", "The sky is...", stream = TRUE, output = "text", temperature generate("llama3", "The sky is...", stream = FALSE, output = "jsonlist") \dontshow{\}) # examplesIf} } +\references{ +\href{https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion}{API documentation} +} diff --git a/man/insert_message.Rd b/man/insert_message.Rd index 4193dea..45548d8 100644 --- a/man/insert_message.Rd +++ b/man/insert_message.Rd @@ -26,7 +26,7 @@ The role and content are converted to a list and inserted into the input list at messages <- list( list(role = "system", content = "Be friendly"), list(role = "user", content = "How are you?") - ) +) insert_message("INSERT MESSAGE AT THE END", "user", messages) insert_message("INSERT MESSAGE AT THE BEGINNING", "user", messages, 2) } diff --git a/man/resp_process.Rd b/man/resp_process.Rd index 47929fc..ef6c469 100644 --- a/man/resp_process.Rd +++ b/man/resp_process.Rd @@ -20,10 +20,10 @@ Process httr2 response object. \examples{ \dontshow{if (test_connection()$status_code == 200) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} resp <- list_models("resp") -resp_process(resp, "df") # parse response to dataframe/tibble -resp_process(resp, "jsonlist") # parse response to list -resp_process(resp, "raw") # parse response to raw string -resp_process(resp, "resp") # return input response object -resp_process(resp, "text") # return text/character vector +resp_process(resp, "df") # parse response to dataframe/tibble +resp_process(resp, "jsonlist") # parse response to list +resp_process(resp, "raw") # parse response to raw string +resp_process(resp, "resp") # return input response object +resp_process(resp, "text") # return text/character vector \dontshow{\}) # examplesIf} } diff --git a/man/resp_process_stream.Rd b/man/resp_process_stream.Rd new file mode 100644 index 0000000..84c7719 --- /dev/null +++ b/man/resp_process_stream.Rd @@ -0,0 +1,12 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{resp_process_stream} +\alias{resp_process_stream} +\title{Process httr2 response object for streaming.} +\usage{ +resp_process_stream(resp, output) +} +\description{ +Process httr2 response object for streaming. +} +\keyword{internal}
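# Illustrative sketch (not part of the patch): one possible way to fill in the
# "TODO" chat branch of resp_process_stream(), mirroring its generate branch and
# the inline streaming code in chat(). process_chat_stream() is a hypothetical
# helper, not the package's implementation.
process_chat_stream <- function(resp, output) {
  json_lines <- strsplit(rawToChar(resp$body), "\n")[[1]]
  json_lines_output <- vector("list", length = length(json_lines))
  df_response <- tibble::tibble(
    model = character(length(json_lines)),
    role = character(length(json_lines)),
    content = character(length(json_lines)),
    created_at = character(length(json_lines))
  )
  for (i in seq_along(json_lines)) {
    json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]])
    df_response$model[i] <- json_lines_output[[i]]$model
    df_response$role[i] <- json_lines_output[[i]]$message$role
    df_response$content[i] <- json_lines_output[[i]]$message$content
    df_response$created_at[i] <- json_lines_output[[i]]$created_at
  }
  if (output == "jsonlist") return(json_lines_output)
  if (output == "df") return(df_response)
  if (output == "text") return(paste0(df_response$content, collapse = ""))
  df_response
}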