Skip to content

Commit

Permalink
[Vertex AI] Bypass proxy for testing against staging
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Dec 9, 2024
1 parent 16381c7 commit 742d7b0
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 20 deletions.
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Sources/Constants.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Foundation
/// Constants associated with the Vertex AI for Firebase SDK.
enum Constants {
/// The Vertex AI backend endpoint URL.
static let baseURL = "https://firebasevertexai.googleapis.com"
static let baseURL = "staging-aiplatform.sandbox.googleapis.com"

/// The base reverse-DNS name for `NSError` or `CustomNSError` error domains.
///
Expand Down
5 changes: 4 additions & 1 deletion FirebaseVertexAI/Sources/CountTokensRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Foundation
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct CountTokensRequest {
let model: String
let location: String

let contents: [ModelContent]
let systemInstruction: ModelContent?
Expand All @@ -31,7 +32,9 @@ extension CountTokensRequest: GenerativeAIRequest {
typealias Response = CountTokensResponse

var url: URL {
URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):countTokens")!
URL(
string: "https://\(location)-\(Constants.baseURL)/\(options.apiVersion)/\(model):countTokens"
)!
}
}

Expand Down
3 changes: 2 additions & 1 deletion FirebaseVertexAI/Sources/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Foundation
struct GenerateContentRequest {
/// Model name.
let model: String
let location: String
let contents: [ModelContent]
let generationConfig: GenerationConfig?
let safetySettings: [SafetySetting]?
Expand Down Expand Up @@ -45,7 +46,7 @@ extension GenerateContentRequest: GenerativeAIRequest {
typealias Response = GenerateContentResponse

var url: URL {
let modelURL = "\(Constants.baseURL)/\(options.apiVersion)/\(model)"
let modelURL = "https://\(location)-\(Constants.baseURL)/\(options.apiVersion)/\(model)"
if isStreaming {
return URL(string: "\(modelURL):streamGenerateContent?alt=sse")!
} else {
Expand Down
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Sources/GenerativeAIRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public struct RequestOptions {
let timeout: TimeInterval

/// The API version to use in requests to the backend.
let apiVersion = "v1beta"
let apiVersion = "v1beta1"

/// Initializes a request options object.
///
Expand Down
24 changes: 8 additions & 16 deletions FirebaseVertexAI/Sources/GenerativeAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -180,28 +180,20 @@ struct GenerativeAIService {
private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
var urlRequest = URLRequest(url: request.url)
urlRequest.httpMethod = "POST"
urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")
guard let accessToken = ProcessInfo.processInfo.environment["GCLOUD_ACCESS_TOKEN"] else {
fatalError("""
Missing access token; run `gcloud auth print-access-token` and add an environment variable \
`GCLOUD_ACCESS_TOKEN` with the printed value.
Note: This value will only be valid for 60 minutes.
""")
}
urlRequest.setValue("Bearer \(accessToken)", forHTTPHeaderField: "Authorization")
urlRequest.setValue(
"\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)",
forHTTPHeaderField: "x-goog-api-client"
)
urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

if let appCheck {
let tokenResult = await appCheck.getToken(forcingRefresh: false)
urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
if let error = tokenResult.error {
VertexLog.error(
code: .appCheckTokenFetchFailed,
"Failed to fetch AppCheck token. Error: \(error)"
)
}
}

if let auth, let authToken = try await auth.getToken(forcingRefresh: false) {
urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
}

let encoder = JSONEncoder()
encoder.keyEncodingStrategy = .convertToSnakeCase
urlRequest.httpBody = try encoder.encode(request)
Expand Down
7 changes: 7 additions & 0 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public final class GenerativeModel {
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

let location: String

/// Configuration parameters used for the MultiModalModel.
let generationConfig: GenerationConfig?

Expand Down Expand Up @@ -61,6 +63,7 @@ public final class GenerativeModel {
init(name: String,
projectID: String,
apiKey: String,
location: String = "us-central1",
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
tools: [Tool]?,
Expand All @@ -78,6 +81,7 @@ public final class GenerativeModel {
auth: auth,
urlSession: urlSession
)
self.location = location
self.generationConfig = generationConfig
self.safetySettings = safetySettings
self.tools = tools
Expand Down Expand Up @@ -125,6 +129,7 @@ public final class GenerativeModel {
try content.throwIfError()
let response: GenerateContentResponse
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
location: location,
contents: content,
generationConfig: generationConfig,
safetySettings: safetySettings,
Expand Down Expand Up @@ -182,6 +187,7 @@ public final class GenerativeModel {
-> AsyncThrowingStream<GenerateContentResponse, Error> {
try content.throwIfError()
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
location: location,
contents: content,
generationConfig: generationConfig,
safetySettings: safetySettings,
Expand Down Expand Up @@ -253,6 +259,7 @@ public final class GenerativeModel {
public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
let countTokensRequest = CountTokensRequest(
model: modelResourceName,
location: location,
contents: content,
systemInstruction: systemInstruction,
tools: tools,
Expand Down
1 change: 1 addition & 0 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ public class VertexAI {
name: modelResourceName(modelName: modelName),
projectID: projectID,
apiKey: apiKey,
location: location,
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
Expand Down

0 comments on commit 742d7b0

Please sign in to comment.