diff --git a/.changeset/strong-trains-act.md b/.changeset/strong-trains-act.md new file mode 100644 index 0000000000..d0394c8efa --- /dev/null +++ b/.changeset/strong-trains-act.md @@ -0,0 +1,8 @@ +--- +'@tanstack/start-client-core': minor +'@tanstack/start-plugin-core': patch +'@tanstack/start-server-core': patch +'@tanstack/start-fn-stubs': patch +--- + +add createCsrfMiddleware based on Sec-Fetch-Site header, auto-apply to unconfigured servers, warn for others diff --git a/docs/start/framework/react/guide/middleware.md b/docs/start/framework/react/guide/middleware.md index 09c8e2afab..27cf384a43 100644 --- a/docs/start/framework/react/guide/middleware.md +++ b/docs/start/framework/react/guide/middleware.md @@ -440,6 +440,51 @@ export const startInstance = createStart(() => { > [!NOTE] > Global **request** middleware runs before **every request, including server routes, SSR and server functions**. +### CSRF Middleware + +Server functions are same-origin RPC endpoints and should be protected from cross-site requests. If your app does not define `src/start.ts`, TanStack Start installs its CSRF middleware automatically for server functions. + +If you define a custom `src/start.ts`, add `createCsrfMiddleware()` explicitly: + +```tsx +// src/start.ts +import { createStart, createCsrfMiddleware } from '@tanstack/react-start' + +const csrfMiddleware = createCsrfMiddleware({ + filter: (ctx) => ctx.handlerType === 'serverFn', +}) + +export const startInstance = createStart(() => ({ + requestMiddleware: [csrfMiddleware], +})) +``` + +By default, `Origin` and `Referer` checks compare against the incoming request URL origin. If your deployment needs to allow a different public origin, configure it on the CSRF middleware with `createCsrfMiddleware({ origin: 'https://app.example.com' })`. + +By default, `createCsrfMiddleware()` validates every request handled by the middleware. Use `filter: (ctx) => ctx.handlerType === 'serverFn'` when installing it globally for server function protection. It verifies same-origin browser request metadata with `Sec-Fetch-Site`, `Origin`, or `Referer` headers and rejects requests that cannot be proven same-origin. + +You can also use the same middleware to protect any other route. + +```tsx +export const Route = createFileRoute('/api/foo')({ + server: { + middleware: [createCsrfMiddleware()], + handlers: { GET: () => {...} } + } +}) +``` + +If you define `src/start.ts` without the CSRF middleware, Start shows a development warning for server function requests. If you intentionally handle CSRF another way, disable the warning: + +```tsx +// vite.config.ts or rsbuild.config.ts +tanstackStart({ + serverFns: { + disableCsrfMiddlewareWarning: true, + }, +}) +``` + ### Global Server Function Middleware To have a middleware run for **every server function in your application**, add it to the `functionMiddleware` array in your `src/start.ts` file: diff --git a/docs/start/framework/react/guide/server-functions.md b/docs/start/framework/react/guide/server-functions.md index 5e9e6e16e2..fb521b8c8e 100644 --- a/docs/start/framework/react/guide/server-functions.md +++ b/docs/start/framework/react/guide/server-functions.md @@ -21,6 +21,30 @@ const time = await getServerTime() Server functions provide server capabilities (database access, environment variables, file system) while maintaining type safety across the network boundary. +## Same-Origin Requests + +Server functions are same-origin RPC endpoints for your application. Browser requests to server functions should come from the same origin, verified with Fetch Metadata (`Sec-Fetch-Site`), `Origin`, or `Referer` headers. Use server routes for public APIs or endpoints that intentionally support cross-origin requests. + +TanStack Start provides `createCsrfMiddleware()` to protect server functions from cross-site requests. If your app does not define `src/start.ts`, Start installs this middleware automatically for server functions. If you define `src/start.ts`, add the middleware explicitly: + +```tsx +// src/start.ts +import { createStart, createCsrfMiddleware } from '@tanstack/react-start' + +const csrfMiddleware = createCsrfMiddleware({ + filter: (ctx) => ctx.handlerType === 'serverFn', +}) + +export const startInstance = createStart(() => ({ + requestMiddleware: [csrfMiddleware], +})) +``` + +By default, `Origin` and `Referer` checks compare against the incoming request URL origin. If your deployment needs to allow a different public origin, configure it on the CSRF middleware with `createCsrfMiddleware({ origin: 'https://app.example.com' })`. + +> [!TIP] +> Requests without any of these headers (`Sec-Fetch-Site`, `Origin`, or `Referer`) are rejected by default. If your deployment strips these headers and you have another layer that guarantees same-origin server function requests, you can opt in with `createCsrfMiddleware({ filter: (ctx) => ctx.handlerType === 'serverFn', allowRequestsWithoutOriginCheck: true })`. + ## Basic Usage Server functions are created with `createServerFn()` and can specify HTTP method: diff --git a/e2e/react-start/server-functions/src/start.ts b/e2e/react-start/server-functions/src/start.ts index 1cd840ca96..8c3a96dd71 100644 --- a/e2e/react-start/server-functions/src/start.ts +++ b/e2e/react-start/server-functions/src/start.ts @@ -1,4 +1,4 @@ -import { createStart } from '@tanstack/react-start' +import { createCsrfMiddleware, createStart } from '@tanstack/react-start' import type { CustomFetch } from '@tanstack/react-start' /** @@ -16,6 +16,11 @@ const globalServerFnFetch: CustomFetch = (input, init) => { } export const startInstance = createStart(() => ({ + requestMiddleware: [ + createCsrfMiddleware({ + filter: (ctx) => ctx.handlerType === 'serverFn', + }), + ], serverFns: { fetch: globalServerFnFetch, }, diff --git a/e2e/react-start/server-functions/tests/server-functions.spec.ts b/e2e/react-start/server-functions/tests/server-functions.spec.ts index 4173362e4c..9b23cb8a2a 100644 --- a/e2e/react-start/server-functions/tests/server-functions.spec.ts +++ b/e2e/react-start/server-functions/tests/server-functions.spec.ts @@ -288,6 +288,32 @@ test('Direct POST submitting FormData to a Server function returns the correct m expect(result).toBe(expected) }) +test('CSRF middleware rejects cross-site Server function requests', async ({ + page, + request, +}) => { + await page.goto('/submit-post-formdata') + await page.waitForLoadState('networkidle') + + const actionUrl = await page + .getByTestId('submit-post-formdata-form') + .getAttribute('action') + + expect(actionUrl).toBeTruthy() + + const response = await request.post(actionUrl!, { + headers: { + 'Sec-Fetch-Site': 'cross-site', + }, + multipart: { + name: 'Sean', + }, + }) + + expect(response.status()).toBe(403) + await expect(response.text()).resolves.toBe('Forbidden') +}) + test("server function's dead code is preserved if already there", async ({ page, }) => { diff --git a/packages/react-start-client/src/tests/createServerFn.test-d.tsx b/packages/react-start-client/src/tests/createServerFn.test-d.tsx index e7d8c7be08..b020676956 100644 --- a/packages/react-start-client/src/tests/createServerFn.test-d.tsx +++ b/packages/react-start-client/src/tests/createServerFn.test-d.tsx @@ -24,7 +24,9 @@ test('createServerFn returns async array', () => { return result }) - expectTypeOf(serverFn()).toEqualTypeOf>>() + expectTypeOf>().toEqualTypeOf< + Promise> + >() }) test('createServerFn returns sync array', () => { @@ -33,7 +35,9 @@ test('createServerFn returns sync array', () => { return result }) - expectTypeOf(serverFn()).toEqualTypeOf>>() + expectTypeOf>().toEqualTypeOf< + Promise> + >() }) test('createServerFn returns async union', () => { @@ -42,7 +46,9 @@ test('createServerFn returns async union', () => { return result }) - expectTypeOf(serverFn()).toEqualTypeOf>() + expectTypeOf>().toEqualTypeOf< + Promise + >() }) test('createServerFn returns sync union', () => { @@ -51,5 +57,7 @@ test('createServerFn returns sync union', () => { return result }) - expectTypeOf(serverFn()).toEqualTypeOf>() + expectTypeOf>().toEqualTypeOf< + Promise + >() }) diff --git a/packages/start-client-core/src/createCsrfMiddleware.ts b/packages/start-client-core/src/createCsrfMiddleware.ts new file mode 100644 index 0000000000..b7152d44a9 --- /dev/null +++ b/packages/start-client-core/src/createCsrfMiddleware.ts @@ -0,0 +1,197 @@ +import { createIsomorphicFn } from '@tanstack/start-fn-stubs' +import { createMiddleware } from './createMiddleware' +import type { + RequestMiddlewareAfterServer, + RequestServerOptions, +} from './createMiddleware' +import type { Register } from '@tanstack/router-core' + +export const csrfSymbol = Symbol.for('tanstack-start:csrf-middleware') + +export type CsrfSecFetchSite = + | 'same-origin' + | 'same-site' + | 'cross-site' + | 'none' + +export type CsrfMatcher = + | TValue + | Array + | (( + value: TValue | (string & {}), + ctx: RequestServerOptions, + ) => boolean | Promise) + +export interface CsrfMiddlewareOptions< + TRegister = Register, + TMiddlewares = unknown, +> { + /** + * Return `true` to validate this request, or `false` to skip validation. + * + * @default undefined, which validates every request handled by this middleware. + */ + filter?: ( + ctx: RequestServerOptions, + ) => boolean | Promise + /** + * Allowed Origin values. Defaults to the trusted request origin. + */ + origin?: CsrfMatcher + /** + * Allowed Sec-Fetch-Site values. + * + * @default 'same-origin' + */ + secFetchSite?: CsrfMatcher + /** + * Whether to use Referer as a fallback when Sec-Fetch-Site and Origin are absent. + * + * @default true + */ + referer?: + | boolean + | (( + referer: string, + ctx: RequestServerOptions, + ) => boolean | Promise) + /** + * Allow requests when Sec-Fetch-Site, Origin, and Referer are all missing. + * + * @default false + */ + allowRequestsWithoutOriginCheck?: boolean + /** + * Optional response returned when CSRF validation fails. + * + * @default new Response('Forbidden', { status: 403 }) + */ + failureResponse?: + | Response + | (( + ctx: RequestServerOptions, + ) => Response | Promise) +} + +type CreateCsrfMiddleware = ( + opts?: CsrfMiddlewareOptions, +) => RequestMiddlewareAfterServer<{}, undefined, undefined> + +const innerCreateCsrfMiddleware: CreateCsrfMiddleware = (opts = {}) => { + const middleware = createMiddleware().server(async (ctx) => { + const csrfCtx = ctx as RequestServerOptions & typeof ctx + + if (opts.filter && !(await opts.filter(csrfCtx))) { + return ctx.next() + } + + if (await isCsrfRequestAllowed(opts, csrfCtx)) { + return ctx.next() + } + + return getFailureResponse(opts, csrfCtx) + }) + + if (process.env.NODE_ENV !== 'production') { + Object.defineProperty(middleware, csrfSymbol, { value: true }) + } + + return middleware +} + +export const createCsrfMiddleware: CreateCsrfMiddleware = + createIsomorphicFn().server(innerCreateCsrfMiddleware) as CreateCsrfMiddleware + +export async function isCsrfRequestAllowed( + opts: CsrfMiddlewareOptions, + ctx: RequestServerOptions, +): Promise { + const result = await getCsrfRequestValidationResult(opts, ctx) + return ( + result === true || + (result === undefined && opts.allowRequestsWithoutOriginCheck === true) + ) +} + +export async function getCsrfRequestValidationResult( + opts: CsrfMiddlewareOptions, + ctx: RequestServerOptions, +): Promise { + const fetchSite = ctx.request.headers.get('Sec-Fetch-Site') + if (fetchSite !== null) { + return matchValue(opts.secFetchSite ?? 'same-origin', fetchSite, ctx) + } + + const origin = ctx.request.headers.get('Origin') + if (origin !== null) { + if (opts.origin) { + return matchValue(opts.origin, origin, ctx) + } + + return origin === new URL(ctx.request.url).origin + } + + const referer = ctx.request.headers.get('Referer') + if (referer === null || opts.referer === false) { + return undefined + } + + if (typeof opts.referer === 'function') { + return opts.referer(referer, ctx) + } + + if (opts.origin) { + const refererOrigin = getOriginFromUrl(referer) + return ( + refererOrigin !== undefined && matchValue(opts.origin, refererOrigin, ctx) + ) + } + + return isRefererSameOrigin(referer, new URL(ctx.request.url).origin) +} + +async function matchValue( + matcher: CsrfMatcher, + value: string, + ctx: RequestServerOptions, +): Promise { + if (typeof matcher === 'function') { + return matcher(value, ctx) + } + + if (Array.isArray(matcher)) { + // typescript is dumb for array.includes() + return matcher.includes(value as TValue) + } + + return value === matcher +} + +function getOriginFromUrl(url: string): string | undefined { + try { + return new URL(url).origin + } catch { + return undefined + } +} + +function isRefererSameOrigin(referer: string, requestOrigin: string): boolean { + if (referer === requestOrigin) return true + if (!referer.startsWith(requestOrigin)) return false + if (referer.length === requestOrigin.length) return true + const code = referer.charCodeAt(requestOrigin.length) + return code === 47 /* '/' */ || code === 63 /* '?' */ || code === 35 /* '#' */ +} + +async function getFailureResponse( + opts: CsrfMiddlewareOptions, + ctx: RequestServerOptions, +): Promise { + if (typeof opts.failureResponse === 'function') { + return opts.failureResponse(ctx) + } + + return ( + opts.failureResponse?.clone() ?? new Response('Forbidden', { status: 403 }) + ) +} diff --git a/packages/start-client-core/src/createMiddleware.ts b/packages/start-client-core/src/createMiddleware.ts index 9e0a435791..8f4ff474fc 100644 --- a/packages/start-client-core/src/createMiddleware.ts +++ b/packages/start-client-core/src/createMiddleware.ts @@ -770,6 +770,10 @@ export interface RequestServerOptions { pathname: string context: Expand> next: RequestServerNextFn + /** + * Type of Start handler currently processing this request. + */ + handlerType: 'serverFn' | 'router' /** * Metadata about the server function being invoked. * This is only present when the request is handling a server function call. diff --git a/packages/start-client-core/src/index.tsx b/packages/start-client-core/src/index.tsx index f7835c9859..5bf09c4440 100644 --- a/packages/start-client-core/src/index.tsx +++ b/packages/start-client-core/src/index.tsx @@ -84,6 +84,17 @@ export { flattenMiddlewares, executeMiddleware, } from './createServerFn' +export { + createCsrfMiddleware, + csrfSymbol, + getCsrfRequestValidationResult, + isCsrfRequestAllowed, +} from './createCsrfMiddleware' +export type { + CsrfMatcher, + CsrfMiddlewareOptions, + CsrfSecFetchSite, +} from './createCsrfMiddleware' export { TSS_FORMDATA_CONTEXT, diff --git a/packages/start-client-core/src/tests/createCsrfMiddleware.test.ts b/packages/start-client-core/src/tests/createCsrfMiddleware.test.ts new file mode 100644 index 0000000000..f152e82144 --- /dev/null +++ b/packages/start-client-core/src/tests/createCsrfMiddleware.test.ts @@ -0,0 +1,290 @@ +import { describe, expect, it, vi } from 'vitest' +import { + createCsrfMiddleware, + csrfSymbol, + getCsrfRequestValidationResult, + isCsrfRequestAllowed, +} from '../createCsrfMiddleware' +import type { RequestServerOptions } from '../createMiddleware' +import type { Register } from '@tanstack/router-core' + +const requestOrigin = 'https://app.example.com' + +function trackHeaders(init: Record) { + const headers = new Map( + Object.entries(init).map(([key, value]) => [key.toLowerCase(), value]), + ) + const reads: Array = [] + + return { + reads, + request: { + url: `${requestOrigin}/_serverFn/test`, + headers: { + get(name: string) { + reads.push(name) + return headers.get(name.toLowerCase()) ?? null + }, + }, + } as Request, + } +} + +function createContext(init: { + headers?: Record + handlerType?: 'serverFn' | 'router' + origin?: string +}): RequestServerOptions { + const { request } = trackHeaders(init.headers ?? {}) + return { + request, + pathname: new URL(request.url).pathname, + context: undefined, + next: (() => undefined) as any, + handlerType: init.handlerType ?? 'serverFn', + } +} + +async function runMiddleware( + middleware: ReturnType, + ctx: RequestServerOptions, +) { + const next = vi.fn(() => ({ request: ctx.request, pathname: ctx.pathname })) + const result = await middleware.options.server!({ + ...ctx, + next, + } as any) + return { result, next } +} + +describe('getCsrfRequestValidationResult', () => { + it('allows same-origin fetch metadata without reading Origin or Referer', async () => { + const { request, reads } = trackHeaders({ + 'Sec-Fetch-Site': 'same-origin', + Origin: 'https://evil.example.com', + Referer: 'https://evil.example.com/path', + }) + + await expect( + getCsrfRequestValidationResult({}, createContextFromRequest(request)), + ).resolves.toBe(true) + expect(reads).toEqual(['Sec-Fetch-Site']) + }) + + it.each(['same-site', 'cross-site', 'none', 'invalid'])( + 'rejects %s fetch metadata', + async (fetchSite) => { + const ctx = createContext({ + headers: { + 'Sec-Fetch-Site': fetchSite, + Origin: requestOrigin, + Referer: `${requestOrigin}/path`, + }, + }) + + await expect(getCsrfRequestValidationResult({}, ctx)).resolves.toBe(false) + }, + ) + + it('allows matching Origin without reading Referer', async () => { + const { request, reads } = trackHeaders({ + Origin: requestOrigin, + Referer: 'https://evil.example.com/path', + }) + + await expect( + getCsrfRequestValidationResult({}, createContextFromRequest(request)), + ).resolves.toBe(true) + expect(reads).toEqual(['Sec-Fetch-Site', 'Origin']) + }) + + it('rejects mismatched Origin without reading Referer', async () => { + const { request, reads } = trackHeaders({ + Origin: 'https://evil.example.com', + Referer: `${requestOrigin}/path`, + }) + + await expect( + getCsrfRequestValidationResult({}, createContextFromRequest(request)), + ).resolves.toBe(false) + expect(reads).toEqual(['Sec-Fetch-Site', 'Origin']) + }) + + it.each([ + requestOrigin, + `${requestOrigin}/path`, + `${requestOrigin}?query=1`, + `${requestOrigin}#hash`, + ])('allows same-origin Referer fallback: %s', async (referer) => { + const ctx = createContext({ headers: { Referer: referer } }) + + await expect(getCsrfRequestValidationResult({}, ctx)).resolves.toBe(true) + }) + + it.each([ + 'https://evil.example.com/path', + `${requestOrigin}.evil/path`, + `${requestOrigin}:443/path`, + ])('rejects cross-origin Referer fallback: %s', async (referer) => { + const ctx = createContext({ headers: { Referer: referer } }) + + await expect(getCsrfRequestValidationResult({}, ctx)).resolves.toBe(false) + }) + + it('returns undefined for requests without origin check headers', async () => { + const ctx = createContext({}) + + await expect( + getCsrfRequestValidationResult({}, ctx), + ).resolves.toBeUndefined() + }) + + it('rejects empty Referer header as known invalid origin check', async () => { + const ctx = createContext({ headers: { Referer: '' } }) + + await expect(getCsrfRequestValidationResult({}, ctx)).resolves.toBe(false) + }) + + it('rejects missing origin check headers by default', async () => { + const ctx = createContext({}) + + await expect(isCsrfRequestAllowed({}, ctx)).resolves.toBe(false) + }) + + it('allows missing origin check headers with the opt-in', async () => { + const ctx = createContext({}) + + await expect( + isCsrfRequestAllowed({ allowRequestsWithoutOriginCheck: true }, ctx), + ).resolves.toBe(true) + }) + + it('does not allow invalid origin check headers with the opt-in', async () => { + const ctx = createContext({ headers: { 'Sec-Fetch-Site': 'cross-site' } }) + + await expect( + isCsrfRequestAllowed({ allowRequestsWithoutOriginCheck: true }, ctx), + ).resolves.toBe(false) + }) + + it('uses custom origin matchers', async () => { + const ctx = createContext({ + headers: { Origin: 'https://preview.example.com' }, + }) + + await expect( + getCsrfRequestValidationResult( + { origin: ['https://app.example.com', 'https://preview.example.com'] }, + ctx, + ), + ).resolves.toBe(true) + }) + + it('uses custom Sec-Fetch-Site matchers', async () => { + const ctx = createContext({ headers: { 'Sec-Fetch-Site': 'same-site' } }) + + await expect( + getCsrfRequestValidationResult( + { secFetchSite: ['same-origin', 'same-site'] }, + ctx, + ), + ).resolves.toBe(true) + }) + + it('reads request URL origin only when needed', async () => { + const getUrl = vi.fn(() => `${requestOrigin}/_serverFn/test`) + const sameOriginFetch = createContext({ + headers: { 'Sec-Fetch-Site': 'same-origin' }, + }) + Object.defineProperty(sameOriginFetch.request, 'url', { get: getUrl }) + + await expect( + getCsrfRequestValidationResult({}, sameOriginFetch), + ).resolves.toBe(true) + expect(getUrl).not.toHaveBeenCalled() + + const originFallback = createContext({ + headers: { Origin: requestOrigin }, + }) + Object.defineProperty(originFallback.request, 'url', { get: getUrl }) + + await expect( + getCsrfRequestValidationResult({}, originFallback), + ).resolves.toBe(true) + expect(getUrl).toHaveBeenCalledTimes(1) + }) +}) + +describe('createCsrfMiddleware', () => { + it('marks middleware with csrfSymbol in non-production environments', () => { + const middleware = createCsrfMiddleware() + + expect(csrfSymbol in middleware).toBe(true) + }) + + it('protects router requests by default', async () => { + const middleware = createCsrfMiddleware() + const ctx = createContext({ handlerType: 'router' }) + + const { result, next } = await runMiddleware(middleware, ctx) + + expect(next).not.toHaveBeenCalled() + expect(result).toBeInstanceOf(Response) + expect((result as Response).status).toBe(403) + }) + + it('protects server function requests by default', async () => { + const middleware = createCsrfMiddleware() + const ctx = createContext({ handlerType: 'serverFn' }) + + const { result, next } = await runMiddleware(middleware, ctx) + + expect(next).not.toHaveBeenCalled() + expect(result).toBeInstanceOf(Response) + expect((result as Response).status).toBe(403) + }) + + it('filters requests', async () => { + const middleware = createCsrfMiddleware({ filter: () => false }) + const ctx = createContext({ handlerType: 'serverFn' }) + + const { next } = await runMiddleware(middleware, ctx) + + expect(next).toHaveBeenCalledTimes(1) + }) + + it('can filter to server function requests', async () => { + const middleware = createCsrfMiddleware({ + filter: (ctx) => ctx.handlerType === 'serverFn', + }) + const ctx = createContext({ handlerType: 'router' }) + + const { next } = await runMiddleware(middleware, ctx) + + expect(next).toHaveBeenCalledTimes(1) + }) + + it('uses custom failure responses', async () => { + const middleware = createCsrfMiddleware({ + failureResponse: new Response('CSRF failed', { status: 419 }), + }) + const ctx = createContext({ handlerType: 'serverFn' }) + + const { result } = await runMiddleware(middleware, ctx) + + expect((result as Response).status).toBe(419) + await expect((result as Response).text()).resolves.toBe('CSRF failed') + }) +}) + +function createContextFromRequest( + request: Request, +): RequestServerOptions { + return { + request, + pathname: new URL(request.url).pathname, + context: undefined, + next: (() => undefined) as any, + handlerType: 'serverFn', + } +} diff --git a/packages/start-client-core/src/tests/createServerMiddleware.test-d.ts b/packages/start-client-core/src/tests/createServerMiddleware.test-d.ts index 196c94397e..27c9753073 100644 --- a/packages/start-client-core/src/tests/createServerMiddleware.test-d.ts +++ b/packages/start-client-core/src/tests/createServerMiddleware.test-d.ts @@ -675,6 +675,7 @@ test('createMiddleware with type request, no middleware or context', () => { next: RequestServerNextFn<{}, undefined> pathname: string context: undefined + handlerType: 'serverFn' | 'router' serverFnMeta?: ServerFnMeta }>() @@ -698,6 +699,7 @@ test('createMiddleware with type request, no middleware with context', () => { next: RequestServerNextFn<{}, undefined> pathname: string context: undefined + handlerType: 'serverFn' | 'router' serverFnMeta?: ServerFnMeta }>() @@ -722,6 +724,7 @@ test('createMiddleware with type request, middleware and context', () => { next: RequestServerNextFn<{}, undefined> pathname: string context: undefined + handlerType: 'serverFn' | 'router' serverFnMeta?: ServerFnMeta }>() @@ -746,6 +749,7 @@ test('createMiddleware with type request, middleware and context', () => { next: RequestServerNextFn<{}, undefined> pathname: string context: { a: string } + handlerType: 'serverFn' | 'router' serverFnMeta?: ServerFnMeta }>() @@ -769,6 +773,7 @@ test('createMiddleware with type request can return Response directly', () => { next: RequestServerNextFn<{}, undefined> pathname: string context: undefined + handlerType: 'serverFn' | 'router' serverFnMeta?: ServerFnMeta }>() @@ -789,6 +794,7 @@ test('createMiddleware with type request can return Promise', () => { next: RequestServerNextFn<{}, undefined> pathname: string context: undefined + handlerType: 'serverFn' | 'router' serverFnMeta?: ServerFnMeta }>() @@ -804,6 +810,7 @@ test('createMiddleware with type request can return sync Response', () => { next: RequestServerNextFn<{}, undefined> pathname: string context: undefined + handlerType: 'serverFn' | 'router' serverFnMeta?: ServerFnMeta }>() diff --git a/packages/start-fn-stubs/src/createIsomorphicFn.ts b/packages/start-fn-stubs/src/createIsomorphicFn.ts index 55258bf5a3..695da54453 100644 --- a/packages/start-fn-stubs/src/createIsomorphicFn.ts +++ b/packages/start-fn-stubs/src/createIsomorphicFn.ts @@ -34,13 +34,34 @@ export interface IsomorphicFnBase { ) => ClientOnlyFn } -// this is a dummy function, it will be replaced by the transformer -// if we use `createIsomorphicFn` in this library itself, vite tries to execute it before the transformer runs -// therefore we must return a dummy function that allows calling `server` and `client` method chains. +// The Start compiler normally rewrites createIsomorphicFn() chains before they +// run. Some package tests/build steps execute this stub uncompiled though, for +// example while Vite loads server-side modules during a build. +// +// In those uncompiled contexts we need a real callable fallback, not just a +// chain-shaped object. These contexts are server-side, so once a .server() +// implementation is registered we keep using it even if .client() is chained +// later. Client bundles still get the correct client/no-op implementation +// because the compiler rewrites the original call chain before runtime. export function createIsomorphicFn(): IsomorphicFnBase { - const fn = () => undefined + return createRuntimeFn(() => undefined) as any +} + +type RuntimeFallbackFn = (() => any) & { + server: (serverImpl: () => any) => RuntimeFallbackFn + client: (clientImpl: () => any) => RuntimeFallbackFn +} + +function createRuntimeFn( + fn: () => any, + serverImpl?: () => any, +): RuntimeFallbackFn { return Object.assign(fn, { - server: () => ({ client: () => () => {} }), - client: () => ({ server: () => () => {} }), - }) as any + server: (nextServerImpl: () => any) => { + return createRuntimeFn(nextServerImpl, nextServerImpl) + }, + client: (clientImpl: () => any) => { + return createRuntimeFn(serverImpl ?? clientImpl, serverImpl) + }, + }) } diff --git a/packages/start-fn-stubs/tests/createIsomorphicFn.test.ts b/packages/start-fn-stubs/tests/createIsomorphicFn.test.ts new file mode 100644 index 0000000000..23e97908e6 --- /dev/null +++ b/packages/start-fn-stubs/tests/createIsomorphicFn.test.ts @@ -0,0 +1,24 @@ +import { describe, expect, it } from 'vitest' +import { createIsomorphicFn } from '../src/createIsomorphicFn' + +describe('createIsomorphicFn runtime fallback', () => { + it('returns a callable server implementation', () => { + const fn = createIsomorphicFn().server(() => 'server') + + expect(fn()).toBe('server') + }) + + it('prefers the server implementation when both implementations are registered', () => { + const fn = createIsomorphicFn() + .server(() => 'server') + .client(() => 'client') + + expect(fn()).toBe('server') + }) + + it('returns a callable client-only implementation', () => { + const fn = createIsomorphicFn().client(() => 'client') + + expect(fn()).toBe('client') + }) +}) diff --git a/packages/start-plugin-core/src/rsbuild/plugin.ts b/packages/start-plugin-core/src/rsbuild/plugin.ts index 337cdd55c4..17da5940f3 100644 --- a/packages/start-plugin-core/src/rsbuild/plugin.ts +++ b/packages/start-plugin-core/src/rsbuild/plugin.ts @@ -196,6 +196,17 @@ export function tanStackStartRsbuild( 'import.meta.env.TSS_INLINE_CSS_ENABLED': JSON.stringify( inlineCssEnabled ? 'true' : 'false', ), + 'process.env.TSS_DISABLE_CSRF_MIDDLEWARE_WARNING': JSON.stringify( + startConfig.serverFns.disableCsrfMiddlewareWarning + ? 'true' + : 'false', + ), + 'import.meta.env.TSS_DISABLE_CSRF_MIDDLEWARE_WARNING': + JSON.stringify( + startConfig.serverFns.disableCsrfMiddlewareWarning + ? 'true' + : 'false', + ), }, }, server: { diff --git a/packages/start-plugin-core/src/schema.ts b/packages/start-plugin-core/src/schema.ts index e0287858e6..4034dc9762 100644 --- a/packages/start-plugin-core/src/schema.ts +++ b/packages/start-plugin-core/src/schema.ts @@ -224,6 +224,7 @@ export const tanstackStartOptionsObjectSchema = z.object({ serverFns: z .object({ base: z.string().optional().default('/_serverFn'), + disableCsrfMiddlewareWarning: z.boolean().optional().default(false), generateFunctionId: z .function() .args( diff --git a/packages/start-plugin-core/src/vite/planning.ts b/packages/start-plugin-core/src/vite/planning.ts index 5db724ba8d..39b702c221 100644 --- a/packages/start-plugin-core/src/vite/planning.ts +++ b/packages/start-plugin-core/src/vite/planning.ts @@ -138,6 +138,7 @@ export function createViteDefineConfig(opts: { devSsrStylesBasepath: string inlineCssEnabled: boolean staticNodeEnv: boolean + disableCsrfMiddlewareWarning: boolean }) { return { ...defineReplaceEnv('TSS_SERVER_FN_BASE', opts.serverFnBase), @@ -161,6 +162,10 @@ export function createViteDefineConfig(opts: { 'TSS_INLINE_CSS_ENABLED', opts.inlineCssEnabled ? 'true' : 'false', ), + ...defineReplaceEnv( + 'TSS_DISABLE_CSRF_MIDDLEWARE_WARNING', + opts.disableCsrfMiddlewareWarning ? 'true' : 'false', + ), ...(opts.command === 'build' && opts.staticNodeEnv ? { 'process.env.NODE_ENV': JSON.stringify( diff --git a/packages/start-plugin-core/src/vite/plugin.ts b/packages/start-plugin-core/src/vite/plugin.ts index 915fcd7aca..6c2a301c97 100644 --- a/packages/start-plugin-core/src/vite/plugin.ts +++ b/packages/start-plugin-core/src/vite/plugin.ts @@ -188,6 +188,8 @@ export function tanStackStartVite( inlineCssEnabled: command === 'build' && startConfig.server.build.inlineCss, staticNodeEnv: startConfig.server.build.staticNodeEnv, + disableCsrfMiddlewareWarning: + startConfig.serverFns.disableCsrfMiddlewareWarning, }), builder: { sharedPlugins: true, diff --git a/packages/start-plugin-core/src/vite/start-compiler-plugin/plugin.ts b/packages/start-plugin-core/src/vite/start-compiler-plugin/plugin.ts index 3130ca47c1..d423798547 100644 --- a/packages/start-plugin-core/src/vite/start-compiler-plugin/plugin.ts +++ b/packages/start-plugin-core/src/vite/start-compiler-plugin/plugin.ts @@ -1,3 +1,4 @@ +import { AsyncLocalStorage } from 'node:async_hooks' import { VIRTUAL_MODULES } from '@tanstack/start-server-core' import { resolve as resolvePath } from 'pathe' import { @@ -46,6 +47,20 @@ type ModuleInvalidationEnvironment = { } } +type StartCompilerPluginContext = { + environment: { + name: string + mode: string + transformRequest: (url: string) => Promise + } + load: (options: { id: string }) => Promise<{ code?: string | null }> + resolve: ( + source: string, + importer?: string, + ) => Promise<{ id: string; external?: boolean | string } | null> + error: (message: string) => never +} + function invalidateMatchingFileModules( environment: ModuleInvalidationEnvironment, ids: Iterable, @@ -181,6 +196,17 @@ export function startCompilerPlugin( opts: StartCompilerPluginOptions, ): PluginOption { const compilers = new Map>() + const compilerContextStorage = + new AsyncLocalStorage() + + const getCompilerContext = () => { + const context = compilerContextStorage.getStore() + if (!context) { + throw new Error('Start compiler Vite context is unavailable.') + } + + return context + } // Shared registry of server functions across all environments const serverFnsById: Record = {} @@ -262,27 +288,30 @@ export function startCompilerPlugin( ? createViteDevServerFnModuleSpecifierEncoder(root) : undefined, loadModule: async (id: string) => { + const compilerContext = getCompilerContext() + if (mode === 'build') { - const loaded = await this.load({ id }) + const loaded = await compilerContext.load({ id }) const code = loaded.code ?? '' compiler!.ingestModule({ code, id }) return } - if (this.environment.mode !== 'dev') { - this.error( - `could not load module ${id}: unknown environment mode ${this.environment.mode}`, + if (compilerContext.environment.mode !== 'dev') { + compilerContext.error( + `could not load module ${id}: unknown environment mode ${compilerContext.environment.mode}`, ) } - await this.environment.transformRequest( + await compilerContext.environment.transformRequest( `${id}?${SERVER_FN_LOOKUP}`, ) }, resolveId: async (source: string, importer?: string) => { - const r = await this.resolve(source, importer) + const compilerContext = getCompilerContext() + const r = await compilerContext.resolve(source, importer) if (r) { if (!r.external) { @@ -302,11 +331,15 @@ export function startCompilerPlugin( compilerTransforms, }) - const result = await compiler.compile({ - id, - code, - detectedKinds, - }) + const result = await compilerContextStorage.run( + this as unknown as StartCompilerPluginContext, + () => + compiler.compile({ + id, + code, + detectedKinds, + }), + ) return result }, }, diff --git a/packages/start-plugin-core/tests/csrf-warning-config.test.ts b/packages/start-plugin-core/tests/csrf-warning-config.test.ts new file mode 100644 index 0000000000..abe28087d0 --- /dev/null +++ b/packages/start-plugin-core/tests/csrf-warning-config.test.ts @@ -0,0 +1,65 @@ +import { describe, expect, test } from 'vitest' +import { parseStartConfig as parseViteStartConfig } from '../src/vite/schema' +import { parseStartConfig as parseRsbuildStartConfig } from '../src/rsbuild/schema' +import { createViteDefineConfig } from '../src/vite/planning' + +const root = process.cwd() +const corePluginOpts = { framework: 'react' as const } + +describe('disableCsrfMiddlewareWarning plugin config', () => { + test('is accepted by the vite plugin config', () => { + const config = parseViteStartConfig( + { + serverFns: { + disableCsrfMiddlewareWarning: true, + }, + }, + corePluginOpts, + root, + ) + + expect(config.serverFns.disableCsrfMiddlewareWarning).toBe(true) + }) + + test('is accepted by the rsbuild plugin config', () => { + const config = parseRsbuildStartConfig( + { + serverFns: { + disableCsrfMiddlewareWarning: true, + }, + }, + corePluginOpts, + root, + ) + + expect(config.serverFns.disableCsrfMiddlewareWarning).toBe(true) + }) + + test('defaults to false', () => { + const config = parseViteStartConfig({}, corePluginOpts, root) + + expect(config.serverFns.disableCsrfMiddlewareWarning).toBe(false) + }) + + test('emits the vite define used by the server runtime', () => { + const define = createViteDefineConfig({ + command: 'serve', + mode: 'development', + serverFnBase: '/_serverFn', + routerBasepath: '/', + spaEnabled: true, + devSsrStylesEnabled: true, + devSsrStylesBasepath: '/', + inlineCssEnabled: false, + staticNodeEnv: true, + disableCsrfMiddlewareWarning: true, + }) + + expect(define['process.env.TSS_DISABLE_CSRF_MIDDLEWARE_WARNING']).toBe( + JSON.stringify('true'), + ) + expect(define['import.meta.env.TSS_DISABLE_CSRF_MIDDLEWARE_WARNING']).toBe( + JSON.stringify('true'), + ) + }) +}) diff --git a/packages/start-server-core/src/createStartHandler.ts b/packages/start-server-core/src/createStartHandler.ts index a31a81fde8..293eda1d43 100644 --- a/packages/start-server-core/src/createStartHandler.ts +++ b/packages/start-server-core/src/createStartHandler.ts @@ -1,6 +1,8 @@ import { createMemoryHistory } from '@tanstack/history' import { + createCsrfMiddleware, createNullProtoObject, + csrfSymbol, flattenMiddlewares, mergeHeaders, safeObjectMerge, @@ -340,6 +342,10 @@ interface Entries { // that can cause race conditions during module initialization let entriesPromise: Promise | undefined let baseManifestPromise: Promise | undefined +let hasWarnedMissingCsrfMiddleware = false +const defaultCsrfMiddleware = createCsrfMiddleware({ + filter: (ctx) => ctx.handlerType === 'serverFn', +}) /** * Cached final manifest (with client entry script tag). In production, @@ -370,6 +376,39 @@ function getEntries() { return entriesPromise } +function hasCsrfMiddleware( + middlewares: Array, +): boolean { + return middlewares.some((middleware) => csrfSymbol in middleware) +} + +function warnMissingCsrfMiddlewareOnce() { + if (hasWarnedMissingCsrfMiddleware) return + hasWarnedMissingCsrfMiddleware = true + + console.warn(`TanStack Start server functions are not protected by the CSRF middleware. + +Server functions are same-origin RPC endpoints and should be protected from cross-site requests. + +Add the CSRF middleware in src/start.ts: + + const csrfMiddleware = createCsrfMiddleware({ + filter: (ctx) => ctx.handlerType === 'serverFn', + }) + + export const startInstance = createStart(() => ({ + requestMiddleware: [csrfMiddleware], + })) + +If you intentionally handle CSRF another way, disable this warning: + + tanstackStart({ + serverFns: { + disableCsrfMiddlewareWarning: true, + }, + })`) +} + /** * Returns the raw manifest data (without client entry script tag baked in). * In dev mode, always returns fresh data. In prod, cached. @@ -682,6 +721,7 @@ export function createStartHandler( } const entries = await getEntries() + const hasStartInstance = !!entries.startEntry.startInstance const startOptions: AnyStartInstanceOptions = (await entries.startEntry.startInstance?.getOptions()) || ({} as AnyStartInstanceOptions) @@ -697,12 +737,15 @@ export function createStartHandler( const requestStartOptions = { ...startOptions, + requestMiddleware: hasStartInstance + ? startOptions.requestMiddleware + : [defaultCsrfMiddleware], serializationAdapters, } // Flatten request middlewares once - const flattenedRequestMiddlewares = startOptions.requestMiddleware - ? flattenMiddlewares(startOptions.requestMiddleware) + const flattenedRequestMiddlewares = requestStartOptions.requestMiddleware + ? flattenMiddlewares(requestStartOptions.requestMiddleware) : [] // Create set for deduplication @@ -745,6 +788,14 @@ export function createStartHandler( // Check for server function requests first (early exit) if (SERVER_FN_BASE && url.pathname.startsWith(SERVER_FN_BASE)) { + if ( + process.env.NODE_ENV !== 'production' && + process.env.TSS_DISABLE_CSRF_MIDDLEWARE_WARNING !== 'true' && + !hasCsrfMiddleware(flattenedRequestMiddlewares) + ) { + warnMissingCsrfMiddlewareOnce() + } + const serverFnId = url.pathname .slice(SERVER_FN_BASE.length) .split('/')[0] @@ -778,6 +829,7 @@ export function createStartHandler( const ctx = await executeMiddleware([...middlewares, serverFnHandler], { request, pathname: url.pathname, + handlerType: 'serverFn', context: createNullProtoObject(requestOpts?.context), }) @@ -937,6 +989,7 @@ export function createStartHandler( { request, pathname: url.pathname, + handlerType: 'router', context: createNullProtoObject(requestOpts?.context), }, ) @@ -1107,6 +1160,7 @@ async function handleServerRoutes({ context, params: routeParams, pathname, + handlerType: 'router', }) // RFC 9110 ยง9.3.2: HEAD must carry the same header fields as GET but no body. diff --git a/packages/start-server-core/src/global.d.ts b/packages/start-server-core/src/global.d.ts index dd9048293e..8d15f8c211 100644 --- a/packages/start-server-core/src/global.d.ts +++ b/packages/start-server-core/src/global.d.ts @@ -7,6 +7,7 @@ declare global { TSS_SHELL?: 'true' | 'false' TSS_PRERENDERING?: 'true' | 'false' TSS_DEV_SERVER?: 'true' | 'false' + TSS_DISABLE_CSRF_MIDDLEWARE_WARNING?: 'true' | 'false' } } }