Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vertex AI] Simplify ModelContent initializers #13832

Merged
merged 21 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor FileDataPart
  • Loading branch information
andrewheard committed Oct 4, 2024
commit 2da135d80d0e1038466c7d411d66b8f7524289c9
44 changes: 13 additions & 31 deletions FirebaseVertexAI/Sources/ModelContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ public struct ModelContent: Equatable, Sendable {
case let .inlineData(mimetype, data):
convertedParts.append(InlineDataPart(data: data, mimeType: mimetype))
case let .fileData(mimetype, uri):
convertedParts.append(FileDataPart(fileData: FileData(mimeType: mimetype, uri: uri)))
convertedParts.append(FileDataPart(uri: uri, mimeType: mimetype))
case let .functionCall(functionCall):
convertedParts.append(FunctionCallPart(functionCall: functionCall))
convertedParts.append(FunctionCallPart(functionCall))
case let .functionResponse(functionResponse):
convertedParts.append(FunctionResponsePart(functionResponse: functionResponse))
}
Expand All @@ -120,7 +120,7 @@ public struct ModelContent: Equatable, Sendable {
convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data))
case let fileDataPart as FileDataPart:
let fileData = fileDataPart.fileData
convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.uri))
convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI))
case let functionCallPart as FunctionCallPart:
convertedParts.append(.functionCall(functionCallPart.functionCall))
case let functionResponsePart as FunctionResponsePart:
Expand All @@ -145,7 +145,7 @@ public struct ModelContent: Equatable, Sendable {
convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data))
case let fileDataPart as FileDataPart:
let fileData = fileDataPart.fileData
convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.uri))
convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI))
case let functionCallPart as FunctionCallPart:
convertedParts.append(.functionCall(functionCallPart.functionCall))
case let functionResponsePart as FunctionResponsePart:
Expand Down Expand Up @@ -192,31 +192,15 @@ extension ModelContent.InternalPart: Codable {
case functionResponse
}

enum InlineDataKeys: String, CodingKey {
case mimeType = "mime_type"
case bytes = "data"
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
switch self {
case let .text(a0):
try container.encode(a0, forKey: .text)
case let .text(text):
try container.encode(text, forKey: .text)
case let .inlineData(mimetype, bytes):
var inlineDataContainer = container.nestedContainer(
keyedBy: InlineDataKeys.self,
forKey: .inlineData
)
try inlineDataContainer.encode(mimetype, forKey: .mimeType)
try inlineDataContainer.encode(bytes, forKey: .bytes)
try container.encode(InlineData(data: bytes, mimeType: mimetype), forKey: .inlineData)
case let .fileData(mimetype: mimetype, url):
// var fileDataContainer = container.nestedContainer(
// keyedBy: FileDataKeys.self,
// forKey: .fileData
// )
try container.encode(FileData(mimeType: mimetype, uri: url), forKey: .fileData)
// try fileDataContainer.encode(mimetype, forKey: .mimeType)
// try fileDataContainer.encode(url, forKey: .uri)
try container.encode(FileData(fileURI: url, mimeType: mimetype), forKey: .fileData)
case let .functionCall(functionCall):
try container.encode(functionCall, forKey: .functionCall)
case let .functionResponse(functionResponse):
Expand All @@ -229,13 +213,11 @@ extension ModelContent.InternalPart: Codable {
if values.contains(.text) {
self = try .text(values.decode(String.self, forKey: .text))
} else if values.contains(.inlineData) {
let dataContainer = try values.nestedContainer(
keyedBy: InlineDataKeys.self,
forKey: .inlineData
)
let mimetype = try dataContainer.decode(String.self, forKey: .mimeType)
let bytes = try dataContainer.decode(Data.self, forKey: .bytes)
self = .inlineData(mimetype: mimetype, bytes)
let inlineData = try values.decode(InlineData.self, forKey: .inlineData)
self = .inlineData(mimetype: inlineData.mimeType, inlineData.data)
} else if values.contains(.fileData) {
let fileData = try values.decode(FileData.self, forKey: .fileData)
self = .fileData(mimetype: fileData.mimeType, uri: fileData.fileURI)
} else if values.contains(.functionCall) {
self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall))
} else if values.contains(.functionResponse) {
Expand Down
11 changes: 11 additions & 0 deletions FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ struct InlineData: Codable, Equatable, Sendable {
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct FileData: Codable, Equatable, Sendable {
let fileURI: String
let mimeType: String

init(fileURI: String, mimeType: String) {
self.fileURI = fileURI
self.mimeType = mimeType
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct ErrorPart: Part, Error {
let error: Error
Expand Down
27 changes: 10 additions & 17 deletions FirebaseVertexAI/Sources/Types/Public/Part.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,41 +43,34 @@ public struct InlineDataPart: Part {
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct FileData: Codable, Equatable, Sendable {
enum CodingKeys: String, CodingKey {
case mimeType = "mime_type"
case uri = "file_uri"
}
public struct FileDataPart: Part {
let fileData: FileData

public let mimeType: String
public let uri: String
public var uri: String { fileData.fileURI }
public var mimeType: String { fileData.mimeType }

public init(mimeType: String, uri: String) {
self.mimeType = mimeType
self.uri = uri
public init(uri: String, mimeType: String) {
self.init(FileData(fileURI: uri, mimeType: mimeType))
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct FileDataPart: Part {
public let fileData: FileData

public init(fileData: FileData) {
init(_ fileData: FileData) {
self.fileData = fileData
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct FunctionCallPart: Part {
// TODO: Consider making FunctionCall internal and exposing params on FunctionCallPart instead.
public let functionCall: FunctionCall

public init(functionCall: FunctionCall) {
public init(_ functionCall: FunctionCall) {
self.functionCall = functionCall
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct FunctionResponsePart: Part {
// TODO: Consider making FunctionResponsePart internal and exposing params here instead.
public let functionResponse: FunctionResponse

public init(functionResponse: FunctionResponse) {
Expand Down
65 changes: 50 additions & 15 deletions FirebaseVertexAI/Tests/Unit/PartTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,28 @@ final class PartTests: XCTestCase {

// MARK: - Part Decoding

func testDecodeTextPart() throws {
let expectedText = "Hello, world!"
let json = """
{
"text" : "\(expectedText)"
}
"""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let part = try decoder.decode(TextPart.self, from: jsonData)

XCTAssertEqual(part.text, expectedText)
}

func testDecodeInlineDataPart() throws {
let imageBase64 = try blueSquareImage()
let imageBase64 = try PartTests.blueSquareImage()
let mimeType = "image/png"

let json = """
{
"inlineData": {
"data": "\(imageBase64)",
"mimeType": "\(mimeType)"
"inlineData" : {
"data" : "\(imageBase64)",
"mimeType" : "\(mimeType)"
}
}
"""
Expand Down Expand Up @@ -75,9 +88,23 @@ final class PartTests: XCTestCase {

// MARK: - Part Encoding

func testEncodeTextPart() throws {
let expectedText = "Hello, world!"
let textPart = TextPart(expectedText)

let jsonData = try encoder.encode(textPart)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"text" : "\(expectedText)"
}
""")
}

func testEncodeInlineDataPart() throws {
let mimeType = "image/png"
let imageBase64 = try blueSquareImage()
let imageBase64 = try PartTests.blueSquareImage()
let imageBase64Data = Data(base64Encoded: imageBase64)
let inlineDataPart = InlineDataPart(data: imageBase64Data!, mimeType: mimeType)

Expand All @@ -97,26 +124,34 @@ final class PartTests: XCTestCase {
func testEncodeFileDataPart() throws {
let mimeType = "image/jpeg"
let fileURI = "gs://test-bucket/image.jpg"
let fileDataPart = FileDataPart(fileData: FileData(mimeType: mimeType, uri: fileURI))
let fileDataPart = FileDataPart(uri: fileURI, mimeType: mimeType)

let jsonData = try encoder.encode(fileDataPart)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"fileData" : {
"file_uri" : "\(fileURI)",
"mime_type" : "\(mimeType)"
"fileURI" : "\(fileURI)",
"mimeType" : "\(mimeType)"
}
}
""")
}
}

// MARK: - Helpers
// MARK: - Helpers

func blueSquareImage() throws -> String {
let imageURL = Bundle.module.url(forResource: "blue", withExtension: "png")!
let imageData = try Data(contentsOf: imageURL)
return imageData.base64EncodedString()
private static func bundle() -> Bundle {
#if SWIFT_PACKAGE
return Bundle.module
#else // SWIFT_PACKAGE
return Bundle(for: Self.self)
#endif // SWIFT_PACKAGE
}

private static func blueSquareImage() throws -> String {
let imageURL = Bundle.module.url(forResource: "blue", withExtension: "png")!
let imageData = try Data(contentsOf: imageURL)
return imageData.base64EncodedString()
}
}
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ final class VertexAIAPITests: XCTestCase {
let _ = try await genAI.generateContent(str, "abc", "def")
let _ = try await genAI.generateContent(
str,
FileDataPart(fileData: FileData(mimeType: "image/jpeg", uri: "gs://test-bucket/image.jpg"))
FileDataPart(uri: "gs://test-bucket/image.jpg", mimeType: "image/jpeg")
)
#if canImport(UIKit)
_ = try await genAI.generateContent(UIImage())
Expand Down