Add usageMetadata to GenerateContentResponse #159

Merged 1 commit on May 3, 2024
38 changes: 37 additions & 1 deletion Sources/GoogleAI/GenerateContentResponse.swift
@@ -17,13 +17,28 @@ import Foundation
/// The model's response to a generate content request.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public struct GenerateContentResponse {
/// Token usage metadata for processing the generate content request.
public struct UsageMetadata {
/// The number of tokens in the request prompt.
public let promptTokenCount: Int

/// The total number of tokens across the generated response candidates.
public let candidatesTokenCount: Int

/// The total number of tokens in both the request and response.
public let totalTokenCount: Int
}

/// A list of candidate response content, ordered from best to worst.
public let candidates: [CandidateResponse]

/// A value containing the safety ratings for the response, or, if the request was blocked, a
/// reason for blocking the request.
public let promptFeedback: PromptFeedback?

/// Token usage metadata for processing the generate content request.
public let usageMetadata: UsageMetadata?

/// The response's content as text, if it exists.
public var text: String? {
guard let candidate = candidates.first else {
@@ -51,9 +66,11 @@ public struct GenerateContentResponse {
}

/// Initializer for SwiftUI previews or tests.
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil) {
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil,
usageMetadata: UsageMetadata? = nil) {
self.candidates = candidates
self.promptFeedback = promptFeedback
self.usageMetadata = usageMetadata
}
}

@@ -170,6 +187,7 @@ extension GenerateContentResponse: Decodable {
enum CodingKeys: CodingKey {
case candidates
case promptFeedback
case usageMetadata
}

public init(from decoder: Decoder) throws {
@@ -194,6 +212,24 @@ extension GenerateContentResponse: Decodable {
candidates = []
}
promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback)
usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata)
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension GenerateContentResponse.UsageMetadata: Decodable {
enum CodingKeys: CodingKey {
case promptTokenCount
case candidatesTokenCount
case totalTokenCount
}

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
promptTokenCount = try container.decodeIfPresent(Int.self, forKey: .promptTokenCount) ?? 0
candidatesTokenCount = try container
.decodeIfPresent(Int.self, forKey: .candidatesTokenCount) ?? 0
totalTokenCount = try container.decodeIfPresent(Int.self, forKey: .totalTokenCount) ?? 0
}
}
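
With this change in place, callers can read token counts straight off a response. Below is a minimal sketch of the call site; the GoogleGenerativeAI module name, the GenerativeModel(name:apiKey:) setup, and the prompt are assumptions for illustration and do not appear in this diff.

import GoogleGenerativeAI

func printTokenUsage() async throws {
  // Model name, API key, and prompt are placeholders for illustration.
  let model = GenerativeModel(name: "gemini-pro", apiKey: "YOUR_API_KEY")
  let response = try await model.generateContent("Write a haiku about the sea.")

  // `usageMetadata` is optional; the backend may omit it.
  if let usage = response.usageMetadata {
    print("prompt tokens: \(usage.promptTokenCount)")
    print("candidate tokens: \(usage.candidatesTokenCount)")
    print("total tokens: \(usage.totalTokenCount)")
  }
}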

Test fixture: streaming-success-basic-reply-short.txt
@@ -1,2 +1 @@
data: {"candidates": [{"content": {"parts": [{"text": "Cheyenne"}]},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"promptFeedback": {"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}}

data: {"candidates": [{"content": {"parts": [{"text": "Mountain View, California"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"candidatesTokenCount": 4}}
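
As the updated fixture shows, only the final chunk of a streamed reply carries usageMetadata. Here is a sketch of how a caller might surface it, under the same assumed GoogleGenerativeAI setup as the example above.

func printStreamedUsage(model: GenerativeModel) async throws {
  var finalUsage: GenerateContentResponse.UsageMetadata?

  for try await chunk in model.generateContentStream("Write a haiku about the sea.") {
    if let text = chunk.text {
      print(text, terminator: "")
    }
    // Earlier chunks leave `usageMetadata` nil; keep the last non-nil value.
    if let usage = chunk.usageMetadata {
      finalUsage = usage
    }
  }

  if let usage = finalUsage {
    print("\ncandidate tokens: \(usage.candidatesTokenCount)")
  }
}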
Test fixture: unary-success-basic-reply-short.json
@@ -4,7 +4,7 @@
"content": {
"parts": [
{
"text": "Mountain View, California, United States"
"text": "Mountain View, California"
}
],
"role": "model"
@@ -31,24 +31,7 @@
]
}
],
"promptFeedback": {
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
}
]
"usageMetadata": {
"candidatesTokenCount": 4
}
}
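
Because the new decoder falls back to 0 for any missing count (the decodeIfPresent(...) ?? 0 lines above), a fixture that only carries candidatesTokenCount, like the one above, still decodes. A standalone sketch of that behavior, again assuming the GoogleGenerativeAI module name:

import Foundation
import GoogleGenerativeAI

let json = Data(#"{"candidatesTokenCount": 4}"#.utf8)
let usage = try JSONDecoder().decode(GenerateContentResponse.UsageMetadata.self, from: json)
// promptTokenCount and totalTokenCount default to 0 when absent from the payload.
print(usage.promptTokenCount, usage.candidatesTokenCount, usage.totalTokenCount)  // 0 4 0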
49 changes: 45 additions & 4 deletions Tests/GoogleAITests/GenerativeModelTests.swift
@@ -82,11 +82,9 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(candidate.safetyRatings, safetyRatingsNegligible)
XCTAssertEqual(candidate.content.parts.count, 1)
let part = try XCTUnwrap(candidate.content.parts.first)
XCTAssertEqual(part.text, "Mountain View, California, United States")
XCTAssertEqual(part.text, "Mountain View, California")
XCTAssertEqual(response.text, part.text)
let promptFeedback = try XCTUnwrap(response.promptFeedback)
XCTAssertNil(promptFeedback.blockReason)
XCTAssertEqual(promptFeedback.safetyRatings, safetyRatingsNegligible)
XCTAssertNil(response.promptFeedback)
XCTAssertEqual(response.functionCalls, [])
}

@@ -256,6 +254,22 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(response.functionCalls, [functionCall])
}

func testGenerateContent_usageMetadata() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-basic-reply-short",
withExtension: "json"
)

let response = try await model.generateContent(testPrompt)

let usageMetadata = try XCTUnwrap(response.usageMetadata)
// TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added.
XCTAssertEqual(usageMetadata.promptTokenCount, 0)
XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
XCTAssertEqual(usageMetadata.totalTokenCount, 0)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol
@@ -756,6 +770,33 @@ final class GenerativeModelTests: XCTestCase {
}))
}

func testGenerateContentStream_usageMetadata() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "streaming-success-basic-reply-short",
withExtension: "txt"
)
var responses = [GenerateContentResponse]()

let stream = model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}

for (index, response) in responses.enumerated() {
if index == responses.endIndex - 1 {
let usageMetadata = try XCTUnwrap(response.usageMetadata)
// TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added.
XCTAssertEqual(usageMetadata.promptTokenCount, 0)
XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
XCTAssertEqual(usageMetadata.totalTokenCount, 0)
} else {
// Only the last streamed response contains usage metadata
XCTAssertNil(response.usageMetadata)
}
}
}

func testGenerateContentStream_errorMidStream() async throws {
MockURLProtocol.requestHandler = try httpRequestHandler(
forResource: "streaming-failure-error-mid-stream",