Skip to content

Commit

Permalink
Add ImagenModel.imageGenerationParameters() tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Dec 9, 2024
1 parent 24da495 commit 0d8f559
Showing 1 changed file with 169 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,175 @@ final class ImageGenerationParametersTests: XCTestCase {
encoder.outputFormatting = [.sortedKeys, .prettyPrinted, .withoutEscapingSlashes]
}

func testDefaultParameters_noneSpecified() throws {
let expectedParameters = ImageGenerationParameters(
sampleCount: 1,
storageURI: nil,
negativePrompt: nil,
aspectRatio: nil,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: nil,
addWatermark: nil,
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
safetySettings: nil
)

XCTAssertEqual(parameters, expectedParameters)
}

func testDefaultParameters_includeStorageURI() throws {
let storageURI = "gs://test-bucket/path"
let expectedParameters = ImageGenerationParameters(
sampleCount: 1,
storageURI: storageURI,
negativePrompt: nil,
aspectRatio: nil,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: nil,
addWatermark: nil,
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: nil,
safetySettings: nil
)

XCTAssertEqual(parameters, expectedParameters)
}

func testParameters_includeGenerationConfig() throws {
let sampleCount = 2
let negativePrompt = "test-negative-prompt"
let aspectRatio = ImagenAspectRatio.landscape16x9
let compressionQuality = 80
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
let addWatermark = true
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio,
imageFormat: imageFormat,
addWatermark: addWatermark
)
let expectedParameters = ImageGenerationParameters(
sampleCount: sampleCount,
storageURI: nil,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio.rawValue,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: ImageGenerationOutputOptions(
mimeType: imageFormat.mimeType,
compressionQuality: imageFormat.compressionQuality
),
addWatermark: addWatermark,
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: generationConfig,
safetySettings: nil
)

XCTAssertEqual(parameters, expectedParameters)
XCTAssertEqual(parameters.aspectRatio, "16:9")
XCTAssertEqual(parameters.outputOptions?.mimeType, "image/jpeg")
XCTAssertEqual(parameters.outputOptions?.compressionQuality, compressionQuality)
}

func testDefaultParameters_includeSafetySettings() throws {
let safetyFilterLevel = ImagenSafetySettings.SafetyFilterLevel.blockOnlyHigh
let personGeneration = ImagenSafetySettings.PersonGeneration.allowAll
let includeFilterReason = false
let safetySettings = ImagenSafetySettings(
safetyFilterLevel: safetyFilterLevel,
includeFilterReason: includeFilterReason,
personGeneration: personGeneration
)
let expectedParameters = ImageGenerationParameters(
sampleCount: 1,
storageURI: nil,
negativePrompt: nil,
aspectRatio: nil,
safetyFilterLevel: safetyFilterLevel.rawValue,
personGeneration: personGeneration.rawValue,
outputOptions: nil,
addWatermark: nil,
includeResponsibleAIFilterReason: includeFilterReason
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
safetySettings: safetySettings
)

XCTAssertEqual(parameters, expectedParameters)
XCTAssertEqual(parameters.safetyFilterLevel, "block_only_high")
XCTAssertEqual(parameters.personGeneration, "allow_all")
}

func testParameters_includeAll() throws {
let storageURI = "gs://test-bucket/path"
let sampleCount = 4
let negativePrompt = "test-negative-prompt"
let aspectRatio = ImagenAspectRatio.portrait3x4
let imageFormat = ImagenImageFormat.png()
let addWatermark = false
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio,
imageFormat: imageFormat,
addWatermark: addWatermark
)
let safetyFilterLevel = ImagenSafetySettings.SafetyFilterLevel.blockNone
let personGeneration = ImagenSafetySettings.PersonGeneration.blockAll
let includeFilterReason = false
let safetySettings = ImagenSafetySettings(
safetyFilterLevel: safetyFilterLevel,
includeFilterReason: includeFilterReason,
personGeneration: personGeneration
)
let expectedParameters = ImageGenerationParameters(
sampleCount: sampleCount,
storageURI: storageURI,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio.rawValue,
safetyFilterLevel: safetyFilterLevel.rawValue,
personGeneration: personGeneration.rawValue,
outputOptions: ImageGenerationOutputOptions(
mimeType: imageFormat.mimeType,
compressionQuality: imageFormat.compressionQuality
),
addWatermark: addWatermark,
includeResponsibleAIFilterReason: includeFilterReason
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: generationConfig,
safetySettings: safetySettings
)

XCTAssertEqual(parameters, expectedParameters)
XCTAssertEqual(parameters.aspectRatio, "3:4")
XCTAssertEqual(parameters.safetyFilterLevel, "block_none")
XCTAssertEqual(parameters.personGeneration, "dont_allow")
XCTAssertEqual(parameters.outputOptions?.mimeType, "image/png")
XCTAssertNil(parameters.outputOptions?.compressionQuality)
}

// MARK: - Encoding Tests

func testEncodeParameters_allSpecified() throws {
Expand Down

0 comments on commit 0d8f559

Please sign in to comment.