From 4bb8f1603080f6b90dcabbe2cd9da49dce03665f Mon Sep 17 00:00:00 2001 From: John Jeong Date: Mon, 22 Dec 2025 20:18:56 +0900 Subject: [PATCH 1/2] feat(onboarding): add local STT model selection for Hyprnote --- .../onboarding/configure-notice.tsx | 198 ++++++++++++++++++ 1 file changed, 198 insertions(+) diff --git a/apps/desktop/src/components/onboarding/configure-notice.tsx b/apps/desktop/src/components/onboarding/configure-notice.tsx index 2985f9e02..80d80955c 100644 --- a/apps/desktop/src/components/onboarding/configure-notice.tsx +++ b/apps/desktop/src/components/onboarding/configure-notice.tsx @@ -1,13 +1,59 @@ +import { useQuery } from "@tanstack/react-query"; +import { useCallback, useEffect, useState } from "react"; + +import { + commands as localSttCommands, + type SupportedSttModel, +} from "@hypr/plugin-local-stt"; +import { cn } from "@hypr/utils"; + import { Route } from "../../routes/app/onboarding"; +import * as settings from "../../store/tinybase/settings"; import { getBack, getNext, type StepProps } from "./config"; import { OnboardingContainer } from "./shared"; export const STEP_ID_CONFIGURE_NOTICE = "configure-notice" as const; +const sttModelQueries = { + isDownloaded: (model: SupportedSttModel) => ({ + refetchInterval: 1000, + queryKey: ["stt", "model", model, "downloaded"], + queryFn: () => localSttCommands.isModelDownloaded(model), + select: (result: { status: string; data?: boolean; error?: string }) => { + if (result.status === "error") { + throw new Error(result.error); + } + return result.data; + }, + }), + isDownloading: (model: SupportedSttModel) => ({ + refetchInterval: 1000, + queryKey: ["stt", "model", model, "downloading"], + queryFn: () => localSttCommands.isModelDownloading(model), + select: (result: { status: string; data?: boolean; error?: string }) => { + if (result.status === "error") { + throw new Error(result.error); + } + return result.data; + }, + }), +}; + export function ConfigureNotice({ onNavigate }: StepProps) { const search = Route.useSearch(); const backStep = getBack(search); + if (search.local) { + return ( + onNavigate({ ...search, step: backStep }) : undefined + } + /> + ); + } + return ( void; +}) { + const search = Route.useSearch(); + const [selectedModel, setSelectedModel] = useState( + null, + ); + + const handleSelectProvider = settings.UI.useSetValueCallback( + "current_stt_provider", + (provider: string) => provider, + [], + settings.STORE_ID, + ); + + const handleSelectModel = settings.UI.useSetValueCallback( + "current_stt_model", + (model: string) => model, + [], + settings.STORE_ID, + ); + + const p2Downloaded = useQuery(sttModelQueries.isDownloaded("am-parakeet-v2")); + const p3Downloaded = useQuery(sttModelQueries.isDownloaded("am-parakeet-v3")); + + useEffect(() => { + if (p2Downloaded.data || p3Downloaded.data) { + onNavigate({ ...search, step: getNext(search) }); + } + }, [p2Downloaded.data, p3Downloaded.data, search, onNavigate]); + + const handleUseModel = useCallback(() => { + if (!selectedModel) return; + + handleSelectProvider("hyprnote"); + handleSelectModel(selectedModel); + void localSttCommands.downloadModel(selectedModel); + onNavigate({ ...search, step: getNext(search) }); + }, [ + selectedModel, + search, + onNavigate, + handleSelectProvider, + handleSelectModel, + ]); + + if (p2Downloaded.isLoading || p3Downloaded.isLoading) { + return ( + +
+
+
+
+ ); + } + + return ( + +
+ setSelectedModel("am-parakeet-v2")} + /> + setSelectedModel("am-parakeet-v3")} + /> +
+ +
+ +
+
+ ); +} + +function LocalModelRow({ + model, + displayName, + description, + isSelected, + onSelect, +}: { + model: SupportedSttModel; + displayName: string; + description: string; + isSelected: boolean; + onSelect: () => void; +}) { + const isDownloaded = useQuery(sttModelQueries.isDownloaded(model)); + + return ( +
{ + if (e.key === "Enter" || e.key === " ") { + e.preventDefault(); + onSelect(); + } + }} + className={cn([ + "relative border rounded-xl py-3 px-4 flex flex-col gap-1 text-left transition-all cursor-pointer", + isSelected + ? "border-stone-500 bg-stone-50" + : "border-neutral-200 hover:border-neutral-300", + ])} + > +
+
+

{displayName}

+

{description}

+
+ {isDownloaded.data && ( + + Already downloaded + + )} +
+
+ ); +} + function Requirement({ title, description, From 82d3ed7985bf4b168f1b77155e45a76c399ff834 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Mon, 22 Dec 2025 13:08:10 +0000 Subject: [PATCH 2/2] refactor: extract shared local STT model queries and download hook - Create useLocalSttModel.ts with shared query key factory (localSttKeys) and query options (localSttQueries) - Move useLocalModelDownload hook to shared location for reuse - Update configure.tsx (settings) to use shared hook - Update configure-notice.tsx (onboarding) to use shared queries - Remove duplicated sttModelQueries from both files This follows React Query key factory patterns for better cache management and code reuse. Co-Authored-By: yujonglee --- .../onboarding/configure-notice.tsx | 32 +---- .../components/settings/ai/stt/configure.tsx | 84 +----------- .../src/components/settings/ai/stt/shared.tsx | 34 +---- apps/desktop/src/hooks/useLocalSttModel.ts | 126 ++++++++++++++++++ 4 files changed, 136 insertions(+), 140 deletions(-) create mode 100644 apps/desktop/src/hooks/useLocalSttModel.ts diff --git a/apps/desktop/src/components/onboarding/configure-notice.tsx b/apps/desktop/src/components/onboarding/configure-notice.tsx index 80d80955c..ecb5184c7 100644 --- a/apps/desktop/src/components/onboarding/configure-notice.tsx +++ b/apps/desktop/src/components/onboarding/configure-notice.tsx @@ -7,6 +7,7 @@ import { } from "@hypr/plugin-local-stt"; import { cn } from "@hypr/utils"; +import { localSttQueries } from "../../hooks/useLocalSttModel"; import { Route } from "../../routes/app/onboarding"; import * as settings from "../../store/tinybase/settings"; import { getBack, getNext, type StepProps } from "./config"; @@ -14,31 +15,6 @@ import { OnboardingContainer } from "./shared"; export const STEP_ID_CONFIGURE_NOTICE = "configure-notice" as const; -const sttModelQueries = { - isDownloaded: (model: SupportedSttModel) => ({ - refetchInterval: 1000, - queryKey: ["stt", "model", model, "downloaded"], - queryFn: () => localSttCommands.isModelDownloaded(model), - select: (result: { status: string; data?: boolean; error?: string }) => { - if (result.status === "error") { - throw new Error(result.error); - } - return result.data; - }, - }), - isDownloading: (model: SupportedSttModel) => ({ - refetchInterval: 1000, - queryKey: ["stt", "model", model, "downloading"], - queryFn: () => localSttCommands.isModelDownloading(model), - select: (result: { status: string; data?: boolean; error?: string }) => { - if (result.status === "error") { - throw new Error(result.error); - } - return result.data; - }, - }), -}; - export function ConfigureNotice({ onNavigate }: StepProps) { const search = Route.useSearch(); const backStep = getBack(search); @@ -111,8 +87,8 @@ function LocalConfigureNotice({ settings.STORE_ID, ); - const p2Downloaded = useQuery(sttModelQueries.isDownloaded("am-parakeet-v2")); - const p3Downloaded = useQuery(sttModelQueries.isDownloaded("am-parakeet-v3")); + const p2Downloaded = useQuery(localSttQueries.isDownloaded("am-parakeet-v2")); + const p3Downloaded = useQuery(localSttQueries.isDownloaded("am-parakeet-v3")); useEffect(() => { if (p2Downloaded.data || p3Downloaded.data) { @@ -202,7 +178,7 @@ function LocalModelRow({ isSelected: boolean; onSelect: () => void; }) { - const isDownloaded = useQuery(sttModelQueries.isDownloaded(model)); + const isDownloaded = useQuery(localSttQueries.isDownloaded(model)); return (
void, -) { - const [progress, setProgress] = useState(0); - const [isStarting, setIsStarting] = useState(false); - const [hasError, setHasError] = useState(false); - - const isDownloaded = useQuery(sttModelQueries.isDownloaded(model)); - const isDownloading = useQuery(sttModelQueries.isDownloading(model)); - - const showProgress = - !isDownloaded.data && (isStarting || (isDownloading.data ?? false)); - - useEffect(() => { - if (isDownloading.data) { - setIsStarting(false); - } - }, [isDownloading.data]); - - useEffect(() => { - const unlisten = localSttEvents.downloadProgressPayload.listen((event) => { - if (event.payload.model === model) { - if (event.payload.progress < 0) { - setHasError(true); - setIsStarting(false); - setProgress(0); - } else { - setHasError(false); - const next = Math.max(0, Math.min(100, event.payload.progress)); - setProgress(next); - } - } - }); - - return () => { - void unlisten.then((fn) => fn()); - }; - }, [model]); - - useEffect(() => { - if (isDownloaded.data && progress > 0) { - setProgress(0); - onDownloadComplete?.(model); - } - }, [isDownloaded.data, model, onDownloadComplete, progress]); - - const handleDownload = () => { - if (isDownloaded.data || isDownloading.data || isStarting) { - return; - } - setHasError(false); - setIsStarting(true); - setProgress(0); - void localSttCommands.downloadModel(model).then((result) => { - if (result.status === "error") { - setHasError(true); - setIsStarting(false); - } - }); - }; - - const handleCancel = () => { - void localSttCommands.cancelDownload(model); - setIsStarting(false); - setProgress(0); - }; - - return { - progress, - hasError, - isDownloaded: isDownloaded.data ?? false, - showProgress, - handleDownload, - handleCancel, - }; -} - function ProviderContext({ providerId }: { providerId: ProviderId }) { const content = providerId === "hyprnote" diff --git a/apps/desktop/src/components/settings/ai/stt/shared.tsx b/apps/desktop/src/components/settings/ai/stt/shared.tsx index 762adf087..510d97a65 100644 --- a/apps/desktop/src/components/settings/ai/stt/shared.tsx +++ b/apps/desktop/src/components/settings/ai/stt/shared.tsx @@ -1,9 +1,7 @@ import { Icon } from "@iconify-icon/react"; import { AssemblyAI, Fireworks, OpenAI } from "@lobehub/icons"; -import { queryOptions } from "@tanstack/react-query"; import type { ReactNode } from "react"; -import { commands as localSttCommands } from "@hypr/plugin-local-stt"; import type { AmModel, SupportedSttModel, @@ -11,12 +9,15 @@ import type { } from "@hypr/plugin-local-stt"; import { env } from "../../../../env"; +import { localSttQueries } from "../../../../hooks/useLocalSttModel"; import { type ProviderRequirement, requiresEntitlement, } from "../shared/eligibility"; import { sortProviders } from "../shared/sort-providers"; +export { localSttQueries as sttModelQueries }; + type Provider = { disabled: boolean; id: string; @@ -200,32 +201,3 @@ export const sttProviderRequiresPro = (providerId: ProviderId) => { const provider = PROVIDERS.find((p) => p.id === providerId); return provider ? requiresEntitlement(provider.requirements, "pro") : false; }; - -export const sttModelQueries = { - isDownloaded: (model: SupportedSttModel) => - queryOptions({ - refetchInterval: 1000, - queryKey: ["stt", "model", model, "downloaded"], - queryFn: () => localSttCommands.isModelDownloaded(model), - select: (result) => { - if (result.status === "error") { - throw new Error(result.error); - } - - return result.data; - }, - }), - isDownloading: (model: SupportedSttModel) => - queryOptions({ - refetchInterval: 1000, - queryKey: ["stt", "model", model, "downloading"], - queryFn: () => localSttCommands.isModelDownloading(model), - select: (result) => { - if (result.status === "error") { - throw new Error(result.error); - } - - return result.data; - }, - }), -}; diff --git a/apps/desktop/src/hooks/useLocalSttModel.ts b/apps/desktop/src/hooks/useLocalSttModel.ts new file mode 100644 index 000000000..680d1dcba --- /dev/null +++ b/apps/desktop/src/hooks/useLocalSttModel.ts @@ -0,0 +1,126 @@ +import { queryOptions } from "@tanstack/react-query"; +import { useQuery } from "@tanstack/react-query"; +import { useCallback, useEffect, useState } from "react"; + +import { + commands as localSttCommands, + events as localSttEvents, + type SupportedSttModel, +} from "@hypr/plugin-local-stt"; + +export const localSttKeys = { + all: ["local-stt"] as const, + models: () => [...localSttKeys.all, "model"] as const, + model: (model: SupportedSttModel) => + [...localSttKeys.models(), model] as const, + modelDownloaded: (model: SupportedSttModel) => + [...localSttKeys.model(model), "downloaded"] as const, + modelDownloading: (model: SupportedSttModel) => + [...localSttKeys.model(model), "downloading"] as const, +}; + +export const localSttQueries = { + isDownloaded: (model: SupportedSttModel) => + queryOptions({ + refetchInterval: 1000, + queryKey: localSttKeys.modelDownloaded(model), + queryFn: () => localSttCommands.isModelDownloaded(model), + select: (result) => { + if (result.status === "error") { + throw new Error(result.error); + } + return result.data; + }, + }), + isDownloading: (model: SupportedSttModel) => + queryOptions({ + refetchInterval: 1000, + queryKey: localSttKeys.modelDownloading(model), + queryFn: () => localSttCommands.isModelDownloading(model), + select: (result) => { + if (result.status === "error") { + throw new Error(result.error); + } + return result.data; + }, + }), +}; + +export function useLocalModelDownload( + model: SupportedSttModel, + onDownloadComplete?: (model: SupportedSttModel) => void, +) { + const [progress, setProgress] = useState(0); + const [isStarting, setIsStarting] = useState(false); + const [hasError, setHasError] = useState(false); + + const isDownloaded = useQuery(localSttQueries.isDownloaded(model)); + const isDownloading = useQuery(localSttQueries.isDownloading(model)); + + const showProgress = + !isDownloaded.data && (isStarting || (isDownloading.data ?? false)); + + useEffect(() => { + if (isDownloading.data) { + setIsStarting(false); + } + }, [isDownloading.data]); + + useEffect(() => { + const unlisten = localSttEvents.downloadProgressPayload.listen((event) => { + if (event.payload.model === model) { + if (event.payload.progress < 0) { + setHasError(true); + setIsStarting(false); + setProgress(0); + } else { + setHasError(false); + const next = Math.max(0, Math.min(100, event.payload.progress)); + setProgress(next); + } + } + }); + + return () => { + void unlisten.then((fn) => fn()); + }; + }, [model]); + + useEffect(() => { + if (isDownloaded.data && progress > 0) { + setProgress(0); + onDownloadComplete?.(model); + } + }, [isDownloaded.data, model, onDownloadComplete, progress]); + + const handleDownload = useCallback(() => { + if (isDownloaded.data || isDownloading.data || isStarting) { + return; + } + setHasError(false); + setIsStarting(true); + setProgress(0); + void localSttCommands.downloadModel(model).then((result) => { + if (result.status === "error") { + setHasError(true); + setIsStarting(false); + } + }); + }, [isDownloaded.data, isDownloading.data, isStarting, model]); + + const handleCancel = useCallback(() => { + void localSttCommands.cancelDownload(model); + setIsStarting(false); + setProgress(0); + }, [model]); + + return { + progress, + hasError, + isDownloaded: isDownloaded.data ?? false, + isDownloadedLoading: isDownloaded.isLoading, + showProgress, + handleDownload, + handleCancel, + }; +}