Skip to content

Commit

Permalink
Merge pull request #5 from hauselin/tawab
Browse files Browse the repository at this point in the history
  • Loading branch information
hauselin authored Jul 10, 2024
2 parents 0a92633 + 2bcf330 commit b1f4bf4
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 17 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
push:
branches: [main, master, dev, test]
branches: [main, master, dev, test, tawab]
pull_request:
branches: [main, master, dev, test]
branches: [main, master, dev, test, tawab]

name: R-CMD-check

Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Description: An interface to easily run local language models with 'Ollama' <htt
License: MIT + file LICENSE
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Suggests:
testthat (>= 3.0.0)
Config/testthat/edition: 3
Expand Down
26 changes: 16 additions & 10 deletions R/ollama.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end
#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text".
#' @param stream Enable response streaming. Default is FALSE.
#' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#'
#' @return A response in the format specified in the output parameter.
#' @export
Expand All @@ -120,9 +121,9 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end
#' list(role = "user", content = "List all the previous messages.")
#' )
#' chat("llama3", messages, stream = TRUE)
chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, endpoint = "/api/chat") {
chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, endpoint = "/api/chat", host = NULL) {

req <- create_request(endpoint)
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

body_json <- list(model = model,
Expand Down Expand Up @@ -216,15 +217,16 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
#' @param model A character string of the model name such as "llama3".
#' @param stream Enable response streaming. Default is TRUE.
#' @param endpoint The endpoint to pull the model. Default is "/api/pull".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#'
#' @return A httr2 response object.
#' @export
#'
#' @examplesIf test_connection()$status_code == 200
#' pull("llama3")
#" pull("all-minilm", stream = FALSE)
pull <- function(model, stream = TRUE, endpoint = "/api/pull") {
req <- create_request(endpoint)
pull <- function(model, stream = TRUE, endpoint = "/api/pull", host = NULL) {
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

body_json <- list(model = model, stream = stream)
Expand Down Expand Up @@ -275,6 +277,7 @@ pull <- function(model, stream = TRUE, endpoint = "/api/pull") {
#'
#' @param model A character string of the model name such as "llama3".
#' @param endpoint The endpoint to delete the model. Default is "/api/delete".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#'
#' @return A httr2 response object.
#' @export
Expand All @@ -283,8 +286,8 @@ pull <- function(model, stream = TRUE, endpoint = "/api/pull") {
#' \dontrun{
#' delete("llama3")
#' }
delete <- function(model, endpoint = "/api/delete") {
req <- create_request(endpoint)
delete <- function(model, endpoint = "/api/delete", host = NULL) {
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "DELETE")
body_json <- list(model = model)
req <- httr2::req_body_json(req, body_json)
Expand Down Expand Up @@ -313,6 +316,7 @@ normalize <- function(x) {
#' @param normalize Normalize the vector to length 1. Default is TRUE.
#' @param keep_alive The time to keep the connection alive. Default is "5m" (5 minutes).
#' @param endpoint The endpoint to get the vector embedding. Default is "/api/embeddings".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#' @param ... Additional options to pass to the model.
#'
#' @return A numeric vector of the embedding.
Expand All @@ -322,8 +326,8 @@ normalize <- function(x) {
#' embeddings("nomic-embed-text:latest", "The quick brown fox jumps over the lazy dog.")
#' # pass model options to the model
#' embeddings("nomic-embed-text:latest", "Hello!", temperature = 0.1, num_predict = 3)
embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpoint = "/api/embeddings", ...) {
req <- create_request(endpoint)
embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpoint = "/api/embeddings", host = NULL, ...) {
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

opts <- list(...)
Expand Down Expand Up @@ -367,6 +371,8 @@ embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpo
#' @param output A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text".
#' @param stream Enable response streaming. Default is FALSE.
#' @param endpoint The endpoint to generate the completion. Default is "/api/generate".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#'
#'
#' @return A response in the format specified in the output parameter.
#' @export
Expand All @@ -375,9 +381,9 @@ embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpo
#' generate("llama3", "The sky is...", stream = FALSE, output = "df")
#' generate("llama3", "The sky is...", stream = TRUE, output = "df")
#' generate("llama3", "The sky is...", stream = FALSE, output = "jsonlist")
generate <- function(model, prompt, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, endpoint = "/api/generate") {
generate <- function(model, prompt, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, endpoint = "/api/generate", host = NULL) {

req <- create_request(endpoint)
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

body_json <- list(model = model,
Expand Down
5 changes: 4 additions & 1 deletion man/chat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/delete.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/embeddings.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/generate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/pull.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit b1f4bf4

Please sign in to comment.