diff --git a/packages/firebase_ai/firebase_ai/example/ios/Podfile b/packages/firebase_ai/firebase_ai/example/ios/Podfile new file mode 100644 index 000000000000..620e46eba607 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/example/ios/Podfile @@ -0,0 +1,43 @@ +# Uncomment this line to define a global platform for your project +# platform :ios, '13.0' + +# CocoaPods analytics sends network stats synchronously affecting flutter build latency. +ENV['COCOAPODS_DISABLE_STATS'] = 'true' + +project 'Runner', { + 'Debug' => :debug, + 'Profile' => :release, + 'Release' => :release, +} + +def flutter_root + generated_xcode_build_settings_path = File.expand_path(File.join('..', 'Flutter', 'Generated.xcconfig'), __FILE__) + unless File.exist?(generated_xcode_build_settings_path) + raise "#{generated_xcode_build_settings_path} must exist. If you're running pod install manually, make sure flutter pub get is executed first" + end + + File.foreach(generated_xcode_build_settings_path) do |line| + matches = line.match(/FLUTTER_ROOT\=(.*)/) + return matches[1].strip if matches + end + raise "FLUTTER_ROOT not found in #{generated_xcode_build_settings_path}. Try deleting Generated.xcconfig, then run flutter pub get" +end + +require File.expand_path(File.join('packages', 'flutter_tools', 'bin', 'podhelper'), flutter_root) + +flutter_ios_podfile_setup + +target 'Runner' do + use_frameworks! + + flutter_install_all_ios_pods File.dirname(File.realpath(__FILE__)) + target 'RunnerTests' do + inherit! :search_paths + end +end + +post_install do |installer| + installer.pods_project.targets.each do |target| + flutter_additional_ios_build_settings(target) + end +end diff --git a/packages/firebase_ai/firebase_ai/example/macos/Podfile b/packages/firebase_ai/firebase_ai/example/macos/Podfile new file mode 100644 index 000000000000..ff5ddb3b8bdc --- /dev/null +++ b/packages/firebase_ai/firebase_ai/example/macos/Podfile @@ -0,0 +1,42 @@ +platform :osx, '10.15' + +# CocoaPods analytics sends network stats synchronously affecting flutter build latency. +ENV['COCOAPODS_DISABLE_STATS'] = 'true' + +project 'Runner', { + 'Debug' => :debug, + 'Profile' => :release, + 'Release' => :release, +} + +def flutter_root + generated_xcode_build_settings_path = File.expand_path(File.join('..', 'Flutter', 'ephemeral', 'Flutter-Generated.xcconfig'), __FILE__) + unless File.exist?(generated_xcode_build_settings_path) + raise "#{generated_xcode_build_settings_path} must exist. If you're running pod install manually, make sure \"flutter pub get\" is executed first" + end + + File.foreach(generated_xcode_build_settings_path) do |line| + matches = line.match(/FLUTTER_ROOT\=(.*)/) + return matches[1].strip if matches + end + raise "FLUTTER_ROOT not found in #{generated_xcode_build_settings_path}. Try deleting Flutter-Generated.xcconfig, then run \"flutter pub get\"" +end + +require File.expand_path(File.join('packages', 'flutter_tools', 'bin', 'podhelper'), flutter_root) + +flutter_macos_podfile_setup + +target 'Runner' do + use_frameworks! + + flutter_install_all_macos_pods File.dirname(File.realpath(__FILE__)) + target 'RunnerTests' do + inherit! :search_paths + end +end + +post_install do |installer| + installer.pods_project.targets.each do |target| + flutter_additional_macos_build_settings(target) + end +end diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart index 9e48a0c58686..c239dbee762a 100644 --- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart @@ -66,6 +66,8 @@ export 'src/error.dart' QuotaExceeded, UnsupportedUserLocation; export 'src/firebase_ai.dart' show FirebaseAI; +export 'src/hybrid_config.dart' + show HybridConfig, HybridPreference, LocalModelConfig; export 'src/image_config.dart' show ImageConfig, ImageAspectRatio, ImageSize; export 'src/imagen/imagen_api.dart' show diff --git a/packages/firebase_ai/firebase_ai/lib/src/base_model.dart b/packages/firebase_ai/firebase_ai/lib/src/base_model.dart index f02bc723b580..0a394e1bf78a 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/base_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/base_model.dart @@ -17,6 +17,7 @@ import 'dart:async'; import 'package:firebase_app_check/firebase_app_check.dart'; import 'package:firebase_auth/firebase_auth.dart'; import 'package:firebase_core/firebase_core.dart'; +import 'package:flutter/foundation.dart' show debugPrint; import 'package:http/http.dart' as http; import 'package:meta/meta.dart'; @@ -26,12 +27,14 @@ import 'content.dart'; import 'developer/api.dart'; import 'error.dart'; import 'firebaseai_version.dart'; +import 'hybrid_config.dart'; import 'imagen/imagen_api.dart'; import 'imagen/imagen_content.dart'; import 'imagen/imagen_edit.dart'; import 'imagen/imagen_reference.dart'; import 'live_api.dart'; import 'live_session.dart'; +import 'local_model_runner.dart'; import 'platform_header_helper.dart'; import 'server_template/template_tool.dart'; import 'tool.dart'; diff --git a/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart index 40971c347362..177c97dd3386 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart @@ -17,6 +17,7 @@ import 'package:firebase_auth/firebase_auth.dart'; import 'package:firebase_core/firebase_core.dart'; import 'package:firebase_core_platform_interface/firebase_core_platform_interface.dart' show FirebasePluginPlatform; +import 'package:flutter_gemma/flutter_gemma.dart' as gemma; import 'package:meta/meta.dart'; import '../firebase_ai.dart'; @@ -131,6 +132,16 @@ class FirebaseAI extends FirebasePluginPlatform { return newInstance; } + /// Initialize the local on-device model framework (flutter_gemma). + /// Must be called before loading local models or switching preferences to local. + static Future initializeLocal({ + String? huggingFaceToken, + }) async { + await gemma.FlutterGemma.initialize( + huggingFaceToken: huggingFaceToken, + ); + } + /// Create a [GenerativeModel] backed by the generative model named [model]. /// /// The [model] argument can be a model name (such as `'gemini-pro'`) or a @@ -149,6 +160,7 @@ class FirebaseAI extends FirebasePluginPlatform { List? tools, ToolConfig? toolConfig, Content? systemInstruction, + HybridConfig? hybridConfig, }) { return createGenerativeModel( model: model, @@ -163,6 +175,7 @@ class FirebaseAI extends FirebasePluginPlatform { toolConfig: toolConfig, systemInstruction: systemInstruction, useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens, + hybridConfig: hybridConfig, ); } diff --git a/packages/firebase_ai/firebase_ai/lib/src/generative_model.dart b/packages/firebase_ai/firebase_ai/lib/src/generative_model.dart index 2bf58351eafe..3cfad22e124d 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/generative_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/generative_model.dart @@ -46,10 +46,14 @@ final class GenerativeModel extends BaseApiClientModel { ToolConfig? toolConfig, Content? systemInstruction, http.Client? httpClient, + HybridConfig? hybridConfig, }) : _safetySettings = safetySettings ?? [], _generationConfig = generationConfig, _toolConfig = toolConfig, _systemInstruction = systemInstruction, + _hybridConfig = hybridConfig, + _preference = + hybridConfig?.initialPreference ?? HybridPreference.onlyCloud, super( serializationStrategy: useVertexBackend ? VertexSerialization() @@ -61,7 +65,11 @@ final class GenerativeModel extends BaseApiClientModel { apiKey: app.options.apiKey, httpClient: httpClient, requestHeaders: BaseModel.firebaseTokens( - appCheck, auth, app, useLimitedUseAppCheckTokens))); + appCheck, auth, app, useLimitedUseAppCheckTokens))) { + if (hybridConfig != null) { + _localRunner = LocalModelRunner(hybridConfig.localConfig); + } + } GenerativeModel._constructTestModel({ required String model, @@ -77,10 +85,14 @@ final class GenerativeModel extends BaseApiClientModel { ToolConfig? toolConfig, Content? systemInstruction, ApiClient? apiClient, + HybridConfig? hybridConfig, }) : _safetySettings = safetySettings ?? [], _generationConfig = generationConfig, _toolConfig = toolConfig, _systemInstruction = systemInstruction, + _hybridConfig = hybridConfig, + _preference = + hybridConfig?.initialPreference ?? HybridPreference.onlyCloud, super( serializationStrategy: useVertexBackend ? VertexSerialization() @@ -92,7 +104,11 @@ final class GenerativeModel extends BaseApiClientModel { HttpApiClient( apiKey: app.options.apiKey, requestHeaders: BaseModel.firebaseTokens( - appCheck, auth, app, useLimitedUseAppCheckTokens))); + appCheck, auth, app, useLimitedUseAppCheckTokens))) { + if (hybridConfig != null) { + _localRunner = LocalModelRunner(hybridConfig.localConfig); + } + } final List _safetySettings; final GenerationConfig? _generationConfig; @@ -103,52 +119,281 @@ final class GenerativeModel extends BaseApiClientModel { final ToolConfig? _toolConfig; final Content? _systemInstruction; - /// Generates content responding to [prompt]. - /// - /// Sends a "generateContent" API request for the configured model, - /// and waits for the response. - /// - /// Example: - /// ```dart - /// final response = await model.generateContent([Content.text(prompt)]); - /// print(response.text); - /// ``` + // Hybrid Fields + final HybridConfig? _hybridConfig; + HybridPreference _preference; + LocalModelRunner? _localRunner; + + /// Test-only setter to allow mocking the local runner. + @visibleForTesting + // ignore: avoid_setters_without_getters + set localRunner(LocalModelRunner? runner) { + _localRunner = runner; + } + + /// Get the current hybrid preference policy. + HybridPreference get preference => _preference; + + /// Update the hybrid preference policy at runtime. + Future setPreference(HybridPreference preference) async { + if (_hybridConfig == null && preference != HybridPreference.onlyCloud) { + throw StateError( + 'Cannot set preference to $preference because hybridConfig was not ' + 'provided during model initialization.', + ); + } + _preference = preference; + if (preference == HybridPreference.onlyLocal) { + await _localRunner!.initialize(); + } else if (preference == HybridPreference.preferLocal) { + if (await isLocalModelInstalled()) { + await _localRunner!.initialize(); + } + } + } + + /// Checks if the local on-device model is installed and ready. + Future isLocalModelInstalled() async { + if (_localRunner == null) return false; + return _localRunner!.isInstalled(); + } + + /// Initiates the download of the on-device local model. + Future downloadLocalModel( + {void Function(int progress)? onProgress}) async { + if (_localRunner == null) { + throw StateError( + 'Cannot download local model: hybridConfig was not provided during model initialization.'); + } + await _localRunner!.download(onProgress: onProgress); + } + + /// Generates content responding to [prompt], automatically routing to + /// either Cloud or Local based on preference policies. Future generateContent(Iterable prompt, - {List? safetySettings, - GenerationConfig? generationConfig, - List? tools, - ToolConfig? toolConfig}) => - makeRequest( - Task.generateContent, - _serializationStrategy.generateContentRequest( + {List? safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig}) async { + if (_hybridConfig == null) { + return _generateContentCloud( + prompt, + safetySettings: safetySettings, + generationConfig: generationConfig, + tools: tools, + toolConfig: toolConfig, + ); + } + final runLocal = await _shouldRunLocal(); + if (runLocal) { + try { + await _localRunner!.initialize(); + return await _localRunner!.generateContent( + prompt, + tools: tools ?? this.tools, + toolConfig: toolConfig ?? _toolConfig, + systemInstruction: _systemInstruction, + ); + } catch (e) { + if (_preference == HybridPreference.preferLocal) { + debugPrint( + 'LocalModelRunner: On-device generation failed ($e). Falling back to Cloud backend...'); + return _generateContentCloud( prompt, - model, - safetySettings ?? _safetySettings, - generationConfig ?? _generationConfig, - tools ?? this.tools, - toolConfig ?? _toolConfig, - _systemInstruction, - ), - _serializationStrategy.parseGenerateContentResponse); + safetySettings: safetySettings, + generationConfig: generationConfig, + tools: tools, + toolConfig: toolConfig, + ); + } + rethrow; + } + } else { + try { + return await _generateContentCloud( + prompt, + safetySettings: safetySettings, + generationConfig: generationConfig, + tools: tools, + toolConfig: toolConfig, + ); + } catch (e) { + if (_preference == HybridPreference.preferCloud && + await isLocalModelInstalled()) { + debugPrint( + 'LocalModelRunner: Cloud generation failed ($e). Falling back to local on-device backend...'); + await _localRunner!.initialize(); + return _localRunner!.generateContent( + prompt, + tools: tools ?? this.tools, + toolConfig: toolConfig ?? _toolConfig, + systemInstruction: _systemInstruction, + ); + } + rethrow; + } + } + } /// Generates a stream of content responding to [prompt]. - /// - /// Sends a "streamGenerateContent" API request for the configured model, - /// and waits for the response. - /// - /// Example: - /// ```dart - /// final responses = await model.generateContent([Content.text(prompt)]); - /// await for (final response in responses) { - /// print(response.text); - /// } - /// ``` Stream generateContentStream( Iterable prompt, {List? safetySettings, GenerationConfig? generationConfig, List? tools, ToolConfig? toolConfig}) { + if (_hybridConfig == null) { + return _generateContentStreamCloud( + prompt, + safetySettings: safetySettings, + generationConfig: generationConfig, + tools: tools, + toolConfig: toolConfig, + ); + } + return _generateContentStreamHybrid( + prompt, + safetySettings: safetySettings, + generationConfig: generationConfig, + tools: tools, + toolConfig: toolConfig, + ); + } + + Stream _generateContentStreamHybrid( + Iterable prompt, + {List? safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig}) async* { + final runLocal = await _shouldRunLocal(); + if (runLocal) { + try { + await _localRunner!.initialize(); + await for (final response in _localRunner!.generateContentStream( + prompt, + tools: tools ?? this.tools, + toolConfig: toolConfig ?? _toolConfig, + systemInstruction: _systemInstruction, + )) { + yield response; + } + } catch (e) { + if (_preference == HybridPreference.preferLocal) { + debugPrint( + 'LocalModelRunner: On-device stream failed ($e). Falling back to Cloud backend...'); + yield* _generateContentStreamCloud( + prompt, + safetySettings: safetySettings, + generationConfig: generationConfig, + tools: tools, + toolConfig: toolConfig, + ); + } else { + rethrow; + } + } + } else { + try { + await for (final response in _generateContentStreamCloud( + prompt, + safetySettings: safetySettings, + generationConfig: generationConfig, + tools: tools, + toolConfig: toolConfig, + )) { + yield response; + } + } catch (e) { + if (_preference == HybridPreference.preferCloud && + await isLocalModelInstalled()) { + debugPrint( + 'LocalModelRunner: Cloud stream failed ($e). Falling back to local on-device backend...'); + await _localRunner!.initialize(); + yield* _localRunner!.generateContentStream( + prompt, + tools: tools ?? this.tools, + toolConfig: toolConfig ?? _toolConfig, + systemInstruction: _systemInstruction, + ); + } else { + rethrow; + } + } + } + } + + /// Counts the total number of tokens in [contents]. + Future countTokens( + Iterable contents, + ) async { + if (_hybridConfig == null) { + return _countTokensCloud(contents); + } + final runLocal = await _shouldRunLocal(); + if (runLocal) { + try { + await _localRunner!.initialize(); + return await _localRunner!.countTokens(contents); + } catch (e) { + if (_preference == HybridPreference.preferLocal) { + return _countTokensCloud(contents); + } + rethrow; + } + } else { + try { + return await _countTokensCloud(contents); + } catch (e) { + if (_preference == HybridPreference.preferCloud && + await isLocalModelInstalled()) { + await _localRunner!.initialize(); + return _localRunner!.countTokens(contents); + } + rethrow; + } + } + } + + // === Private Fallback Helpers === + + Future _shouldRunLocal() async { + if (_preference == HybridPreference.onlyLocal) return true; + if (_preference == HybridPreference.onlyCloud) return false; + if (_preference == HybridPreference.preferLocal) { + return isLocalModelInstalled(); + } + return false; // preferCloud defaults to cloud first + } + + Future _generateContentCloud( + Iterable prompt, { + List? safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig, + }) { + return makeRequest( + Task.generateContent, + _serializationStrategy.generateContentRequest( + prompt, + model, + safetySettings ?? _safetySettings, + generationConfig ?? _generationConfig, + tools ?? this.tools, + toolConfig ?? _toolConfig, + _systemInstruction, + ), + _serializationStrategy.parseGenerateContentResponse); + } + + Stream _generateContentStreamCloud( + Iterable prompt, { + List? safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig, + }) { final response = client.streamRequest( taskUri(Task.streamGenerateContent), _serializationStrategy.generateContentRequest( @@ -163,26 +408,8 @@ final class GenerativeModel extends BaseApiClientModel { return response.map(_serializationStrategy.parseGenerateContentResponse); } - /// Counts the total number of tokens in [contents]. - /// - /// Sends a "countTokens" API request for the configured model, - /// and waits for the response. - /// - /// Example: - /// ```dart - /// final promptContent = [Content.text(prompt)]; - /// final totalTokens = - /// (await model.countTokens(promptContent)).totalTokens; - /// if (totalTokens > maxPromptSize) { - /// print('Prompt is too long!'); - /// } else { - /// final response = await model.generateContent(promptContent); - /// print(response.text); - /// } - /// ``` - Future countTokens( - Iterable contents, - ) async { + Future _countTokensCloud( + Iterable contents) async { final parameters = _serializationStrategy.countTokensRequest( contents, model, @@ -210,6 +437,7 @@ GenerativeModel createGenerativeModel({ List? tools, ToolConfig? toolConfig, Content? systemInstruction, + HybridConfig? hybridConfig, }) => GenerativeModel._( model: model, @@ -224,6 +452,7 @@ GenerativeModel createGenerativeModel({ tools: tools, toolConfig: toolConfig, systemInstruction: systemInstruction, + hybridConfig: hybridConfig, ); /// Creates a model with an overridden [ApiClient] for testing. @@ -243,6 +472,7 @@ GenerativeModel createModelWithClient({ List? safetySettings, List? tools, ToolConfig? toolConfig, + HybridConfig? hybridConfig, }) => GenerativeModel._constructTestModel( model: model, @@ -257,4 +487,5 @@ GenerativeModel createModelWithClient({ systemInstruction: systemInstruction, tools: tools, toolConfig: toolConfig, - apiClient: client); + apiClient: client, + hybridConfig: hybridConfig); diff --git a/packages/firebase_ai/firebase_ai/lib/src/hybrid_config.dart b/packages/firebase_ai/firebase_ai/lib/src/hybrid_config.dart new file mode 100644 index 000000000000..fe5172984add --- /dev/null +++ b/packages/firebase_ai/firebase_ai/lib/src/hybrid_config.dart @@ -0,0 +1,94 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'package:flutter_gemma/flutter_gemma.dart' as gemma; + +/// Routing and fallback policies for the Hybrid Generative Model. +enum HybridPreference { + /// Use the local on-device model if it is installed and ready. + /// Fallback to the cloud model if the local model is not downloaded or fails to load. + preferLocal, + + /// Use the cloud model by default. + /// Fallback to the local on-device model if the cloud call fails (e.g., network offline or quota exceeded). + preferCloud, + + /// Only use the local on-device model. + /// Throws a [StateError] or exception if the model is not downloaded/ready. + /// Never attempts to contact the cloud. + onlyLocal, + + /// Only use the cloud model. + /// Throws an exception if the cloud call fails. + /// Never attempts to load or use local resources. + onlyCloud, +} + +/// Configuration options specifically for the local on-device model powered by flutter_gemma. +class LocalModelConfig { + + /// Creates a new [LocalModelConfig] for on-device model execution. + LocalModelConfig({ + required this.modelType, + this.fileType = gemma.ModelFileType.task, + this.modelPath, + this.modelUrl, + this.hfToken, + this.preferredBackend, + this.maxTokens = 1024, + this.supportImage = false, + this.supportAudio = false, + }); + /// The type of model (e.g., gemmaIt, gemma4, deepSeek, qwen, etc.). + final gemma.ModelType modelType; + + /// The file type of the model (task or binary). Defaults to task. + final gemma.ModelFileType fileType; + + /// The local path to the model file on the device (if already downloaded or bundled). + final String? modelPath; + + /// The remote HTTP/HTTPS URL to download the model from if not installed. + final String? modelUrl; + + /// Optional HuggingFace authentication token for gated model downloads (e.g., Gemma 3). + final String? hfToken; + + /// Backend preference (CPU or GPU). GPU is highly recommended if available. + final gemma.PreferredBackend? preferredBackend; + + /// The maximum context length for the model (default: 1024). + final int maxTokens; + + /// Whether the model supports multimodal image inputs (default: false). + final bool supportImage; + + /// Whether the model supports audio inputs (default: false, Gemma 3n E4B only). + final bool supportAudio; +} + +/// Configuration wrapper for Hybrid Mode, containing local configurations and initial routing preference. +class HybridConfig { + + /// Creates a new [HybridConfig] with [localConfig] and [initialPreference]. + HybridConfig({ + required this.localConfig, + this.initialPreference = HybridPreference.preferCloud, + }); + /// Configuration for the on-device local model. + final LocalModelConfig localConfig; + + /// Initial routing preference policy. Defaults to [HybridPreference.preferCloud]. + final HybridPreference initialPreference; +} diff --git a/packages/firebase_ai/firebase_ai/lib/src/local_model_runner.dart b/packages/firebase_ai/firebase_ai/lib/src/local_model_runner.dart new file mode 100644 index 000000000000..98d07b4c9b0c --- /dev/null +++ b/packages/firebase_ai/firebase_ai/lib/src/local_model_runner.dart @@ -0,0 +1,351 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:async'; + +import 'package:flutter/foundation.dart'; +import 'package:flutter_gemma/flutter_gemma.dart' as gemma; + +import 'api.dart'; +import 'content.dart'; +import 'hybrid_config.dart'; +import 'tool.dart'; + +/// Package-private helper to manage all on-device local model execution details +/// using the `flutter_gemma` SDK. +class LocalModelRunner { + + /// Creates a new [LocalModelRunner] instance with [LocalModelConfig]. + LocalModelRunner(this._config); + final LocalModelConfig _config; + gemma.InferenceModel? _activeModel; + bool _initialized = false; + + /// Checks if the local model file is downloaded and ready on the device. + Future isInstalled() async { + final modelId = _getModelId(); + return gemma.FlutterGemma.isModelInstalled(modelId); + } + + /// Starts the model download. + Future download({void Function(int progress)? onProgress}) async { + if (_config.modelUrl == null) { + throw StateError('Cannot download local model: LocalModelConfig.modelUrl is null.'); + } + + final builder = gemma.FlutterGemma.installModel( + modelType: _config.modelType, + fileType: _config.fileType, + ).fromNetwork(_config.modelUrl!, token: _config.hfToken); + + if (onProgress != null) { + builder.withProgress(onProgress); + } + + await builder.install(); + } + + /// Lazily initializes the flutter_gemma framework and loads the active model. + Future initialize() async { + if (_initialized) return; + + // 1. Ensure flutter_gemma is initialized (safe to call multiple times) + await gemma.FlutterGemma.initialize( + huggingFaceToken: _config.hfToken, + ); + + // 2. Ensure model is installed + final installed = await isInstalled(); + if (!installed) { + if (_config.modelUrl != null) { + debugPrint('LocalModelRunner: Local model not found. Starting auto-download...'); + await download(); + } else { + throw StateError( + 'Local model not found at path and no download URL was provided. ' + 'Please pre-download the model or specify LocalModelConfig.modelUrl.', + ); + } + } + + // 3. Load active model + _activeModel = await gemma.FlutterGemma.getActiveModel( + maxTokens: _config.maxTokens, + preferredBackend: _config.preferredBackend, + supportImage: _config.supportImage, + supportAudio: _config.supportAudio, + ); + + _initialized = true; + debugPrint('LocalModelRunner: Local model successfully initialized and loaded.'); + } + + /// Executes a single generation call. + Future generateContent( + Iterable prompt, { + List? tools, + ToolConfig? toolConfig, + Content? systemInstruction, + }) async { + await initialize(); + + final gemmaTools = _mapTools(tools); + final systemInstructionStr = _mapSystemInstruction(systemInstruction); + + // Create a new chat session for this request to avoid history leaks. + final chat = await _activeModel!.createChat( + tools: gemmaTools, + supportsFunctionCalls: gemmaTools.isNotEmpty, + modelType: _config.modelType, + systemInstruction: systemInstructionStr, + ); + + // Feed the history/prompt into the chat + final messages = _mapPrompt(prompt); + for (var i = 0; i < messages.length - 1; i++) { + await chat.addQueryChunk(messages[i]); + } + + if (messages.isNotEmpty) { + await chat.addQueryChunk(messages.last); + } + + final response = await chat.generateChatResponse(); + return _mapResponse(response); + } + + /// Executes a streaming generation call. + Stream generateContentStream( + Iterable prompt, { + List? tools, + ToolConfig? toolConfig, + Content? systemInstruction, + }) async* { + await initialize(); + + final gemmaTools = _mapTools(tools); + final systemInstructionStr = _mapSystemInstruction(systemInstruction); + + final chat = await _activeModel!.createChat( + tools: gemmaTools, + supportsFunctionCalls: gemmaTools.isNotEmpty, + modelType: _config.modelType, + systemInstruction: systemInstructionStr, + ); + + final messages = _mapPrompt(prompt); + for (var i = 0; i < messages.length - 1; i++) { + await chat.addQueryChunk(messages[i]); + } + + if (messages.isNotEmpty) { + await chat.addQueryChunk(messages.last); + } + + // Generate response asynchronously (streaming) + yield* chat.session.getResponseAsync().map((token) { + // Stream yielding token chunks + return _mapResponse(gemma.TextResponse(token)); + }); + } + + /// Estimates total tokens in contents. + Future countTokens(Iterable contents) async { + await initialize(); + final textToEstimate = contents + .expand((c) => c.parts) + .whereType() + .map((p) => p.text) + .join(' '); + + // Use session.sizeInTokens if available + final dummySession = await _activeModel!.createSession(); + try { + final tokens = await dummySession.sizeInTokens(textToEstimate); + return CountTokensResponse(tokens); + } finally { + await dummySession.close(); + } + } + + /// Release on-device model resources. + Future close() async { + if (_activeModel != null) { + await _activeModel!.close(); + _activeModel = null; + _initialized = false; + debugPrint('LocalModelRunner: Model resources released.'); + } + } + + // === Mappings Helper Methods === + + String _getModelId() { + if (_config.modelPath != null && _config.modelPath!.isNotEmpty) { + return _config.modelPath!.split('/').last; + } + if (_config.modelUrl != null && _config.modelUrl!.isNotEmpty) { + return _config.modelUrl!.split('/').last; + } + throw StateError('No modelPath or modelUrl provided in LocalModelConfig to identify the model.'); + } + + String? _mapSystemInstruction(Content? systemInstruction) { + if (systemInstruction == null) return null; + return systemInstruction.parts.whereType().map((p) => p.text).join('\n'); + } + + List _mapTools(List? tools) { + if (tools == null || tools.isEmpty) return const []; + final gemmaTools = []; + + for (final tool in tools) { + final decls = tool.autoFunctionDeclarations.isNotEmpty + ? tool.autoFunctionDeclarations + : (tool.toJson()['functionDeclarations'] as List?) ?? []; + + for (final decl in decls) { + if (decl is AutoFunctionDeclaration) { + final json = decl.toJson(); + final parameters = (json['parameters'] ?? json['parametersJsonSchema']) as Map? ?? {}; + gemmaTools.add(gemma.Tool( + name: decl.name, + description: decl.description, + parameters: parameters, + )); + } else if (decl is FunctionDeclaration) { + final json = decl.toJson(); + final parameters = (json['parameters'] ?? json['parametersJsonSchema']) as Map? ?? {}; + gemmaTools.add(gemma.Tool( + name: decl.name, + description: decl.description, + parameters: parameters, + )); + } else if (decl is Map) { + final name = decl['name'] as String; + final desc = decl['description'] as String? ?? ''; + final parameters = (decl['parameters'] ?? decl['parametersJsonSchema']) as Map? ?? {}; + gemmaTools.add(gemma.Tool( + name: name, + description: desc, + parameters: parameters, + )); + } + } + } + return gemmaTools; + } + + List _mapPrompt(Iterable prompt) { + final messages = []; + + for (final content in prompt) { + final isUser = content.role == 'user' || content.role == 'function'; + final textParts = content.parts.whereType().map((p) => p.text).toList(); + final imageParts = content.parts.whereType().toList(); + final functionCallParts = content.parts.whereType().toList(); + final functionResponseParts = content.parts.whereType().toList(); + + final combinedText = textParts.join('\n'); + + if (functionResponseParts.isNotEmpty) { + // Client reporting tool response output back to model + for (final resp in functionResponseParts) { + messages.add(gemma.Message.toolResponse( + toolName: resp.name, + response: resp.response.cast(), + )); + } + } else if (functionCallParts.isNotEmpty) { + // Model called a function previously in conversation + for (final call in functionCallParts) { + messages.add(gemma.Message.toolCall( + text: 'Called function: ${call.name} with args ${call.args}', + )); + } + } else if (imageParts.isNotEmpty) { + // Multimodal input + final List imageBytesList = imageParts.map((p) => p.bytes).toList(); + messages.add(gemma.Message.withImages( + text: combinedText, + imageBytes: imageBytesList, + isUser: isUser, + )); + } else { + // Standard text + messages.add(gemma.Message.text( + text: combinedText, + isUser: isUser, + )); + } + } + + return messages; + } + + GenerateContentResponse _mapResponse(gemma.ModelResponse response) { + if (response is gemma.TextResponse) { + final content = Content('model', [TextPart(response.token)]); + return GenerateContentResponse([ + Candidate( + content, + [], + null, + FinishReason.stop, + null, + ) + ], null); + } else if (response is gemma.FunctionCallResponse) { + final content = Content('model', [ + FunctionCall(response.name, response.args.cast()) + ]); + return GenerateContentResponse([ + Candidate( + content, + [], + null, + FinishReason.stop, + null, + ) + ], null); + } else if (response is gemma.ParallelFunctionCallResponse) { + final parts = response.calls.map((call) { + return FunctionCall(call.name, call.args.cast()); + }).toList(); + final content = Content('model', parts); + return GenerateContentResponse([ + Candidate( + content, + [], + null, + FinishReason.stop, + null, + ) + ], null); + } else if (response is gemma.ThinkingResponse) { + final content = Content('model', [TextPart(response.content, isThought: true)]); + return GenerateContentResponse([ + Candidate( + content, + [], + null, + FinishReason.stop, + null, + ) + ], null); + } + + throw UnsupportedError('Unknown local model response type: ${response.runtimeType}'); + } +} diff --git a/packages/firebase_ai/firebase_ai/pubspec.yaml b/packages/firebase_ai/firebase_ai/pubspec.yaml index fada966d9c5d..90ba575ddaae 100644 --- a/packages/firebase_ai/firebase_ai/pubspec.yaml +++ b/packages/firebase_ai/firebase_ai/pubspec.yaml @@ -27,6 +27,7 @@ dependencies: firebase_core_platform_interface: ^7.0.1 flutter: sdk: flutter + flutter_gemma: ^0.15.1 http: ^1.1.0 meta: ^1.15.0 web_socket_channel: ^3.0.1 diff --git a/packages/firebase_ai/firebase_ai/test/hybrid_mode_test.dart b/packages/firebase_ai/firebase_ai/test/hybrid_mode_test.dart new file mode 100644 index 000000000000..8d135d7d72f5 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/test/hybrid_mode_test.dart @@ -0,0 +1,425 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:async'; +import 'package:firebase_ai/firebase_ai.dart'; +import 'package:firebase_ai/src/base_model.dart'; +import 'package:firebase_ai/src/client.dart'; +import 'package:firebase_ai/src/local_model_runner.dart'; +import 'package:firebase_core/firebase_core.dart'; +import 'package:flutter_gemma/flutter_gemma.dart' as gemma; +import 'package:flutter_test/flutter_test.dart'; + +import 'mock.dart'; +import 'utils/matchers.dart'; + +class FakeLocalModelRunner extends Fake implements LocalModelRunner { + bool isInstalledResult = true; + bool initializeCalled = false; + bool closeCalled = false; + int initializeCount = 0; + + GenerateContentResponse? generateContentResult; + Object? generateContentError; + + Stream? generateContentStreamResult; + Object? generateContentStreamError; + + CountTokensResponse? countTokensResult; + Object? countTokensError; + + @override + Future isInstalled() async => isInstalledResult; + + @override + Future initialize() async { + initializeCalled = true; + initializeCount++; + } + + @override + Future generateContent( + Iterable prompt, { + List? tools, + ToolConfig? toolConfig, + Content? systemInstruction, + }) async { + if (generateContentError != null) { + throw generateContentError!; + } + return generateContentResult ?? + GenerateContentResponse([ + Candidate( + Content('model', [const TextPart('fake local response')]), + null, + null, + null, + null, + ) + ], null); + } + + @override + Stream generateContentStream( + Iterable prompt, { + List? tools, + ToolConfig? toolConfig, + Content? systemInstruction, + }) { + if (generateContentStreamError != null) { + return Stream.error(generateContentStreamError!); + } + if (generateContentStreamResult != null) { + return generateContentStreamResult!; + } + return Stream.value(GenerateContentResponse([ + Candidate( + Content('model', [const TextPart('fake local stream response')]), + null, + null, + null, + null, + ) + ], null)); + } + + @override + Future countTokens(Iterable contents) async { + if (countTokensError != null) { + throw countTokensError!; + } + return countTokensResult ?? CountTokensResponse(42); + } + + @override + Future close() async { + closeCalled = true; + } +} + +class MockApiClient implements ApiClient { + Map? response; + List>? streamResponses; + Object? error; + Uri? lastUri; + Map? lastBody; + int requestCount = 0; + + MockApiClient({this.response, this.streamResponses, this.error}); + + @override + Future> makeRequest( + Uri uri, + Map body, + ) async { + lastUri = uri; + lastBody = body; + requestCount++; + if (error != null) throw error!; + return response ?? {}; + } + + @override + Stream> streamRequest( + Uri uri, + Map body, + ) { + lastUri = uri; + lastBody = body; + requestCount++; + if (error != null) return Stream.error(error!); + return Stream.fromIterable(streamResponses ?? []); + } +} + +const Map arbitraryGenerateContentResponse = { + 'candidates': [ + { + 'content': { + 'role': 'model', + 'parts': [ + {'text': 'Some Response'}, + ], + }, + }, + ], +}; + +void main() { + setupFirebaseVertexAIMocks(); + late FirebaseApp app; + + setUpAll(() async { + app = await Firebase.initializeApp(); + }); + + group('Hybrid Mode Tests', () { + const modelName = 'gemini-1.5-flash'; + final hybridConfig = HybridConfig( + localConfig: LocalModelConfig( + modelType: gemma.ModelType.gemmaIt, + modelPath: 'models/gemma-2b.bin', + ), + initialPreference: HybridPreference.onlyCloud, + ); + + GenerativeModel createTestModel({ + ApiClient? client, + HybridConfig? config, + FakeLocalModelRunner? fakeLocalRunner, + }) { + final model = createModelWithClient( + app: app, + location: 'us-central1', + model: modelName, + client: client ?? MockApiClient(), + useVertexBackend: false, + hybridConfig: config ?? hybridConfig, + ); + if (fakeLocalRunner != null) { + model.localRunner = fakeLocalRunner; + } + return model; + } + + group('onlyCloud preference', () { + test('generateContent goes to cloud', () async { + final mockClient = MockApiClient(response: arbitraryGenerateContentResponse); + final model = createTestModel( + client: mockClient, + fakeLocalRunner: FakeLocalModelRunner(), + ); + + await model.setPreference(HybridPreference.onlyCloud); + + final response = await model.generateContent([Content.text('hello')]); + + expect(response.text, 'Some Response'); + expect(mockClient.requestCount, 1); + }); + + test('generateContentStream goes to cloud', () async { + final mockClient = MockApiClient(streamResponses: [arbitraryGenerateContentResponse]); + final model = createTestModel( + client: mockClient, + fakeLocalRunner: FakeLocalModelRunner(), + ); + + await model.setPreference(HybridPreference.onlyCloud); + + final stream = model.generateContentStream([Content.text('hello')]); + final results = await stream.toList(); + + expect(results.first.text, 'Some Response'); + expect(mockClient.requestCount, 1); + }); + + test('countTokens goes to cloud', () async { + final mockClient = MockApiClient(response: {'totalTokens': 5}); + final model = createTestModel( + client: mockClient, + fakeLocalRunner: FakeLocalModelRunner(), + ); + + await model.setPreference(HybridPreference.onlyCloud); + + final response = await model.countTokens([Content.text('hello')]); + + expect(response.totalTokens, 5); + expect(mockClient.requestCount, 1); + }); + }); + + group('onlyLocal preference', () { + test('generateContent goes to local', () async { + final fakeLocalRunner = FakeLocalModelRunner(); + final model = createTestModel(fakeLocalRunner: fakeLocalRunner); + + await model.setPreference(HybridPreference.onlyLocal); + + final response = await model.generateContent([Content.text('hello')]); + + expect(response.text, 'fake local response'); + expect(fakeLocalRunner.initializeCalled, true); + }); + + test('generateContentStream goes to local', () async { + final fakeLocalRunner = FakeLocalModelRunner(); + final model = createTestModel(fakeLocalRunner: fakeLocalRunner); + + await model.setPreference(HybridPreference.onlyLocal); + + final stream = model.generateContentStream([Content.text('hello')]); + final results = await stream.toList(); + + expect(results.first.text, 'fake local stream response'); + expect(fakeLocalRunner.initializeCalled, true); + }); + + test('countTokens goes to local', () async { + final fakeLocalRunner = FakeLocalModelRunner(); + final model = createTestModel(fakeLocalRunner: fakeLocalRunner); + + await model.setPreference(HybridPreference.onlyLocal); + + final response = await model.countTokens([Content.text('hello')]); + + expect(response.totalTokens, 42); + expect(fakeLocalRunner.initializeCalled, true); + }); + }); + + group('preferLocal preference', () { + test('goes to local when installed', () async { + final fakeLocalRunner = FakeLocalModelRunner()..isInstalledResult = true; + final model = createTestModel(fakeLocalRunner: fakeLocalRunner); + + await model.setPreference(HybridPreference.preferLocal); + + final response = await model.generateContent([Content.text('hello')]); + + expect(response.text, 'fake local response'); + expect(fakeLocalRunner.initializeCalled, true); + }); + + test('goes to cloud when local not installed', () async { + final fakeLocalRunner = FakeLocalModelRunner()..isInstalledResult = false; + final mockClient = MockApiClient(response: arbitraryGenerateContentResponse); + final model = createTestModel( + client: mockClient, + fakeLocalRunner: fakeLocalRunner, + ); + + await model.setPreference(HybridPreference.preferLocal); + + final response = await model.generateContent([Content.text('hello')]); + + expect(response.text, 'Some Response'); + expect(fakeLocalRunner.initializeCalled, false); + expect(mockClient.requestCount, 1); + }); + + test('falls back to cloud when local fails', () async { + final fakeLocalRunner = FakeLocalModelRunner() + ..isInstalledResult = true + ..generateContentError = Exception('Local engine crash'); + final mockClient = MockApiClient(response: arbitraryGenerateContentResponse); + final model = createTestModel( + client: mockClient, + fakeLocalRunner: fakeLocalRunner, + ); + + await model.setPreference(HybridPreference.preferLocal); + + final response = await model.generateContent([Content.text('hello')]); + + expect(response.text, 'Some Response'); + expect(fakeLocalRunner.initializeCalled, true); + expect(mockClient.requestCount, 1); + }); + + test('falls back to cloud when local stream fails', () async { + final fakeLocalRunner = FakeLocalModelRunner() + ..isInstalledResult = true + ..generateContentStreamError = Exception('Local stream crash'); + final mockClient = MockApiClient(streamResponses: [arbitraryGenerateContentResponse]); + final model = createTestModel( + client: mockClient, + fakeLocalRunner: fakeLocalRunner, + ); + + await model.setPreference(HybridPreference.preferLocal); + + final stream = model.generateContentStream([Content.text('hello')]); + final results = await stream.toList(); + + expect(results.first.text, 'Some Response'); + expect(fakeLocalRunner.initializeCalled, true); + expect(mockClient.requestCount, 1); + }); + }); + + group('preferCloud preference', () { + test('goes to cloud when cloud succeeds', () async { + final fakeLocalRunner = FakeLocalModelRunner(); + final mockClient = MockApiClient(response: arbitraryGenerateContentResponse); + final model = createTestModel( + client: mockClient, + fakeLocalRunner: fakeLocalRunner, + ); + + await model.setPreference(HybridPreference.preferCloud); + + final response = await model.generateContent([Content.text('hello')]); + + expect(response.text, 'Some Response'); + expect(fakeLocalRunner.initializeCalled, false); + expect(mockClient.requestCount, 1); + }); + + test('falls back to local when cloud fails and local is installed', () async { + final fakeLocalRunner = FakeLocalModelRunner()..isInstalledResult = true; + final mockClient = MockApiClient(error: Exception('Cloud offline')); + final model = createTestModel( + client: mockClient, + fakeLocalRunner: fakeLocalRunner, + ); + + await model.setPreference(HybridPreference.preferCloud); + + final response = await model.generateContent([Content.text('hello')]); + + expect(response.text, 'fake local response'); + expect(fakeLocalRunner.initializeCalled, true); + expect(mockClient.requestCount, 1); + }); + + test('does NOT fall back to local when cloud fails but local is NOT installed', () async { + final fakeLocalRunner = FakeLocalModelRunner()..isInstalledResult = false; + final mockClient = MockApiClient(error: Exception('Cloud offline')); + final model = createTestModel( + client: mockClient, + fakeLocalRunner: fakeLocalRunner, + ); + + await model.setPreference(HybridPreference.preferCloud); + + await expectLater( + model.generateContent([Content.text('hello')]), + throwsA(isA()), + ); + expect(fakeLocalRunner.initializeCalled, false); + expect(mockClient.requestCount, 1); + }); + + test('falls back to local when cloud stream fails and local is installed', () async { + final fakeLocalRunner = FakeLocalModelRunner()..isInstalledResult = true; + final mockClient = MockApiClient(error: Exception('Cloud stream offline')); + final model = createTestModel( + client: mockClient, + fakeLocalRunner: fakeLocalRunner, + ); + + await model.setPreference(HybridPreference.preferCloud); + + final stream = model.generateContentStream([Content.text('hello')]); + final results = await stream.toList(); + + expect(results.first.text, 'fake local stream response'); + expect(fakeLocalRunner.initializeCalled, true); + expect(mockClient.requestCount, 1); + }); + }); + }); +} diff --git a/packages/firebase_data_connect/firebase_data_connect/pubspec.yaml b/packages/firebase_data_connect/firebase_data_connect/pubspec.yaml index 3a398dc706a7..5f22f5ddd485 100644 --- a/packages/firebase_data_connect/firebase_data_connect/pubspec.yaml +++ b/packages/firebase_data_connect/firebase_data_connect/pubspec.yaml @@ -26,7 +26,7 @@ dependencies: path: ^1.9.0 path_provider: ^2.0.0 protobuf: ^3.1.0 - sqlite3: ^2.9.0 + sqlite3: '>=2.9.0 <4.0.0' sqlite3_flutter_libs: ^0.5.40 web_socket_channel: ^3.0.1 diff --git a/tests/windows/flutter/generated_plugin_registrant.cc b/tests/windows/flutter/generated_plugin_registrant.cc index b6d55c18d245..18bc97507433 100644 --- a/tests/windows/flutter/generated_plugin_registrant.cc +++ b/tests/windows/flutter/generated_plugin_registrant.cc @@ -12,6 +12,7 @@ #include #include #include +#include void RegisterPlugins(flutter::PluginRegistry* registry) { FirebaseAppCheckPluginCApiRegisterWithRegistrar( @@ -26,4 +27,6 @@ void RegisterPlugins(flutter::PluginRegistry* registry) { registry->GetRegistrarForPlugin("FirebaseRemoteConfigPluginCApi")); FirebaseStoragePluginCApiRegisterWithRegistrar( registry->GetRegistrarForPlugin("FirebaseStoragePluginCApi")); + FlutterGemmaPluginRegisterWithRegistrar( + registry->GetRegistrarForPlugin("FlutterGemmaPlugin")); } diff --git a/tests/windows/flutter/generated_plugins.cmake b/tests/windows/flutter/generated_plugins.cmake index 89f9a20f171a..234ba1c26211 100644 --- a/tests/windows/flutter/generated_plugins.cmake +++ b/tests/windows/flutter/generated_plugins.cmake @@ -9,9 +9,11 @@ list(APPEND FLUTTER_PLUGIN_LIST firebase_database firebase_remote_config firebase_storage + flutter_gemma ) list(APPEND FLUTTER_FFI_PLUGIN_LIST + jni ) set(PLUGIN_BUNDLED_LIBRARIES)