diff --git a/apps/desktop/src/components/onboarding/configure-notice.tsx b/apps/desktop/src/components/onboarding/configure-notice.tsx index 2985f9e02..ecb5184c7 100644 --- a/apps/desktop/src/components/onboarding/configure-notice.tsx +++ b/apps/desktop/src/components/onboarding/configure-notice.tsx @@ -1,4 +1,15 @@ +import { useQuery } from "@tanstack/react-query"; +import { useCallback, useEffect, useState } from "react"; + +import { + commands as localSttCommands, + type SupportedSttModel, +} from "@hypr/plugin-local-stt"; +import { cn } from "@hypr/utils"; + +import { localSttQueries } from "../../hooks/useLocalSttModel"; import { Route } from "../../routes/app/onboarding"; +import * as settings from "../../store/tinybase/settings"; import { getBack, getNext, type StepProps } from "./config"; import { OnboardingContainer } from "./shared"; @@ -8,6 +19,17 @@ export function ConfigureNotice({ onNavigate }: StepProps) { const search = Route.useSearch(); const backStep = getBack(search); + if (search.local) { + return ( + onNavigate({ ...search, step: backStep }) : undefined + } + /> + ); + } + return ( void; +}) { + const search = Route.useSearch(); + const [selectedModel, setSelectedModel] = useState( + null, + ); + + const handleSelectProvider = settings.UI.useSetValueCallback( + "current_stt_provider", + (provider: string) => provider, + [], + settings.STORE_ID, + ); + + const handleSelectModel = settings.UI.useSetValueCallback( + "current_stt_model", + (model: string) => model, + [], + settings.STORE_ID, + ); + + const p2Downloaded = useQuery(localSttQueries.isDownloaded("am-parakeet-v2")); + const p3Downloaded = useQuery(localSttQueries.isDownloaded("am-parakeet-v3")); + + useEffect(() => { + if (p2Downloaded.data || p3Downloaded.data) { + onNavigate({ ...search, step: getNext(search) }); + } + }, [p2Downloaded.data, p3Downloaded.data, search, onNavigate]); + + const handleUseModel = useCallback(() => { + if (!selectedModel) return; + + handleSelectProvider("hyprnote"); + handleSelectModel(selectedModel); + void localSttCommands.downloadModel(selectedModel); + onNavigate({ ...search, step: getNext(search) }); + }, [ + selectedModel, + search, + onNavigate, + handleSelectProvider, + handleSelectModel, + ]); + + if (p2Downloaded.isLoading || p3Downloaded.isLoading) { + return ( + +
+
+
+
+ ); + } + + return ( + +
+ setSelectedModel("am-parakeet-v2")} + /> + setSelectedModel("am-parakeet-v3")} + /> +
+ +
+ +
+
+ ); +} + +function LocalModelRow({ + model, + displayName, + description, + isSelected, + onSelect, +}: { + model: SupportedSttModel; + displayName: string; + description: string; + isSelected: boolean; + onSelect: () => void; +}) { + const isDownloaded = useQuery(localSttQueries.isDownloaded(model)); + + return ( +
{ + if (e.key === "Enter" || e.key === " ") { + e.preventDefault(); + onSelect(); + } + }} + className={cn([ + "relative border rounded-xl py-3 px-4 flex flex-col gap-1 text-left transition-all cursor-pointer", + isSelected + ? "border-stone-500 bg-stone-50" + : "border-neutral-200 hover:border-neutral-300", + ])} + > +
+
+

{displayName}

+

{description}

+
+ {isDownloaded.data && ( + + Already downloaded + + )} +
+
+ ); +} + function Requirement({ title, description, diff --git a/apps/desktop/src/components/settings/ai/stt/configure.tsx b/apps/desktop/src/components/settings/ai/stt/configure.tsx index eaf8960cd..a6098b0d6 100644 --- a/apps/desktop/src/components/settings/ai/stt/configure.tsx +++ b/apps/desktop/src/components/settings/ai/stt/configure.tsx @@ -2,11 +2,10 @@ import { Icon } from "@iconify-icon/react"; import { useQuery } from "@tanstack/react-query"; import { openPath } from "@tauri-apps/plugin-opener"; import { arch, platform } from "@tauri-apps/plugin-os"; -import { useCallback, useEffect, useState } from "react"; +import { useCallback } from "react"; import { commands as localSttCommands, - events as localSttEvents, type SupportedSttModel, } from "@hypr/plugin-local-stt"; import { @@ -20,9 +19,10 @@ import { cn } from "@hypr/utils"; import { useBillingAccess } from "../../../../billing"; import { useListener } from "../../../../contexts/listener"; +import { useLocalModelDownload } from "../../../../hooks/useLocalSttModel"; import * as settings from "../../../../store/tinybase/settings"; import { NonHyprProviderCard, StyledStreamdown } from "../shared"; -import { ProviderId, PROVIDERS, sttModelQueries } from "./shared"; +import { ProviderId, PROVIDERS } from "./shared"; export function ConfigureProviders() { return ( @@ -344,84 +344,6 @@ function HyprProviderLocalRow({ ); } -function useLocalModelDownload( - model: SupportedSttModel, - onDownloadComplete?: (model: SupportedSttModel) => void, -) { - const [progress, setProgress] = useState(0); - const [isStarting, setIsStarting] = useState(false); - const [hasError, setHasError] = useState(false); - - const isDownloaded = useQuery(sttModelQueries.isDownloaded(model)); - const isDownloading = useQuery(sttModelQueries.isDownloading(model)); - - const showProgress = - !isDownloaded.data && (isStarting || (isDownloading.data ?? false)); - - useEffect(() => { - if (isDownloading.data) { - setIsStarting(false); - } - }, [isDownloading.data]); - - useEffect(() => { - const unlisten = localSttEvents.downloadProgressPayload.listen((event) => { - if (event.payload.model === model) { - if (event.payload.progress < 0) { - setHasError(true); - setIsStarting(false); - setProgress(0); - } else { - setHasError(false); - const next = Math.max(0, Math.min(100, event.payload.progress)); - setProgress(next); - } - } - }); - - return () => { - void unlisten.then((fn) => fn()); - }; - }, [model]); - - useEffect(() => { - if (isDownloaded.data && progress > 0) { - setProgress(0); - onDownloadComplete?.(model); - } - }, [isDownloaded.data, model, onDownloadComplete, progress]); - - const handleDownload = () => { - if (isDownloaded.data || isDownloading.data || isStarting) { - return; - } - setHasError(false); - setIsStarting(true); - setProgress(0); - void localSttCommands.downloadModel(model).then((result) => { - if (result.status === "error") { - setHasError(true); - setIsStarting(false); - } - }); - }; - - const handleCancel = () => { - void localSttCommands.cancelDownload(model); - setIsStarting(false); - setProgress(0); - }; - - return { - progress, - hasError, - isDownloaded: isDownloaded.data ?? false, - showProgress, - handleDownload, - handleCancel, - }; -} - function ProviderContext({ providerId }: { providerId: ProviderId }) { const content = providerId === "hyprnote" diff --git a/apps/desktop/src/components/settings/ai/stt/shared.tsx b/apps/desktop/src/components/settings/ai/stt/shared.tsx index 762adf087..510d97a65 100644 --- a/apps/desktop/src/components/settings/ai/stt/shared.tsx +++ b/apps/desktop/src/components/settings/ai/stt/shared.tsx @@ -1,9 +1,7 @@ import { Icon } from "@iconify-icon/react"; import { AssemblyAI, Fireworks, OpenAI } from "@lobehub/icons"; -import { queryOptions } from "@tanstack/react-query"; import type { ReactNode } from "react"; -import { commands as localSttCommands } from "@hypr/plugin-local-stt"; import type { AmModel, SupportedSttModel, @@ -11,12 +9,15 @@ import type { } from "@hypr/plugin-local-stt"; import { env } from "../../../../env"; +import { localSttQueries } from "../../../../hooks/useLocalSttModel"; import { type ProviderRequirement, requiresEntitlement, } from "../shared/eligibility"; import { sortProviders } from "../shared/sort-providers"; +export { localSttQueries as sttModelQueries }; + type Provider = { disabled: boolean; id: string; @@ -200,32 +201,3 @@ export const sttProviderRequiresPro = (providerId: ProviderId) => { const provider = PROVIDERS.find((p) => p.id === providerId); return provider ? requiresEntitlement(provider.requirements, "pro") : false; }; - -export const sttModelQueries = { - isDownloaded: (model: SupportedSttModel) => - queryOptions({ - refetchInterval: 1000, - queryKey: ["stt", "model", model, "downloaded"], - queryFn: () => localSttCommands.isModelDownloaded(model), - select: (result) => { - if (result.status === "error") { - throw new Error(result.error); - } - - return result.data; - }, - }), - isDownloading: (model: SupportedSttModel) => - queryOptions({ - refetchInterval: 1000, - queryKey: ["stt", "model", model, "downloading"], - queryFn: () => localSttCommands.isModelDownloading(model), - select: (result) => { - if (result.status === "error") { - throw new Error(result.error); - } - - return result.data; - }, - }), -}; diff --git a/apps/desktop/src/hooks/useLocalSttModel.ts b/apps/desktop/src/hooks/useLocalSttModel.ts new file mode 100644 index 000000000..680d1dcba --- /dev/null +++ b/apps/desktop/src/hooks/useLocalSttModel.ts @@ -0,0 +1,126 @@ +import { queryOptions } from "@tanstack/react-query"; +import { useQuery } from "@tanstack/react-query"; +import { useCallback, useEffect, useState } from "react"; + +import { + commands as localSttCommands, + events as localSttEvents, + type SupportedSttModel, +} from "@hypr/plugin-local-stt"; + +export const localSttKeys = { + all: ["local-stt"] as const, + models: () => [...localSttKeys.all, "model"] as const, + model: (model: SupportedSttModel) => + [...localSttKeys.models(), model] as const, + modelDownloaded: (model: SupportedSttModel) => + [...localSttKeys.model(model), "downloaded"] as const, + modelDownloading: (model: SupportedSttModel) => + [...localSttKeys.model(model), "downloading"] as const, +}; + +export const localSttQueries = { + isDownloaded: (model: SupportedSttModel) => + queryOptions({ + refetchInterval: 1000, + queryKey: localSttKeys.modelDownloaded(model), + queryFn: () => localSttCommands.isModelDownloaded(model), + select: (result) => { + if (result.status === "error") { + throw new Error(result.error); + } + return result.data; + }, + }), + isDownloading: (model: SupportedSttModel) => + queryOptions({ + refetchInterval: 1000, + queryKey: localSttKeys.modelDownloading(model), + queryFn: () => localSttCommands.isModelDownloading(model), + select: (result) => { + if (result.status === "error") { + throw new Error(result.error); + } + return result.data; + }, + }), +}; + +export function useLocalModelDownload( + model: SupportedSttModel, + onDownloadComplete?: (model: SupportedSttModel) => void, +) { + const [progress, setProgress] = useState(0); + const [isStarting, setIsStarting] = useState(false); + const [hasError, setHasError] = useState(false); + + const isDownloaded = useQuery(localSttQueries.isDownloaded(model)); + const isDownloading = useQuery(localSttQueries.isDownloading(model)); + + const showProgress = + !isDownloaded.data && (isStarting || (isDownloading.data ?? false)); + + useEffect(() => { + if (isDownloading.data) { + setIsStarting(false); + } + }, [isDownloading.data]); + + useEffect(() => { + const unlisten = localSttEvents.downloadProgressPayload.listen((event) => { + if (event.payload.model === model) { + if (event.payload.progress < 0) { + setHasError(true); + setIsStarting(false); + setProgress(0); + } else { + setHasError(false); + const next = Math.max(0, Math.min(100, event.payload.progress)); + setProgress(next); + } + } + }); + + return () => { + void unlisten.then((fn) => fn()); + }; + }, [model]); + + useEffect(() => { + if (isDownloaded.data && progress > 0) { + setProgress(0); + onDownloadComplete?.(model); + } + }, [isDownloaded.data, model, onDownloadComplete, progress]); + + const handleDownload = useCallback(() => { + if (isDownloaded.data || isDownloading.data || isStarting) { + return; + } + setHasError(false); + setIsStarting(true); + setProgress(0); + void localSttCommands.downloadModel(model).then((result) => { + if (result.status === "error") { + setHasError(true); + setIsStarting(false); + } + }); + }, [isDownloaded.data, isDownloading.data, isStarting, model]); + + const handleCancel = useCallback(() => { + void localSttCommands.cancelDownload(model); + setIsStarting(false); + setProgress(0); + }, [model]); + + return { + progress, + hasError, + isDownloaded: isDownloaded.data ?? false, + isDownloadedLoading: isDownloaded.isLoading, + showProgress, + handleDownload, + handleCancel, + }; +}