Skip to content

feat(firebase_ai): Add support for Grounding with Google Search #17468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/firebase_ai/firebase_ai/lib/firebase_ai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export 'src/error.dart'
ServerException,
UnsupportedUserLocation;
export 'src/firebase_ai.dart' show FirebaseAI;
export 'src/function_calling.dart'
export 'src/tool.dart'
show
FunctionCallingConfig,
FunctionCallingMode,
Expand Down
307 changes: 282 additions & 25 deletions packages/firebase_ai/firebase_ai/lib/src/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import 'content.dart';
import 'error.dart';
import 'function_calling.dart' show Tool, ToolConfig;
import 'tool.dart' show Tool, ToolConfig;
import 'schema.dart';

/// Response for Count Tokens
Expand Down Expand Up @@ -187,7 +187,8 @@ final class Candidate {
// TODO: token count?
// ignore: public_member_api_docs
Candidate(this.content, this.safetyRatings, this.citationMetadata,
this.finishReason, this.finishMessage);
this.finishReason, this.finishMessage,
{this.groundingMetadata});

/// Generated content returned from the model.
final Content content;
Expand All @@ -212,6 +213,9 @@ final class Candidate {
/// Message for finish reason.
final String? finishMessage;

/// Metadata returned to the client when grounding is enabled.
final GroundingMetadata? groundingMetadata;

/// The concatenation of the text parts of [content], if any.
///
/// If this candidate was finished for a reason of [FinishReason.recitation]
Expand Down Expand Up @@ -243,6 +247,144 @@ final class Candidate {
}
}

/// Represents a specific segment within a [Content], often used to pinpoint
/// the exact location of text or data that grounding information refers to.
final class Segment {
Segment(
{required this.partIndex,
required this.startIndex,
required this.endIndex,
required this.text});

/// The zero-based index of the [Part] object within the `parts` array of its
/// parent [Content] object.
///
/// This identifies which part of the content the segment belongs to.
final int partIndex;

/// The zero-based start index of the segment within the specified [Part],
/// measured in UTF-8 bytes.
///
/// This offset is inclusive, starting from 0 at the beginning of the
/// part's content.
final int startIndex;

/// The zero-based end index of the segment within the specified [Part],
/// measured in UTF-8 bytes.
///
/// This offset is exclusive, meaning the character at this index is not
/// included in the segment.
final int endIndex;

/// The text corresponding to the segment from the response.
final String text;
}

/// A grounding chunk sourced from the web.
final class WebGroundingChunk {
WebGroundingChunk({this.uri, this.title, this.domain});

/// The URI of the retrieved web page.
final String? uri;

/// The title of the retrieved web page.
final String? title;

/// The domain of the original URI from which the content was retrieved.
///
/// This field is only populated when using the Vertex AI Gemini API.
final String? domain;
}

/// Represents a chunk of retrieved data that supports a claim in the model's
/// response.
///
/// This is part of the grounding information provided when grounding is
/// enabled.
final class GroundingChunk {
GroundingChunk({this.web});

/// Contains details if the grounding chunk is from a web source.
final WebGroundingChunk? web;
}

/// Provides information about how a specific segment of the model's response
/// is supported by the retrieved grounding chunks.
final class GroundingSupport {
GroundingSupport(
{required this.segment, required this.groundingChunkIndices});

/// Specifies the segment of the model's response content that this
/// grounding support pertains to.
final Segment segment;

/// A list of indices that refer to specific [GroundingChunk]s within the
/// [GroundingMetadata.groundingChunks] array.
///
/// These referenced chunks are the sources that
/// support the claim made in the associated `segment` of the response.
/// For example, an array `[1, 3, 4]`
/// means that `groundingChunks[1]`, `groundingChunks[3]`, and
/// `groundingChunks[4]` are the
/// retrieved content supporting this part of the response.
final List<int> groundingChunkIndices;
}

/// Google Search entry point for web searches.
final class SearchEntryPoint {
SearchEntryPoint({required this.renderedContent});

/// An HTML/CSS snippet that **must** be embedded in an app to display a
/// Google Search entry point for follow-up web searches related to the
/// model's "Grounded Response".
///
/// To ensure proper rendering, it's recommended to display this content
/// within a `WebView`.
final String renderedContent;
}

/// Metadata returned to the client when grounding is enabled.
///
/// > Important: If using Grounding with Google Search, you are required to
/// comply with the "Grounding with Google Search" usage requirements for your
/// chosen API provider:
/// [Gemini Developer API](https://ai.google.dev/gemini-api/terms#grounding-with-google-search)
/// or Vertex AI Gemini API (see [Service Terms](https://cloud.google.com/terms/service-terms)
/// section within the Service Specific Terms).
final class GroundingMetadata {
GroundingMetadata(
{this.searchEntryPoint,
required this.groundingChunks,
required this.groundingSupport,
required this.webSearchQueries});

/// Google Search entry point for web searches.
///
/// This contains an HTML/CSS snippet that **must** be embedded in an app to
// display a Google Search entry point for follow-up web searches related to
// the model's "Grounded Response".
final SearchEntryPoint? searchEntryPoint;

/// A list of [GroundingChunk]s.
///
/// Each chunk represents a piece of retrieved content (e.g., from a web
/// page) that the model used to ground its response.
final List<GroundingChunk> groundingChunks;

/// A list of [GroundingSupport]s.
///
/// Each object details how specific segments of the
/// model's response are supported by the `groundingChunks`.
final List<GroundingSupport> groundingSupport;

/// A list of web search queries that the model performed to gather the
/// grounding information.
///
/// These can be used to allow users to explore the search results
/// themselves.
final List<String> webSearchQueries;
}

/// Safety rating for a piece of content.
///
/// The safety rating contains the category of harm and the harm probability
Expand Down Expand Up @@ -1027,29 +1169,33 @@ Candidate _parseCandidate(Object? jsonObject) {
}

return Candidate(
jsonObject.containsKey('content')
? parseContent(jsonObject['content'] as Object)
: Content(null, []),
switch (jsonObject) {
{'safetyRatings': final List<Object?> safetyRatings} =>
safetyRatings.map(_parseSafetyRating).toList(),
_ => null
},
switch (jsonObject) {
{'citationMetadata': final Object citationMetadata} =>
_parseCitationMetadata(citationMetadata),
_ => null
},
switch (jsonObject) {
{'finishReason': final Object finishReason} =>
FinishReason._parseValue(finishReason),
_ => null
},
switch (jsonObject) {
{'finishMessage': final String finishMessage} => finishMessage,
_ => null
},
);
jsonObject.containsKey('content')
? parseContent(jsonObject['content'] as Object)
: Content(null, []),
switch (jsonObject) {
{'safetyRatings': final List<Object?> safetyRatings} =>
safetyRatings.map(_parseSafetyRating).toList(),
_ => null
},
switch (jsonObject) {
{'citationMetadata': final Object citationMetadata} =>
_parseCitationMetadata(citationMetadata),
_ => null
},
switch (jsonObject) {
{'finishReason': final Object finishReason} =>
FinishReason._parseValue(finishReason),
_ => null
},
switch (jsonObject) {
{'finishMessage': final String finishMessage} => finishMessage,
_ => null
},
groundingMetadata: switch (jsonObject) {
{'groundingMetadata': final Object groundingMetadata} =>
_parseGroundingMetadata(groundingMetadata),
_ => null
});
}

PromptFeedback _parsePromptFeedback(Object jsonObject) {
Expand Down Expand Up @@ -1163,3 +1309,114 @@ Citation _parseCitationSource(Object? jsonObject) {
jsonObject['license'] as String?,
);
}

GroundingMetadata _parseGroundingMetadata(Object? jsonObject) {
if (jsonObject is! Map) {
throw unhandledFormat('GroundingMetadata', jsonObject);
}

final searchEntryPoint = switch (jsonObject) {
{'searchEntryPoint': final Object? searchEntryPoint} =>
_parseSearchEntryPoint(searchEntryPoint),
_ => null,
};
final groundingChunks = switch (jsonObject) {
{'groundingChunks': final List<Object?> groundingChunks} =>
groundingChunks.map(_parseGroundingChunk).toList(),
_ => null,
} ??
[];
// Filters out null elements, which are returned from _parseGroundingSupport when
// segment is null.
final groundingSupport = switch (jsonObject) {
{'groundingSupport': final List<Object?> groundingSupport} =>
groundingSupport
.map(_parseGroundingSupport)
.whereType<GroundingSupport>()
.toList(),
_ => null,
} ??
[];
final webSearchQueries = switch (jsonObject) {
{'webSearchQueries': final List<String>? webSearchQueries} =>
webSearchQueries,
_ => null,
} ??
[];

return GroundingMetadata(
searchEntryPoint: searchEntryPoint,
groundingChunks: groundingChunks,
groundingSupport: groundingSupport,
webSearchQueries: webSearchQueries);
}

Segment _parseSegment(Object? jsonObject) {
if (jsonObject is! Map) {
throw unhandledFormat('Segment', jsonObject);
}

return Segment(
partIndex: (jsonObject['partIndex'] as int?) ?? 0,
startIndex: (jsonObject['startIndex'] as int?) ?? 0,
endIndex: (jsonObject['endIndex'] as int?) ?? 0,
text: (jsonObject['text'] as String?) ?? '');
}

WebGroundingChunk _parseWebGroundingChunk(Object? jsonObject) {
if (jsonObject is! Map) {
throw unhandledFormat('WebGroundingChunk', jsonObject);
}

return WebGroundingChunk(
uri: jsonObject['uri'] as String?,
title: jsonObject['title'] as String?,
domain: jsonObject['domain'] as String?,
);
}

GroundingChunk _parseGroundingChunk(Object? jsonObject) {
if (jsonObject is! Map) {
throw unhandledFormat('GroundingChunk', jsonObject);
}

return GroundingChunk(
web: jsonObject['web'] != null
? _parseWebGroundingChunk(jsonObject['web'])
: null,
);
}

GroundingSupport? _parseGroundingSupport(Object? jsonObject) {
if (jsonObject is! Map) {
throw unhandledFormat('GroundingSupport', jsonObject);
}

final segment = switch (jsonObject) {
{'segment': final Object? segment} => _parseSegment(segment),
_ => null,
};
if (segment == null) {
return null;
}

return GroundingSupport(
segment: segment,
groundingChunkIndices:
(jsonObject['groundingChunkIndices'] as List<int>?) ?? []);
}

SearchEntryPoint _parseSearchEntryPoint(Object? jsonObject) {
if (jsonObject is! Map) {
throw unhandledFormat('SearchEntryPoint', jsonObject);
}

final renderedContent = jsonObject['renderedContent'] as String?;
if (renderedContent == null) {
throw unhandledFormat('SearchEntryPoint', jsonObject);
}

return SearchEntryPoint(
renderedContent: renderedContent,
);
}
2 changes: 1 addition & 1 deletion packages/firebase_ai/firebase_ai/lib/src/base_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import 'api.dart';
import 'client.dart';
import 'content.dart';
import 'developer/api.dart';
import 'function_calling.dart';
import 'tool.dart';
import 'imagen_api.dart';
import 'imagen_content.dart';
import 'live_api.dart';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import '../api.dart'
createUsageMetadata;
import '../content.dart' show Content, FunctionCall, Part, TextPart;
import '../error.dart';
import '../function_calling.dart' show Tool, ToolConfig;
import '../tool.dart' show Tool, ToolConfig;

HarmProbability _parseHarmProbability(Object jsonObject) =>
switch (jsonObject) {
Expand Down
Loading
Loading