Commit

Clean up code

hauselin committed Jul 25, 2024
1 parent a69d150 commit 159d0db
Showing 5 changed files with 145 additions and 160 deletions.
258 changes: 131 additions & 127 deletions R/ollama.R
@@ -60,6 +60,7 @@ create_request <- function(endpoint, host = NULL) {
#' list_models("jsonlist")
#' list_models("raw")
list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), endpoint = "/api/tags", host = NULL) {

if (!output[1] %in% c("df", "resp", "jsonlist", "raw", "text")) {
stop("Invalid output format specified. Supported formats are 'df', 'resp', 'jsonlist', 'raw', 'text'.")
}
@@ -68,7 +69,6 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end
tryCatch(
{
resp <- httr2::req_perform(req)
print(resp)
return(resp_process(resp = resp, output = output[1]))
},
error = function(e) {
@@ -79,125 +79,6 @@



#' Chat with Ollama models
#'
#' @param model A character string of the model name such as "llama3".
#' @param messages A list of lists of messages for the model (see examples below).
#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text".
#' @param stream Enable response streaming. Default is FALSE.
#' @param keep_alive The duration to keep the connection alive. Default is "5m".
#' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#' @param ... Additional options to pass to the model.
#'
#' @return A response in the format specified in the output parameter.
#' @export
#'
#' @examplesIf test_connection()$status_code == 200
#' # one message
#' messages <- list(
#' list(role = "user", content = "How are you doing?")
#' )
#' chat("llama3", messages) # returns response by default
#' chat("llama3", messages, "text") # returns text/vector
#' chat("llama3", messages, "hello!", temperature = 2.8) # additional options
#' chat("llama3", messages, stream = TRUE) # stream response
#' chat("llama3", messages, output = "df", stream = TRUE) # stream and return dataframe
#'
#' # multiple messages
#' messages <- list(
#' list(role = "user", content = "Hello!"),
#' list(role = "assistant", content = "Hi! How are you?"),
#' list(role = "user", content = "Who is the prime minister of the uk?"),
#' list(role = "assistant", content = "Rishi Sunak"),
#' list(role = "user", content = "List all the previous messages.")
#' )
#' chat("llama3", messages, stream = TRUE)
chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/chat", host = NULL, ...) {
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

body_json <- list(
model = model,
messages = messages,
stream = stream,
keep_alive = keep_alive
)
opts <- list(...)
if (length(opts) > 0) {
if (validate_options(...)) {
body_json$options <- opts
} else {
stop("Invalid model options passed to ... argument. Please check the model options and try again.")
}
}

req <- httr2::req_body_json(req, body_json)

content <- ""
if (!stream) {
tryCatch(
{
resp <- httr2::req_perform(req)
print(resp)
return(resp_process(resp = resp, output = output[1]))
},
error = function(e) {
stop(e)
}
)
}

# streaming
env <- new.env()
env$buffer <- ""
env$content <- ""
env$accumulated_data <- raw()
wrapped_handler <- function(x) stream_handler(x, env, endpoint)
resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1)
cat("\n\n")

# process streaming output
json_lines <- strsplit(rawToChar(env$accumulated_data), "\n")[[1]]
json_lines_output <- vector("list", length = length(json_lines))
df_response <- tibble::tibble(
model = character(length(json_lines_output)),
role = character(length(json_lines_output)),
content = character(length(json_lines_output)),
created_at = character(length(json_lines_output))
)

if (output[1] == "raw") {
return(rawToChar(env$accumulated_data))
}

for (i in seq_along(json_lines)) {
json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]])
df_response$model[i] <- json_lines_output[[i]]$model
df_response$role[i] <- json_lines_output[[i]]$message$role
df_response$content[i] <- json_lines_output[[i]]$message$content
df_response$created_at[i] <- json_lines_output[[i]]$created_at
}

if (output[1] == "jsonlist") {
return(json_lines_output)
}

if (output[1] == "df") {
return(df_response)
}

if (output[1] == "text") {
return(paste0(df_response$content, collapse = ""))
}

return(resp)
}






#' Pull/download a model
#'
@@ -216,18 +97,17 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
#' pull("llama3")
#' pull("all-minilm", stream = FALSE)
pull <- function(model, stream = TRUE, insecure = FALSE, endpoint = "/api/pull", host = NULL) {

req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

body_json <- list(model = model, stream = stream, insecure = insecure)
req <- httr2::req_body_json(req, body_json)

content <- ""
if (!stream) {
tryCatch(
{
resp <- httr2::req_perform(req)
print(resp)
return(resp)
},
error = function(e) {
@@ -272,7 +152,6 @@ delete <- function(model, endpoint = "/api/delete", host = NULL) {
tryCatch(
{
resp <- httr2::req_perform(req)
print(resp)
return(resp)
},
error = function(e) {
@@ -294,6 +173,8 @@ normalize <- function(x) {

#' Get vector embedding for a single prompt
#'
#' This function will be deprecated over time and has been superseded by `embed()`. See `embed()` for more details.
#'
#' @param model A character string of the model name such as "llama3".
#' @param prompt A character string of the prompt that you want to get the vector embedding for.
#' @param normalize Normalize the vector to length 1. Default is TRUE.
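A note on the deprecation added above: a minimal migration sketch from `embeddings()` to `embed()`, assuming a running local Ollama server and that the `all-minilm` model (used in the `pull()` examples) has already been pulled; the embedding dimension shown is illustrative and depends on the model.

```r
library(ollamar)

# Old API: one prompt in, one numeric vector out (unit length by default).
v <- embeddings("all-minilm", "Hello, world")

# Newer API: embed() accepts one or more inputs and returns a matrix
# with one unit-length column per input (see the embed() hunk below).
m <- embed("all-minilm", c("Hello, world", "Goodbye, world"))
dim(m)  # e.g. 384 x 2 for all-minilm
```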
@@ -328,7 +209,6 @@ embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpo
tryCatch(
{
resp <- httr2::req_perform(req)
print(resp)
v <- unlist(resp_process(resp, "jsonlist")$embedding)
if (normalize) {
v <- normalize(v)
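The `normalize()` helper used here is defined earlier in this file, outside the visible hunk. Based on the `@param normalize` doc above ("Normalize the vector to length 1"), a plausible minimal implementation is plain Euclidean normalization — this is an assumed sketch, not the committed body:

```r
# Assumed implementation: scale a numeric vector to unit Euclidean length.
normalize <- function(x) {
  x / sqrt(sum(x^2))
}

v <- c(3, 4)
normalize(v)               # 0.6 0.8
sqrt(sum(normalize(v)^2))  # 1
```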
@@ -390,7 +270,6 @@ embed <- function(model, input, truncate = TRUE, normalize = TRUE, keep_alive =
tryCatch(
{
resp <- httr2::req_perform(req)
print(resp)
json_body <- httr2::resp_body_json(resp)$embeddings
m <- do.call(cbind, lapply(json_body, function(x) {
v <- unlist(x)
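The matrix assembly in `embed()` above is easy to check in isolation: `lapply()` walks the list of embedding vectors parsed from the JSON body, and `do.call(cbind, ...)` binds them into one column per input. A standalone sketch with stand-in data:

```r
# json_body stands in for httr2::resp_body_json(resp)$embeddings.
json_body <- list(c(1, 2, 2), c(0, 3, 4))
normalize <- function(x) x / sqrt(sum(x^2))  # assumed helper, as above

m <- do.call(cbind, lapply(json_body, function(x) {
  v <- unlist(x)
  normalize(v)
}))
m             # 3 x 2 matrix, one unit-length column per input
colSums(m^2)  # 1 1
```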
@@ -475,12 +354,10 @@ generate <- function(model, prompt, system = "", template = "", raw = FALSE, out

req <- httr2::req_body_json(req, body_json, stream = stream)

content <- ""
if (!stream) {
tryCatch(
{
resp <- httr2::req_perform(req)
print(resp)
return(resp_process(resp = resp, output = output[1]))
},
error = function(e) {
@@ -532,3 +409,130 @@ generate <- function(model, prompt, system = "", template = "", raw = FALSE, out

return(resp)
}











#' Chat with Ollama models
#'
#' @param model A character string of the model name such as "llama3".
#' @param messages A list of lists of messages for the model (see examples below).
#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text".
#' @param stream Enable response streaming. Default is FALSE.
#' @param keep_alive The duration to keep the connection alive. Default is "5m".
#' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#' @param ... Additional options to pass to the model.
#'
#' @return A response in the format specified in the output parameter.
#' @export
#'
#' @examplesIf test_connection()$status_code == 200
#' # one message
#' messages <- list(
#' list(role = "user", content = "How are you doing?")
#' )
#' chat("llama3", messages) # returns response by default
#' chat("llama3", messages, "text") # returns text/vector
#' chat("llama3", messages, "hello!", temperature = 2.8) # additional options
#' chat("llama3", messages, stream = TRUE) # stream response
#' chat("llama3", messages, output = "df", stream = TRUE) # stream and return dataframe
#'
#' # multiple messages
#' messages <- list(
#' list(role = "user", content = "Hello!"),
#' list(role = "assistant", content = "Hi! How are you?"),
#' list(role = "user", content = "Who is the prime minister of the uk?"),
#' list(role = "assistant", content = "Rishi Sunak"),
#' list(role = "user", content = "List all the previous messages.")
#' )
#' chat("llama3", messages, stream = TRUE)
chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/chat", host = NULL, ...) {
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

body_json <- list(
model = model,
messages = messages,
stream = stream,
keep_alive = keep_alive
)

opts <- list(...)
if (length(opts) > 0) {
if (validate_options(...)) {
body_json$options <- opts
} else {
stop("Invalid model options passed to ... argument. Please check the model options and try again.")
}
}

req <- httr2::req_body_json(req, body_json)

if (!stream) {
tryCatch(
{
resp <- httr2::req_perform(req)
return(resp_process(resp = resp, output = output[1]))
},
error = function(e) {
stop(e)
}
)
}

# streaming
env <- new.env()
env$buffer <- ""
env$content <- ""
env$accumulated_data <- raw()
wrapped_handler <- function(x) stream_handler(x, env, endpoint)
resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1)
cat("\n\n")

# process streaming output
json_lines <- strsplit(rawToChar(env$accumulated_data), "\n")[[1]]
json_lines_output <- vector("list", length = length(json_lines))
df_response <- tibble::tibble(
model = character(length(json_lines_output)),
role = character(length(json_lines_output)),
content = character(length(json_lines_output)),
created_at = character(length(json_lines_output))
)

if (output[1] == "raw") {
return(rawToChar(env$accumulated_data))
}

for (i in seq_along(json_lines)) {
json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]])
df_response$model[i] <- json_lines_output[[i]]$model
df_response$role[i] <- json_lines_output[[i]]$message$role
df_response$content[i] <- json_lines_output[[i]]$message$content
df_response$created_at[i] <- json_lines_output[[i]]$created_at
}

if (output[1] == "jsonlist") {
return(json_lines_output)
}

if (output[1] == "df") {
return(df_response)
}

if (output[1] == "text") {
return(paste0(df_response$content, collapse = ""))
}

return(resp)
}
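The streaming branch of `chat()` accumulates raw bytes and then parses them as newline-delimited JSON, one object per streamed chunk. A standalone sketch of that post-processing, using hand-written stand-in data rather than a live response (the chunk shape follows the fields the loop above reads):

```r
library(jsonlite)
library(tibble)

# Stand-in for rawToChar(env$accumulated_data): two NDJSON chunks as
# the /api/chat stream would emit them.
raw_text <- paste(
  '{"model":"llama3","created_at":"t1","message":{"role":"assistant","content":"Hel"}}',
  '{"model":"llama3","created_at":"t2","message":{"role":"assistant","content":"lo!"}}',
  sep = "\n"
)

json_lines <- strsplit(raw_text, "\n")[[1]]
parsed <- lapply(json_lines, jsonlite::fromJSON)

df_response <- tibble(
  model      = vapply(parsed, function(x) x$model, character(1)),
  role       = vapply(parsed, function(x) x$message$role, character(1)),
  content    = vapply(parsed, function(x) x$message$content, character(1)),
  created_at = vapply(parsed, function(x) x$created_at, character(1))
)

paste0(df_response$content, collapse = "")  # "Hello!" -- the "text" output
```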



16 changes: 5 additions & 11 deletions R/utils.R
@@ -11,24 +11,18 @@ stream_handler <- function(x, env, endpoint) {
tryCatch(
{
json_string <- paste0(env$buffer, json_strings[i], "\n", collapse = "")

if (endpoint == "/api/generate") {
stream_content <- jsonlite::fromJSON(json_string)$response
env$content <- c(env$content, stream_content)
env$buffer <- ""
cat(stream_content) # stream/print stream
} else if (endpoint == "/api/chat") {
stream_content <- jsonlite::fromJSON(json_string)$message$content
env$content <- c(env$content, stream_content)
env$buffer <- ""
cat(stream_content)
} else if (endpoint == "/api/pull") {
json_string <- paste0(env$buffer, json_strings[i], "\n", collapse = "")
stream_content <- jsonlite::fromJSON(json_string)$status
env$content <<- c(env$content, stream_content)
env$buffer <<- ""
cat(stream_content, "\n")
stream_content <- paste0(stream_content, "\n")
}
# concatenate the content
env$content <- c(env$content, stream_content)
env$buffer <- ""
cat(stream_content) # stream/print stream
},
error = function(e) {
env$buffer <- paste0(env$buffer, json_strings[i])
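This hunk pulls the bookkeeping that every endpoint branch repeated (append to `env$content`, clear `env$buffer`, `cat()` the chunk) into one shared block; the `/api/pull` branch now just appends its trailing newline to `stream_content` first. For context, a sketch of how `httr2` drives this handler, assuming Ollama's default local address; the callback receives raw chunks and should return `TRUE` to keep the stream open:

```r
library(httr2)

env <- new.env()
env$buffer <- ""
env$content <- ""
env$accumulated_data <- raw()

req <- request("http://localhost:11434/api/generate") |>
  req_body_json(list(model = "llama3", prompt = "Hi!", stream = TRUE))

resp <- req_perform_stream(
  req,
  function(x) stream_handler(x, env, "/api/generate"),
  buffer_kb = 1
)
cat("\n")
```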