From 5fabd1da89e5f25cdf8d8a181b50a88e01f75bad Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Thu, 5 Dec 2024 18:07:53 -0500 Subject: [PATCH 1/2] [Vertex AI] Add ImageGenerationResponse for decoding predict response --- .../Imagen/ImageGenerationResponse.swift | 124 ++++++ .../Imagen/ImageGenerationResponseTests.swift | 376 ++++++++++++++++++ 2 files changed, 500 insertions(+) create mode 100644 FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift create mode 100644 FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift new file mode 100644 index 00000000000..7988d5811cc --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift @@ -0,0 +1,124 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct ImageGenerationResponse { + let images: [Image] + let raiFilteredReason: String? +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImageGenerationResponse { + struct Image: Equatable { + let mimeType: String + let bytesBase64Encoded: String? + let gcsURI: String? + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImageGenerationResponse { + struct RAIFilteredReason { + let raiFilteredReason: String + } +} + +// MARK: - Codable Conformances + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImageGenerationResponse.Image: Decodable { + enum CodingKeys: String, CodingKey { + case mimeType + case bytesBase64Encoded + case gcsURI = "gcsUri" + } + + init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + mimeType = try container.decode(String.self, forKey: .mimeType) + bytesBase64Encoded = try container.decodeIfPresent(String.self, forKey: .bytesBase64Encoded) + gcsURI = try container.decodeIfPresent(String.self, forKey: .gcsURI) + guard bytesBase64Encoded != nil || gcsURI != nil else { + throw DecodingError.dataCorrupted( + DecodingError.Context( + codingPath: [CodingKeys.bytesBase64Encoded, CodingKeys.gcsURI], + debugDescription: """ + Expected one of \(CodingKeys.bytesBase64Encoded.rawValue) or \ + \(CodingKeys.gcsURI.rawValue); both are nil. + """ + ) + ) + } + guard bytesBase64Encoded == nil || gcsURI == nil else { + throw DecodingError.dataCorrupted( + DecodingError.Context( + codingPath: [CodingKeys.bytesBase64Encoded, CodingKeys.gcsURI], + debugDescription: """ + Expected one of \(CodingKeys.bytesBase64Encoded.rawValue) or \ + \(CodingKeys.gcsURI.rawValue); both are specified. + """ + ) + ) + } + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImageGenerationResponse.RAIFilteredReason: Decodable { + enum CodingKeys: CodingKey { + case raiFilteredReason + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImageGenerationResponse: Decodable { + enum CodingKeys: CodingKey { + case predictions + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + guard container.contains(.predictions) else { + images = [] + raiFilteredReason = nil + // TODO: Log warning if no predictions. + return + } + var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions) + + var images = [Image]() + var raiFilteredReasons = [String]() + while !predictionsContainer.isAtEnd { + if let image = try? predictionsContainer.decode(Image.self) { + images.append(image) + } else if let filterReason = try? predictionsContainer.decode(RAIFilteredReason.self) { + raiFilteredReasons.append(filterReason.raiFilteredReason) + } else if let _ = try? predictionsContainer.decode(JSONObject.self) { + // TODO: Log or throw unsupported prediction type + } else { + // This should never be thrown since JSONObject accepts any valid JSON. + throw DecodingError.dataCorruptedError( + in: predictionsContainer, + debugDescription: "Failed to decode Prediction." + ) + } + } + + self.images = images + raiFilteredReason = raiFilteredReasons.first + // TODO: Log if more than one RAI Filtered Reason; unexpected behaviour. + } +} diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift new file mode 100644 index 00000000000..6178b217c41 --- /dev/null +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift @@ -0,0 +1,376 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest + +@testable import FirebaseVertexAI + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class ImageGenerationResponseTests: XCTestCase { + let decoder = JSONDecoder() + + // MARK: - Image Decoding + + func testDecodeImage_bytesBase64Encoded() throws { + let mimeType = "image/png" + let bytesBase64Encoded = "test-base64-bytes" + let json = """ + { + "bytesBase64Encoded": "\(bytesBase64Encoded)", + "mimeType": "\(mimeType)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let image = try decoder.decode(ImageGenerationResponse.Image.self, from: jsonData) + + XCTAssertEqual(image.mimeType, mimeType) + XCTAssertEqual(image.bytesBase64Encoded, bytesBase64Encoded) + XCTAssertNil(image.gcsURI) + } + + func testDecodeImage_gcsURI() throws { + let gcsURI = "gs://test-bucket/images/123456789/sample_0.png" + let mimeType = "image/jpeg" + let json = """ + { + "mimeType": "\(mimeType)", + "gcsUri": "\(gcsURI)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let image = try decoder.decode(ImageGenerationResponse.Image.self, from: jsonData) + + XCTAssertEqual(image.mimeType, mimeType) + XCTAssertEqual(image.gcsURI, gcsURI) + XCTAssertNil(image.bytesBase64Encoded) + } + + func testDecodeImage_missingBytesBase64EncodedAndGCSURI_throws() throws { + let json = """ + { + "mimeType": "image/jpeg" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + do { + _ = try decoder.decode(ImageGenerationResponse.Image.self, from: jsonData) + XCTFail("Expected an error; none thrown.") + } catch let DecodingError.dataCorrupted(context) { + let codingPath = try XCTUnwrap(context + .codingPath as? [ImageGenerationResponse.Image.CodingKeys]) + XCTAssertEqual(codingPath, [.bytesBase64Encoded, .gcsURI]) + XCTAssertTrue(context.debugDescription.contains("both are nil")) + } catch { + XCTFail("Expected a DecodingError.dataCorrupted error; got \(error).") + } + } + + func testDecodeImage_bothBytesBase64EncodedAndGCSURI_throws() throws { + let json = """ + { + "bytesBase64Encoded": "test-base64-bytes", + "mimeType": "image/png", + "gcsUri": "gs://test-bucket/images/123456789/sample_0.png" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + do { + _ = try decoder.decode(ImageGenerationResponse.Image.self, from: jsonData) + XCTFail("Expected an error; none thrown.") + } catch let DecodingError.dataCorrupted(context) { + let codingPath = try XCTUnwrap(context + .codingPath as? [ImageGenerationResponse.Image.CodingKeys]) + XCTAssertEqual(codingPath, [.bytesBase64Encoded, .gcsURI]) + XCTAssertTrue(context.debugDescription.contains("both are specified")) + } catch { + XCTFail("Expected a DecodingError.dataCorrupted error; got \(error).") + } + } + + // MARK: - RAI Filtered Reason Decoding + + func testDecodeRAIFilteredReason() throws { + let raiFilteredReason = """ + Unable to show generated images. All images were filtered out because they violated Vertex \ + AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. \ + If you think this was an error, send feedback. Support codes: 1234567 + """ + let json = """ + { + "raiFilteredReason": "\(raiFilteredReason)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let filterReason = try decoder.decode( + ImageGenerationResponse.RAIFilteredReason.self, + from: jsonData + ) + + XCTAssertEqual(filterReason.raiFilteredReason, raiFilteredReason) + } + + func testDecodeRAIFilteredReason_reasonNotSpecified_throws() throws { + let json = """ + { + "otherField": "test-value" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + do { + _ = try decoder.decode(ImageGenerationResponse.RAIFilteredReason.self, from: jsonData) + XCTFail("Expected an error; none thrown.") + } catch let DecodingError.keyNotFound(codingKey, _) { + let codingKey = try XCTUnwrap( + codingKey as? ImageGenerationResponse.RAIFilteredReason.CodingKeys + ) + XCTAssertEqual(codingKey, .raiFilteredReason) + } catch { + XCTFail("Expected a DecodingError.keyNotFound error; got \(error).") + } + } + + // MARK: - Image Generation Response Decoding + + func testDecodeResponse_oneBase64Image_noneFiltered() throws { + let mimeType = "image/png" + let bytesBase64Encoded = "test-base64-bytes" + let image = ImageGenerationResponse.Image( + mimeType: mimeType, + bytesBase64Encoded: bytesBase64Encoded, + gcsURI: nil + ) + let json = """ + { + "predictions": [ + { + "bytesBase64Encoded": "\(bytesBase64Encoded)", + "mimeType": "\(mimeType)" + }, + ] + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + + XCTAssertEqual(response.images, [image]) + XCTAssertNil(response.raiFilteredReason) + } + + func testDecodeResponse_multipleBase64Images_noneFiltered() throws { + let mimeType = "image/png" + let bytesBase64Encoded1 = "test-base64-bytes-1" + let bytesBase64Encoded2 = "test-base64-bytes-2" + let bytesBase64Encoded3 = "test-base64-bytes-3" + let image1 = ImageGenerationResponse.Image( + mimeType: mimeType, + bytesBase64Encoded: bytesBase64Encoded1, + gcsURI: nil + ) + let image2 = ImageGenerationResponse.Image( + mimeType: mimeType, + bytesBase64Encoded: bytesBase64Encoded2, + gcsURI: nil + ) + let image3 = ImageGenerationResponse.Image( + mimeType: mimeType, + bytesBase64Encoded: bytesBase64Encoded3, + gcsURI: nil + ) + let json = """ + { + "predictions": [ + { + "bytesBase64Encoded": "\(bytesBase64Encoded1)", + "mimeType": "\(mimeType)" + }, + { + "bytesBase64Encoded": "\(bytesBase64Encoded2)", + "mimeType": "\(mimeType)" + }, + { + "bytesBase64Encoded": "\(bytesBase64Encoded3)", + "mimeType": "\(mimeType)" + }, + ] + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + + XCTAssertEqual(response.images, [image1, image2, image3]) + XCTAssertNil(response.raiFilteredReason) + } + + func testDecodeResponse_multipleBase64Images_someFiltered() throws { + let mimeType = "image/png" + let bytesBase64Encoded1 = "test-base64-bytes-1" + let bytesBase64Encoded2 = "test-base64-bytes-2" + let image1 = ImageGenerationResponse.Image( + mimeType: mimeType, + bytesBase64Encoded: bytesBase64Encoded1, + gcsURI: nil + ) + let image2 = ImageGenerationResponse.Image( + mimeType: mimeType, + bytesBase64Encoded: bytesBase64Encoded2, + gcsURI: nil + ) + let raiFilteredReason = """ + Your current safety filter threshold filtered out 2 generated images. You will not be charged \ + for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback. + """ + let json = """ + { + "predictions": [ + { + "bytesBase64Encoded": "\(bytesBase64Encoded1)", + "mimeType": "\(mimeType)" + }, + { + "bytesBase64Encoded": "\(bytesBase64Encoded2)", + "mimeType": "\(mimeType)" + }, + { + "raiFilteredReason": "\(raiFilteredReason)" + }, + ] + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + + XCTAssertEqual(response.images, [image1, image2]) + XCTAssertEqual(response.raiFilteredReason, raiFilteredReason) + } + + func testDecodeResponse_multipleGCSImages_noneFiltered() throws { + let mimeType = "image/png" + let gcsURI1 = "gs://test-bucket/images/123456789/sample_0.png" + let gcsURI2 = "gs://test-bucket/images/123456789/sample_1.png" + let image1 = ImageGenerationResponse.Image( + mimeType: mimeType, + bytesBase64Encoded: nil, + gcsURI: gcsURI1 + ) + let image2 = ImageGenerationResponse.Image( + mimeType: mimeType, + bytesBase64Encoded: nil, + gcsURI: gcsURI2 + ) + let json = """ + { + "predictions": [ + { + "gcsUri": "\(gcsURI1)", + "mimeType": "\(mimeType)" + }, + { + "gcsUri": "\(gcsURI2)", + "mimeType": "\(mimeType)" + }, + ] + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + + XCTAssertEqual(response.images, [image1, image2]) + XCTAssertNil(response.raiFilteredReason) + } + + func testDecodeResponse_noImages_allFiltered() throws { + let raiFilteredReason = """ + Unable to show generated images. All images were filtered out because they violated Vertex \ + AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. \ + If you think this was an error, send feedback. Support codes: 1234567 + """ + let json = """ + { + "predictions": [ + { + "raiFilteredReason": "\(raiFilteredReason)" + }, + ] + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + + XCTAssertEqual(response.images, []) + XCTAssertEqual(response.raiFilteredReason, raiFilteredReason) + } + + func testDecodeResponse_noImagesAnd_noFilteredReason() throws { + let json = "{}" + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + + XCTAssertEqual(response.images, []) + XCTAssertNil(response.raiFilteredReason) + } + + func testDecodeResponse_multipleFilterReasons_returnsFirst() throws { + let raiFilteredReason1 = "filtered-reason-1" + let raiFilteredReason2 = "filtered-reason-2" + let json = """ + { + "predictions": [ + { + "raiFilteredReason": "\(raiFilteredReason1)" + }, + { + "raiFilteredReason": "\(raiFilteredReason2)" + }, + ] + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + + XCTAssertEqual(response.images, []) + XCTAssertEqual(response.raiFilteredReason, raiFilteredReason1) + XCTAssertNotEqual(response.raiFilteredReason, raiFilteredReason2) + } + + func testDecodeResponse_unknownPrediction() throws { + let json = """ + { + "predictions": [ + { + "someField": "some-value" + }, + ] + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + + XCTAssertEqual(response.images, []) + XCTAssertNil(response.raiFilteredReason) + } +} From 67c37ed6a65e49ceb071a7f623df3c770d6f7deb Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Thu, 5 Dec 2024 19:43:00 -0500 Subject: [PATCH 2/2] Add ImagenImage and ImagenImageRepresentable --- .../Imagen/DecodableImagenImage.swift | 62 ++++++++ .../Imagen/ImageGenerationResponse.swift | 67 +------- .../Internal/Imagen/InternalImagenImage.swift | 32 ++++ .../Internal/Imagen/RAIFilteredReason.swift | 25 +++ .../Types/Public/Imagen/ImagenImage.swift | 22 +++ .../Imagen/ImagenImageRepresentable.swift | 27 ++++ .../Imagen/ImageGenerationResponseTests.swift | 148 +----------------- .../Imagen/InternalImagenImageTests.swift | 101 ++++++++++++ .../Types/Imagen/RAIFilteredReasonTests.swift | 64 ++++++++ 9 files changed, 344 insertions(+), 204 deletions(-) create mode 100644 FirebaseVertexAI/Sources/Types/Internal/Imagen/DecodableImagenImage.swift create mode 100644 FirebaseVertexAI/Sources/Types/Internal/Imagen/InternalImagenImage.swift create mode 100644 FirebaseVertexAI/Sources/Types/Internal/Imagen/RAIFilteredReason.swift create mode 100644 FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenImage.swift create mode 100644 FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenImageRepresentable.swift create mode 100644 FirebaseVertexAI/Tests/Unit/Types/Imagen/InternalImagenImageTests.swift create mode 100644 FirebaseVertexAI/Tests/Unit/Types/Imagen/RAIFilteredReasonTests.swift diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/DecodableImagenImage.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/DecodableImagenImage.swift new file mode 100644 index 00000000000..2b4cc888830 --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/DecodableImagenImage.swift @@ -0,0 +1,62 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +protocol DecodableImagenImage: ImagenImage, Decodable { + init(mimeType: String, bytesBase64Encoded: String?, gcsURI: String?) +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +enum ImagenImageCodingKeys: String, CodingKey { + case mimeType + case bytesBase64Encoded + case gcsURI = "gcsUri" +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension DecodableImagenImage { + init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: ImagenImageCodingKeys.self) + let mimeType = try container.decode(String.self, forKey: .mimeType) + let bytesBase64Encoded = try container.decodeIfPresent( + String.self, + forKey: .bytesBase64Encoded + ) + let gcsURI = try container.decodeIfPresent(String.self, forKey: .gcsURI) + guard bytesBase64Encoded != nil || gcsURI != nil else { + throw DecodingError.dataCorrupted( + DecodingError.Context( + codingPath: [ImagenImageCodingKeys.bytesBase64Encoded, ImagenImageCodingKeys.gcsURI], + debugDescription: """ + Expected one of \(ImagenImageCodingKeys.bytesBase64Encoded.rawValue) or \ + \(ImagenImageCodingKeys.gcsURI.rawValue); both are nil. + """ + ) + ) + } + guard bytesBase64Encoded == nil || gcsURI == nil else { + throw DecodingError.dataCorrupted( + DecodingError.Context( + codingPath: [ImagenImageCodingKeys.bytesBase64Encoded, ImagenImageCodingKeys.gcsURI], + debugDescription: """ + Expected one of \(ImagenImageCodingKeys.bytesBase64Encoded.rawValue) or \ + \(ImagenImageCodingKeys.gcsURI.rawValue); both are specified. + """ + ) + ) + } + + self.init(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded, gcsURI: gcsURI) + } +} diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift index 7988d5811cc..bb1b2809bd1 100644 --- a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift @@ -16,73 +16,12 @@ import Foundation @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) struct ImageGenerationResponse { - let images: [Image] + let images: [InternalImagenImage] let raiFilteredReason: String? } -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ImageGenerationResponse { - struct Image: Equatable { - let mimeType: String - let bytesBase64Encoded: String? - let gcsURI: String? - } -} - -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ImageGenerationResponse { - struct RAIFilteredReason { - let raiFilteredReason: String - } -} - // MARK: - Codable Conformances -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ImageGenerationResponse.Image: Decodable { - enum CodingKeys: String, CodingKey { - case mimeType - case bytesBase64Encoded - case gcsURI = "gcsUri" - } - - init(from decoder: any Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - mimeType = try container.decode(String.self, forKey: .mimeType) - bytesBase64Encoded = try container.decodeIfPresent(String.self, forKey: .bytesBase64Encoded) - gcsURI = try container.decodeIfPresent(String.self, forKey: .gcsURI) - guard bytesBase64Encoded != nil || gcsURI != nil else { - throw DecodingError.dataCorrupted( - DecodingError.Context( - codingPath: [CodingKeys.bytesBase64Encoded, CodingKeys.gcsURI], - debugDescription: """ - Expected one of \(CodingKeys.bytesBase64Encoded.rawValue) or \ - \(CodingKeys.gcsURI.rawValue); both are nil. - """ - ) - ) - } - guard bytesBase64Encoded == nil || gcsURI == nil else { - throw DecodingError.dataCorrupted( - DecodingError.Context( - codingPath: [CodingKeys.bytesBase64Encoded, CodingKeys.gcsURI], - debugDescription: """ - Expected one of \(CodingKeys.bytesBase64Encoded.rawValue) or \ - \(CodingKeys.gcsURI.rawValue); both are specified. - """ - ) - ) - } - } -} - -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ImageGenerationResponse.RAIFilteredReason: Decodable { - enum CodingKeys: CodingKey { - case raiFilteredReason - } -} - @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) extension ImageGenerationResponse: Decodable { enum CodingKeys: CodingKey { @@ -99,10 +38,10 @@ extension ImageGenerationResponse: Decodable { } var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions) - var images = [Image]() + var images = [InternalImagenImage]() var raiFilteredReasons = [String]() while !predictionsContainer.isAtEnd { - if let image = try? predictionsContainer.decode(Image.self) { + if let image = try? predictionsContainer.decode(InternalImagenImage.self) { images.append(image) } else if let filterReason = try? predictionsContainer.decode(RAIFilteredReason.self) { raiFilteredReasons.append(filterReason.raiFilteredReason) diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/InternalImagenImage.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/InternalImagenImage.swift new file mode 100644 index 00000000000..a9f175b9241 --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/InternalImagenImage.swift @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct InternalImagenImage { + let mimeType: String + let bytesBase64Encoded: String? + let gcsURI: String? + + init(mimeType: String, bytesBase64Encoded: String?, gcsURI: String?) { + self.mimeType = mimeType + self.bytesBase64Encoded = bytesBase64Encoded + self.gcsURI = gcsURI + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension InternalImagenImage: DecodableImagenImage {} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension InternalImagenImage: Equatable {} diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/RAIFilteredReason.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/RAIFilteredReason.swift new file mode 100644 index 00000000000..d1bb64da9cc --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/RAIFilteredReason.swift @@ -0,0 +1,25 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct RAIFilteredReason { + let raiFilteredReason: String +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension RAIFilteredReason: Decodable { + enum CodingKeys: CodingKey { + case raiFilteredReason + } +} diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenImage.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenImage.swift new file mode 100644 index 00000000000..4a5e90fc785 --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenImage.swift @@ -0,0 +1,22 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public protocol ImagenImage: ImagenImageRepresentable { + var mimeType: String { get } + var bytesBase64Encoded: String? { get } + var gcsURI: String? { get } +} diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenImageRepresentable.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenImageRepresentable.swift new file mode 100644 index 00000000000..816bbed5d28 --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenImageRepresentable.swift @@ -0,0 +1,27 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public protocol ImagenImageRepresentable { + var imagenImage: any ImagenImage { get } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public extension ImagenImage { + var imagenImage: any ImagenImage { + return self + } +} diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift index 6178b217c41..ab66bffa6e8 100644 --- a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift @@ -20,138 +20,10 @@ import XCTest final class ImageGenerationResponseTests: XCTestCase { let decoder = JSONDecoder() - // MARK: - Image Decoding - - func testDecodeImage_bytesBase64Encoded() throws { - let mimeType = "image/png" - let bytesBase64Encoded = "test-base64-bytes" - let json = """ - { - "bytesBase64Encoded": "\(bytesBase64Encoded)", - "mimeType": "\(mimeType)" - } - """ - let jsonData = try XCTUnwrap(json.data(using: .utf8)) - - let image = try decoder.decode(ImageGenerationResponse.Image.self, from: jsonData) - - XCTAssertEqual(image.mimeType, mimeType) - XCTAssertEqual(image.bytesBase64Encoded, bytesBase64Encoded) - XCTAssertNil(image.gcsURI) - } - - func testDecodeImage_gcsURI() throws { - let gcsURI = "gs://test-bucket/images/123456789/sample_0.png" - let mimeType = "image/jpeg" - let json = """ - { - "mimeType": "\(mimeType)", - "gcsUri": "\(gcsURI)" - } - """ - let jsonData = try XCTUnwrap(json.data(using: .utf8)) - - let image = try decoder.decode(ImageGenerationResponse.Image.self, from: jsonData) - - XCTAssertEqual(image.mimeType, mimeType) - XCTAssertEqual(image.gcsURI, gcsURI) - XCTAssertNil(image.bytesBase64Encoded) - } - - func testDecodeImage_missingBytesBase64EncodedAndGCSURI_throws() throws { - let json = """ - { - "mimeType": "image/jpeg" - } - """ - let jsonData = try XCTUnwrap(json.data(using: .utf8)) - - do { - _ = try decoder.decode(ImageGenerationResponse.Image.self, from: jsonData) - XCTFail("Expected an error; none thrown.") - } catch let DecodingError.dataCorrupted(context) { - let codingPath = try XCTUnwrap(context - .codingPath as? [ImageGenerationResponse.Image.CodingKeys]) - XCTAssertEqual(codingPath, [.bytesBase64Encoded, .gcsURI]) - XCTAssertTrue(context.debugDescription.contains("both are nil")) - } catch { - XCTFail("Expected a DecodingError.dataCorrupted error; got \(error).") - } - } - - func testDecodeImage_bothBytesBase64EncodedAndGCSURI_throws() throws { - let json = """ - { - "bytesBase64Encoded": "test-base64-bytes", - "mimeType": "image/png", - "gcsUri": "gs://test-bucket/images/123456789/sample_0.png" - } - """ - let jsonData = try XCTUnwrap(json.data(using: .utf8)) - - do { - _ = try decoder.decode(ImageGenerationResponse.Image.self, from: jsonData) - XCTFail("Expected an error; none thrown.") - } catch let DecodingError.dataCorrupted(context) { - let codingPath = try XCTUnwrap(context - .codingPath as? [ImageGenerationResponse.Image.CodingKeys]) - XCTAssertEqual(codingPath, [.bytesBase64Encoded, .gcsURI]) - XCTAssertTrue(context.debugDescription.contains("both are specified")) - } catch { - XCTFail("Expected a DecodingError.dataCorrupted error; got \(error).") - } - } - - // MARK: - RAI Filtered Reason Decoding - - func testDecodeRAIFilteredReason() throws { - let raiFilteredReason = """ - Unable to show generated images. All images were filtered out because they violated Vertex \ - AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. \ - If you think this was an error, send feedback. Support codes: 1234567 - """ - let json = """ - { - "raiFilteredReason": "\(raiFilteredReason)" - } - """ - let jsonData = try XCTUnwrap(json.data(using: .utf8)) - - let filterReason = try decoder.decode( - ImageGenerationResponse.RAIFilteredReason.self, - from: jsonData - ) - - XCTAssertEqual(filterReason.raiFilteredReason, raiFilteredReason) - } - - func testDecodeRAIFilteredReason_reasonNotSpecified_throws() throws { - let json = """ - { - "otherField": "test-value" - } - """ - let jsonData = try XCTUnwrap(json.data(using: .utf8)) - - do { - _ = try decoder.decode(ImageGenerationResponse.RAIFilteredReason.self, from: jsonData) - XCTFail("Expected an error; none thrown.") - } catch let DecodingError.keyNotFound(codingKey, _) { - let codingKey = try XCTUnwrap( - codingKey as? ImageGenerationResponse.RAIFilteredReason.CodingKeys - ) - XCTAssertEqual(codingKey, .raiFilteredReason) - } catch { - XCTFail("Expected a DecodingError.keyNotFound error; got \(error).") - } - } - - // MARK: - Image Generation Response Decoding - func testDecodeResponse_oneBase64Image_noneFiltered() throws { let mimeType = "image/png" let bytesBase64Encoded = "test-base64-bytes" - let image = ImageGenerationResponse.Image( + let image = InternalImagenImage( mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded, gcsURI: nil @@ -179,17 +51,17 @@ final class ImageGenerationResponseTests: XCTestCase { let bytesBase64Encoded1 = "test-base64-bytes-1" let bytesBase64Encoded2 = "test-base64-bytes-2" let bytesBase64Encoded3 = "test-base64-bytes-3" - let image1 = ImageGenerationResponse.Image( + let image1 = InternalImagenImage( mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded1, gcsURI: nil ) - let image2 = ImageGenerationResponse.Image( + let image2 = InternalImagenImage( mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded2, gcsURI: nil ) - let image3 = ImageGenerationResponse.Image( + let image3 = InternalImagenImage( mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded3, gcsURI: nil @@ -224,12 +96,12 @@ final class ImageGenerationResponseTests: XCTestCase { let mimeType = "image/png" let bytesBase64Encoded1 = "test-base64-bytes-1" let bytesBase64Encoded2 = "test-base64-bytes-2" - let image1 = ImageGenerationResponse.Image( + let image1 = InternalImagenImage( mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded1, gcsURI: nil ) - let image2 = ImageGenerationResponse.Image( + let image2 = InternalImagenImage( mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded2, gcsURI: nil @@ -267,16 +139,12 @@ final class ImageGenerationResponseTests: XCTestCase { let mimeType = "image/png" let gcsURI1 = "gs://test-bucket/images/123456789/sample_0.png" let gcsURI2 = "gs://test-bucket/images/123456789/sample_1.png" - let image1 = ImageGenerationResponse.Image( + let image1 = InternalImagenImage( mimeType: mimeType, bytesBase64Encoded: nil, gcsURI: gcsURI1 ) - let image2 = ImageGenerationResponse.Image( - mimeType: mimeType, - bytesBase64Encoded: nil, - gcsURI: gcsURI2 - ) + let image2 = InternalImagenImage(mimeType: mimeType, bytesBase64Encoded: nil, gcsURI: gcsURI2) let json = """ { "predictions": [ diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/InternalImagenImageTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/InternalImagenImageTests.swift new file mode 100644 index 00000000000..c66cc052785 --- /dev/null +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/InternalImagenImageTests.swift @@ -0,0 +1,101 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest + +@testable import FirebaseVertexAI + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class InternalImagenImageTests: XCTestCase { + let decoder = JSONDecoder() + + func testDecodeImage_bytesBase64Encoded() throws { + let mimeType = "image/png" + let bytesBase64Encoded = "test-base64-bytes" + let json = """ + { + "bytesBase64Encoded": "\(bytesBase64Encoded)", + "mimeType": "\(mimeType)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let image = try decoder.decode(InternalImagenImage.self, from: jsonData) + + XCTAssertEqual(image.mimeType, mimeType) + XCTAssertEqual(image.bytesBase64Encoded, bytesBase64Encoded) + XCTAssertNil(image.gcsURI) + } + + func testDecodeImage_gcsURI() throws { + let gcsURI = "gs://test-bucket/images/123456789/sample_0.png" + let mimeType = "image/jpeg" + let json = """ + { + "mimeType": "\(mimeType)", + "gcsUri": "\(gcsURI)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let image = try decoder.decode(InternalImagenImage.self, from: jsonData) + + XCTAssertEqual(image.mimeType, mimeType) + XCTAssertEqual(image.gcsURI, gcsURI) + XCTAssertNil(image.bytesBase64Encoded) + } + + func testDecodeImage_missingBytesBase64EncodedAndGCSURI_throws() throws { + let json = """ + { + "mimeType": "image/jpeg" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + do { + _ = try decoder.decode(InternalImagenImage.self, from: jsonData) + XCTFail("Expected an error; none thrown.") + } catch let DecodingError.dataCorrupted(context) { + let codingPath = try XCTUnwrap(context + .codingPath as? [ImagenImageCodingKeys]) + XCTAssertEqual(codingPath, [.bytesBase64Encoded, .gcsURI]) + XCTAssertTrue(context.debugDescription.contains("both are nil")) + } catch { + XCTFail("Expected a DecodingError.dataCorrupted error; got \(error).") + } + } + + func testDecodeImage_bothBytesBase64EncodedAndGCSURI_throws() throws { + let json = """ + { + "bytesBase64Encoded": "test-base64-bytes", + "mimeType": "image/png", + "gcsUri": "gs://test-bucket/images/123456789/sample_0.png" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + do { + _ = try decoder.decode(InternalImagenImage.self, from: jsonData) + XCTFail("Expected an error; none thrown.") + } catch let DecodingError.dataCorrupted(context) { + let codingPath = try XCTUnwrap(context.codingPath as? [ImagenImageCodingKeys]) + XCTAssertEqual(codingPath, [.bytesBase64Encoded, .gcsURI]) + XCTAssertTrue(context.debugDescription.contains("both are specified")) + } catch { + XCTFail("Expected a DecodingError.dataCorrupted error; got \(error).") + } + } +} diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/RAIFilteredReasonTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/RAIFilteredReasonTests.swift new file mode 100644 index 00000000000..0282eca58ce --- /dev/null +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/RAIFilteredReasonTests.swift @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest + +@testable import FirebaseVertexAI + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class RAIFilteredReasonTests: XCTestCase { + let decoder = JSONDecoder() + + func testDecodeRAIFilteredReason() throws { + let raiFilteredReason = """ + Unable to show generated images. All images were filtered out because they violated Vertex \ + AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. \ + If you think this was an error, send feedback. Support codes: 1234567 + """ + let json = """ + { + "raiFilteredReason": "\(raiFilteredReason)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let filterReason = try decoder.decode( + RAIFilteredReason.self, + from: jsonData + ) + + XCTAssertEqual(filterReason.raiFilteredReason, raiFilteredReason) + } + + func testDecodeRAIFilteredReason_reasonNotSpecified_throws() throws { + let json = """ + { + "otherField": "test-value" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + do { + _ = try decoder.decode(RAIFilteredReason.self, from: jsonData) + XCTFail("Expected an error; none thrown.") + } catch let DecodingError.keyNotFound(codingKey, _) { + let codingKey = try XCTUnwrap( + codingKey as? RAIFilteredReason.CodingKeys + ) + XCTAssertEqual(codingKey, .raiFilteredReason) + } catch { + XCTFail("Expected a DecodingError.keyNotFound error; got \(error).") + } + } +}