diff --git a/apps/sim/lib/api-key/byok.test.ts b/apps/sim/lib/api-key/byok.test.ts index 439c392d94..6b288b6213 100644 --- a/apps/sim/lib/api-key/byok.test.ts +++ b/apps/sim/lib/api-key/byok.test.ts @@ -3,9 +3,8 @@ */ import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockOrderBy, mockGetWorkspaceById, mockDecryptSecret } = vi.hoisted(() => ({ +const { mockOrderBy, mockDecryptSecret } = vi.hoisted(() => ({ mockOrderBy: vi.fn(), - mockGetWorkspaceById: vi.fn(), mockDecryptSecret: vi.fn(), })) @@ -19,10 +18,6 @@ vi.mock('@sim/db', () => ({ }, })) -vi.mock('@/lib/workspaces/permissions/utils', () => ({ - getWorkspaceById: mockGetWorkspaceById, -})) - vi.mock('@/lib/core/security/encryption', () => ({ decryptSecret: mockDecryptSecret, })) @@ -70,7 +65,6 @@ const storedKey = (id: string) => ({ id, encryptedApiKey: `encrypted-${id}` }) describe('getBYOKKey', () => { beforeEach(() => { vi.clearAllMocks() - mockGetWorkspaceById.mockResolvedValue({ id: 'workspace' }) mockOrderBy.mockResolvedValue([]) mockDecryptSecret.mockImplementation(async (encrypted: string) => ({ decrypted: encrypted.replace('encrypted-', 'decrypted-'), @@ -80,13 +74,6 @@ describe('getBYOKKey', () => { it('returns null when no workspaceId is provided', async () => { expect(await getBYOKKey(undefined, 'openai')).toBeNull() expect(await getBYOKKey(null, 'openai')).toBeNull() - expect(mockGetWorkspaceById).not.toHaveBeenCalled() - }) - - it('returns null when the workspace does not exist', async () => { - mockGetWorkspaceById.mockResolvedValue(null) - - expect(await getBYOKKey(uniqueWorkspaceId(), 'openai')).toBeNull() }) it('returns null when the workspace has no keys for the provider', async () => { @@ -123,6 +110,17 @@ describe('getBYOKKey', () => { ]) }) + it('reads the key list fresh from the database on every call', async () => { + const workspaceId = uniqueWorkspaceId() + mockOrderBy.mockResolvedValue([storedKey('key-1')]) + + await getBYOKKey(workspaceId, 'openai') + await 
getBYOKKey(workspaceId, 'openai') + await getBYOKKey(workspaceId, 'openai') + + expect(mockOrderBy).toHaveBeenCalledTimes(3) + }) + it('tracks rotation independently per provider within a workspace', async () => { const workspaceId = uniqueWorkspaceId() mockOrderBy.mockResolvedValue([storedKey('key-1'), storedKey('key-2')]) diff --git a/apps/sim/lib/api-key/byok.ts b/apps/sim/lib/api-key/byok.ts index b131d2f742..73df4a3a1b 100644 --- a/apps/sim/lib/api-key/byok.ts +++ b/apps/sim/lib/api-key/byok.ts @@ -6,7 +6,6 @@ import { getRotatingApiKey } from '@/lib/core/config/api-keys' import { env } from '@/lib/core/config/env' import { isHosted } from '@/lib/core/config/env-flags' import { decryptSecret } from '@/lib/core/security/encryption' -import { getWorkspaceById } from '@/lib/workspaces/permissions/utils' import { getHostedModels } from '@/providers/models' import { PROVIDER_PLACEHOLDER_KEY } from '@/providers/utils' import { useProvidersStore } from '@/stores/providers/store' @@ -37,6 +36,9 @@ function nextRotationIndex(poolKey: string, poolSize: number): number { * multiple keys stored for the provider, requests round-robin across them in * creation order. A key that fails to decrypt is skipped in favor of the next * one in the pool. + * + * The key list is read fresh every call (not cached): BYOK is not a hot query, + * and reading fresh keeps revocation immediate across ECS tasks. 
*/ export async function getBYOKKey( workspaceId: string | undefined | null, @@ -47,11 +49,6 @@ export async function getBYOKKey( } try { - const activeWorkspace = await getWorkspaceById(workspaceId) - if (!activeWorkspace) { - return null - } - const keys = await db .select({ id: workspaceBYOKKeys.id, encryptedApiKey: workspaceBYOKKeys.encryptedApiKey }) .from(workspaceBYOKKeys) diff --git a/apps/sim/lib/billing/calculations/usage-monitor.ts b/apps/sim/lib/billing/calculations/usage-monitor.ts index 906007689f..4e259f2036 100644 --- a/apps/sim/lib/billing/calculations/usage-monitor.ts +++ b/apps/sim/lib/billing/calculations/usage-monitor.ts @@ -467,6 +467,9 @@ export async function checkOrgMemberUsageLimit( return { isExceeded: false, currentUsage: 0, limit: null } } + // Resolve the cap first and short-circuit when unset (the common case); only + // then is computing usage worthwhile. Kept sequential, not raced, to avoid a + // usage query on every uncapped member's execution. const limit = await getOrgMemberUsageLimit(organizationId, userId) if (limit === null) { return { isExceeded: false, currentUsage: 0, limit: null } diff --git a/apps/sim/lib/billing/core/plan.test.ts b/apps/sim/lib/billing/core/plan.test.ts new file mode 100644 index 0000000000..cf223952b0 --- /dev/null +++ b/apps/sim/lib/billing/core/plan.test.ts @@ -0,0 +1,193 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +/** + * Drizzle mock for `getHighestPrioritySubscription`. 
It issues up to four + * queries keyed by table: + * - `subscription` for the user's personal subs (parallelized with members) + * - `member` for the user's org memberships (parallelized with subs) + * - `organization` for the org-existence follow-up + * - `subscription` again for the org-scoped subs follow-up + * + * The mock routes results by the table object passed to `.from()`, serving the + * (twice-read) `subscription` table from a FIFO queue (first read = personal, + * second = org). It records which tables were queried so we can assert the + * parallelized pair both run and that follow-ups are skipped when appropriate. + * + * Table sentinels and shared mock state live inside `vi.hoisted` so the + * `vi.mock` factories (hoisted to the top of the file) can reference them. + */ +const { SUBSCRIPTION_TABLE, MEMBER_TABLE, ORGANIZATION_TABLE, resultsByTable, fromCalls, select } = + vi.hoisted(() => { + const SUBSCRIPTION_TABLE = { __table: 'subscription' } + const MEMBER_TABLE = { __table: 'member' } + const ORGANIZATION_TABLE = { __table: 'organization' } + + const resultsByTable: Record<string, unknown[][]> = { + subscription: [], + member: [], + organization: [], + } + const fromCalls: string[] = [] + + const select = vi.fn(() => ({ + from: (table: { __table: string }) => { + fromCalls.push(table.__table) + const where = () => { + const queue = resultsByTable[table.__table] + const next = queue.length > 0 ? queue.shift() : [] + return Promise.resolve(next ?? []) + } + return { where } + }, + })) + + return { + SUBSCRIPTION_TABLE, + MEMBER_TABLE, + ORGANIZATION_TABLE, + resultsByTable, + fromCalls, + select, + } + }) + +vi.mock('@sim/db', () => ({ + db: { select }, +})) + +vi.mock('@sim/db/schema', () => ({ + subscription: SUBSCRIPTION_TABLE, + member: MEMBER_TABLE, + organization: ORGANIZATION_TABLE, +})) + +/** + * Realistic plan-check predicates so `pickHighestPrioritySubscription` exercises + * the real Enterprise > Team > Pro priority ordering over the rows we feed it. 
+ */ +vi.mock('@/lib/billing/subscriptions/utils', () => ({ + ENTITLED_SUBSCRIPTION_STATUSES: ['active', 'past_due'], + checkEnterprisePlan: (s: any) => + s?.plan === 'enterprise' && ['active', 'past_due'].includes(s?.status), + checkTeamPlan: (s: any) => s?.plan === 'team' && ['active', 'past_due'].includes(s?.status), + checkProPlan: (s: any) => s?.plan === 'pro' && ['active', 'past_due'].includes(s?.status), +})) + +import { getHighestPrioritySubscription } from '@/lib/billing/core/plan' + +interface SubRow { + id: string + referenceId: string + plan: string + status: string +} + +function personalPro(userId: string): SubRow { + return { id: 'sub-personal-pro', referenceId: userId, plan: 'pro', status: 'active' } +} + +function orgEnterprise(orgId: string): SubRow { + return { id: 'sub-org-enterprise', referenceId: orgId, plan: 'enterprise', status: 'active' } +} + +function queue(table: 'subscription' | 'member' | 'organization', rows: unknown[]) { + resultsByTable[table].push(rows) +} + +describe('getHighestPrioritySubscription', () => { + beforeEach(() => { + vi.clearAllMocks() + resultsByTable.subscription = [] + resultsByTable.member = [] + resultsByTable.organization = [] + fromCalls.length = 0 + }) + + it('picks the org Enterprise sub over a personal Pro sub (priority order)', async () => { + queue('subscription', [personalPro('user-1')]) // personalSubs query + queue('member', [{ organizationId: 'org-1' }]) // memberships query + queue('organization', [{ id: 'org-1' }]) // org-existence query + queue('subscription', [orgEnterprise('org-1')]) // org-subscriptions query + + const result = await getHighestPrioritySubscription('user-1') + + expect(result).not.toBeNull() + expect(result?.id).toBe('sub-org-enterprise') + expect(result?.plan).toBe('enterprise') + }) + + it('selection is deterministic regardless of which parallelized query resolves first', async () => { + queue('subscription', [personalPro('user-1')]) + queue('member', [{ organizationId: 'org-1' 
}]) + queue('organization', [{ id: 'org-1' }]) + queue('subscription', [orgEnterprise('org-1')]) + + const result = await getHighestPrioritySubscription('user-1') + + expect(result?.id).toBe('sub-org-enterprise') + }) + + it('issues BOTH the personal-subscriptions and memberships queries (parallelized pair)', async () => { + queue('subscription', [personalPro('user-1')]) + queue('member', [{ organizationId: 'org-1' }]) + queue('organization', [{ id: 'org-1' }]) + queue('subscription', [orgEnterprise('org-1')]) + + await getHighestPrioritySubscription('user-1') + + expect(fromCalls).toContain('subscription') + expect(fromCalls).toContain('member') + // First two queries are exactly the parallelized pair (in either order). + expect(fromCalls.slice(0, 2).sort()).toEqual(['member', 'subscription']) + }) + + it('returns the personal sub and skips org follow-ups when there are no memberships', async () => { + queue('subscription', [personalPro('user-1')]) + queue('member', []) + + const result = await getHighestPrioritySubscription('user-1') + + expect(result?.id).toBe('sub-personal-pro') + expect(result?.plan).toBe('pro') + // org-existence + org-subscription follow-ups are NOT issued. + expect(fromCalls).not.toContain('organization') + expect(fromCalls.filter((t) => t === 'subscription')).toHaveLength(1) + }) + + it('returns null when neither personal nor org subscriptions exist', async () => { + queue('subscription', []) + queue('member', []) + + const result = await getHighestPrioritySubscription('user-1') + + expect(result).toBeNull() + }) + + it('excludes orphaned org memberships whose organization row no longer exists', async () => { + queue('subscription', []) + queue('member', [{ organizationId: 'ghost-org' }]) // membership points at a deleted org + queue('organization', []) + + const result = await getHighestPrioritySubscription('user-1') + + // Org subs are never fetched (no valid org ids) -> falls back to null. 
+ expect(result).toBeNull() + expect(fromCalls).toContain('organization') + // Only the initial personal-subs read on `subscription`; org-subs query skipped. + expect(fromCalls.filter((t) => t === 'subscription')).toHaveLength(1) + }) + + it('falls back to the personal sub when the only org is orphaned', async () => { + queue('subscription', [personalPro('user-1')]) + queue('member', [{ organizationId: 'ghost-org' }]) + queue('organization', []) + + const result = await getHighestPrioritySubscription('user-1') + + expect(result?.id).toBe('sub-personal-pro') + expect(fromCalls.filter((t) => t === 'subscription')).toHaveLength(1) + }) +}) diff --git a/apps/sim/lib/billing/core/plan.ts b/apps/sim/lib/billing/core/plan.ts index b4a56dab13..633d2e55c8 100644 --- a/apps/sim/lib/billing/core/plan.ts +++ b/apps/sim/lib/billing/core/plan.ts @@ -82,20 +82,21 @@ export async function getHighestPrioritySubscription( ) { const { onError = 'return-null', executor = db } = options try { - const personalSubs = await executor - .select() - .from(subscription) - .where( - and( - eq(subscription.referenceId, userId), - inArray(subscription.status, ENTITLED_SUBSCRIPTION_STATUSES) - ) - ) - - const memberships = await executor - .select({ organizationId: member.organizationId }) - .from(member) - .where(eq(member.userId, userId)) + const [personalSubs, memberships] = await Promise.all([ + executor + .select() + .from(subscription) + .where( + and( + eq(subscription.referenceId, userId), + inArray(subscription.status, ENTITLED_SUBSCRIPTION_STATUSES) + ) + ), + executor + .select({ organizationId: member.organizationId }) + .from(member) + .where(eq(member.userId, userId)), + ]) const orgIds = memberships.map((m: { organizationId: string }) => m.organizationId) diff --git a/apps/sim/lib/execution/preprocessing.test.ts b/apps/sim/lib/execution/preprocessing.test.ts index c00e5653ec..e72c44a567 100644 --- a/apps/sim/lib/execution/preprocessing.test.ts +++ 
b/apps/sim/lib/execution/preprocessing.test.ts @@ -28,7 +28,11 @@ vi.mock('@/lib/core/execution-limits', () => ({ getExecutionTimeout: vi.fn(() => 0), })) vi.mock('@/lib/core/rate-limiter/rate-limiter', () => ({ - RateLimiter: vi.fn(() => ({ checkRateLimitWithSubscription: mockCheckRateLimit })), + // Regular function (not an arrow) so `new RateLimiter()` is constructable under + // vitest 4.x, which rejects `new` on an arrow-implemented mock. + RateLimiter: vi.fn(function (this: unknown) { + return { checkRateLimitWithSubscription: mockCheckRateLimit } + }), })) vi.mock('@/lib/logs/execution/logging-session', () => loggingSessionMock) vi.mock('@/lib/workspaces/utils', () => ({ @@ -176,7 +180,7 @@ describe('preprocessExecution ban gate', () => { } as any) }) - it('blocks execution with 403 when the actor is banned, before any billing queries', async () => { + it('blocks execution with 403 when the actor is banned (ban wins over the parallel gates)', async () => { mockGetActivelyBannedUserIds.mockResolvedValue(['billed-account-1']) const loggingSession = { @@ -194,8 +198,79 @@ describe('preprocessExecution ban gate', () => { error: { statusCode: 403, logCreated: true, message: 'Account suspended' }, }) expect(loggingSession.safeStart).toHaveBeenCalled() - expect(getHighestPrioritySubscription).not.toHaveBeenCalled() - expect(checkServerSideUsageLimits).not.toHaveBeenCalled() + }) + + it('returns 403 (ban precedence) when ban, usage, and rate limit all fail simultaneously', async () => { + mockGetActivelyBannedUserIds.mockResolvedValue(['billed-account-1']) + vi.mocked(checkServerSideUsageLimits).mockResolvedValue({ + isExceeded: true, + currentUsage: 20, + limit: 10, + message: 'Usage limit exceeded. 
Please upgrade your plan to continue.', + } as any) + mockCheckRateLimit.mockResolvedValue({ + allowed: false, + remaining: 0, + resetAt: new Date(), + }) + + const loggingSession = { + safeStart: vi.fn().mockResolvedValue(true), + safeCompleteWithError: vi.fn().mockResolvedValue(undefined), + } + + const result = await preprocessExecution({ + ...baseOptions, + checkRateLimit: true, + loggingSession: loggingSession as any, + }) + + // Ban (403) takes precedence over usage (402) and rate limit (429), + // independent of which parallel gate's promise settled first. + expect(result).toMatchObject({ + success: false, + error: { statusCode: 403, logCreated: true, message: 'Account suspended' }, + }) + }) + + it('does not debit rate-limit quota when the ban gate rejects', async () => { + // The rate-limit gate consumes a token, so it must not run for a request + // an earlier gate (ban) already rejects. + mockGetActivelyBannedUserIds.mockResolvedValue(['billed-account-1']) + + const result = await preprocessExecution({ ...baseOptions, checkRateLimit: true }) + + expect(result).toMatchObject({ success: false, error: { statusCode: 403 } }) + expect(mockCheckRateLimit).not.toHaveBeenCalled() + }) + + it('does not debit rate-limit quota when the usage gate rejects', async () => { + vi.mocked(checkServerSideUsageLimits).mockResolvedValue({ + isExceeded: true, + currentUsage: 20, + limit: 10, + message: 'Usage limit exceeded. 
Please upgrade your plan to continue.', + } as any) + + const result = await preprocessExecution({ ...baseOptions, checkRateLimit: true }) + + expect(result).toMatchObject({ success: false, error: { statusCode: 402 } }) + expect(mockCheckRateLimit).not.toHaveBeenCalled() + }) + + it('consumes the rate-limit gate exactly once when the ban and usage gates pass', async () => { + mockCheckRateLimit.mockResolvedValue({ allowed: true, remaining: 5, resetAt: new Date() }) + + // skipConcurrencyReservation bypasses the STEP 7 admission reservation so the + // assertion isolates the rate gate and does not depend on Redis availability. + const result = await preprocessExecution({ + ...baseOptions, + checkRateLimit: true, + skipConcurrencyReservation: true, + }) + + expect(result.success).toBe(true) + expect(mockCheckRateLimit).toHaveBeenCalledTimes(1) }) it('checks the billing actor, caller-provided userId, and workflow owner in one call', async () => { @@ -234,6 +309,5 @@ describe('preprocessExecution ban gate', () => { success: false, error: { statusCode: 500, logCreated: true }, }) - expect(checkServerSideUsageLimits).not.toHaveBeenCalled() }) }) diff --git a/apps/sim/lib/execution/preprocessing.ts b/apps/sim/lib/execution/preprocessing.ts index 075ee57af9..a41d3dbbbc 100644 --- a/apps/sim/lib/execution/preprocessing.ts +++ b/apps/sim/lib/execution/preprocessing.ts @@ -322,85 +322,118 @@ export async function preprocessExecution( } } - // ========== STEP 3.5: Reject Banned Accounts ========== - // Blocks executions when the billing actor, the workflow owner, or the - // caller-provided userId (chat deployer, authenticated caller) has an - // active ban or a blocked email domain. The owner comes from the workflow - // record so schedules — which pass the 'unknown' sentinel — are covered. 
- const banCandidateIds = [actorUserId] - if (userId && userId !== 'unknown' && userId !== actorUserId) { - banCandidateIds.push(userId) + // ========== STEPS 3.5–6: Preflight Gates ========== + // Ban and subscription lookups run concurrently; the usage gate (which needs the + // subscription) and then the stateful rate-limit gate follow. Precedence: ban 403 → usage 402 → rate 429. + + /** + * A failing gate's deferred outcome: the response to return, plus an optional + * error-log write to flush before returning. Evaluated in precedence order. + */ + interface GateFailure { + response: PreprocessExecutionResult + recordError?: Parameters<typeof recordPreprocessingError>[0] } - if (workflowRecord.userId && !banCandidateIds.includes(workflowRecord.userId)) { - banCandidateIds.push(workflowRecord.userId) + + /** Usage figures captured by STEP 5 and reused by the STEP 7 reservation. */ + interface UsageSnapshot { + currentUsage: number + limit: number } - try { - const bannedUserIds = await getActivelyBannedUserIds(banCandidateIds) - if (bannedUserIds.length > 0) { - logger.warn(`[${requestId}] Execution blocked: banned account`, { - workflowId, - bannedUserIds, - triggerType, - }) - await recordPreprocessingError({ - workflowId, - executionId, - triggerType, - requestId, - userId: actorUserId, - workspaceId, - errorMessage: 'This account has been suspended. Workflow executions are blocked.', - loggingSession: providedLoggingSession, - triggerData, - }) + const banCheck = (async (): Promise<GateFailure | null> => { + // Blocks executions when the billing actor, the workflow owner, or the + // caller-provided userId (chat deployer, authenticated caller) has an + // active ban or a blocked email domain. The owner comes from the workflow + // record so schedules — which pass the 'unknown' sentinel — are covered. 
+ const banCandidateIds = [actorUserId] + if (userId && userId !== 'unknown' && userId !== actorUserId) { + banCandidateIds.push(userId) + } + if (workflowRecord.userId && !banCandidateIds.includes(workflowRecord.userId)) { + banCandidateIds.push(workflowRecord.userId) + } + try { + const bannedUserIds = await getActivelyBannedUserIds(banCandidateIds) + if (bannedUserIds.length > 0) { + logger.warn(`[${requestId}] Execution blocked: banned account`, { + workflowId, + bannedUserIds, + triggerType, + }) + + return { + response: { + success: false, + error: { + message: 'Account suspended', + statusCode: 403, + logCreated: true, + }, + }, + recordError: { + workflowId, + executionId, + triggerType, + requestId, + userId: actorUserId, + workspaceId, + errorMessage: 'This account has been suspended. Workflow executions are blocked.', + loggingSession: providedLoggingSession, + triggerData, + }, + } + } + return null + } catch (error) { + logger.error(`[${requestId}] Error checking account ban status`, { error, actorUserId }) return { - success: false, - error: { - message: 'Account suspended', - statusCode: 403, - logCreated: true, + response: { + success: false, + error: { + message: 'Unable to verify account status. Execution blocked for security.', + statusCode: 500, + logCreated: true, + retryable: isRetryableInfrastructureError(error), + cause: describeRetryableInfrastructureError(error), + }, + }, + recordError: { + workflowId, + executionId, + triggerType, + requestId, + userId: actorUserId, + workspaceId, + errorMessage: 'Unable to verify account status. Execution blocked for security.', + loggingSession: providedLoggingSession, + triggerData, }, } } - } catch (error) { - logger.error(`[${requestId}] Error checking account ban status`, { error, actorUserId }) - - await recordPreprocessingError({ - workflowId, - executionId, - triggerType, - requestId, - userId: actorUserId, - workspaceId, - errorMessage: 'Unable to verify account status. 
Execution blocked for security.', - loggingSession: providedLoggingSession, - triggerData, - }) - - return { - success: false, - error: { - message: 'Unable to verify account status. Execution blocked for security.', - statusCode: 500, - logCreated: true, - retryable: isRetryableInfrastructureError(error), - cause: describeRetryableInfrastructureError(error), - }, - } - } + })() // ========== STEP 4: Get Subscription ========== - const userSubscription = await getHighestPrioritySubscription(actorUserId) + const subscriptionFetch = getHighestPrioritySubscription(actorUserId) + + const [banFailure, userSubscription] = await Promise.all([banCheck, subscriptionFetch]) - // ========== STEP 5: Check Usage Limits ========== - // Snapshot reused by the STEP 7 admission reservation. - let usageSnapshot: { currentUsage: number; limit: number } | null = null - if (!skipUsageLimits) { + /** + * STEP 5: usage + per-member org usage gate. Returns the failure outcome (or + * `null` on pass/skip) plus the usage snapshot reused by the STEP 7 admission + * reservation. The snapshot is returned rather than written to an outer variable. + * Starts only after the ban and subscription lookups settle, since it needs `userSubscription`. + */ + const usageCheckTask = (async (): Promise<{ + failure: GateFailure | null + snapshot: UsageSnapshot | null + }> => { + if (skipUsageLimits) return { failure: null, snapshot: null } + let snapshot: UsageSnapshot | null = null try { const usageCheck = await checkServerSideUsageLimits(actorUserId, userSubscription) - usageSnapshot = { currentUsage: usageCheck.currentUsage, limit: usageCheck.limit } + snapshot = { currentUsage: usageCheck.currentUsage, limit: usageCheck.limit } if (usageCheck.isExceeded) { logger.warn( `[${requestId}] User ${actorUserId} has exceeded usage limits. 
Blocking execution.`, @@ -412,28 +445,33 @@ export async function preprocessExecution( } ) - await recordPreprocessingError({ - workflowId, - executionId, - triggerType, - requestId, - userId: actorUserId, - workspaceId, - errorMessage: - usageCheck.message || - `Usage limit exceeded: $${usageCheck.currentUsage?.toFixed(2)} used of $${usageCheck.limit?.toFixed(2)} limit. Please upgrade your plan to continue.`, - loggingSession: providedLoggingSession, - triggerData, - }) - return { - success: false, - error: { - message: - usageCheck.message || 'Usage limit exceeded. Please upgrade your plan to continue.', - statusCode: 402, - logCreated: true, + failure: { + response: { + success: false, + error: { + message: + usageCheck.message || + 'Usage limit exceeded. Please upgrade your plan to continue.', + statusCode: 402, + logCreated: true, + }, + }, + recordError: { + workflowId, + executionId, + triggerType, + requestId, + userId: actorUserId, + workspaceId, + errorMessage: + usageCheck.message || + `Usage limit exceeded: $${usageCheck.currentUsage?.toFixed(2)} used of $${usageCheck.limit?.toFixed(2)} limit. 
Please upgrade your plan to continue.`, + loggingSession: providedLoggingSession, + triggerData, + }, }, + snapshot, } } @@ -457,126 +495,167 @@ export async function preprocessExecution( } ) - await recordPreprocessingError({ - workflowId, - executionId, - triggerType, - requestId, - userId: actorUserId, - workspaceId, - errorMessage: memberLimitMessage, - loggingSession: providedLoggingSession, - triggerData, - }) - return { - success: false, - error: { - message: memberLimitMessage, - statusCode: 402, - logCreated: true, + failure: { + response: { + success: false, + error: { + message: memberLimitMessage, + statusCode: 402, + logCreated: true, + }, + }, + recordError: { + workflowId, + executionId, + triggerType, + requestId, + userId: actorUserId, + workspaceId, + errorMessage: memberLimitMessage, + loggingSession: providedLoggingSession, + triggerData, + }, }, + snapshot, } } + return { failure: null, snapshot } } catch (error) { logger.error(`[${requestId}] Error checking usage limits`, { error, actorUserId, }) - await recordPreprocessingError({ - workflowId, - executionId, - triggerType, - requestId, - userId: actorUserId, - workspaceId, - errorMessage: - 'Unable to determine usage limits. Execution blocked for security. Please contact support.', - loggingSession: providedLoggingSession, - triggerData, - }) - return { - success: false, - error: { - message: 'Unable to determine usage limits. Execution blocked for security.', - statusCode: 500, - logCreated: true, - retryable: isRetryableInfrastructureError(error), - cause: describeRetryableInfrastructureError(error), + failure: { + response: { + success: false, + error: { + message: 'Unable to determine usage limits. 
Execution blocked for security.', + statusCode: 500, + logCreated: true, + retryable: isRetryableInfrastructureError(error), + cause: describeRetryableInfrastructureError(error), + }, + }, + recordError: { + workflowId, + executionId, + triggerType, + requestId, + userId: actorUserId, + workspaceId, + errorMessage: + 'Unable to determine usage limits. Execution blocked for security. Please contact support.', + loggingSession: providedLoggingSession, + triggerData, + }, }, + snapshot, } } - } + })() // ========== STEP 6: Check Rate Limits ========== let rateLimitInfo: { allowed: boolean; remaining: number; resetAt: Date } | undefined - if (checkRateLimit) { + /** + * STEP 6: rate-limit gate. Unlike the other gates this one is NOT read-only — + * `checkRateLimitWithSubscription` consumes a token — so it is invoked + * sequentially only after the ban and usage gates pass, matching the original + * order. Running it eagerly or in parallel would debit rate-limit quota for + * requests that ban or usage rejects. Returns the failure outcome, or `null` + * on pass/skip; on a non-error outcome it populates `rateLimitInfo`. + */ + const runRateLimitGate = async (): Promise<GateFailure | null> => { + if (!checkRateLimit) return null try { const rateLimiter = new RateLimiter() - rateLimitInfo = await rateLimiter.checkRateLimitWithSubscription( + const info = await rateLimiter.checkRateLimitWithSubscription( actorUserId, userSubscription, triggerType, false // not async ) + rateLimitInfo = info - if (!rateLimitInfo.allowed) { + if (!info.allowed) { logger.warn(`[${requestId}] Rate limit exceeded for user ${actorUserId}`, { triggerType, - remaining: rateLimitInfo.remaining, - resetAt: rateLimitInfo.resetAt, + remaining: info.remaining, + resetAt: info.resetAt, }) - await recordPreprocessingError({ + return { + response: { + success: false, + error: { + message: `Rate limit exceeded. 
Please try again later.`, + statusCode: 429, + logCreated: true, + }, + }, + recordError: { + workflowId, + executionId, + triggerType, + requestId, + userId: actorUserId, + workspaceId, + errorMessage: `Rate limit exceeded. ${info.remaining} requests remaining. Resets at ${info.resetAt.toISOString()}.`, + loggingSession: providedLoggingSession, + triggerData, + }, + } + } + return null + } catch (error) { + logger.error(`[${requestId}] Error checking rate limits`, { error, actorUserId }) + + return { + response: { + success: false, + error: { + message: 'Error checking rate limits', + statusCode: 500, + logCreated: true, + retryable: isRetryableInfrastructureError(error), + cause: describeRetryableInfrastructureError(error), + }, + }, + recordError: { workflowId, executionId, triggerType, requestId, userId: actorUserId, workspaceId, - errorMessage: `Rate limit exceeded. ${rateLimitInfo.remaining} requests remaining. Resets at ${rateLimitInfo.resetAt.toISOString()}.`, + errorMessage: 'Error checking rate limits. Execution blocked for safety.', loggingSession: providedLoggingSession, triggerData, - }) - - return { - success: false, - error: { - message: `Rate limit exceeded. Please try again later.`, - statusCode: 429, - logCreated: true, - }, - } + }, } - } catch (error) { - logger.error(`[${requestId}] Error checking rate limits`, { error, actorUserId }) + } + } - await recordPreprocessingError({ - workflowId, - executionId, - triggerType, - requestId, - userId: actorUserId, - workspaceId, - errorMessage: 'Error checking rate limits. 
Execution blocked for safety.', - loggingSession: providedLoggingSession, - triggerData, - }) + const usageResult = await usageCheckTask + const usageSnapshot = usageResult.snapshot - return { - success: false, - error: { - message: 'Error checking rate limits', - statusCode: 500, - logCreated: true, - retryable: isRetryableInfrastructureError(error), - cause: describeRetryableInfrastructureError(error), - }, - } + const readGateFailure = banFailure ?? usageResult.failure + if (readGateFailure) { + if (readGateFailure.recordError) { + await recordPreprocessingError(readGateFailure.recordError) + } + return readGateFailure.response + } + + const rateLimitFailure = await runRateLimitGate() + if (rateLimitFailure) { + if (rateLimitFailure.recordError) { + await recordPreprocessingError(rateLimitFailure.recordError) } + return rateLimitFailure.response } /** diff --git a/apps/sim/lib/workflows/executor/execution-core.test.ts b/apps/sim/lib/workflows/executor/execution-core.test.ts index eba2011484..58fbca16d5 100644 --- a/apps/sim/lib/workflows/executor/execution-core.test.ts +++ b/apps/sim/lib/workflows/executor/execution-core.test.ts @@ -72,26 +72,22 @@ vi.mock('@/lib/workflows/triggers/triggers', () => ({ vi.mock('@/lib/workflows/utils', () => workflowsUtilsMock) vi.mock('@/executor', () => ({ - Executor: vi.fn().mockImplementation( - class { - constructor(args: unknown) { - executorConstructorMock(args) - // biome-ignore lint/correctness/noConstructorReturn: vitest 4 constructs mocks via Reflect.construct; returning the instance overrides `new Executor(...)` - return { - execute: executorExecuteMock, - executeFromBlock: executorExecuteMock, - } + Executor: class { + constructor(args: unknown) { + executorConstructorMock(args) + // biome-ignore lint/correctness/noConstructorReturn: returning the instance overrides `new Executor(...)` so consumers get the mocked methods + return { + execute: executorExecuteMock, + executeFromBlock: executorExecuteMock, } } - ), + }, 
})) vi.mock('@/serializer', () => ({ - Serializer: vi.fn().mockImplementation( - class { - serializeWorkflow = serializeWorkflowMock - } - ), + Serializer: class { + serializeWorkflow = serializeWorkflowMock + }, })) import { @@ -192,6 +188,96 @@ describe('executeWorkflowCore terminal finalization sequencing', () => { clearExecutionCancellationMock.mockResolvedValue(undefined) }) + it('loads workflow state and env vars concurrently, then starts logging before constructing the executor', async () => { + const callOrder: string[] = [] + + let releaseWorkflowLoad: (() => void) | undefined + let releaseEnvLoad: (() => void) | undefined + const workflowLoadGate = new Promise<void>((resolve) => { + releaseWorkflowLoad = resolve + }) + const envLoadGate = new Promise<void>((resolve) => { + releaseEnvLoad = resolve + }) + + loadWorkflowFromNormalizedTablesMock.mockImplementation(async () => { + callOrder.push('load-workflow:start') + await workflowLoadGate + callOrder.push('load-workflow:end') + return { + blocks: { + 'start-block': { + id: 'start-block', + type: 'start_trigger', + subBlocks: {}, + name: 'Start', + }, + }, + edges: [], + loops: {}, + parallels: {}, + } + }) + + getPersonalAndWorkspaceEnvMock.mockImplementation(async () => { + callOrder.push('load-env:start') + await envLoadGate + callOrder.push('load-env:end') + return { + personalEncrypted: {}, + workspaceEncrypted: {}, + personalDecrypted: {}, + workspaceDecrypted: {}, + } + }) + + safeStartMock.mockImplementation(async () => { + callOrder.push('safeStart') + return true + }) + + executorConstructorMock.mockImplementation(() => { + callOrder.push('executor-construct') + }) + + executorExecuteMock.mockResolvedValue({ + success: true, + status: 'completed', + output: { done: true }, + logs: [], + metadata: { duration: 123, startTime: 'start', endTime: 'end' }, + }) + + const executionPromise = executeWorkflowCore({ + snapshot: createSnapshot() as any, + callbacks: {}, + loggingSession: loggingSession as any, + }) + + 
await Promise.resolve() + + expect(callOrder).toContain('load-workflow:start') + expect(callOrder).toContain('load-env:start') + expect(callOrder).not.toContain('safeStart') + expect(callOrder).not.toContain('executor-construct') + + releaseWorkflowLoad?.() + releaseEnvLoad?.() + + await executionPromise + + expect(callOrder).toEqual([ + 'load-workflow:start', + 'load-env:start', + 'load-workflow:end', + 'load-env:end', + 'safeStart', + 'executor-construct', + ]) + expect(safeStartMock).toHaveBeenCalledTimes(1) + expect(executorConstructorMock).toHaveBeenCalledTimes(1) + }) + it('routes onBlockStart through logging session persistence path', async () => { executorExecuteMock.mockResolvedValue({ success: true, diff --git a/apps/sim/lib/workflows/executor/execution-core.ts b/apps/sim/lib/workflows/executor/execution-core.ts index 3763203aaa..068e4bf9e3 100644 --- a/apps/sim/lib/workflows/executor/execution-core.ts +++ b/apps/sim/lib/workflows/executor/execution-core.ts @@ -349,62 +349,78 @@ export async function executeWorkflowCore( } try { - let blocks - let edges: Edge[] - let loops - let parallels - - // Use workflowStateOverride if provided (for diff workflows) - if (metadata.workflowStateOverride) { - blocks = metadata.workflowStateOverride.blocks - edges = metadata.workflowStateOverride.edges - loops = metadata.workflowStateOverride.loops || {} - parallels = metadata.workflowStateOverride.parallels || {} - deploymentVersionId = metadata.workflowStateOverride.deploymentVersionId - - logger.info(`[${requestId}] Using workflow state override (diff workflow execution)`, { - blocksCount: Object.keys(blocks).length, - edgesCount: edges.length, - }) - } else if (useDraftState) { - const draftData = await loadWorkflowFromNormalizedTables(workflowId) + const personalEnvUserId = + metadata.isClientSession && metadata.sessionUserId + ? 
metadata.sessionUserId + : metadata.workflowUserId - if (!draftData) { - throw new Error('Workflow not found or not yet saved') + if (!personalEnvUserId) { + throw new Error('Missing workflowUserId in execution metadata') + } + + /** + * Resolves the workflow state from the override, the draft tables, or the + * deployed snapshot. The async load (draft/deployed) has no data dependency + * on the environment load, so the two are awaited concurrently below. + */ + const loadWorkflowState = async () => { + if (metadata.workflowStateOverride) { + const override = metadata.workflowStateOverride + logger.info(`[${requestId}] Using workflow state override (diff workflow execution)`, { + blocksCount: Object.keys(override.blocks).length, + edgesCount: override.edges.length, + }) + return { + blocks: override.blocks, + edges: override.edges, + loops: override.loops || {}, + parallels: override.parallels || {}, + deploymentVersionId: override.deploymentVersionId, + } } - blocks = draftData.blocks - edges = draftData.edges - loops = draftData.loops - parallels = draftData.parallels + if (useDraftState) { + const draftData = await loadWorkflowFromNormalizedTables(workflowId) - logger.info( - `[${requestId}] Using draft workflow state from normalized tables (client execution)` - ) - } else { - const deployedData = await loadDeployedWorkflowState(workflowId) - blocks = deployedData.blocks - edges = deployedData.edges - loops = deployedData.loops - parallels = deployedData.parallels - deploymentVersionId = deployedData.deploymentVersionId + if (!draftData) { + throw new Error('Workflow not found or not yet saved') + } + logger.info( + `[${requestId}] Using draft workflow state from normalized tables (client execution)` + ) + return { + blocks: draftData.blocks, + edges: draftData.edges, + loops: draftData.loops, + parallels: draftData.parallels, + deploymentVersionId: undefined, + } + } + + const deployedData = await loadDeployedWorkflowState(workflowId) 
logger.info(`[${requestId}] Using deployed workflow state (deployed execution)`) + return { + blocks: deployedData.blocks, + edges: deployedData.edges, + loops: deployedData.loops, + parallels: deployedData.parallels, + deploymentVersionId: deployedData.deploymentVersionId, + } } - const mergedStates = mergeSubblockStateWithValues(blocks) + const [workflowState, env] = await Promise.all([ + loadWorkflowState(), + getPersonalAndWorkspaceEnv(personalEnvUserId, providedWorkspaceId), + ]) - const personalEnvUserId = - metadata.isClientSession && metadata.sessionUserId - ? metadata.sessionUserId - : metadata.workflowUserId + const { blocks, loops, parallels } = workflowState + const edges: Edge[] = workflowState.edges + deploymentVersionId = workflowState.deploymentVersionId - if (!personalEnvUserId) { - throw new Error('Missing workflowUserId in execution metadata') - } + const mergedStates = mergeSubblockStateWithValues(blocks) - const { personalEncrypted, workspaceEncrypted, personalDecrypted, workspaceDecrypted } = - await getPersonalAndWorkspaceEnv(personalEnvUserId, providedWorkspaceId) + const { personalEncrypted, workspaceEncrypted, personalDecrypted, workspaceDecrypted } = env // Use encrypted values for logging (don't log decrypted secrets) const variables = EnvVarsSchema.parse({ ...personalEncrypted, ...workspaceEncrypted }) diff --git a/apps/sim/lib/workflows/persistence/utils.test.ts b/apps/sim/lib/workflows/persistence/utils.test.ts index afcc4081fa..964577f9cf 100644 --- a/apps/sim/lib/workflows/persistence/utils.test.ts +++ b/apps/sim/lib/workflows/persistence/utils.test.ts @@ -113,6 +113,22 @@ vi.mock('@sim/db', () => ({ webhook: {}, })) +const { mockSanitizeAgentToolsInBlocks } = vi.hoisted(() => ({ + mockSanitizeAgentToolsInBlocks: vi.fn(), +})) + +/** + * Default identity behavior for the mocked migration step. 
Re-applied in the + * cache describe block's `beforeEach` because the outer `afterEach` calls + * `vi.resetAllMocks()`, which clears implementations. + */ +const sanitizeIdentity = (blocks: unknown) => ({ blocks }) +mockSanitizeAgentToolsInBlocks.mockImplementation(sanitizeIdentity) + +vi.mock('@/lib/workflows/sanitization/validation', () => ({ + sanitizeAgentToolsInBlocks: mockSanitizeAgentToolsInBlocks, +})) + import * as dbHelpers from '@/lib/workflows/persistence/utils' const mockWorkflowId = 'test-workflow-123' @@ -307,6 +323,7 @@ const mockWorkflowState = createWorkflowState({ describe('Database Helpers', () => { beforeEach(() => { vi.clearAllMocks() + mockSanitizeAgentToolsInBlocks.mockImplementation(sanitizeIdentity) }) afterEach(() => { @@ -1550,4 +1567,157 @@ describe('Database Helpers', () => { expect(messages2).toEqual([{ role: 'system', content: 'System' }]) }) }) + + describe('loadDeployedWorkflowState deployed-state cache', () => { + /** + * Minimal but realistic deployed state: a couple of plain (non-agent, + * credential-free) blocks plus an edge. Plain blocks make the real + * downstream migration steps (agent-message, subblock-id, credential, + * canonical-mode) no-ops, so the only observable "heavy work" is the + * mocked `sanitizeAgentToolsInBlocks` first step, which we use as the + * migration call counter. 
+ */ + function buildDeployedState() { + return { + blocks: { + 'block-1': { + id: 'block-1', + type: 'api', + name: 'API Block', + position: { x: 0, y: 0 }, + enabled: true, + subBlocks: { url: { id: 'url', type: 'short-input', value: 'https://example.com' } }, + outputs: {}, + data: {}, + }, + 'block-2': { + id: 'block-2', + type: 'function', + name: 'Function Block', + position: { x: 100, y: 0 }, + enabled: true, + subBlocks: { code: { id: 'code', type: 'code', value: 'return 1' } }, + outputs: {}, + data: {}, + }, + }, + edges: [ + { + id: 'edge-1', + source: 'block-1', + target: 'block-2', + sourceHandle: 'output', + targetHandle: 'input', + }, + ], + loops: {}, + parallels: {}, + variables: { threshold: 5 }, + } + } + + /** + * Wires `db.select` to return a single active deployment-version row for the + * given id. Returns the inner `where` spy so tests can assert how many times + * the active-version SELECT ran. + */ + function mockActiveVersionSelect(versionId: string, state: unknown) { + const where = vi.fn().mockReturnValue({ + orderBy: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([{ id: versionId, state, createdAt: new Date() }]), + }), + }) + mockDb.select.mockReturnValue({ + from: vi.fn().mockReturnValue({ where }), + }) + return where + } + + beforeEach(() => { + vi.clearAllMocks() + mockSanitizeAgentToolsInBlocks.mockImplementation(sanitizeIdentity) + dbHelpers.invalidateDeployedStateCache() + }) + + it('serves a cache HIT, skipping migrations on the second call for the same active version', async () => { + const where = mockActiveVersionSelect('dv-hit', buildDeployedState()) + + const first = await dbHelpers.loadDeployedWorkflowState('wf-1', 'workspace-1') + const second = await dbHelpers.loadDeployedWorkflowState('wf-1', 'workspace-1') + + expect(first).toBeDefined() + expect(second).toBeDefined() + expect(mockSanitizeAgentToolsInBlocks).toHaveBeenCalledTimes(1) + expect(where).toHaveBeenCalledTimes(2) + }) + + it('still runs the 
active-version SELECT on every call so rollback/redeploy stays observable', async () => { + const where = mockActiveVersionSelect('dv-active', buildDeployedState()) + + await dbHelpers.loadDeployedWorkflowState('wf-2', 'workspace-1') + await dbHelpers.loadDeployedWorkflowState('wf-2', 'workspace-1') + + expect(where).toHaveBeenCalledTimes(2) + }) + + it('deep-clones on read: mutating the first result does not corrupt the cached copy', async () => { + mockActiveVersionSelect('dv-clone', buildDeployedState()) + + const first = await dbHelpers.loadDeployedWorkflowState('wf-3', 'workspace-1') + ;(first.blocks['block-1'] as any).name = 'MUTATED' + ;(first.blocks['block-1'].subBlocks.url as any).value = 'https://hacked.example' + first.edges.push({ + id: 'edge-injected', + source: 'block-2', + target: 'block-1', + } as any) + + const second = await dbHelpers.loadDeployedWorkflowState('wf-3', 'workspace-1') + + expect(second.blocks['block-1'].name).toBe('API Block') + expect(second.blocks['block-1'].subBlocks.url.value).toBe('https://example.com') + expect(second.edges).toHaveLength(1) + expect(second.blocks).toEqual(buildDeployedState().blocks) + }) + + it('keys the cache by deploymentVersionId: a different active id triggers a fresh build', async () => { + mockActiveVersionSelect('dv-old', buildDeployedState()) + await dbHelpers.loadDeployedWorkflowState('wf-4', 'workspace-1') + expect(mockSanitizeAgentToolsInBlocks).toHaveBeenCalledTimes(1) + + mockActiveVersionSelect('dv-new', buildDeployedState()) + await dbHelpers.loadDeployedWorkflowState('wf-4', 'workspace-1') + expect(mockSanitizeAgentToolsInBlocks).toHaveBeenCalledTimes(2) + }) + + it('invalidateDeployedStateCache(id) forces a rebuild on the next call', async () => { + mockActiveVersionSelect('dv-inv', buildDeployedState()) + + await dbHelpers.loadDeployedWorkflowState('wf-5', 'workspace-1') + await dbHelpers.loadDeployedWorkflowState('wf-5', 'workspace-1') + 
expect(mockSanitizeAgentToolsInBlocks).toHaveBeenCalledTimes(1) + + dbHelpers.invalidateDeployedStateCache('dv-inv') + + await dbHelpers.loadDeployedWorkflowState('wf-5', 'workspace-1') + expect(mockSanitizeAgentToolsInBlocks).toHaveBeenCalledTimes(2) + }) + + it('throws when there is no active deployment and does not cache the failure', async () => { + const where = vi.fn().mockReturnValue({ + orderBy: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([]), + }), + }) + mockDb.select.mockReturnValue({ + from: vi.fn().mockReturnValue({ where }), + }) + + await expect(dbHelpers.loadDeployedWorkflowState('wf-6', 'workspace-1')).rejects.toThrow( + 'Workflow wf-6 has no active deployment' + ) + + expect(mockSanitizeAgentToolsInBlocks).not.toHaveBeenCalled() + }) + }) }) diff --git a/apps/sim/lib/workflows/persistence/utils.ts b/apps/sim/lib/workflows/persistence/utils.ts index 5663d0f3fd..725e7a1347 100644 --- a/apps/sim/lib/workflows/persistence/utils.ts +++ b/apps/sim/lib/workflows/persistence/utils.ts @@ -13,6 +13,7 @@ import type { DbOrTx, NormalizedWorkflowData } from '@sim/workflow-persistence/t import type { BlockState, Loop, Parallel, WorkflowState } from '@sim/workflow-types/workflow' import type { InferSelectModel } from 'drizzle-orm' import { and, desc, eq, inArray, lt, sql } from 'drizzle-orm' +import { LRUCache } from 'lru-cache' import type { Edge } from 'reactflow' import { remapConditionBlockIds, remapConditionEdgeHandle } from '@/lib/workflows/condition-ids' import { @@ -99,6 +100,29 @@ export async function blockExistsInDeployment( } } +const DEPLOYED_STATE_CACHE_MAX_ENTRIES = 500 +const DEPLOYED_STATE_CACHE_TTL_MS = 5 * 60 * 1000 + +/** + * Caches post-migration deployed state by the immutable `deploymentVersionId`, so + * a redeploy/rollback (which changes the active id) self-invalidates. 
The TTL is + * absolute on purpose — it bounds the one non-immutable part, the live credential + * remap in `applyBlockMigrations` — so credential changes still propagate. + */ +const deployedStateCache = new LRUCache({ + max: DEPLOYED_STATE_CACHE_MAX_ENTRIES, + ttl: DEPLOYED_STATE_CACHE_TTL_MS, +}) + +/** Evicts one deployed-state entry, or clears the cache when no id is given. */ +export function invalidateDeployedStateCache(deploymentVersionId?: string): void { + if (deploymentVersionId) { + deployedStateCache.delete(deploymentVersionId) + return + } + deployedStateCache.clear() +} + export async function loadDeployedWorkflowState( workflowId: string, providedWorkspaceId?: string @@ -124,6 +148,11 @@ export async function loadDeployedWorkflowState( throw new Error(`Workflow ${workflowId} has no active deployment`) } + const cached = deployedStateCache.get(active.id) + if (cached) { + return structuredClone(cached) + } + const state = active.state as WorkflowState & { variables?: Record } let resolvedWorkspaceId = providedWorkspaceId @@ -141,7 +170,7 @@ export async function loadDeployedWorkflowState( resolvedWorkspaceId ) - return { + const deployedState: DeployedWorkflowData = { blocks: migratedBlocks, edges: state.edges || [], loops: state.loops || {}, @@ -150,6 +179,10 @@ export async function loadDeployedWorkflowState( isFromNormalizedTables: false, deploymentVersionId: active.id, } + + deployedStateCache.set(active.id, deployedState) + + return structuredClone(deployedState) } catch (error) { logger.error(`Error loading deployed workflow state ${workflowId}:`, error) throw error diff --git a/apps/sim/providers/anthropic/index.ts b/apps/sim/providers/anthropic/index.ts index 543c328fb1..043ae4b0f0 100644 --- a/apps/sim/providers/anthropic/index.ts +++ b/apps/sim/providers/anthropic/index.ts @@ -2,6 +2,7 @@ import Anthropic from '@anthropic-ai/sdk' import { createLogger } from '@sim/logger' import type { StreamingExecution } from '@/executor/types' import { 
executeAnthropicProviderRequest } from '@/providers/anthropic/core' +import { getCachedProviderClient } from '@/providers/client-cache' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types' @@ -21,13 +22,19 @@ export const anthropicProvider: ProviderConfig = { return executeAnthropicProviderRequest(request, { providerId: 'anthropic', providerLabel: 'Anthropic', - createClient: (apiKey, useNativeStructuredOutputs) => - new Anthropic({ - apiKey, - defaultHeaders: useNativeStructuredOutputs - ? { 'anthropic-beta': 'structured-outputs-2025-11-13' } - : undefined, - }), + createClient: (apiKey, useNativeStructuredOutputs) => { + const cacheKey = `anthropic::${apiKey}::${useNativeStructuredOutputs ? 'beta' : 'default'}` + return getCachedProviderClient( + cacheKey, + () => + new Anthropic({ + apiKey, + defaultHeaders: useNativeStructuredOutputs + ? { 'anthropic-beta': 'structured-outputs-2025-11-13' } + : undefined, + }) + ) + }, logger, }) }, diff --git a/apps/sim/providers/azure-anthropic/index.ts b/apps/sim/providers/azure-anthropic/index.ts index 39980d77c2..2f0498992a 100644 --- a/apps/sim/providers/azure-anthropic/index.ts +++ b/apps/sim/providers/azure-anthropic/index.ts @@ -4,6 +4,7 @@ import { env } from '@/lib/core/config/env' import { createPinnedFetch, validateUrlWithDNS } from '@/lib/core/security/input-validation.server' import type { StreamingExecution } from '@/executor/types' import { executeAnthropicProviderRequest } from '@/providers/anthropic/core' +import { getCachedProviderClient } from '@/providers/client-cache' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types' @@ -29,6 +30,7 @@ export const azureAnthropicProvider: ProviderConfig = { } let pinnedFetch: typeof fetch | undefined + let pinnedIP: string | undefined if 
(userProvidedEndpoint) { const validation = await validateUrlWithDNS(userProvidedEndpoint, 'azureEndpoint') if (!validation.isValid) { @@ -41,7 +43,8 @@ export const azureAnthropicProvider: ProviderConfig = { if (!validation.resolvedIP) { throw new Error('Invalid Azure Anthropic endpoint: could not resolve a pinnable IP address') } - pinnedFetch = createPinnedFetch(validation.resolvedIP) + pinnedIP = validation.resolvedIP + pinnedFetch = createPinnedFetch(pinnedIP) } const apiKey = request.apiKey @@ -68,19 +71,32 @@ export const azureAnthropicProvider: ProviderConfig = { { providerId: 'azure-anthropic', providerLabel: 'Azure Anthropic', - createClient: (apiKey, useNativeStructuredOutputs) => - new Anthropic({ - baseURL, + createClient: (apiKey, useNativeStructuredOutputs) => { + const cacheKey = [ + 'azure-anthropic', apiKey, - ...(pinnedFetch ? { fetch: pinnedFetch } : {}), - defaultHeaders: { - 'api-key': apiKey, - 'anthropic-version': anthropicVersion, - ...(useNativeStructuredOutputs - ? { 'anthropic-beta': 'structured-outputs-2025-11-13' } - : {}), - }, - }), + baseURL, + anthropicVersion, + pinnedIP ?? 'no-pin', + useNativeStructuredOutputs ? 'beta' : 'default', + ].join('::') + return getCachedProviderClient( + cacheKey, + () => + new Anthropic({ + baseURL, + apiKey, + ...(pinnedFetch ? { fetch: pinnedFetch } : {}), + defaultHeaders: { + 'api-key': apiKey, + 'anthropic-version': anthropicVersion, + ...(useNativeStructuredOutputs + ? 
{ 'anthropic-beta': 'structured-outputs-2025-11-13' } + : {}), + }, + }) + ) + }, logger, } ) diff --git a/apps/sim/providers/bedrock/index.test.ts b/apps/sim/providers/bedrock/index.test.ts index 38cb857425..3a9abceefb 100644 --- a/apps/sim/providers/bedrock/index.test.ts +++ b/apps/sim/providers/bedrock/index.test.ts @@ -50,10 +50,12 @@ vi.mock('@/tools', () => ({ import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime' import { bedrockProvider } from '@/providers/bedrock/index' +import { clearProviderClientCacheForTests } from '@/providers/client-cache' describe('bedrockProvider credential handling', () => { beforeEach(() => { vi.clearAllMocks() + clearProviderClientCacheForTests() mockSend.mockResolvedValue({ output: { message: { content: [{ text: 'response' }] } }, usage: { inputTokens: 10, outputTokens: 5 }, diff --git a/apps/sim/providers/bedrock/index.ts b/apps/sim/providers/bedrock/index.ts index 32be407867..4e512d15b4 100644 --- a/apps/sim/providers/bedrock/index.ts +++ b/apps/sim/providers/bedrock/index.ts @@ -24,6 +24,7 @@ import { generateToolUseId, getBedrockInferenceProfileId, } from '@/providers/bedrock/utils' +import { getCachedProviderClient } from '@/providers/client-cache' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import { createStreamingExecution } from '@/providers/streaming-execution' import { enrichLastModelSegment } from '@/providers/trace-enrichment' @@ -138,7 +139,16 @@ export const bedrockProvider: ProviderConfig = { } } - const client = new BedrockRuntimeClient(clientConfig) + // Key on the full credential (access key id + secret) so a corrected secret + // under the same access key id yields a fresh client rather than a stale one. + const credentialKey = + request.bedrockAccessKeyId && request.bedrockSecretKey + ? 
`${request.bedrockAccessKeyId}:${request.bedrockSecretKey}` + : 'default-chain' + const client = getCachedProviderClient( + `bedrock::${region}::${credentialKey}`, + () => new BedrockRuntimeClient(clientConfig) + ) const messages: BedrockMessage[] = [] const systemContent: SystemContentBlock[] = [] diff --git a/apps/sim/providers/client-cache.test.ts b/apps/sim/providers/client-cache.test.ts new file mode 100644 index 0000000000..6fd03ee97b --- /dev/null +++ b/apps/sim/providers/client-cache.test.ts @@ -0,0 +1,107 @@ +/** + * @vitest-environment node + */ +import { describe, expect, it, vi } from 'vitest' +import { getCachedProviderClient } from '@/providers/client-cache' + +/** + * Builds a fresh fake "client" object on every call so identity comparisons + * (`toBe`) tell us whether the cache returned the memoized instance or a new one + * from the factory. We never construct a real SDK client — these tests exercise + * the cache, not any provider SDK. + */ +function makeFactory() { + return vi.fn(() => ({}) as object) +} + +/** + * Generates a unique suffix per test so distinct tests never collide on cache + * keys. The cache util exposes no reset hook, so isolation is achieved by + * namespacing keys rather than clearing shared state. 
+ */ +let keyCounter = 0 +function uniqueNs(): string { + keyCounter += 1 + return `ns-${keyCounter}-${Date.now()}` +} + +describe('getCachedProviderClient', () => { + it('returns the SAME instance for an identical key and runs the factory once (memoized)', () => { + const key = `anthropic::${uniqueNs()}::default` + const factory = makeFactory() + + const first = getCachedProviderClient(key, factory) + const second = getCachedProviderClient(key, factory) + + expect(second).toBe(first) + expect(factory).toHaveBeenCalledTimes(1) + }) + + it('returns a DIFFERENT instance for a different apiKey (tenant isolation)', () => { + const ns = uniqueNs() + const factoryA = makeFactory() + const factoryB = makeFactory() + + const tenantA = getCachedProviderClient(`anthropic::${ns}-tenant-a::default`, factoryA) + const tenantB = getCachedProviderClient(`anthropic::${ns}-tenant-b::default`, factoryB) + + expect(tenantB).not.toBe(tenantA) + expect(factoryA).toHaveBeenCalledTimes(1) + expect(factoryB).toHaveBeenCalledTimes(1) + }) + + it('namespaces by provider: the same apiKey under different provider prefixes does not collide', () => { + const ns = uniqueNs() + const apiKey = `${ns}-shared-key` + const anthropicFactory = makeFactory() + const bedrockFactory = makeFactory() + + const anthropicClient = getCachedProviderClient(`anthropic::${apiKey}`, anthropicFactory) + const bedrockClient = getCachedProviderClient(`bedrock::${apiKey}`, bedrockFactory) + + expect(bedrockClient).not.toBe(anthropicClient) + }) + + it('treats every distinct key dimension as a distinct client', () => { + const ns = uniqueNs() + const base = `azure-anthropic::${ns}-key::https://a.example.com::2023-06-01::10.0.0.1::default` + const baseFactory = makeFactory() + const baseClient = getCachedProviderClient(base, baseFactory) + + const variants = [ + `azure-anthropic::${ns}-key::https://b.example.com::2023-06-01::10.0.0.1::default`, + 
`azure-anthropic::${ns}-key::https://a.example.com::2024-10-22::10.0.0.1::default`, + `azure-anthropic::${ns}-key::https://a.example.com::2023-06-01::10.0.0.2::default`, + `azure-anthropic::${ns}-key::https://a.example.com::2023-06-01::no-pin::default`, + `azure-anthropic::${ns}-key::https://a.example.com::2023-06-01::10.0.0.1::beta`, + ] + + for (const key of variants) { + const factory = makeFactory() + const client = getCachedProviderClient(key, factory) + expect(client).not.toBe(baseClient) + expect(factory).toHaveBeenCalledTimes(1) + } + }) + + it('evicts the least-recently-used entry once the cache cap is exceeded', () => { + const ns = uniqueNs() + const CAP = 1_000 + + const oldestKey = `evict::${ns}::0` + const oldestFactory = makeFactory() + getCachedProviderClient(oldestKey, oldestFactory) + expect(oldestFactory).toHaveBeenCalledTimes(1) + + // Fill the remaining capacity, then push one past the cap. The oldest key has + // not been touched since insertion, so it is the LRU eviction victim. + for (let i = 1; i <= CAP; i += 1) { + getCachedProviderClient(`evict::${ns}::${i}`, makeFactory()) + } + + const reFactory = makeFactory() + getCachedProviderClient(oldestKey, reFactory) + expect(reFactory).toHaveBeenCalledTimes(1) + expect(oldestFactory).toHaveBeenCalledTimes(1) + }) +}) diff --git a/apps/sim/providers/client-cache.ts b/apps/sim/providers/client-cache.ts new file mode 100644 index 0000000000..7908a94d21 --- /dev/null +++ b/apps/sim/providers/client-cache.ts @@ -0,0 +1,36 @@ +import { LRUCache } from 'lru-cache' + +const CLIENT_CACHE_MAX_ENTRIES = 1_000 +const CLIENT_CACHE_TTL_MS = 30 * 60 * 1_000 + +/** + * `updateAgeOnGet` makes the TTL idle-based: a continuously-used client keeps its + * warm keep-alive connections, while idle keys age out. 
+ */ +const clientCache = new LRUCache({ + max: CLIENT_CACHE_MAX_ENTRIES, + ttl: CLIENT_CACHE_TTL_MS, + updateAgeOnGet: true, +}) + +/** + * Memoizes provider SDK clients so connections stay warm across requests rather + * than re-handshaking per call. The key must be namespaced per provider and + * encode every input that varies the client; the API key is always part of it, + * making it the tenant boundary (clients are never shared across keys). + */ +export function getCachedProviderClient(key: string, factory: () => T): T { + const existing = clientCache.get(key) + if (existing) { + return existing as T + } + + const client = factory() + clientCache.set(key, client) + return client +} + +/** Clears the cache so tests asserting client construction start from a miss. */ +export function clearProviderClientCacheForTests(): void { + clientCache.clear() +} diff --git a/apps/sim/providers/vllm/index.test.ts b/apps/sim/providers/vllm/index.test.ts index 8739c95f98..925c61ca2d 100644 --- a/apps/sim/providers/vllm/index.test.ts +++ b/apps/sim/providers/vllm/index.test.ts @@ -79,6 +79,7 @@ vi.mock('@/stores/providers', () => ({ useProvidersStore: { getState: () => ({ setProviderModels: vi.fn() }) }, })) +import { clearProviderClientCacheForTests } from '@/providers/client-cache' import type { ProviderToolConfig } from '@/providers/types' import { vllmProvider } from '@/providers/vllm/index' @@ -117,6 +118,7 @@ const createPayload = (callIndex: number) => mockCreate.mock.calls[callIndex][0] describe('vllmProvider', () => { beforeEach(() => { vi.clearAllMocks() + clearProviderClientCacheForTests() openAIArgs.length = 0 envState.VLLM_BASE_URL = 'http://localhost:8000' envState.VLLM_API_KEY = undefined diff --git a/apps/sim/providers/vllm/index.ts b/apps/sim/providers/vllm/index.ts index 572b5df51c..936bf3af63 100644 --- a/apps/sim/providers/vllm/index.ts +++ b/apps/sim/providers/vllm/index.ts @@ -7,6 +7,7 @@ import { createPinnedFetch, validateUrlWithDNS } from 
'@/lib/core/security/input import type { StreamingExecution } from '@/executor/types' import { MAX_TOOL_ITERATIONS } from '@/providers' import { formatMessagesForProvider } from '@/providers/attachments' +import { getCachedProviderClient } from '@/providers/client-cache' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import { createStreamingExecution } from '@/providers/streaming-execution' import { adaptOpenAIChatToolSchema } from '@/providers/tool-schema-adapter' @@ -114,6 +115,7 @@ export const vllmProvider: ProviderConfig = { * IP blocklist and blocked-port checks still apply, so SSRF protection is intact. */ let pinnedFetch: typeof fetch | undefined + let pinnedIP: string | undefined if (userProvidedEndpoint) { const validation = await validateUrlWithDNS(userProvidedEndpoint, 'vLLM endpoint', { allowHttp: true, @@ -128,15 +130,20 @@ export const vllmProvider: ProviderConfig = { if (!validation.resolvedIP) { throw new Error('Invalid vLLM endpoint: could not resolve a pinnable IP address') } - pinnedFetch = createPinnedFetch(validation.resolvedIP) + pinnedIP = validation.resolvedIP + pinnedFetch = createPinnedFetch(pinnedIP) } const apiKey = request.apiKey || env.VLLM_API_KEY || 'empty' - const vllm = new OpenAI({ - apiKey, - baseURL: `${baseUrl}/v1`, - ...(pinnedFetch ? { fetch: pinnedFetch } : {}), - }) + const vllm = getCachedProviderClient( + `vllm::${apiKey}::${baseUrl}::${pinnedIP ?? 'no-pin'}`, + () => + new OpenAI({ + apiKey, + baseURL: `${baseUrl}/v1`, + ...(pinnedFetch ? { fetch: pinnedFetch } : {}), + }) + ) const allMessages: Message[] = []