|
| 1 | +import { useQuery } from "@tanstack/react-query"; |
| 2 | +import { useCallback, useEffect, useState } from "react"; |
| 3 | + |
| 4 | +import { |
| 5 | + commands as localSttCommands, |
| 6 | + type SupportedSttModel, |
| 7 | +} from "@hypr/plugin-local-stt"; |
| 8 | +import { cn } from "@hypr/utils"; |
| 9 | + |
1 | 10 | import { Route } from "../../routes/app/onboarding"; |
| 11 | +import * as settings from "../../store/tinybase/settings"; |
2 | 12 | import { getBack, getNext, type StepProps } from "./config"; |
3 | 13 | import { OnboardingContainer } from "./shared"; |
4 | 14 |
|
5 | 15 | export const STEP_ID_CONFIGURE_NOTICE = "configure-notice" as const; |
6 | 16 |
|
| 17 | +const sttModelQueries = { |
| 18 | + isDownloaded: (model: SupportedSttModel) => ({ |
| 19 | + refetchInterval: 1000, |
| 20 | + queryKey: ["stt", "model", model, "downloaded"], |
| 21 | + queryFn: () => localSttCommands.isModelDownloaded(model), |
| 22 | + select: (result: { status: string; data?: boolean; error?: string }) => { |
| 23 | + if (result.status === "error") { |
| 24 | + throw new Error(result.error); |
| 25 | + } |
| 26 | + return result.data; |
| 27 | + }, |
| 28 | + }), |
| 29 | + isDownloading: (model: SupportedSttModel) => ({ |
| 30 | + refetchInterval: 1000, |
| 31 | + queryKey: ["stt", "model", model, "downloading"], |
| 32 | + queryFn: () => localSttCommands.isModelDownloading(model), |
| 33 | + select: (result: { status: string; data?: boolean; error?: string }) => { |
| 34 | + if (result.status === "error") { |
| 35 | + throw new Error(result.error); |
| 36 | + } |
| 37 | + return result.data; |
| 38 | + }, |
| 39 | + }), |
| 40 | +}; |
| 41 | + |
7 | 42 | export function ConfigureNotice({ onNavigate }: StepProps) { |
8 | 43 | const search = Route.useSearch(); |
9 | 44 | const backStep = getBack(search); |
10 | 45 |
|
| 46 | + if (search.local) { |
| 47 | + return ( |
| 48 | + <LocalConfigureNotice |
| 49 | + onNavigate={onNavigate} |
| 50 | + onBack={ |
| 51 | + backStep ? () => onNavigate({ ...search, step: backStep }) : undefined |
| 52 | + } |
| 53 | + /> |
| 54 | + ); |
| 55 | + } |
| 56 | + |
11 | 57 | return ( |
12 | 58 | <OnboardingContainer |
13 | 59 | title="AI models are needed for best experience" |
@@ -39,6 +85,158 @@ export function ConfigureNotice({ onNavigate }: StepProps) { |
39 | 85 | ); |
40 | 86 | } |
41 | 87 |
|
| 88 | +function LocalConfigureNotice({ |
| 89 | + onNavigate, |
| 90 | + onBack, |
| 91 | +}: { |
| 92 | + onNavigate: StepProps["onNavigate"]; |
| 93 | + onBack?: () => void; |
| 94 | +}) { |
| 95 | + const search = Route.useSearch(); |
| 96 | + const [selectedModel, setSelectedModel] = useState<SupportedSttModel | null>( |
| 97 | + null, |
| 98 | + ); |
| 99 | + |
| 100 | + const handleSelectProvider = settings.UI.useSetValueCallback( |
| 101 | + "current_stt_provider", |
| 102 | + (provider: string) => provider, |
| 103 | + [], |
| 104 | + settings.STORE_ID, |
| 105 | + ); |
| 106 | + |
| 107 | + const handleSelectModel = settings.UI.useSetValueCallback( |
| 108 | + "current_stt_model", |
| 109 | + (model: string) => model, |
| 110 | + [], |
| 111 | + settings.STORE_ID, |
| 112 | + ); |
| 113 | + |
| 114 | + const p2Downloaded = useQuery(sttModelQueries.isDownloaded("am-parakeet-v2")); |
| 115 | + const p3Downloaded = useQuery(sttModelQueries.isDownloaded("am-parakeet-v3")); |
| 116 | + |
| 117 | + useEffect(() => { |
| 118 | + if (p2Downloaded.data || p3Downloaded.data) { |
| 119 | + onNavigate({ ...search, step: getNext(search) }); |
| 120 | + } |
| 121 | + }, [p2Downloaded.data, p3Downloaded.data, search, onNavigate]); |
| 122 | + |
| 123 | + const handleUseModel = useCallback(() => { |
| 124 | + if (!selectedModel) return; |
| 125 | + |
| 126 | + handleSelectProvider("hyprnote"); |
| 127 | + handleSelectModel(selectedModel); |
| 128 | + void localSttCommands.downloadModel(selectedModel); |
| 129 | + onNavigate({ ...search, step: getNext(search) }); |
| 130 | + }, [ |
| 131 | + selectedModel, |
| 132 | + search, |
| 133 | + onNavigate, |
| 134 | + handleSelectProvider, |
| 135 | + handleSelectModel, |
| 136 | + ]); |
| 137 | + |
| 138 | + if (p2Downloaded.isLoading || p3Downloaded.isLoading) { |
| 139 | + return ( |
| 140 | + <OnboardingContainer |
| 141 | + title="Checking for existing models..." |
| 142 | + onBack={onBack} |
| 143 | + > |
| 144 | + <div className="flex justify-center py-8"> |
| 145 | + <div className="animate-spin rounded-full h-8 w-8 border-b-2 border-stone-500"></div> |
| 146 | + </div> |
| 147 | + </OnboardingContainer> |
| 148 | + ); |
| 149 | + } |
| 150 | + |
| 151 | + return ( |
| 152 | + <OnboardingContainer |
| 153 | + title="Help Hyprnote listen to your conversations" |
| 154 | + description="Select a speech-to-text model to download" |
| 155 | + onBack={onBack} |
| 156 | + > |
| 157 | + <div className="flex flex-col gap-3"> |
| 158 | + <LocalModelRow |
| 159 | + model="am-parakeet-v2" |
| 160 | + displayName="Parakeet v2" |
| 161 | + description="Best for English" |
| 162 | + isSelected={selectedModel === "am-parakeet-v2"} |
| 163 | + onSelect={() => setSelectedModel("am-parakeet-v2")} |
| 164 | + /> |
| 165 | + <LocalModelRow |
| 166 | + model="am-parakeet-v3" |
| 167 | + displayName="Parakeet v3" |
| 168 | + description="Better for European languages" |
| 169 | + isSelected={selectedModel === "am-parakeet-v3"} |
| 170 | + onSelect={() => setSelectedModel("am-parakeet-v3")} |
| 171 | + /> |
| 172 | + </div> |
| 173 | + |
| 174 | + <div className="flex flex-col gap-3 mt-4"> |
| 175 | + <button |
| 176 | + onClick={handleUseModel} |
| 177 | + disabled={!selectedModel} |
| 178 | + className={cn([ |
| 179 | + "w-full py-3 rounded-full text-white text-sm font-medium duration-150", |
| 180 | + selectedModel |
| 181 | + ? "bg-gradient-to-t from-stone-600 to-stone-500 hover:scale-[1.01] active:scale-[0.99]" |
| 182 | + : "bg-gray-300 cursor-not-allowed opacity-50", |
| 183 | + ])} |
| 184 | + > |
| 185 | + Use this model |
| 186 | + </button> |
| 187 | + </div> |
| 188 | + </OnboardingContainer> |
| 189 | + ); |
| 190 | +} |
| 191 | + |
| 192 | +function LocalModelRow({ |
| 193 | + model, |
| 194 | + displayName, |
| 195 | + description, |
| 196 | + isSelected, |
| 197 | + onSelect, |
| 198 | +}: { |
| 199 | + model: SupportedSttModel; |
| 200 | + displayName: string; |
| 201 | + description: string; |
| 202 | + isSelected: boolean; |
| 203 | + onSelect: () => void; |
| 204 | +}) { |
| 205 | + const isDownloaded = useQuery(sttModelQueries.isDownloaded(model)); |
| 206 | + |
| 207 | + return ( |
| 208 | + <div |
| 209 | + role="button" |
| 210 | + tabIndex={0} |
| 211 | + onClick={onSelect} |
| 212 | + onKeyDown={(e) => { |
| 213 | + if (e.key === "Enter" || e.key === " ") { |
| 214 | + e.preventDefault(); |
| 215 | + onSelect(); |
| 216 | + } |
| 217 | + }} |
| 218 | + className={cn([ |
| 219 | + "relative border rounded-xl py-3 px-4 flex flex-col gap-1 text-left transition-all cursor-pointer", |
| 220 | + isSelected |
| 221 | + ? "border-stone-500 bg-stone-50" |
| 222 | + : "border-neutral-200 hover:border-neutral-300", |
| 223 | + ])} |
| 224 | + > |
| 225 | + <div className="flex items-center justify-between w-full"> |
| 226 | + <div className="flex flex-col gap-1"> |
| 227 | + <p className="text-sm font-medium">{displayName}</p> |
| 228 | + <p className="text-xs text-neutral-500 flex-1">{description}</p> |
| 229 | + </div> |
| 230 | + {isDownloaded.data && ( |
| 231 | + <span className="text-xs text-green-600 font-medium"> |
| 232 | + Already downloaded |
| 233 | + </span> |
| 234 | + )} |
| 235 | + </div> |
| 236 | + </div> |
| 237 | + ); |
| 238 | +} |
| 239 | + |
42 | 240 | function Requirement({ |
43 | 241 | title, |
44 | 242 | description, |
|
0 commit comments