diff --git a/apps/desktop/src/components/main/sidebar/banner/component.tsx b/apps/desktop/src/components/main/sidebar/banner/component.tsx index 565b5e2c9..b62f17465 100644 --- a/apps/desktop/src/components/main/sidebar/banner/component.tsx +++ b/apps/desktop/src/components/main/sidebar/banner/component.tsx @@ -17,7 +17,10 @@ export function Banner({ className={cn([ "relative group overflow-hidden rounded-lg", "flex flex-col gap-2", - "bg-white border border-neutral-200 shadow-sm p-4", + "bg-white p-4", + banner.variant === "error" + ? "border border-red-300 shadow-sm shadow-red-100" + : "border border-neutral-200 shadow-sm", ])} > {banner.dismissible && onDismiss && ( diff --git a/apps/desktop/src/components/main/sidebar/banner/index.tsx b/apps/desktop/src/components/main/sidebar/banner/index.tsx index bcaf52fcf..398935546 100644 --- a/apps/desktop/src/components/main/sidebar/banner/index.tsx +++ b/apps/desktop/src/components/main/sidebar/banner/index.tsx @@ -1,10 +1,12 @@ import { AnimatePresence, motion } from "motion/react"; import { useCallback, useEffect, useMemo, useState } from "react"; +import type { ServerStatus } from "@hypr/plugin-local-stt"; import { cn } from "@hypr/utils"; import { useAuth } from "../../../../auth"; import { useConfigValues } from "../../../../config/use-config"; +import { useSTTConnection } from "../../../../hooks/useSTTConnection"; import { useTabs } from "../../../../store/zustand/tabs"; import { Banner } from "./component"; import { createBannerRegistry, getBannerToShow } from "./registry"; @@ -32,7 +34,18 @@ export function BannerArea({ "current_stt_model", ] as const); const hasLLMConfigured = !!(current_llm_provider && current_llm_model); - const hasSttConfigured = !!(current_stt_provider && current_stt_model); + + const { conn: sttConnection, local: sttLocal } = useSTTConnection(); + const sttServerStatus = sttLocal.data?.status as ServerStatus | undefined; + + const isLocalSttModel = + current_stt_provider === "hyprnote" && + !!current_stt_model && + current_stt_model !== "cloud"; + + const hasSttConfigured = isLocalSttModel + ? !!sttConnection + : !!(current_stt_provider && current_stt_model && sttConnection); const currentTab = useTabs((state) => state.currentTab); const isAiTranscriptionTabActive = @@ -72,6 +85,8 @@ export function BannerArea({ isAuthenticated, hasLLMConfigured, hasSttConfigured, + sttServerStatus, + isLocalSttModel, isAiTranscriptionTabActive, isAiIntelligenceTabActive, onSignIn: handleSignIn, @@ -82,6 +97,8 @@ export function BannerArea({ isAuthenticated, hasLLMConfigured, hasSttConfigured, + sttServerStatus, + isLocalSttModel, isAiTranscriptionTabActive, isAiIntelligenceTabActive, handleSignIn, diff --git a/apps/desktop/src/components/main/sidebar/banner/registry.tsx b/apps/desktop/src/components/main/sidebar/banner/registry.tsx index 16add99f7..4cb0684c4 100644 --- a/apps/desktop/src/components/main/sidebar/banner/registry.tsx +++ b/apps/desktop/src/components/main/sidebar/banner/registry.tsx @@ -1,3 +1,5 @@ +import type { ServerStatus } from "@hypr/plugin-local-stt"; + import type { BannerCondition, BannerType } from "./types"; type BannerRegistryEntry = { @@ -9,6 +11,8 @@ type BannerRegistryParams = { isAuthenticated: boolean; hasLLMConfigured: boolean; hasSttConfigured: boolean; + sttServerStatus: ServerStatus | undefined; + isLocalSttModel: boolean; isAiTranscriptionTabActive: boolean; isAiIntelligenceTabActive: boolean; onSignIn: () => void | Promise; @@ -20,6 +24,8 @@ export function createBannerRegistry({ isAuthenticated, hasLLMConfigured, hasSttConfigured, + sttServerStatus, + isLocalSttModel, isAiTranscriptionTabActive, isAiIntelligenceTabActive, onSignIn, @@ -28,6 +34,53 @@ export function createBannerRegistry({ }: BannerRegistryParams): BannerRegistryEntry[] { // order matters return [ + { + banner: { + id: "stt-loading", + description: ( + <> + Transcription model is + + loading + + . This may take a moment. + + ), + primaryAction: { + label: "View status", + onClick: onOpenSTTSettings, + }, + dismissible: false, + }, + condition: () => + isLocalSttModel && + sttServerStatus === "loading" && + !hasSttConfigured && + !isAiTranscriptionTabActive, + }, + { + banner: { + id: "stt-unreachable", + variant: "error", + description: ( + <> + Transcription model{" "} + failed to start. + Please try again. + + ), + primaryAction: { + label: "Configure transcription", + onClick: onOpenSTTSettings, + }, + dismissible: false, + }, + condition: () => + isLocalSttModel && + sttServerStatus === "unreachable" && + !hasSttConfigured && + !isAiTranscriptionTabActive, + }, { banner: { id: "missing-stt", @@ -43,7 +96,8 @@ export function createBannerRegistry({ }, dismissible: false, }, - condition: () => !hasSttConfigured && !isAiTranscriptionTabActive, + condition: () => + !hasSttConfigured && !isLocalSttModel && !isAiTranscriptionTabActive, }, { banner: { diff --git a/apps/desktop/src/components/main/sidebar/banner/types.ts b/apps/desktop/src/components/main/sidebar/banner/types.ts index f32cdab22..74fa2b6d4 100644 --- a/apps/desktop/src/components/main/sidebar/banner/types.ts +++ b/apps/desktop/src/components/main/sidebar/banner/types.ts @@ -13,6 +13,7 @@ export type BannerType = { primaryAction?: BannerAction; secondaryAction?: BannerAction; dismissible: boolean; + variant?: "default" | "error"; }; export type BannerCondition = () => boolean; diff --git a/apps/desktop/src/components/settings/ai/llm/select.tsx b/apps/desktop/src/components/settings/ai/llm/select.tsx index a8bb964b0..13802d11b 100644 --- a/apps/desktop/src/components/settings/ai/llm/select.tsx +++ b/apps/desktop/src/components/settings/ai/llm/select.tsx @@ -82,10 +82,10 @@ export function SelectProviderAndModel() {
diff --git a/apps/desktop/src/components/settings/ai/stt/select.tsx b/apps/desktop/src/components/settings/ai/stt/select.tsx index 3086fcf50..61e8e9137 100644 --- a/apps/desktop/src/components/settings/ai/stt/select.tsx +++ b/apps/desktop/src/components/settings/ai/stt/select.tsx @@ -15,6 +15,10 @@ import { cn } from "@hypr/utils"; import { useBillingAccess } from "../../../../billing"; import { useConfigValues } from "../../../../config/use-config"; +import { + isLocalSttModel, + useSTTConnection, +} from "../../../../hooks/useSTTConnection"; import * as settings from "../../../../store/tinybase/settings"; import { getProviderSelectionBlockers, @@ -35,6 +39,13 @@ export function SelectProviderAndModel() { ] as const); const billing = useBillingAccess(); const configuredProviders = useConfiguredMapping(); + const { local: sttLocal } = useSTTConnection(); + const sttServerStatus = sttLocal.data?.status; + + const isLocal = isLocalSttModel(current_stt_provider, current_stt_model); + + const isUnreachable = isLocal && sttServerStatus === "unreachable"; + const isNotConfigured = !current_stt_provider || !current_stt_model; const handleSelectProvider = settings.UI.useSetValueCallback( "current_stt_provider", @@ -79,10 +90,10 @@ export function SelectProviderAndModel() {
@@ -212,7 +223,7 @@ export function SelectProviderAndModel() { )}
- {(!current_stt_provider || !current_stt_model) && ( + {isNotConfigured && (
Transcription model is @@ -220,6 +231,15 @@ export function SelectProviderAndModel() {
)} + + {isUnreachable && ( +
+ + Transcription model{" "} + failed to start. Please try again. + +
+ )}
); diff --git a/apps/desktop/src/hooks/useSTTConnection.ts b/apps/desktop/src/hooks/useSTTConnection.ts index 8d89d6356..3ce48a84a 100644 --- a/apps/desktop/src/hooks/useSTTConnection.ts +++ b/apps/desktop/src/hooks/useSTTConnection.ts @@ -10,6 +10,17 @@ import { ProviderId } from "../components/settings/ai/stt/shared"; import { env } from "../env"; import * as settings from "../store/tinybase/settings"; +export function isLocalSttModel( + provider: string | undefined, + model: string | undefined, +): boolean { + return ( + provider === "hyprnote" && + !!model && + (model.startsWith("am-") || model.startsWith("Quantized")) + ); +} + export const useSTTConnection = () => { const auth = useAuth(); const billing = useBillingAccess(); @@ -26,11 +37,7 @@ export const useSTTConnection = () => { settings.STORE_ID, ) as AIProviderStorage | undefined; - const isLocalModel = - current_stt_provider === "hyprnote" && - !!current_stt_model && - (current_stt_model.startsWith("am-") || - current_stt_model.startsWith("Quantized")); + const isLocalModel = isLocalSttModel(current_stt_provider, current_stt_model); const isCloudModel = current_stt_provider === "hyprnote" && current_stt_model === "cloud";