diff --git a/CHANGELOG.md b/CHANGELOG.md index 676584906..c82dfee67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - [EE] Added mermaid diagram rendering to Ask Sourcebot answers, with pan/zoom, copy/export, in-thread deep links, and an interleaved right-panel view. [#1369](https://github.com/sourcebot-dev/sourcebot/pull/1369) - [EE] Added a context-window usage gauge to the Ask Sourcebot chat details, showing how much of the selected model's context window each turn occupies. Window sizes are resolved from the models.dev catalog. [#1370](https://github.com/sourcebot-dev/sourcebot/pull/1370) - Added language model input-modality and document capability resolution, automatically resolved from the models.dev catalog (falls back to text-only for uncatalogued/self-hosted models). [#1372](https://github.com/sourcebot-dev/sourcebot/pull/1372) +- [EE] Added DPoP sender-constrained OAuth tokens for MCP clients. [#1395](https://github.com/sourcebot-dev/sourcebot/pull/1395) ### Fixed - Send anonymous server-side PostHog events as personless so unauthenticated requests don't inflate person counts. [#1367](https://github.com/sourcebot-dev/sourcebot/pull/1367) diff --git a/packages/db/prisma/migrations/20260629193000_add_oauth_dpop_binding/migration.sql b/packages/db/prisma/migrations/20260629193000_add_oauth_dpop_binding/migration.sql new file mode 100644 index 000000000..86546d2e9 --- /dev/null +++ b/packages/db/prisma/migrations/20260629193000_add_oauth_dpop_binding/migration.sql @@ -0,0 +1,5 @@ +ALTER TABLE "OAuthAuthorizationCode" ADD COLUMN "dpopJkt" TEXT; + +ALTER TABLE "OAuthRefreshToken" ADD COLUMN "dpopJkt" TEXT; + +ALTER TABLE "OAuthToken" ADD COLUMN "dpopJkt" TEXT; diff --git a/packages/db/prisma/schema.prisma b/packages/db/prisma/schema.prisma index 54444bbe2..014526907 100644 --- a/packages/db/prisma/schema.prisma +++ b/packages/db/prisma/schema.prisma @@ -654,6 +654,7 @@ model OAuthAuthorizationCode { redirectUri String codeChallenge String // BASE64URL(SHA-256(codeVerifier)) resource String? // RFC 8707: canonical URI of the target resource server + dpopJkt String? // RFC 9449: DPoP JWK SHA-256 thumbprint binding expiresAt DateTime createdAt DateTime @default(now()) } @@ -667,6 +668,7 @@ model OAuthRefreshToken { user User @relation(fields: [userId], references: [id], onDelete: Cascade) scope String @default("") resource String? // RFC 8707 + dpopJkt String? // RFC 9449 expiresAt DateTime createdAt DateTime @default(now()) @@ -682,6 +684,7 @@ model OAuthToken { user User @relation(fields: [userId], references: [id], onDelete: Cascade) scope String @default("") resource String? // RFC 8707: canonical URI of the target resource server + dpopJkt String? // RFC 9449: DPoP JWK SHA-256 thumbprint binding expiresAt DateTime createdAt DateTime @default(now()) lastUsedAt DateTime? diff --git a/packages/web/src/__mocks__/prisma.ts b/packages/web/src/__mocks__/prisma.ts index 5e5c28682..9761ebe24 100644 --- a/packages/web/src/__mocks__/prisma.ts +++ b/packages/web/src/__mocks__/prisma.ts @@ -55,6 +55,7 @@ export const MOCK_OAUTH_TOKEN: OAuthToken & { user: User & { accounts: Account[] userId: MOCK_USER_WITH_ACCOUNTS.id, scope: '', resource: null, + dpopJkt: null, expiresAt: new Date(Date.now() + 1000 * 60 * 60), // 1 hour from now createdAt: new Date(), lastUsedAt: null, @@ -67,8 +68,9 @@ export const MOCK_REFRESH_TOKEN: OAuthRefreshToken = { userId: MOCK_USER_WITH_ACCOUNTS.id, scope: '', resource: null, + dpopJkt: null, expiresAt: new Date(Date.now() + 1000 * 60 * 60 * 24 * 90), // 90 days from now createdAt: new Date(), } -export const userScopedPrismaClientExtension = vi.fn(); \ No newline at end of file +export const userScopedPrismaClientExtension = vi.fn(); diff --git a/packages/web/src/app/api/(server)/ee/.well-known/oauth-authorization-server/route.ts b/packages/web/src/app/api/(server)/ee/.well-known/oauth-authorization-server/route.ts index 179d632e1..0b568d6c8 100644 --- a/packages/web/src/app/api/(server)/ee/.well-known/oauth-authorization-server/route.ts +++ b/packages/web/src/app/api/(server)/ee/.well-known/oauth-authorization-server/route.ts @@ -2,6 +2,7 @@ import { oauthApiHandler } from '@/ee/features/oauth/apiHandler'; import { env } from '@sourcebot/shared'; import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; import { hasEntitlement } from '@/lib/entitlements'; +import { SUPPORTED_DPOP_SIGNING_ALGS } from '@/ee/features/oauth/dpop'; // RFC 8414: OAuth 2.0 Authorization Server Metadata // @see: https://datatracker.ietf.org/doc/html/rfc8414 @@ -26,6 +27,7 @@ export const GET = oauthApiHandler(async () => { grant_types_supported: ['authorization_code', 'refresh_token'], code_challenge_methods_supported: ['S256'], token_endpoint_auth_methods_supported: ['none'], + dpop_signing_alg_values_supported: SUPPORTED_DPOP_SIGNING_ALGS, service_documentation: 'https://docs.sourcebot.dev', }); }); diff --git a/packages/web/src/app/api/(server)/ee/.well-known/oauth-protected-resource/[...path]/route.ts b/packages/web/src/app/api/(server)/ee/.well-known/oauth-protected-resource/[...path]/route.ts index 8afdf5031..dc3cd8103 100644 --- a/packages/web/src/app/api/(server)/ee/.well-known/oauth-protected-resource/[...path]/route.ts +++ b/packages/web/src/app/api/(server)/ee/.well-known/oauth-protected-resource/[...path]/route.ts @@ -37,5 +37,6 @@ export const GET = oauthApiHandler(async (_request: NextRequest, { params }: { p authorization_servers: [ issuer ], + bearer_methods_supported: ['header'], }); }); diff --git a/packages/web/src/app/api/(server)/ee/mcp/route.ts b/packages/web/src/app/api/(server)/ee/mcp/route.ts index 15c7f032a..7d7d375dd 100644 --- a/packages/web/src/app/api/(server)/ee/mcp/route.ts +++ b/packages/web/src/app/api/(server)/ee/mcp/route.ts @@ -22,10 +22,14 @@ async function mcpErrorResponse(error: ServiceError): Promise { const response = serviceErrorResponse(error); if (error.statusCode === StatusCodes.UNAUTHORIZED && await hasEntitlement('oauth')) { const issuer = env.AUTH_URL.replace(/\/$/, ''); - response.headers.set( + response.headers.append( 'WWW-Authenticate', `Bearer realm="Sourcebot", resource_metadata_uri="${issuer}/.well-known/oauth-protected-resource/api/mcp"` ); + response.headers.append( + 'WWW-Authenticate', + `DPoP realm="Sourcebot", resource_metadata_uri="${issuer}/.well-known/oauth-protected-resource/api/mcp"` + ); } return response; } diff --git a/packages/web/src/app/api/(server)/ee/oauth/token/route.ts b/packages/web/src/app/api/(server)/ee/oauth/token/route.ts index 3b9843459..9de18feb6 100644 --- a/packages/web/src/app/api/(server)/ee/oauth/token/route.ts +++ b/packages/web/src/app/api/(server)/ee/oauth/token/route.ts @@ -4,6 +4,7 @@ import { env } from '@sourcebot/shared'; import { NextRequest } from 'next/server'; import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; import { hasEntitlement } from '@/lib/entitlements'; +import { DPOP_PROOF_HEADER, DPOP_TOKEN_TYPE, verifyDpopProof } from '@/ee/features/oauth/dpop'; // OAuth 2.0 Token Endpoint // Supports grant_type=authorization_code with PKCE (RFC 7636). @@ -30,6 +31,20 @@ export const POST = oauthApiHandler(async (request: NextRequest) => { ); } + const dpopProof = request.headers.get(DPOP_PROOF_HEADER); + const dpopProofResult = dpopProof + ? await verifyDpopProof({ request, proof: dpopProof }) + : null; + + if (dpopProofResult && !dpopProofResult.ok) { + return Response.json( + { error: dpopProofResult.error, error_description: dpopProofResult.errorDescription }, + { status: 400 } + ); + } + + const dpopJkt = dpopProofResult?.ok ? dpopProofResult.jkt : null; + if (grantType === 'authorization_code') { const code = formData.get('code'); const redirectUri = formData.get('redirect_uri'); @@ -48,6 +63,7 @@ export const POST = oauthApiHandler(async (request: NextRequest) => { redirectUri: redirectUri.toString(), codeVerifier: codeVerifier.toString(), resource: resource ? resource.toString() : null, + dpopJkt, }); if ('error' in result) { @@ -60,7 +76,7 @@ export const POST = oauthApiHandler(async (request: NextRequest) => { return Response.json({ access_token: result.token, refresh_token: result.refreshToken, - token_type: 'Bearer', + token_type: result.dpopJkt ? DPOP_TOKEN_TYPE : 'Bearer', expires_in: env.OAUTH_ACCESS_TOKEN_TTL_SECONDS, scope: '', }); @@ -80,6 +96,7 @@ export const POST = oauthApiHandler(async (request: NextRequest) => { rawRefreshToken: rawRefreshToken.toString(), clientId: clientId.toString(), resource: resource ? resource.toString() : null, + dpopJkt, }); if ('error' in result) { @@ -92,7 +109,7 @@ export const POST = oauthApiHandler(async (request: NextRequest) => { return Response.json({ access_token: result.token, refresh_token: result.refreshToken, - token_type: 'Bearer', + token_type: result.dpopJkt ? DPOP_TOKEN_TYPE : 'Bearer', expires_in: env.OAUTH_ACCESS_TOKEN_TTL_SECONDS, scope: '', }); diff --git a/packages/web/src/app/oauth/authorize/components/consentScreen.tsx b/packages/web/src/app/oauth/authorize/components/consentScreen.tsx index 94eb27ea8..65bc0431e 100644 --- a/packages/web/src/app/oauth/authorize/components/consentScreen.tsx +++ b/packages/web/src/app/oauth/authorize/components/consentScreen.tsx @@ -18,6 +18,7 @@ interface ConsentScreenProps { redirectUri: string; codeChallenge: string; resource: string | null; + dpopJkt: string | null; state: string | undefined; userEmail: string; } @@ -29,6 +30,7 @@ export function ConsentScreen({ redirectUri, codeChallenge, resource, + dpopJkt, state, userEmail, }: ConsentScreenProps) { @@ -43,7 +45,7 @@ export function ConsentScreen({ const onApprove = async () => { captureEvent('wa_oauth_authorization_approved', { clientId, clientName }); setPending('approve'); - const result = await approveAuthorization({ clientId, redirectUri, codeChallenge, resource, state }); + const result = await approveAuthorization({ clientId, redirectUri, codeChallenge, resource, dpopJkt, state }); if (!isServiceError(result)) { if (!isPermittedRedirectUrl(result)) { toast({ description: `❌ Redirect URL is not permitted.` }); diff --git a/packages/web/src/app/oauth/authorize/page.tsx b/packages/web/src/app/oauth/authorize/page.tsx index ae58ee585..67a0d8356 100644 --- a/packages/web/src/app/oauth/authorize/page.tsx +++ b/packages/web/src/app/oauth/authorize/page.tsx @@ -4,6 +4,7 @@ import { ConsentScreen } from './components/consentScreen'; import { __unsafePrisma } from '@/prisma'; import { hasEntitlement } from '@/lib/entitlements'; import { redirect } from 'next/navigation'; +import { isValidDpopJkt } from '@/ee/features/oauth/dpop'; export const dynamic = 'force-dynamic'; @@ -16,6 +17,7 @@ interface AuthorizePageProps { response_type?: string; state?: string; resource?: string | string[]; + dpop_jkt?: string | string[]; }>; } @@ -25,13 +27,14 @@ export default async function AuthorizePage({ searchParams }: AuthorizePageProps } const params = await searchParams; - const { client_id, redirect_uri, code_challenge, code_challenge_method, response_type, state, resource: _resource } = params; + const { client_id, redirect_uri, code_challenge, code_challenge_method, response_type, state, resource: _resource, dpop_jkt: _dpopJkt } = params; // RFC 8707 allows multiple resource parameters to indicate a token intended for multiple resources. // Sourcebot only supports a single resource (the MCP endpoint), so we take the first value. // // @see: https://www.rfc-editor.org/rfc/rfc8707.html#section-2-2.2 const resource = Array.isArray(_resource) ? _resource[0] : _resource; + const dpopJkt = Array.isArray(_dpopJkt) ? _dpopJkt[0] : _dpopJkt; // Validate required parameters. Per spec, do NOT redirect on client errors — // show an error page instead to avoid open redirect vulnerabilities. @@ -47,6 +50,10 @@ export default async function AuthorizePage({ searchParams }: AuthorizePageProps return ; } + if (dpopJkt && !isValidDpopJkt(dpopJkt)) { + return ; + } + const client = await __unsafePrisma.oAuthClient.findUnique({ where: { id: client_id } }); if (!client) { @@ -74,6 +81,7 @@ export default async function AuthorizePage({ searchParams }: AuthorizePageProps redirectUri={redirect_uri!} codeChallenge={code_challenge!} resource={resource ?? null} + dpopJkt={dpopJkt ?? null} state={state} userEmail={session!.user.email!} /> diff --git a/packages/web/src/ee/features/oauth/actions.ts b/packages/web/src/ee/features/oauth/actions.ts index 5d9928d63..1f53a4bac 100644 --- a/packages/web/src/ee/features/oauth/actions.ts +++ b/packages/web/src/ee/features/oauth/actions.ts @@ -4,6 +4,9 @@ import { sew } from "@/middleware/sew"; import { generateAndStoreAuthCode } from '@/ee/features/oauth/server'; import { withAuth } from '@/middleware/withAuth'; import { UNPERMITTED_SCHEMES } from '@/ee/features/oauth/constants'; +import { isValidDpopJkt } from '@/ee/features/oauth/dpop'; +import { ErrorCode } from '@/lib/errorCodes'; +import { StatusCodes } from 'http-status-codes'; export interface ConnectedOauthClient { id: string; @@ -38,21 +41,32 @@ export const approveAuthorization = async ({ redirectUri, codeChallenge, resource, + dpopJkt, state, }: { clientId: string; redirectUri: string; codeChallenge: string; resource: string | null; + dpopJkt: string | null; state: string | undefined; }) => sew(() => withAuth(async ({ user }) => { + if (dpopJkt !== null && !isValidDpopJkt(dpopJkt)) { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_QUERY_PARAMS, + message: 'Invalid dpop_jkt parameter.', + }; + } + const rawCode = await generateAndStoreAuthCode({ clientId, userId: user.id, redirectUri, codeChallenge, resource, + dpopJkt, }); const callbackUrl = new URL(redirectUri); diff --git a/packages/web/src/ee/features/oauth/dpop.test.ts b/packages/web/src/ee/features/oauth/dpop.test.ts new file mode 100644 index 000000000..e8a644887 --- /dev/null +++ b/packages/web/src/ee/features/oauth/dpop.test.ts @@ -0,0 +1,257 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest'; +import crypto from 'crypto'; +import { + __clearDpopReplayCacheForTests, + calculateDpopJkt, + getDpopAccessTokenHash, + verifyDpopProof, +} from './dpop'; + +vi.mock('server-only', () => ({ default: vi.fn() })); + +vi.mock('@sourcebot/shared', () => ({ + env: { + AUTH_URL: 'https://sourcebot.test', + }, +})); + +type PublicEcJwk = { + kty: 'EC'; + crv: 'P-256'; + x: string; + y: string; +}; + +type KeyPair = { + privateKey: CryptoKey; + publicJwk: PublicEcJwk; +}; + +beforeEach(() => { + __clearDpopReplayCacheForTests(); +}); + +afterEach(() => { + vi.useRealTimers(); +}); + +describe('verifyDpopProof', () => { + test('accepts a valid resource request proof for a DPoP-bound access token', async () => { + const keyPair = await generateKeyPair(); + const accessToken = 'sboa_access-token'; + const request = new Request('http://internal.test/api/ee/mcp?ignored=true', { method: 'POST' }); + const proof = await signDpopProof({ + ...keyPair, + htm: 'POST', + htu: 'https://sourcebot.test/api/mcp', + accessToken, + }); + + const result = await verifyDpopProof({ + request, + proof, + expectedJkt: calculateDpopJkt(keyPair.publicJwk), + accessToken, + requireAccessTokenHash: true, + }); + + expect(result).toEqual({ + ok: true, + jkt: calculateDpopJkt(keyPair.publicJwk), + }); + }); + + test('accepts a token endpoint proof without an access-token hash', async () => { + const keyPair = await generateKeyPair(); + const request = new Request('http://internal.test/api/ee/oauth/token', { method: 'POST' }); + const proof = await signDpopProof({ + ...keyPair, + htm: 'POST', + htu: 'https://sourcebot.test/api/ee/oauth/token', + }); + + const result = await verifyDpopProof({ + request, + proof, + }); + + expect(result).toEqual({ + ok: true, + jkt: calculateDpopJkt(keyPair.publicJwk), + }); + }); + + test('rejects a resource request proof with the wrong access-token hash', async () => { + const keyPair = await generateKeyPair(); + const request = new Request('http://internal.test/api/ee/mcp', { method: 'POST' }); + const proof = await signDpopProof({ + ...keyPair, + htm: 'POST', + htu: 'https://sourcebot.test/api/mcp', + accessToken: 'sboa_other-token', + }); + + const result = await verifyDpopProof({ + request, + proof, + expectedJkt: calculateDpopJkt(keyPair.publicJwk), + accessToken: 'sboa_access-token', + requireAccessTokenHash: true, + }); + + expect(result).toMatchObject({ + ok: false, + error: 'invalid_dpop_proof', + }); + }); + + test('rejects replayed proof ids', async () => { + const keyPair = await generateKeyPair(); + const request = new Request('http://internal.test/api/ee/oauth/token', { method: 'POST' }); + const proof = await signDpopProof({ + ...keyPair, + htm: 'POST', + htu: 'https://sourcebot.test/api/ee/oauth/token', + jti: 'replayed-proof', + }); + + await expect(verifyDpopProof({ request, proof })).resolves.toMatchObject({ ok: true }); + await expect(verifyDpopProof({ request, proof })).resolves.toMatchObject({ + ok: false, + error: 'invalid_dpop_proof', + }); + }); + + test('does not record a proof id before access-token hash validation passes', async () => { + const keyPair = await generateKeyPair(); + const request = new Request('http://internal.test/api/ee/mcp', { method: 'POST' }); + const proof = await signDpopProof({ + ...keyPair, + htm: 'POST', + htu: 'https://sourcebot.test/api/mcp', + accessToken: 'sboa_actual-token', + jti: 'ath-mismatch-not-recorded', + }); + + await expect(verifyDpopProof({ + request, + proof, + expectedJkt: calculateDpopJkt(keyPair.publicJwk), + accessToken: 'sboa_other-token', + requireAccessTokenHash: true, + })).resolves.toMatchObject({ + ok: false, + error: 'invalid_dpop_proof', + }); + + await expect(verifyDpopProof({ + request, + proof, + expectedJkt: calculateDpopJkt(keyPair.publicJwk), + accessToken: 'sboa_actual-token', + requireAccessTokenHash: true, + })).resolves.toMatchObject({ ok: true }); + }); + + test('expires replay cache entries based on proof iat plus the accepted window', async () => { + vi.useFakeTimers(); + vi.setSystemTime(new Date('2026-01-01T00:00:00.000Z')); + + const keyPair = await generateKeyPair(); + const request = new Request('http://internal.test/api/ee/oauth/token', { method: 'POST' }); + const proof = await signDpopProof({ + ...keyPair, + htm: 'POST', + htu: 'https://sourcebot.test/api/ee/oauth/token', + iat: Math.floor(Date.now() / 1000) + 60, + jti: 'future-iat-replay', + }); + + await expect(verifyDpopProof({ request, proof })).resolves.toMatchObject({ ok: true }); + + vi.setSystemTime(new Date('2026-01-01T00:05:01.000Z')); + + await expect(verifyDpopProof({ request, proof })).resolves.toMatchObject({ + ok: false, + error: 'invalid_dpop_proof', + errorDescription: 'DPoP proof jti has already been used.', + }); + }); + + test.each([ + { header: null, payload: {} }, + { header: [], payload: {} }, + { header: 'header', payload: {} }, + { header: {}, payload: null }, + { header: {}, payload: [] }, + { header: {}, payload: 'payload' }, + ])('rejects non-object JWT JSON values %#', async ({ header, payload }) => { + const request = new Request('http://internal.test/api/ee/oauth/token', { method: 'POST' }); + const proof = `${base64UrlJson(header)}.${base64UrlJson(payload)}.signature`; + + await expect(verifyDpopProof({ request, proof })).resolves.toMatchObject({ + ok: false, + error: 'invalid_dpop_proof', + }); + }); +}); + +async function generateKeyPair(): Promise { + const keyPair = await crypto.webcrypto.subtle.generateKey( + { name: 'ECDSA', namedCurve: 'P-256' }, + true, + ['sign', 'verify'], + ); + const publicJwk = await crypto.webcrypto.subtle.exportKey('jwk', keyPair.publicKey); + + return { + privateKey: keyPair.privateKey, + publicJwk: { + kty: 'EC', + crv: 'P-256', + x: publicJwk.x!, + y: publicJwk.y!, + }, + }; +} + +async function signDpopProof({ + privateKey, + publicJwk, + htm, + htu, + accessToken, + iat = Math.floor(Date.now() / 1000), + jti = crypto.randomUUID(), +}: KeyPair & { + htm: string; + htu: string; + accessToken?: string; + iat?: number; + jti?: string; +}): Promise { + const encodedHeader = base64UrlJson({ + typ: 'dpop+jwt', + alg: 'ES256', + jwk: publicJwk, + }); + const encodedPayload = base64UrlJson({ + htm, + htu, + iat, + jti, + ...(accessToken ? { ath: getDpopAccessTokenHash(accessToken) } : {}), + }); + const signingInput = `${encodedHeader}.${encodedPayload}`; + const signature = await crypto.webcrypto.subtle.sign( + { name: 'ECDSA', hash: 'SHA-256' }, + privateKey, + new TextEncoder().encode(signingInput), + ); + + return `${signingInput}.${Buffer.from(signature).toString('base64url')}`; +} + +function base64UrlJson(value: unknown): string { + return Buffer.from(JSON.stringify(value)).toString('base64url'); +} diff --git a/packages/web/src/ee/features/oauth/dpop.ts b/packages/web/src/ee/features/oauth/dpop.ts new file mode 100644 index 000000000..2b908f3e0 --- /dev/null +++ b/packages/web/src/ee/features/oauth/dpop.ts @@ -0,0 +1,268 @@ +import 'server-only'; + +import crypto from 'crypto'; +import { env } from '@sourcebot/shared'; + +export const DPOP_AUTH_SCHEME = 'DPoP'; +export const DPOP_PROOF_HEADER = 'DPoP'; +export const DPOP_TOKEN_TYPE = 'DPoP'; +export const SUPPORTED_DPOP_SIGNING_ALGS = ['ES256']; + +const DPOP_PROOF_IAT_WINDOW_SECONDS = 5 * 60; +const DPOP_JKT_PATTERN = /^[A-Za-z0-9_-]{43}$/; + +type DpopJwk = { + kty?: string; + crv?: string; + x?: string; + y?: string; + d?: string; + [key: string]: unknown; +}; + +type DpopHeader = { + typ?: string; + alg?: string; + jwk?: DpopJwk; +}; + +type DpopPayload = { + htm?: unknown; + htu?: unknown; + iat?: unknown; + jti?: unknown; + ath?: unknown; +}; + +type VerifyDpopProofOptions = { + request: Request; + proof: string | null; + expectedJkt?: string | null; + accessToken?: string; + requireAccessTokenHash?: boolean; +}; + +type VerifyDpopProofResult = + | { ok: true; jkt: string } + | { ok: false; error: 'invalid_dpop_proof'; errorDescription: string }; + +const seenProofJtis = new Map(); + +export function getCanonicalRequestUri(request: Request): string { + const requestUrl = new URL(request.url); + const issuer = env.AUTH_URL?.replace(/\/$/, ''); + const pathname = requestUrl.pathname === '/api/ee/mcp' ? '/api/mcp' : requestUrl.pathname; + + if (issuer) { + return `${issuer}${pathname}`; + } + + return `${requestUrl.origin}${pathname}`; +} + +export function getDpopAccessTokenHash(accessToken: string): string { + return crypto.createHash('sha256').update(accessToken).digest('base64url'); +} + +export function isValidDpopJkt(value: string): boolean { + return DPOP_JKT_PATTERN.test(value); +} + +export function calculateDpopJkt(jwk: DpopJwk): string | undefined { + if (!isSupportedPublicJwk(jwk)) { + return undefined; + } + + const thumbprintInput = JSON.stringify({ + crv: jwk.crv, + kty: jwk.kty, + x: jwk.x, + y: jwk.y, + }); + + return crypto.createHash('sha256').update(thumbprintInput).digest('base64url'); +} + +export async function verifyDpopProof({ + request, + proof, + expectedJkt, + accessToken, + requireAccessTokenHash = false, +}: VerifyDpopProofOptions): Promise { + if (!proof) { + return invalidDpopProof('Missing DPoP proof.'); + } + + const parts = proof.split('.'); + if (parts.length !== 3 || parts.some((part) => part.length === 0)) { + return invalidDpopProof('DPoP proof must be a compact JWT.'); + } + + let parsedHeader: unknown; + let parsedPayload: unknown; + try { + parsedHeader = JSON.parse(base64UrlDecode(parts[0]).toString('utf8')); + parsedPayload = JSON.parse(base64UrlDecode(parts[1]).toString('utf8')); + } catch { + return invalidDpopProof('DPoP proof header or payload is not valid JSON.'); + } + + if (!isPlainObject(parsedHeader) || !isPlainObject(parsedPayload)) { + return invalidDpopProof('DPoP proof header and payload must be JSON objects.'); + } + + const header = parsedHeader as DpopHeader; + const payload = parsedPayload as DpopPayload; + + if (header.typ?.toLowerCase() !== 'dpop+jwt') { + return invalidDpopProof('DPoP proof typ must be dpop+jwt.'); + } + + if (header.alg !== 'ES256') { + return invalidDpopProof('DPoP proof alg is not supported.'); + } + + if (!header.jwk || !isSupportedPublicJwk(header.jwk)) { + return invalidDpopProof('DPoP proof must include a supported public JWK.'); + } + + const jkt = calculateDpopJkt(header.jwk); + if (!jkt) { + return invalidDpopProof('DPoP proof JWK thumbprint could not be calculated.'); + } + + if (expectedJkt && jkt !== expectedJkt) { + return invalidDpopProof('DPoP proof key does not match the token binding.'); + } + + const signatureIsValid = await verifyEs256Signature({ + jwk: header.jwk, + signingInput: `${parts[0]}.${parts[1]}`, + signature: parts[2], + }); + if (!signatureIsValid) { + return invalidDpopProof('DPoP proof signature is invalid.'); + } + + const expectedHtm = request.method.toUpperCase(); + if (payload.htm !== expectedHtm) { + return invalidDpopProof('DPoP proof htm does not match the request method.'); + } + + const expectedHtu = getCanonicalRequestUri(request); + if (payload.htu !== expectedHtu) { + return invalidDpopProof('DPoP proof htu does not match the request URI.'); + } + + if (typeof payload.iat !== 'number' || !Number.isFinite(payload.iat)) { + return invalidDpopProof('DPoP proof iat must be a numeric timestamp.'); + } + + const nowSeconds = Math.floor(Date.now() / 1000); + if (Math.abs(nowSeconds - payload.iat) > DPOP_PROOF_IAT_WINDOW_SECONDS) { + return invalidDpopProof('DPoP proof iat is outside the accepted time window.'); + } + + if (typeof payload.jti !== 'string' || payload.jti.length === 0) { + return invalidDpopProof('DPoP proof jti is required.'); + } + const proofExpiresAt = (payload.iat + DPOP_PROOF_IAT_WINDOW_SECONDS) * 1000; + + if (accessToken || requireAccessTokenHash) { + if (typeof payload.ath !== 'string' || !accessToken) { + return invalidDpopProof('DPoP proof ath is required.'); + } + + if (payload.ath !== getDpopAccessTokenHash(accessToken)) { + return invalidDpopProof('DPoP proof ath does not match the access token.'); + } + } + + if (!recordProofJti(jkt, payload.jti, proofExpiresAt)) { + return invalidDpopProof('DPoP proof jti has already been used.'); + } + + return { ok: true, jkt }; +} + +export function __clearDpopReplayCacheForTests() { + seenProofJtis.clear(); +} + +function invalidDpopProof(errorDescription: string): VerifyDpopProofResult { + return { + ok: false, + error: 'invalid_dpop_proof', + errorDescription, + }; +} + +function isPlainObject(value: unknown): value is Record { + return typeof value === 'object' && value !== null && !Array.isArray(value); +} + +function isSupportedPublicJwk(jwk: DpopJwk): jwk is Required> & DpopJwk { + return ( + jwk.kty === 'EC' && + jwk.crv === 'P-256' && + typeof jwk.x === 'string' && + typeof jwk.y === 'string' && + typeof jwk.d !== 'string' + ); +} + +async function verifyEs256Signature({ + jwk, + signingInput, + signature, +}: { + jwk: Required>; + signingInput: string; + signature: string; +}): Promise { + try { + const key = await crypto.webcrypto.subtle.importKey( + 'jwk', + { + kty: jwk.kty, + crv: jwk.crv, + x: jwk.x, + y: jwk.y, + }, + { name: 'ECDSA', namedCurve: 'P-256' }, + false, + ['verify'], + ); + + return await crypto.webcrypto.subtle.verify( + { name: 'ECDSA', hash: 'SHA-256' }, + key, + base64UrlDecode(signature), + new TextEncoder().encode(signingInput), + ); + } catch { + return false; + } +} + +function base64UrlDecode(value: string): Buffer { + return Buffer.from(value, 'base64url'); +} + +function recordProofJti(jkt: string, jti: string, proofExpiresAt: number): boolean { + const now = Date.now(); + for (const [cacheKey, expiresAt] of seenProofJtis.entries()) { + if (expiresAt <= now) { + seenProofJtis.delete(cacheKey); + } + } + + const cacheKey = `${jkt}:${jti}`; + if (seenProofJtis.has(cacheKey)) { + return false; + } + + seenProofJtis.set(cacheKey, proofExpiresAt); + return true; +} diff --git a/packages/web/src/ee/features/oauth/server.test.ts b/packages/web/src/ee/features/oauth/server.test.ts index 244749d6e..85eb1c8b5 100644 --- a/packages/web/src/ee/features/oauth/server.test.ts +++ b/packages/web/src/ee/features/oauth/server.test.ts @@ -31,6 +31,7 @@ const VALID_AUTH_CODE = { // SHA-256('myverifier') base64url codeChallenge: 'Eb223qLjTQNFkRjCVsrDbsBk5ycPKwHdbHNRX99tTeQ', resource: null, + dpopJkt: null, expiresAt: new Date(Date.now() + 10 * 60 * 1000), createdAt: new Date(), }; @@ -59,6 +60,7 @@ describe('verifyAndExchangeCode', () => { redirectUri: 'http://localhost:9999/callback', codeVerifier: 'myverifier', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ @@ -77,6 +79,7 @@ describe('verifyAndExchangeCode', () => { redirectUri: 'http://localhost:9999/callback', codeVerifier: 'myverifier', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_grant' }); @@ -95,6 +98,7 @@ describe('verifyAndExchangeCode', () => { redirectUri: 'http://localhost:9999/callback', codeVerifier: 'myverifier', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_grant' }); @@ -109,6 +113,7 @@ describe('verifyAndExchangeCode', () => { redirectUri: 'http://localhost:9999/callback', codeVerifier: 'myverifier', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_grant' }); @@ -123,6 +128,7 @@ describe('verifyAndExchangeCode', () => { redirectUri: 'http://localhost:9999/wrong', codeVerifier: 'myverifier', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_grant' }); @@ -137,6 +143,7 @@ describe('verifyAndExchangeCode', () => { redirectUri: 'http://localhost:9999/callback', codeVerifier: 'wrongverifier', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_grant' }); @@ -154,11 +161,55 @@ describe('verifyAndExchangeCode', () => { redirectUri: 'http://localhost:9999/callback', codeVerifier: 'myverifier', resource: 'https://other.com/api/mcp', + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_target' }); }); + test('binds issued tokens to the DPoP proof thumbprint', async () => { + prisma.oAuthAuthorizationCode.findUnique.mockResolvedValue(VALID_AUTH_CODE); + prisma.oAuthAuthorizationCode.delete.mockResolvedValue(VALID_AUTH_CODE); + prisma.oAuthToken.create.mockResolvedValue({} as never); + prisma.oAuthRefreshToken.create.mockResolvedValue({} as never); + + const result = await verifyAndExchangeCode({ + rawCode: VALID_CODE_HASH, + clientId: 'test-client-id', + redirectUri: 'http://localhost:9999/callback', + codeVerifier: 'myverifier', + resource: null, + dpopJkt: 'dpop-thumbprint', + }); + + expect(result).toMatchObject({ dpopJkt: 'dpop-thumbprint' }); + expect(prisma.oAuthToken.create).toHaveBeenCalledWith({ + data: expect.objectContaining({ dpopJkt: 'dpop-thumbprint' }), + }); + expect(prisma.oAuthRefreshToken.create).toHaveBeenCalledWith({ + data: expect.objectContaining({ dpopJkt: 'dpop-thumbprint' }), + }); + }); + + test('returns invalid_dpop_proof if DPoP key does not match the bound authorization code', async () => { + prisma.oAuthAuthorizationCode.findUnique.mockResolvedValue({ + ...VALID_AUTH_CODE, + dpopJkt: 'expected-thumbprint', + }); + + const result = await verifyAndExchangeCode({ + rawCode: VALID_CODE_HASH, + clientId: 'test-client-id', + redirectUri: 'http://localhost:9999/callback', + codeVerifier: 'myverifier', + resource: null, + dpopJkt: 'other-thumbprint', + }); + + expect(result).toMatchObject({ error: 'invalid_dpop_proof' }); + expect(prisma.oAuthAuthorizationCode.delete).not.toHaveBeenCalled(); + }); + test('returns invalid_grant if code was already used (P2025)', async () => { const { Prisma } = await vi.importActual('@prisma/client'); prisma.oAuthAuthorizationCode.findUnique.mockResolvedValue(VALID_AUTH_CODE); @@ -172,6 +223,7 @@ describe('verifyAndExchangeCode', () => { redirectUri: 'http://localhost:9999/callback', codeVerifier: 'myverifier', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_grant' }); @@ -188,6 +240,7 @@ describe('verifyAndExchangeCode', () => { redirectUri: 'http://localhost:9999/callback', codeVerifier: 'myverifier', resource: null, + dpopJkt: null, }) ).rejects.toThrow('DB connection lost'); }); @@ -208,6 +261,7 @@ describe('verifyAndRotateRefreshToken', () => { rawRefreshToken: 'sbor_refreshtoken', clientId: 'test-client-id', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ @@ -222,6 +276,7 @@ describe('verifyAndRotateRefreshToken', () => { rawRefreshToken: 'invalid_prefix_token', clientId: 'test-client-id', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_grant' }); @@ -237,6 +292,7 @@ describe('verifyAndRotateRefreshToken', () => { rawRefreshToken: 'sbor_used', clientId: 'test-client-id', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_grant' }); @@ -251,6 +307,7 @@ describe('verifyAndRotateRefreshToken', () => { rawRefreshToken: 'sbor_refreshtoken', clientId: 'wrong-client-id', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_grant' }); @@ -267,6 +324,7 @@ describe('verifyAndRotateRefreshToken', () => { rawRefreshToken: 'sbor_refreshtoken', clientId: 'test-client-id', resource: null, + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_grant' }); @@ -282,11 +340,54 @@ describe('verifyAndRotateRefreshToken', () => { rawRefreshToken: 'sbor_refreshtoken', clientId: 'test-client-id', resource: 'https://other.com/api/mcp', + dpopJkt: null, }); expect(result).toMatchObject({ error: 'invalid_target' }); }); + test('preserves the DPoP thumbprint when rotating a bound refresh token', async () => { + prisma.oAuthRefreshToken.findUnique.mockResolvedValue({ + ...MOCK_REFRESH_TOKEN, + dpopJkt: 'dpop-thumbprint', + }); + prisma.oAuthRefreshToken.delete.mockResolvedValue(MOCK_REFRESH_TOKEN); + prisma.oAuthToken.create.mockResolvedValue({} as never); + prisma.oAuthRefreshToken.create.mockResolvedValue({} as never); + + const result = await verifyAndRotateRefreshToken({ + rawRefreshToken: 'sbor_refreshtoken', + clientId: 'test-client-id', + resource: null, + dpopJkt: 'dpop-thumbprint', + }); + + expect(result).toMatchObject({ dpopJkt: 'dpop-thumbprint' }); + expect(prisma.oAuthToken.create).toHaveBeenCalledWith({ + data: expect.objectContaining({ dpopJkt: 'dpop-thumbprint' }), + }); + expect(prisma.oAuthRefreshToken.create).toHaveBeenCalledWith({ + data: expect.objectContaining({ dpopJkt: 'dpop-thumbprint' }), + }); + }); + + test('returns invalid_dpop_proof if DPoP key does not match the refresh token binding', async () => { + prisma.oAuthRefreshToken.findUnique.mockResolvedValue({ + ...MOCK_REFRESH_TOKEN, + dpopJkt: 'expected-thumbprint', + }); + + const result = await verifyAndRotateRefreshToken({ + rawRefreshToken: 'sbor_refreshtoken', + clientId: 'test-client-id', + resource: null, + dpopJkt: 'other-thumbprint', + }); + + expect(result).toMatchObject({ error: 'invalid_dpop_proof' }); + expect(prisma.oAuthRefreshToken.delete).not.toHaveBeenCalled(); + }); + test('old refresh token is deleted during rotation', async () => { prisma.oAuthRefreshToken.findUnique.mockResolvedValue(MOCK_REFRESH_TOKEN); prisma.oAuthRefreshToken.delete.mockResolvedValue(MOCK_REFRESH_TOKEN); @@ -297,6 +398,7 @@ describe('verifyAndRotateRefreshToken', () => { rawRefreshToken: 'sbor_refreshtoken', clientId: 'test-client-id', resource: null, + dpopJkt: null, }); expect(prisma.oAuthRefreshToken.delete).toHaveBeenCalledWith({ where: { hash: 'refreshtoken' } }); diff --git a/packages/web/src/ee/features/oauth/server.ts b/packages/web/src/ee/features/oauth/server.ts index ac5676b23..be00b55ef 100644 --- a/packages/web/src/ee/features/oauth/server.ts +++ b/packages/web/src/ee/features/oauth/server.ts @@ -20,12 +20,14 @@ export async function generateAndStoreAuthCode({ redirectUri, codeChallenge, resource, + dpopJkt, }: { clientId: string; userId: string; redirectUri: string; codeChallenge: string; resource: string | null; + dpopJkt: string | null; }): Promise { const rawCode = crypto.randomBytes(32).toString('hex'); const codeHash = hashSecret(rawCode); @@ -38,6 +40,7 @@ export async function generateAndStoreAuthCode({ redirectUri, codeChallenge, resource, + dpopJkt, expiresAt: new Date(Date.now() + env.OAUTH_AUTHORIZATION_CODE_TTL_SECONDS * 1000), }, }); @@ -53,13 +56,15 @@ export async function verifyAndExchangeCode({ redirectUri, codeVerifier, resource, + dpopJkt, }: { rawCode: string; clientId: string; redirectUri: string; codeVerifier: string; resource: string | null; -}): Promise<{ token: string; refreshToken: string; expiresIn: number } | { error: string; errorDescription: string }> { + dpopJkt: string | null; +}): Promise<{ token: string; refreshToken: string; expiresIn: number; dpopJkt: string | null } | { error: string; errorDescription: string }> { const codeHash = hashSecret(rawCode); const authCode = await __unsafePrisma.oAuthAuthorizationCode.findUnique({ @@ -98,6 +103,10 @@ export async function verifyAndExchangeCode({ return { error: 'invalid_target', errorDescription: 'resource parameter does not match the value bound to the authorization code.' }; } + if (authCode.dpopJkt !== null && authCode.dpopJkt !== dpopJkt) { + return { error: 'invalid_dpop_proof', errorDescription: 'DPoP proof key does not match the value bound to the authorization code.' }; + } + // Single-use: delete the auth code before issuing token. // Handle concurrent consume attempts gracefully. try { @@ -111,6 +120,7 @@ export async function verifyAndExchangeCode({ const { token, hash } = generateOAuthToken(); const { token: refreshToken, hash: refreshHash } = generateOAuthRefreshToken(); + const tokenDpopJkt = authCode.dpopJkt ?? dpopJkt; await __unsafePrisma.$transaction([ __unsafePrisma.oAuthToken.create({ @@ -119,6 +129,7 @@ export async function verifyAndExchangeCode({ clientId, userId: authCode.userId, resource: authCode.resource, + dpopJkt: tokenDpopJkt, expiresAt: new Date(Date.now() + env.OAUTH_ACCESS_TOKEN_TTL_SECONDS * 1000), }, }), @@ -128,12 +139,13 @@ export async function verifyAndExchangeCode({ clientId, userId: authCode.userId, resource: authCode.resource, + dpopJkt: tokenDpopJkt, expiresAt: new Date(Date.now() + env.OAUTH_REFRESH_TOKEN_TTL_SECONDS * 1000), }, }), ]); - return { token, refreshToken, expiresIn: env.OAUTH_ACCESS_TOKEN_TTL_SECONDS }; + return { token, refreshToken, expiresIn: env.OAUTH_ACCESS_TOKEN_TTL_SECONDS, dpopJkt: tokenDpopJkt }; } // Verifies a refresh token, rotates it, and issues a new access token + refresh token. @@ -143,11 +155,13 @@ export async function verifyAndRotateRefreshToken({ rawRefreshToken, clientId, resource, + dpopJkt, }: { rawRefreshToken: string; clientId: string; resource: string | null; -}): Promise<{ token: string; refreshToken: string; expiresIn: number } | { error: string; errorDescription: string }> { + dpopJkt: string | null; +}): Promise<{ token: string; refreshToken: string; expiresIn: number; dpopJkt: string | null } | { error: string; errorDescription: string }> { if (!rawRefreshToken.startsWith(OAUTH_REFRESH_TOKEN_PREFIX)) { return { error: 'invalid_grant', errorDescription: 'Refresh token is invalid.' }; } @@ -173,8 +187,13 @@ export async function verifyAndRotateRefreshToken({ return { error: 'invalid_target', errorDescription: 'resource parameter does not match the refresh token.' }; } + if (existing.dpopJkt !== null && existing.dpopJkt !== dpopJkt) { + return { error: 'invalid_dpop_proof', errorDescription: 'DPoP proof key does not match the refresh token binding.' }; + } + const { token, hash: newTokenHash } = generateOAuthToken(); const { token: refreshToken, hash: newRefreshHash } = generateOAuthRefreshToken(); + const tokenDpopJkt = existing.dpopJkt ?? dpopJkt; await __unsafePrisma.$transaction([ __unsafePrisma.oAuthRefreshToken.delete({ where: { hash } }), @@ -184,6 +203,7 @@ export async function verifyAndRotateRefreshToken({ clientId, userId: existing.userId, resource: existing.resource, + dpopJkt: tokenDpopJkt, expiresAt: new Date(Date.now() + env.OAUTH_ACCESS_TOKEN_TTL_SECONDS * 1000), }, }), @@ -193,12 +213,13 @@ export async function verifyAndRotateRefreshToken({ clientId, userId: existing.userId, resource: existing.resource, + dpopJkt: tokenDpopJkt, expiresAt: new Date(Date.now() + env.OAUTH_REFRESH_TOKEN_TTL_SECONDS * 1000), }, }), ]); - return { token, refreshToken, expiresIn: env.OAUTH_ACCESS_TOKEN_TTL_SECONDS }; + return { token, refreshToken, expiresIn: env.OAUTH_ACCESS_TOKEN_TTL_SECONDS, dpopJkt: tokenDpopJkt }; } // Revokes an access token or refresh token by hashing it and deleting the DB record. diff --git a/packages/web/src/lib/apiHandler.test.ts b/packages/web/src/lib/apiHandler.test.ts new file mode 100644 index 000000000..087f2b116 --- /dev/null +++ b/packages/web/src/lib/apiHandler.test.ts @@ -0,0 +1,29 @@ +import { describe, expect, test, vi } from 'vitest'; +import { NextRequest } from 'next/server'; +import { apiHandler } from './apiHandler'; +import { getCurrentRequest } from './requestContext'; + +vi.mock('./posthog', () => ({ + captureEvent: vi.fn(), +})); + +describe('apiHandler', () => { + test('stores the current request while the handler runs', async () => { + const request = new NextRequest('https://sourcebot.example.com/api/test', { + method: 'POST', + }); + + const handler = apiHandler(async () => { + expect(getCurrentRequest()).toBe(request); + await Promise.resolve(); + expect(getCurrentRequest()).toBe(request); + + return new Response(null, { status: 204 }); + }); + + const response = await handler(request); + + expect(response.status).toBe(204); + expect(getCurrentRequest()).toBeUndefined(); + }); +}); diff --git a/packages/web/src/lib/apiHandler.ts b/packages/web/src/lib/apiHandler.ts index 0015a805a..5d36d8d14 100644 --- a/packages/web/src/lib/apiHandler.ts +++ b/packages/web/src/lib/apiHandler.ts @@ -1,5 +1,6 @@ import { NextRequest } from 'next/server'; import { captureEvent } from './posthog'; +import { runWithRequestContext } from './requestContext'; interface ApiHandlerConfig { /** @@ -40,19 +41,20 @@ export function apiHandler( ): H { const { track = true } = config; - const wrappedHandler = async (request: NextRequest, ...rest: unknown[]) => { - if (track) { - const path = request.nextUrl.pathname; - const method = request.method; - const source = request.headers.get('X-Sourcebot-Client-Source') ?? 'unknown'; + const wrappedHandler = async (request: NextRequest, ...rest: unknown[]) => + runWithRequestContext(request, async () => { + if (track) { + const path = request.nextUrl.pathname; + const method = request.method; + const source = request.headers.get('X-Sourcebot-Client-Source') ?? 'unknown'; - // Fire and forget - don't await to avoid blocking the request - captureEvent('api_request', { path, method, source }); - } + // Fire and forget - don't await to avoid blocking the request + captureEvent('api_request', { path, method, source }); + } - // Call the original handler with all arguments - return handler(request, ...rest); - }; + // Call the original handler with all arguments + return handler(request, ...rest); + }); return wrappedHandler as H; } diff --git a/packages/web/src/lib/requestContext.ts b/packages/web/src/lib/requestContext.ts new file mode 100644 index 000000000..b9412a52d --- /dev/null +++ b/packages/web/src/lib/requestContext.ts @@ -0,0 +1,12 @@ +import { AsyncLocalStorage } from 'node:async_hooks'; +import type { NextRequest } from 'next/server'; + +const requestStorage = new AsyncLocalStorage(); + +export function runWithRequestContext(request: NextRequest, fn: () => T): T { + return requestStorage.run(request, fn); +} + +export function getCurrentRequest(): NextRequest | undefined { + return requestStorage.getStore(); +} diff --git a/packages/web/src/middleware/withAuth.test.ts b/packages/web/src/middleware/withAuth.test.ts index bc0586615..f8e59c724 100644 --- a/packages/web/src/middleware/withAuth.test.ts +++ b/packages/web/src/middleware/withAuth.test.ts @@ -1,5 +1,6 @@ import { expect, test, vi, beforeEach, describe } from 'vitest'; import { Session } from 'next-auth'; +import { NextRequest } from 'next/server'; import { notAuthenticated } from '../lib/serviceError'; import { getAuthContext, getAuthenticatedUser, withAuth, withOptionalAuth } from './withAuth'; import { MOCK_API_KEY, MOCK_OAUTH_TOKEN, MOCK_ORG, MOCK_USER_WITH_ACCOUNTS, prisma } from '../__mocks__/prisma'; @@ -7,6 +8,7 @@ import { OrgRole } from '@sourcebot/db'; import { ErrorCode } from '../lib/errorCodes'; import { StatusCodes } from 'http-status-codes'; import { userScopedPrismaClientExtension } from '@/prisma'; +import { runWithRequestContext } from '@/lib/requestContext'; const mocks = vi.hoisted(() => { return { @@ -184,6 +186,40 @@ describe('getAuthenticatedUser', () => { }); }); + test('should use the current request context when no request is passed', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: userId, + }); + + const request = new NextRequest('https://sourcebot.example.com/api/test', { + headers: { + Authorization: 'Bearer sourcebot-apikey', + }, + }); + + const result = await runWithRequestContext(request, () => getAuthenticatedUser()); + + expect(result).not.toBeUndefined(); + expect(result?.user.id).toBe(userId); + expect(result?.source).toBe('api_key'); + expect(mocks.headers).not.toHaveBeenCalled(); + expect(prisma.apiKey.update).toHaveBeenCalledWith({ + where: { + hash: 'apikey', + }, + data: { + lastUsedAt: expect.any(Date), + }, + }); + }); + test('should return undefined if a Bearer token is present but the API key does not exist', async () => { prisma.apiKey.findUnique.mockResolvedValue(null); setMockHeaders(new Headers({ 'Authorization': 'Bearer sourcebot-apikey' })); @@ -250,6 +286,27 @@ describe('getAuthenticatedUser', () => { expect(prisma.apiKey.findUnique).not.toHaveBeenCalled(); }); + test('should reject a DPoP-bound OAuth token presented as Bearer', async () => { + mocks.hasEntitlement.mockReturnValue(true); + prisma.oAuthToken.findUnique.mockResolvedValue({ + ...MOCK_OAUTH_TOKEN, + dpopJkt: 'dpop-thumbprint', + }); + setMockHeaders(new Headers({ 'Authorization': 'Bearer sboa_oauthtoken' })); + const user = await getAuthenticatedUser(); + expect(user).toBeUndefined(); + expect(prisma.oAuthToken.update).not.toHaveBeenCalled(); + }); + + test('should reject an unbound OAuth token presented with the DPoP scheme', async () => { + mocks.hasEntitlement.mockReturnValue(true); + prisma.oAuthToken.findUnique.mockResolvedValue(MOCK_OAUTH_TOKEN); + setMockHeaders(new Headers({ 'Authorization': 'DPoP sboa_oauthtoken' })); + const user = await getAuthenticatedUser(); + expect(user).toBeUndefined(); + expect(prisma.oAuthToken.update).not.toHaveBeenCalled(); + }); + test('should return undefined if a Bearer token is present but the user is not found', async () => { prisma.user.findUnique.mockResolvedValue(null); prisma.apiKey.findUnique.mockResolvedValue({ diff --git a/packages/web/src/middleware/withAuth.ts b/packages/web/src/middleware/withAuth.ts index 0e930fa63..a501c0ad5 100644 --- a/packages/web/src/middleware/withAuth.ts +++ b/packages/web/src/middleware/withAuth.ts @@ -9,6 +9,8 @@ import { StatusCodes } from "http-status-codes"; import { ErrorCode } from "../lib/errorCodes"; import { isServiceError } from "../lib/utils"; import { hasEntitlement, isAnonymousAccessEnabled } from "@/lib/entitlements"; +import { DPOP_AUTH_SCHEME, DPOP_PROOF_HEADER, verifyDpopProof } from "@/ee/features/oauth/dpop"; +import { getCurrentRequest } from "@/lib/requestContext"; const LAST_ACTIVE_AT_THRESHOLD_MS = 5 * 60 * 1000; @@ -148,10 +150,14 @@ export const getAuthenticatedUser = async (): Promise<{ user: UserWithAccounts, return user ? { user, source: 'session' } : undefined; } + const currentRequest = getCurrentRequest(); + const requestHeaders = currentRequest?.headers ?? await headers(); + // If not, check for a Bearer token in the Authorization header. - const authorizationHeader = (await headers()).get("Authorization") ?? undefined; - if (authorizationHeader?.startsWith("Bearer ")) { - const bearerToken = authorizationHeader.slice(7); + const authorizationHeader = requestHeaders.get("Authorization") ?? undefined; + const authorization = parseAuthorizationHeader(authorizationHeader); + if (authorization && (authorization.scheme === 'Bearer' || authorization.scheme === DPOP_AUTH_SCHEME)) { + const bearerToken = authorization.token; // OAuth access token if (bearerToken.startsWith(OAUTH_ACCESS_TOKEN_PREFIX)) { @@ -166,6 +172,28 @@ export const getAuthenticatedUser = async (): Promise<{ user: UserWithAccounts, include: { user: { include: { accounts: true } } }, }); if (oauthToken && oauthToken.expiresAt > new Date()) { + if (!oauthToken.dpopJkt && authorization.scheme === DPOP_AUTH_SCHEME) { + return undefined; + } + + if (oauthToken.dpopJkt) { + if (authorization.scheme !== DPOP_AUTH_SCHEME || !currentRequest) { + return undefined; + } + + const proofResult = await verifyDpopProof({ + request: currentRequest, + proof: requestHeaders.get(DPOP_PROOF_HEADER), + expectedJkt: oauthToken.dpopJkt, + accessToken: bearerToken, + requireAccessTokenHash: true, + }); + + if (!proofResult.ok) { + return undefined; + } + } + await __unsafePrisma.oAuthToken.update({ where: { hash }, data: { lastUsedAt: new Date() }, @@ -174,6 +202,10 @@ export const getAuthenticatedUser = async (): Promise<{ user: UserWithAccounts, } } + if (authorization.scheme !== 'Bearer') { + return undefined; + } + // API key Bearer token (sourcebot-) const apiKey = await getVerifiedApiObject(bearerToken); if (apiKey) { @@ -192,7 +224,7 @@ export const getAuthenticatedUser = async (): Promise<{ user: UserWithAccounts, } // If not, check if we have a valid API key. - const apiKeyString = (await headers()).get("X-Sourcebot-Api-Key") ?? undefined; + const apiKeyString = requestHeaders.get("X-Sourcebot-Api-Key") ?? undefined; if (apiKeyString) { const apiKey = await getVerifiedApiObject(apiKeyString); if (!apiKey) { @@ -229,6 +261,24 @@ export const getAuthenticatedUser = async (): Promise<{ user: UserWithAccounts, return undefined; } +function parseAuthorizationHeader(authorizationHeader: string | undefined): { scheme: string; token: string } | undefined { + const match = authorizationHeader?.match(/^(\S+)\s+(.+)$/); + if (!match) { + return undefined; + } + + const scheme = match[1].toLowerCase(); + if (scheme === 'bearer') { + return { scheme: 'Bearer', token: match[2] }; + } + + if (scheme === 'dpop') { + return { scheme: DPOP_AUTH_SCHEME, token: match[2] }; + } + + return { scheme: match[1], token: match[2] }; +} + /** * Returns an API key object if the API key string is valid, otherwise returns undefined. * Supports both the current prefix (sbk_) and the legacy prefix (sourcebot-). @@ -263,5 +313,3 @@ export const getVerifiedApiObject = async (apiKeyString: string): Promise