Skip to content

feat(firebaseai): add think feature #17409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
32 changes: 31 additions & 1 deletion packages/firebase_ai/firebase_ai/lib/src/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ final class UsageMetadata {
{this.promptTokenCount,
this.candidatesTokenCount,
this.totalTokenCount,
this.thoughtsTokenCount,
this.promptTokensDetails,
this.candidatesTokensDetails});

Expand All @@ -158,6 +159,9 @@ final class UsageMetadata {
/// Total token count for the generation request (prompt + candidates).
final int? totalTokenCount;

/// Number of tokens present in thoughts output.
final int? thoughtsTokenCount;

/// List of modalities that were processed in the request input.
final List<ModalityTokenCount>? promptTokensDetails;

Expand All @@ -172,13 +176,15 @@ UsageMetadata createUsageMetadata({
required int? promptTokenCount,
required int? candidatesTokenCount,
required int? totalTokenCount,
required int? thoughtsTokenCount,
required List<ModalityTokenCount>? promptTokensDetails,
required List<ModalityTokenCount>? candidatesTokensDetails,
}) =>
UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
thoughtsTokenCount: thoughtsTokenCount,
promptTokensDetails: promptTokensDetails,
candidatesTokensDetails: candidatesTokensDetails);

Expand Down Expand Up @@ -697,10 +703,25 @@ enum ResponseModalities {
const ResponseModalities(this._jsonString);
final String _jsonString;

/// Convert to json format
// ignore: public_member_api_docs
String toJson() => _jsonString;
}

/// Config for thinking features.
class ThinkingConfig {
// ignore: public_member_api_docs
ThinkingConfig({this.thinkingBudget});

/// The number of thoughts tokens that the model should generate.
final int? thinkingBudget;

// ignore: public_member_api_docs
Map<String, Object?> toJson() => {
if (thinkingBudget case final thinkingBudget?)
'thinkingBudget': thinkingBudget,
};
}

/// Configuration options for model generation and outputs.
abstract class BaseGenerationConfig {
// ignore: public_member_api_docs
Expand Down Expand Up @@ -826,6 +847,7 @@ final class GenerationConfig extends BaseGenerationConfig {
super.responseModalities,
this.responseMimeType,
this.responseSchema,
this.thinkingConfig,
});

/// The set of character sequences (up to 5) that will stop output generation.
Expand All @@ -847,6 +869,12 @@ final class GenerationConfig extends BaseGenerationConfig {
/// a schema; currently this is limited to `application/json`.
final Schema? responseSchema;

/// Config for thinking features.
///
/// An error will be returned if this field is set for models that don't
/// support thinking.
final ThinkingConfig? thinkingConfig;

@override
Map<String, Object?> toJson() => {
...super.toJson(),
Expand All @@ -857,6 +885,8 @@ final class GenerationConfig extends BaseGenerationConfig {
'responseMimeType': responseMimeType,
if (responseSchema case final responseSchema?)
'responseSchema': responseSchema.toJson(),
if (thinkingConfig case final thinkingConfig?)
'thinkingConfig': thinkingConfig.toJson(),
};
}

Expand Down
5 changes: 5 additions & 0 deletions packages/firebase_ai/firebase_ai/lib/src/developer/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,15 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) {
{'totalTokenCount': final int totalTokenCount} => totalTokenCount,
_ => null,
};
final thoughtsTokenCount = switch (jsonObject) {
{'thoughtsTokenCount': final int thoughtsTokenCount} => thoughtsTokenCount,
_ => null,
};
return createUsageMetadata(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
thoughtsTokenCount: thoughtsTokenCount,
promptTokensDetails: null,
candidatesTokensDetails: null,
);
Expand Down
34 changes: 31 additions & 3 deletions packages/firebase_ai/firebase_ai/test/api_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import 'package:firebase_ai/firebase_ai.dart';
import 'package:firebase_ai/src/api.dart';

Expand Down Expand Up @@ -396,6 +395,7 @@ void main() {
group('GenerationConfig & BaseGenerationConfig', () {
test('GenerationConfig toJson with all fields', () {
final schema = Schema.object(properties: {});
final thinkingConfig = ThinkingConfig(thinkingBudget: 100);
final config = GenerationConfig(
candidateCount: 1,
stopSequences: ['\n', 'stop'],
Expand All @@ -407,6 +407,7 @@ void main() {
frequencyPenalty: 0.4,
responseMimeType: 'application/json',
responseSchema: schema,
thinkingConfig: thinkingConfig,
);
expect(config.toJson(), {
'candidateCount': 1,
Expand All @@ -418,8 +419,8 @@ void main() {
'frequencyPenalty': 0.4,
'stopSequences': ['\n', 'stop'],
'responseMimeType': 'application/json',
'responseSchema': schema
.toJson(), // Schema itself not schema.toJson() in the provided code
'responseSchema': schema.toJson(),
'thinkingConfig': {'thinkingBudget': 100},
});
});

Expand All @@ -438,6 +439,33 @@ void main() {
'responseMimeType': 'text/plain',
});
});

test('GenerationConfig toJson without thinkingConfig', () {
final config = GenerationConfig(temperature: 0.5);
expect(config.toJson(), {'temperature': 0.5});
});
});

group('ThinkingConfig', () {
test('toJson with thinkingBudget set', () {
final config = ThinkingConfig(thinkingBudget: 123);
expect(config.toJson(), {'thinkingBudget': 123});
});

test('toJson with thinkingBudget null', () {
final config = ThinkingConfig();
// Expecting the key to be absent or the value to be explicitly null,
// depending on implementation. Current implementation omits the key.
expect(config.toJson(), {});
});

test('constructor initializes thinkingBudget', () {
final config = ThinkingConfig(thinkingBudget: 456);
expect(config.thinkingBudget, 456);

final configNull = ThinkingConfig();
expect(configNull.thinkingBudget, isNull);
});
});

group('Parsing Functions', () {
Expand Down
126 changes: 126 additions & 0 deletions packages/firebase_ai/firebase_ai/test/developer_api_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import 'package:firebase_ai/src/developer/api.dart';
import 'package:flutter_test/flutter_test.dart';

void main() {
group('DeveloperSerialization', () {
group('parseGenerateContentResponse', () {
test('parses usageMetadata with thoughtsTokenCount correctly', () {
final jsonResponse = {
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': 'Some generated text.'}
]
},
'finishReason': 'STOP',
}
],
'usageMetadata': {
'promptTokenCount': 10,
'candidatesTokenCount': 5,
'totalTokenCount': 15,
'thoughtsTokenCount': 3,
}
};
final response =
DeveloperSerialization().parseGenerateContentResponse(jsonResponse);
expect(response.usageMetadata, isNotNull);
expect(response.usageMetadata!.promptTokenCount, 10);
expect(response.usageMetadata!.candidatesTokenCount, 5);
expect(response.usageMetadata!.totalTokenCount, 15);
expect(response.usageMetadata!.thoughtsTokenCount, 3);
});

test('parses usageMetadata when thoughtsTokenCount is missing', () {
final jsonResponse = {
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': 'Some generated text.'}
]
},
'finishReason': 'STOP',
}
],
'usageMetadata': {
'promptTokenCount': 10,
'candidatesTokenCount': 5,
'totalTokenCount': 15,
// thoughtsTokenCount is missing
}
};
final response =
DeveloperSerialization().parseGenerateContentResponse(jsonResponse);
expect(response.usageMetadata, isNotNull);
expect(response.usageMetadata!.promptTokenCount, 10);
expect(response.usageMetadata!.candidatesTokenCount, 5);
expect(response.usageMetadata!.totalTokenCount, 15);
expect(response.usageMetadata!.thoughtsTokenCount, isNull);
});

test('parses usageMetadata when thoughtsTokenCount is present but null',
() {
final jsonResponse = {
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': 'Some generated text.'}
]
},
'finishReason': 'STOP',
}
],
'usageMetadata': {
'promptTokenCount': 10,
'candidatesTokenCount': 5,
'totalTokenCount': 15,
'thoughtsTokenCount': null,
}
};
final response =
DeveloperSerialization().parseGenerateContentResponse(jsonResponse);
expect(response.usageMetadata, isNotNull);
expect(response.usageMetadata!.thoughtsTokenCount, isNull);
});

test('parses response when usageMetadata is missing', () {
final jsonResponse = {
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': 'Some generated text.'}
]
},
'finishReason': 'STOP',
}
],
// usageMetadata is missing
};
final response =
DeveloperSerialization().parseGenerateContentResponse(jsonResponse);
expect(response.usageMetadata, isNull);
});
});
});
}
Loading