Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions apps/sim/lib/api-key/byok.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'

const { mockOrderBy, mockGetWorkspaceById, mockDecryptSecret } = vi.hoisted(() => ({
const { mockOrderBy, mockDecryptSecret } = vi.hoisted(() => ({
mockOrderBy: vi.fn(),
mockGetWorkspaceById: vi.fn(),
mockDecryptSecret: vi.fn(),
}))

Expand All @@ -19,10 +18,6 @@ vi.mock('@sim/db', () => ({
},
}))

vi.mock('@/lib/workspaces/permissions/utils', () => ({
getWorkspaceById: mockGetWorkspaceById,
}))

vi.mock('@/lib/core/security/encryption', () => ({
decryptSecret: mockDecryptSecret,
}))
Expand Down Expand Up @@ -70,7 +65,6 @@ const storedKey = (id: string) => ({ id, encryptedApiKey: `encrypted-${id}` })
describe('getBYOKKey', () => {
beforeEach(() => {
vi.clearAllMocks()
mockGetWorkspaceById.mockResolvedValue({ id: 'workspace' })
mockOrderBy.mockResolvedValue([])
mockDecryptSecret.mockImplementation(async (encrypted: string) => ({
decrypted: encrypted.replace('encrypted-', 'decrypted-'),
Expand All @@ -80,13 +74,6 @@ describe('getBYOKKey', () => {
it('returns null when no workspaceId is provided', async () => {
expect(await getBYOKKey(undefined, 'openai')).toBeNull()
expect(await getBYOKKey(null, 'openai')).toBeNull()
expect(mockGetWorkspaceById).not.toHaveBeenCalled()
})

it('returns null when the workspace does not exist', async () => {
mockGetWorkspaceById.mockResolvedValue(null)

expect(await getBYOKKey(uniqueWorkspaceId(), 'openai')).toBeNull()
})

it('returns null when the workspace has no keys for the provider', async () => {
Expand Down Expand Up @@ -123,6 +110,17 @@ describe('getBYOKKey', () => {
])
})

it('reads the key list fresh from the database on every call', async () => {
const workspaceId = uniqueWorkspaceId()
mockOrderBy.mockResolvedValue([storedKey('key-1')])

await getBYOKKey(workspaceId, 'openai')
await getBYOKKey(workspaceId, 'openai')
await getBYOKKey(workspaceId, 'openai')

expect(mockOrderBy).toHaveBeenCalledTimes(3)
})

it('tracks rotation independently per provider within a workspace', async () => {
const workspaceId = uniqueWorkspaceId()
mockOrderBy.mockResolvedValue([storedKey('key-1'), storedKey('key-2')])
Expand Down
9 changes: 3 additions & 6 deletions apps/sim/lib/api-key/byok.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import { getRotatingApiKey } from '@/lib/core/config/api-keys'
import { env } from '@/lib/core/config/env'
import { isHosted } from '@/lib/core/config/env-flags'
import { decryptSecret } from '@/lib/core/security/encryption'
import { getWorkspaceById } from '@/lib/workspaces/permissions/utils'
import { getHostedModels } from '@/providers/models'
import { PROVIDER_PLACEHOLDER_KEY } from '@/providers/utils'
import { useProvidersStore } from '@/stores/providers/store'
Expand Down Expand Up @@ -37,6 +36,9 @@ function nextRotationIndex(poolKey: string, poolSize: number): number {
* multiple keys stored for the provider, requests round-robin across them in
* creation order. A key that fails to decrypt is skipped in favor of the next
* one in the pool.
*
* The key list is read fresh every call (not cached): BYOK is not a hot query,
* and reading fresh keeps revocation immediate across ECS tasks.
*/
export async function getBYOKKey(
workspaceId: string | undefined | null,
Expand All @@ -47,11 +49,6 @@ export async function getBYOKKey(
}

try {
const activeWorkspace = await getWorkspaceById(workspaceId)
if (!activeWorkspace) {
return null
}

const keys = await db
.select({ id: workspaceBYOKKeys.id, encryptedApiKey: workspaceBYOKKeys.encryptedApiKey })
.from(workspaceBYOKKeys)
Expand Down
3 changes: 3 additions & 0 deletions apps/sim/lib/billing/calculations/usage-monitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,9 @@ export async function checkOrgMemberUsageLimit(
return { isExceeded: false, currentUsage: 0, limit: null }
}

// Resolve the cap first and short-circuit when unset (the common case); only
// then is computing usage worthwhile. Kept sequential, not raced, to avoid a
// usage query on every uncapped member's execution.
const limit = await getOrgMemberUsageLimit(organizationId, userId)
if (limit === null) {
return { isExceeded: false, currentUsage: 0, limit: null }
Expand Down
193 changes: 193 additions & 0 deletions apps/sim/lib/billing/core/plan.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/**
* @vitest-environment node
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'

/**
* Drizzle mock for `getHighestPrioritySubscription`. It issues up to four
* queries keyed by table:
* - `subscription` for the user's personal subs (parallelized with members)
* - `member` for the user's org memberships (parallelized with subs)
* - `organization` for the org-existence follow-up
* - `subscription` again for the org-scoped subs follow-up
*
* The mock routes results by the table object passed to `.from()`, serving the
* (twice-read) `subscription` table from a FIFO queue (first read = personal,
* second = org). It records which tables were queried so we can assert the
* parallelized pair both run and that follow-ups are skipped when appropriate.
*
* Table sentinels and shared mock state live inside `vi.hoisted` so the
* `vi.mock` factories (hoisted to the top of the file) can reference them.
*/
const { SUBSCRIPTION_TABLE, MEMBER_TABLE, ORGANIZATION_TABLE, resultsByTable, fromCalls, select } =
vi.hoisted(() => {
const SUBSCRIPTION_TABLE = { __table: 'subscription' }
const MEMBER_TABLE = { __table: 'member' }
const ORGANIZATION_TABLE = { __table: 'organization' }

const resultsByTable: Record<string, unknown[][]> = {
subscription: [],
member: [],
organization: [],
}
const fromCalls: string[] = []

const select = vi.fn(() => ({
from: (table: { __table: string }) => {
fromCalls.push(table.__table)
const where = () => {
const queue = resultsByTable[table.__table]
const next = queue.length > 0 ? queue.shift() : []
return Promise.resolve(next ?? [])
}
return { where }
},
}))

return {
SUBSCRIPTION_TABLE,
MEMBER_TABLE,
ORGANIZATION_TABLE,
resultsByTable,
fromCalls,
select,
}
})

vi.mock('@sim/db', () => ({
db: { select },
}))

vi.mock('@sim/db/schema', () => ({
subscription: SUBSCRIPTION_TABLE,
member: MEMBER_TABLE,
organization: ORGANIZATION_TABLE,
}))

/**
* Realistic plan-check predicates so `pickHighestPrioritySubscription` exercises
* the real Enterprise > Team > Pro priority ordering over the rows we feed it.
*/
vi.mock('@/lib/billing/subscriptions/utils', () => ({
ENTITLED_SUBSCRIPTION_STATUSES: ['active', 'past_due'],
checkEnterprisePlan: (s: any) =>
s?.plan === 'enterprise' && ['active', 'past_due'].includes(s?.status),
checkTeamPlan: (s: any) => s?.plan === 'team' && ['active', 'past_due'].includes(s?.status),
checkProPlan: (s: any) => s?.plan === 'pro' && ['active', 'past_due'].includes(s?.status),
}))

import { getHighestPrioritySubscription } from '@/lib/billing/core/plan'

interface SubRow {
id: string
referenceId: string
plan: string
status: string
}

function personalPro(userId: string): SubRow {
return { id: 'sub-personal-pro', referenceId: userId, plan: 'pro', status: 'active' }
}

function orgEnterprise(orgId: string): SubRow {
return { id: 'sub-org-enterprise', referenceId: orgId, plan: 'enterprise', status: 'active' }
}

function queue(table: 'subscription' | 'member' | 'organization', rows: unknown[]) {
resultsByTable[table].push(rows)
}

describe('getHighestPrioritySubscription', () => {
beforeEach(() => {
vi.clearAllMocks()
resultsByTable.subscription = []
resultsByTable.member = []
resultsByTable.organization = []
fromCalls.length = 0
})

it('picks the org Enterprise sub over a personal Pro sub (priority order)', async () => {
queue('subscription', [personalPro('user-1')]) // personalSubs query
queue('member', [{ organizationId: 'org-1' }]) // memberships query
queue('organization', [{ id: 'org-1' }]) // org-existence query
queue('subscription', [orgEnterprise('org-1')]) // org-subscriptions query

const result = await getHighestPrioritySubscription('user-1')

expect(result).not.toBeNull()
expect(result?.id).toBe('sub-org-enterprise')
expect(result?.plan).toBe('enterprise')
})

it('selection is deterministic regardless of which parallelized query resolves first', async () => {
queue('subscription', [personalPro('user-1')])
queue('member', [{ organizationId: 'org-1' }])
queue('organization', [{ id: 'org-1' }])
queue('subscription', [orgEnterprise('org-1')])

const result = await getHighestPrioritySubscription('user-1')

expect(result?.id).toBe('sub-org-enterprise')
})

it('issues BOTH the personal-subscriptions and memberships queries (parallelized pair)', async () => {
queue('subscription', [personalPro('user-1')])
queue('member', [{ organizationId: 'org-1' }])
queue('organization', [{ id: 'org-1' }])
queue('subscription', [orgEnterprise('org-1')])

await getHighestPrioritySubscription('user-1')

expect(fromCalls).toContain('subscription')
expect(fromCalls).toContain('member')
// First two queries are exactly the parallelized pair (in either order).
expect(fromCalls.slice(0, 2).sort()).toEqual(['member', 'subscription'])
})

it('returns the personal sub and skips org follow-ups when there are no memberships', async () => {
queue('subscription', [personalPro('user-1')])
queue('member', [])

const result = await getHighestPrioritySubscription('user-1')

expect(result?.id).toBe('sub-personal-pro')
expect(result?.plan).toBe('pro')
// org-existence + org-subscription follow-ups are NOT issued.
expect(fromCalls).not.toContain('organization')
expect(fromCalls.filter((t) => t === 'subscription')).toHaveLength(1)
})

it('returns null when neither personal nor org subscriptions exist', async () => {
queue('subscription', [])
queue('member', [])

const result = await getHighestPrioritySubscription('user-1')

expect(result).toBeNull()
})

it('excludes orphaned org memberships whose organization row no longer exists', async () => {
queue('subscription', [])
queue('member', [{ organizationId: 'ghost-org' }]) // membership points at a deleted org
queue('organization', [])

const result = await getHighestPrioritySubscription('user-1')

// Org subs are never fetched (no valid org ids) -> falls back to null.
expect(result).toBeNull()
expect(fromCalls).toContain('organization')
// Only the initial personal-subs read on `subscription`; org-subs query skipped.
expect(fromCalls.filter((t) => t === 'subscription')).toHaveLength(1)
})

it('falls back to the personal sub when the only org is orphaned', async () => {
queue('subscription', [personalPro('user-1')])
queue('member', [{ organizationId: 'ghost-org' }])
queue('organization', [])

const result = await getHighestPrioritySubscription('user-1')

expect(result?.id).toBe('sub-personal-pro')
expect(fromCalls.filter((t) => t === 'subscription')).toHaveLength(1)
})
})
29 changes: 15 additions & 14 deletions apps/sim/lib/billing/core/plan.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,21 @@ export async function getHighestPrioritySubscription(
) {
const { onError = 'return-null', executor = db } = options
try {
const personalSubs = await executor
.select()
.from(subscription)
.where(
and(
eq(subscription.referenceId, userId),
inArray(subscription.status, ENTITLED_SUBSCRIPTION_STATUSES)
)
)

const memberships = await executor
.select({ organizationId: member.organizationId })
.from(member)
.where(eq(member.userId, userId))
const [personalSubs, memberships] = await Promise.all([
executor
.select()
.from(subscription)
.where(
and(
eq(subscription.referenceId, userId),
inArray(subscription.status, ENTITLED_SUBSCRIPTION_STATUSES)
)
),
executor
.select({ organizationId: member.organizationId })
.from(member)
.where(eq(member.userId, userId)),
])

const orgIds = memberships.map((m: { organizationId: string }) => m.organizationId)

Expand Down
Loading
Loading