Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vertex AI] Bypass proxy for testing against staging #14235

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Sources/Constants.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Foundation
/// Constants associated with the Vertex AI for Firebase SDK.
enum Constants {
/// The Vertex AI backend endpoint URL.
static let baseURL = "https://firebasevertexai.googleapis.com"
static let baseURL = "staging-aiplatform.sandbox.googleapis.com"

/// The base reverse-DNS name for `NSError` or `CustomNSError` error domains.
///
Expand Down
5 changes: 4 additions & 1 deletion FirebaseVertexAI/Sources/CountTokensRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Foundation
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct CountTokensRequest {
let model: String
let location: String

let contents: [ModelContent]
let systemInstruction: ModelContent?
Expand All @@ -31,7 +32,9 @@ extension CountTokensRequest: GenerativeAIRequest {
typealias Response = CountTokensResponse

var url: URL {
URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):countTokens")!
URL(
string: "https://\(location)-\(Constants.baseURL)/\(options.apiVersion)/\(model):countTokens"
)!
}
}

Expand Down
3 changes: 2 additions & 1 deletion FirebaseVertexAI/Sources/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Foundation
struct GenerateContentRequest {
/// Model name.
let model: String
let location: String
let contents: [ModelContent]
let generationConfig: GenerationConfig?
let safetySettings: [SafetySetting]?
Expand Down Expand Up @@ -45,7 +46,7 @@ extension GenerateContentRequest: GenerativeAIRequest {
typealias Response = GenerateContentResponse

var url: URL {
let modelURL = "\(Constants.baseURL)/\(options.apiVersion)/\(model)"
let modelURL = "https://\(location)-\(Constants.baseURL)/\(options.apiVersion)/\(model)"
if isStreaming {
return URL(string: "\(modelURL):streamGenerateContent?alt=sse")!
} else {
Expand Down
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Sources/GenerativeAIRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public struct RequestOptions {
let timeout: TimeInterval

/// The API version to use in requests to the backend.
let apiVersion = "v1beta"
let apiVersion = "v1beta1"

/// Initializes a request options object.
///
Expand Down
24 changes: 8 additions & 16 deletions FirebaseVertexAI/Sources/GenerativeAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -180,28 +180,20 @@ struct GenerativeAIService {
private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
var urlRequest = URLRequest(url: request.url)
urlRequest.httpMethod = "POST"
urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")
guard let accessToken = ProcessInfo.processInfo.environment["GCLOUD_ACCESS_TOKEN"] else {
fatalError("""
Missing access token; run `gcloud auth print-access-token` and add an environment variable \
`GCLOUD_ACCESS_TOKEN` with the printed value.
Note: This value will only be valid for 60 minutes.
""")
}
urlRequest.setValue("Bearer \(accessToken)", forHTTPHeaderField: "Authorization")
urlRequest.setValue(
"\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)",
forHTTPHeaderField: "x-goog-api-client"
)
urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

if let appCheck {
let tokenResult = await appCheck.getToken(forcingRefresh: false)
urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
if let error = tokenResult.error {
VertexLog.error(
code: .appCheckTokenFetchFailed,
"Failed to fetch AppCheck token. Error: \(error)"
)
}
}

if let auth, let authToken = try await auth.getToken(forcingRefresh: false) {
urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
}

let encoder = JSONEncoder()
encoder.keyEncodingStrategy = .convertToSnakeCase
urlRequest.httpBody = try encoder.encode(request)
Expand Down
7 changes: 7 additions & 0 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public final class GenerativeModel {
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

let location: String

/// Configuration parameters used for the MultiModalModel.
let generationConfig: GenerationConfig?

Expand Down Expand Up @@ -61,6 +63,7 @@ public final class GenerativeModel {
init(name: String,
projectID: String,
apiKey: String,
location: String = "us-central1",
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
tools: [Tool]?,
Expand All @@ -78,6 +81,7 @@ public final class GenerativeModel {
auth: auth,
urlSession: urlSession
)
self.location = location
self.generationConfig = generationConfig
self.safetySettings = safetySettings
self.tools = tools
Expand Down Expand Up @@ -125,6 +129,7 @@ public final class GenerativeModel {
try content.throwIfError()
let response: GenerateContentResponse
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
location: location,
contents: content,
generationConfig: generationConfig,
safetySettings: safetySettings,
Expand Down Expand Up @@ -182,6 +187,7 @@ public final class GenerativeModel {
-> AsyncThrowingStream<GenerateContentResponse, Error> {
try content.throwIfError()
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
location: location,
contents: content,
generationConfig: generationConfig,
safetySettings: safetySettings,
Expand Down Expand Up @@ -253,6 +259,7 @@ public final class GenerativeModel {
public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
let countTokensRequest = CountTokensRequest(
model: modelResourceName,
location: location,
contents: content,
systemInstruction: systemInstruction,
tools: tools,
Expand Down
1 change: 1 addition & 0 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ public class VertexAI {
name: modelResourceName(modelName: modelName),
projectID: projectID,
apiKey: apiKey,
location: location,
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
Expand Down
Loading