Skip to content

Commit

Permalink
[Vertex AI] Add ImagenSafetySettings
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Dec 9, 2024
1 parent c5472fc commit 1229559
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 6 deletions.
10 changes: 7 additions & 3 deletions FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ public final class ImagenModel {
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

let safetySettings: ImagenSafetySettings?

/// Configuration parameters for sending requests to the backend.
let requestOptions: RequestOptions

init(name: String,
projectID: String,
apiKey: String,
safetySettings: ImagenSafetySettings?,
requestOptions: RequestOptions,
appCheck: AppCheckInterop?,
auth: AuthInterop?,
Expand All @@ -42,6 +45,7 @@ public final class ImagenModel {
auth: auth,
urlSession: urlSession
)
self.safetySettings = safetySettings
self.requestOptions = requestOptions
}

Expand Down Expand Up @@ -89,16 +93,16 @@ public final class ImagenModel {
seed: nil,
negativePrompt: generationConfig?.negativePrompt,
aspectRatio: generationConfig?.aspectRatio?.rawValue,
safetyFilterLevel: nil,
personGeneration: nil,
safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue,
personGeneration: safetySettings?.personGeneration?.rawValue,
outputOptions: generationConfig?.imageFormat.map {
ImageGenerationOutputOptions(
mimeType: $0.mimeType,
compressionQuality: $0.compressionQuality
)
},
addWatermark: generationConfig?.addWatermark,
includeResponsibleAIFilterReason: true
includeResponsibleAIFilterReason: safetySettings?.includeFilterReason ?? true
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenSafetySettings {
let safetyFilterLevel: SafetyFilterLevel?
let includeFilterReason: Bool?
let personGeneration: PersonGeneration?

public init(safetyFilterLevel: SafetyFilterLevel? = nil, includeFilterReason: Bool? = nil,
personGeneration: PersonGeneration? = nil) {
self.safetyFilterLevel = safetyFilterLevel
self.includeFilterReason = includeFilterReason
self.personGeneration = personGeneration
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public extension ImagenSafetySettings {
struct SafetyFilterLevel: ProtoEnum {
enum Kind: String {
case blockLowAndAbove = "block_low_and_above"
case blockMediumAndAbove = "block_medium_and_above"
case blockOnlyHigh = "block_only_high"
case blockNone = "block_none"
}

public static let blockLowAndAbove = SafetyFilterLevel(kind: .blockLowAndAbove)
public static let blockMediumAndAbove = SafetyFilterLevel(kind: .blockMediumAndAbove)
public static let blockOnlyHigh = SafetyFilterLevel(kind: .blockOnlyHigh)
public static let blockNone = SafetyFilterLevel(kind: .blockNone)

let rawValue: String
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public extension ImagenSafetySettings {
struct PersonGeneration: ProtoEnum {
enum Kind: String {
case blockAll = "dont_allow"
case allowAdult = "allow_adult"
case allowAll = "allow_all"
}

public static let blockAll = PersonGeneration(kind: .blockAll)
public static let allowAdult = PersonGeneration(kind: .allowAdult)
public static let allowAll = PersonGeneration(kind: .allowAll)

let rawValue: String
}
}
5 changes: 3 additions & 2 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ public class VertexAI {
)
}

public func imagenModel(modelName: String, requestOptions: RequestOptions = RequestOptions())
-> ImagenModel {
public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
return ImagenModel(
name: modelResourceName(modelName: modelName),
projectID: projectID,
apiKey: apiKey,
safetySettings: safetySettings,
requestOptions: requestOptions,
appCheck: appCheck,
auth: auth
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ final class IntegrationTests: XCTestCase {
systemInstruction: systemInstruction
)
imagenModel = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001"
modelName: "imagen-3.0-fast-generate-001",
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockLowAndAbove,
personGeneration: .blockAll
)
)

storage = Storage.storage()
Expand Down

0 comments on commit 1229559

Please sign in to comment.