diff --git a/packages/loro-websocket/src/client/index.ts b/packages/loro-websocket/src/client/index.ts index bdb79f3..f528c72 100644 --- a/packages/loro-websocket/src/client/index.ts +++ b/packages/loro-websocket/src/client/index.ts @@ -24,6 +24,9 @@ import type { CrdtDocAdaptor } from "loro-adaptors"; export * from "loro-adaptors"; +export type AuthProvider = () => Uint8Array | Promise; +type AuthOption = Uint8Array | AuthProvider; + interface FragmentBatch { header: DocUpdateFragmentHeader; fragments: Map; @@ -36,7 +39,7 @@ interface PendingRoom { reject: (error: Error) => void; adaptor: CrdtDocAdaptor; roomId: string; - auth?: Uint8Array; + auth?: AuthOption; isRejoin?: boolean; } @@ -173,7 +176,7 @@ export class LoroWebsocketClient { private roomAdaptors: Map = new Map(); // Track roomId for each active id so we can rejoin on reconnect private roomIds: Map = new Map(); - private roomAuth: Map = new Map(); + private roomAuth: Map = new Map(); private roomStatusListeners: Map< string, Set<(s: RoomJoinStatusValue) => void> @@ -206,6 +209,17 @@ export class LoroWebsocketClient { void this.connect(); } + private async resolveAuth(auth?: AuthOption): Promise { + if (typeof auth === "function") { + const value = await auth(); + if (!(value instanceof Uint8Array)) { + throw new Error("Auth provider must return Uint8Array"); + } + return value; + } + return auth ?? new Uint8Array(); + } + get socket(): WebSocket { return this.ws; } @@ -562,17 +576,27 @@ export class LoroWebsocketClient { if (!roomId) continue; const active = this.activeRooms.get(id); if (!active) continue; - this.sendRejoinRequest(id, roomId, adaptor, active.room, this.roomAuth.get(id)); + void this.sendRejoinRequest(id, roomId, adaptor, active.room, this.roomAuth.get(id)); } } - private sendRejoinRequest( + private async sendRejoinRequest( id: string, roomId: string, adaptor: CrdtDocAdaptor, room: LoroWebsocketClientRoom, - auth?: Uint8Array + auth?: AuthOption ) { + let authValue: Uint8Array; + try { + authValue = await this.resolveAuth(auth); + } catch (e) { + console.error("Failed to resolve auth for rejoin:", e); + this.cleanupRoom(roomId, adaptor.crdtType); + this.emitRoomStatus(id, RoomJoinStatus.Error); + return; + } + // Prepare a lightweight pending entry so JoinError handling can retry version formats const pending: PendingRoom = { room: Promise.resolve(room), @@ -589,6 +613,7 @@ export class LoroWebsocketClient { }, reject: (error: Error) => { console.error("Rejoin failed:", error); + this.pendingRooms.delete(id); this.cleanupRoom(roomId, adaptor.crdtType); this.emitRoomStatus(id, RoomJoinStatus.Error); }, @@ -603,7 +628,7 @@ export class LoroWebsocketClient { type: MessageType.JoinRequest, crdt: adaptor.crdtType, roomId, - auth: auth ?? new Uint8Array(), + auth: authValue, version: adaptor.getVersion(), } as JoinRequest); @@ -677,7 +702,7 @@ export class LoroWebsocketClient { // Drop any in-flight join since the server explicitly removed us this.pendingRooms.delete(roomId); if (shouldRejoin && active && adaptor) { - this.sendRejoinRequest(roomId, msg.roomId, adaptor, active.room, auth); + void this.sendRejoinRequest(roomId, msg.roomId, adaptor, active.room, auth); } else { // Remove local room state so client does not auto-retry unless requested this.cleanupRoom(msg.roomId, msg.crdt); @@ -815,6 +840,19 @@ export class LoroWebsocketClient { roomId: string ) { if (msg.code === JoinErrorCode.VersionUnknown) { + let authValue: Uint8Array; + try { + authValue = await this.resolveAuth(pending.auth); + } catch (e) { + pending.reject(e as Error); + this.pendingRooms.delete(roomId); + this.emitRoomStatus( + pending.adaptor.crdtType + pending.roomId, + RoomJoinStatus.Error + ); + return; + } + // Try alternative version format const currentVersion = pending.adaptor.getVersion(); const alternativeVersion = @@ -826,7 +864,7 @@ export class LoroWebsocketClient { type: MessageType.JoinRequest, crdt: pending.adaptor.crdtType, roomId: pending.roomId, - auth: pending.auth ?? new Uint8Array(), + auth: authValue, version: alternativeVersion, } as JoinRequest) ); @@ -838,7 +876,7 @@ export class LoroWebsocketClient { type: MessageType.JoinRequest, crdt: pending.adaptor.crdtType, roomId: pending.roomId, - auth: pending.auth ?? new Uint8Array(), + auth: authValue, version: new Uint8Array(), } as JoinRequest) ); @@ -915,7 +953,11 @@ export class LoroWebsocketClient { } /** - * Join a room; `auth` carries application-defined join metadata forwarded to the server. + * Join a room. + * - `auth` may be a `Uint8Array` or a provider function. + * - The provider is invoked on the initial join and again on protocol-driven retries + * (e.g. `VersionUnknown`) and reconnect rejoins, so it can refresh short-lived tokens. + * If callers need a stable token, memoize in the provider. */ join({ roomId, @@ -925,7 +967,7 @@ export class LoroWebsocketClient { }: { roomId: string; crdtAdaptor: CrdtDocAdaptor; - auth?: Uint8Array; + auth?: AuthOption; onStatusChange?: (s: RoomJoinStatusValue) => void; }): Promise { const id = crdtAdaptor.crdtType + roomId; @@ -940,8 +982,8 @@ export class LoroWebsocketClient { return Promise.resolve(active.room); } - let resolve: (res: JoinResponseOk) => void; - let reject: (error: Error) => void; + let resolve!: (res: JoinResponseOk) => void; + let reject!: (error: Error) => void; const response = new Promise((resolve_, reject_) => { resolve = resolve_; @@ -1005,6 +1047,7 @@ export class LoroWebsocketClient { return room; }); + // Register pending room immediately so concurrent join calls dedupe this.pendingRooms.set(id, { room, resolve: resolve!, @@ -1015,21 +1058,30 @@ export class LoroWebsocketClient { }); this.roomAuth.set(id, auth); - const joinPayload = encode({ - type: MessageType.JoinRequest, - crdt: crdtAdaptor.crdtType, - roomId, - auth: auth ?? new Uint8Array(), - version: crdtAdaptor.getVersion(), - } as JoinRequest); + void this.resolveAuth(auth) + .then(authValue => { + const joinPayload = encode({ + type: MessageType.JoinRequest, + crdt: crdtAdaptor.crdtType, + roomId, + auth: authValue, + version: crdtAdaptor.getVersion(), + } as JoinRequest); - if (this.ws && this.ws.readyState === WebSocket.OPEN) { - this.ws.send(joinPayload); - } else { - this.enqueueJoin(joinPayload); - // ensure a connection attempt is running - void this.connect(); - } + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.send(joinPayload); + } else { + this.enqueueJoin(joinPayload); + // ensure a connection attempt is running + void this.connect(); + } + }) + .catch(err => { + const error = err instanceof Error ? err : new Error(String(err)); + this.emitRoomStatus(id, RoomJoinStatus.Error); + reject(error); + this.cleanupRoom(roomId, crdtAdaptor.crdtType); + }); return room; } diff --git a/packages/loro-websocket/tests/e2e.test.ts b/packages/loro-websocket/tests/e2e.test.ts index 67f6fcb..5e17c1d 100644 --- a/packages/loro-websocket/tests/e2e.test.ts +++ b/packages/loro-websocket/tests/e2e.test.ts @@ -571,6 +571,88 @@ describe("E2E: Client-Server Sync", () => { await authServer.stop(); }, 15000); + it("fetches fresh auth on rejoin when auth provider is used", async () => { + const port = await getPort(); + const tokens: string[] = []; + + const server = new SimpleServer({ + port, + authenticate: async (_roomId, _crdt, auth) => { + tokens.push(new TextDecoder().decode(auth)); + return "write"; + }, + }); + await server.start(); + + const client = new LoroWebsocketClient({ + url: `ws://localhost:${port}`, + reconnect: { initialDelayMs: 20, maxDelayMs: 100, jitter: 0 }, + }); + + let room: LoroWebsocketClientRoom | undefined; + try { + await client.waitConnected(); + let call = 0; + const adaptor = new LoroAdaptor(); + + room = await client.join({ + roomId: "auth-refresh", + crdtAdaptor: adaptor, + auth: async () => new TextEncoder().encode(`token-${++call}`), + }); + + await waitUntil(() => tokens.length >= 1, 5000, 25); + + await server.stop(); + await new Promise(resolve => setTimeout(resolve, 60)); + await server.start(); + + await waitUntil(() => tokens.some(t => t === "token-2"), 10000, 50); + + expect(tokens[0]).toBe("token-1"); + expect(tokens.some(t => t === "token-2")).toBe(true); + } finally { + await room?.destroy(); + client.destroy(); + await server.stop(); + } + }, 15000); + + it("dedupes concurrent join calls even before auth resolves", async () => { + const port = await getPort(); + const tokens: string[] = []; + + const server = new SimpleServer({ + port, + authenticate: async (_roomId, _crdt, auth) => { + tokens.push(new TextDecoder().decode(auth)); + return "write"; + }, + }); + await server.start(); + + const client = new LoroWebsocketClient({ url: `ws://localhost:${port}` }); + await client.waitConnected(); + + const adaptor = new LoroAdaptor(); + const auth = () => new TextEncoder().encode("token-once"); + + const joinPromise1 = client.join({ roomId: "dedupe", crdtAdaptor: adaptor, auth }); + const joinPromise2 = client.join({ roomId: "dedupe", crdtAdaptor: adaptor, auth }); + + expect(joinPromise1).toBe(joinPromise2); + + const [room1, room2] = await Promise.all([joinPromise1, joinPromise2]); + expect(room1).toBe(room2); + + await waitUntil(() => tokens.length >= 1, 5000, 25); + expect(tokens).toHaveLength(1); + + await room1.destroy(); + client.destroy(); + await server.stop(); + }, 15000); + it("destroy rejects pending ping waiters", async () => { const client = new LoroWebsocketClient({ url: `ws://localhost:${port}` }); await client.waitConnected();