Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 70 additions & 32 deletions packages/loro-websocket/src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ import type { CrdtDocAdaptor } from "loro-adaptors";

export * from "loro-adaptors";

export type AuthProvider = () => Uint8Array | Promise<Uint8Array>;
type AuthOption = Uint8Array | AuthProvider;

interface FragmentBatch {
header: DocUpdateFragmentHeader;
fragments: Map<number, Uint8Array>;
Expand Down Expand Up @@ -173,7 +176,7 @@ export class LoroWebsocketClient {
private roomAdaptors: Map<string, CrdtDocAdaptor> = new Map();
// Track roomId for each active id so we can rejoin on reconnect
private roomIds: Map<string, string> = new Map();
private roomAuth: Map<string, Uint8Array | undefined> = new Map();
private roomAuth: Map<string, AuthOption | undefined> = new Map();
private roomStatusListeners: Map<
string,
Set<(s: RoomJoinStatusValue) => void>
Expand Down Expand Up @@ -206,6 +209,17 @@ export class LoroWebsocketClient {
void this.connect();
}

private async resolveAuth(auth?: AuthOption): Promise<Uint8Array> {
if (typeof auth === "function") {
const value = await auth();
if (!(value instanceof Uint8Array)) {
throw new Error("Auth provider must return Uint8Array");
}
return value;
}
return auth ?? new Uint8Array();
}

get socket(): WebSocket {
return this.ws;
}
Expand Down Expand Up @@ -562,17 +576,27 @@ export class LoroWebsocketClient {
if (!roomId) continue;
const active = this.activeRooms.get(id);
if (!active) continue;
this.sendRejoinRequest(id, roomId, adaptor, active.room, this.roomAuth.get(id));
void this.sendRejoinRequest(id, roomId, adaptor, active.room, this.roomAuth.get(id));
}
}

private sendRejoinRequest(
private async sendRejoinRequest(
id: string,
roomId: string,
adaptor: CrdtDocAdaptor,
room: LoroWebsocketClientRoom,
auth?: Uint8Array
auth?: AuthOption
) {
let authValue: Uint8Array;
try {
authValue = await this.resolveAuth(auth);
} catch (e) {
console.error("Failed to resolve auth for rejoin:", e);
this.cleanupRoom(roomId, adaptor.crdtType);
this.emitRoomStatus(id, RoomJoinStatus.Error);
return;
}

// Prepare a lightweight pending entry so JoinError handling can retry version formats
const pending: PendingRoom = {
room: Promise.resolve(room),
Expand All @@ -594,7 +618,7 @@ export class LoroWebsocketClient {
},
adaptor,
roomId,
auth,
auth: authValue,
isRejoin: true,
};
this.pendingRooms.set(id, pending);
Expand All @@ -603,7 +627,7 @@ export class LoroWebsocketClient {
type: MessageType.JoinRequest,
crdt: adaptor.crdtType,
roomId,
auth: auth ?? new Uint8Array(),
auth: authValue,
version: adaptor.getVersion(),
} as JoinRequest);

Expand Down Expand Up @@ -677,7 +701,7 @@ export class LoroWebsocketClient {
// Drop any in-flight join since the server explicitly removed us
this.pendingRooms.delete(roomId);
if (shouldRejoin && active && adaptor) {
this.sendRejoinRequest(roomId, msg.roomId, adaptor, active.room, auth);
void this.sendRejoinRequest(roomId, msg.roomId, adaptor, active.room, auth);
} else {
// Remove local room state so client does not auto-retry unless requested
this.cleanupRoom(msg.roomId, msg.crdt);
Expand Down Expand Up @@ -925,7 +949,7 @@ export class LoroWebsocketClient {
}: {
roomId: string;
crdtAdaptor: CrdtDocAdaptor;
auth?: Uint8Array;
auth?: AuthOption;
onStatusChange?: (s: RoomJoinStatusValue) => void;
}): Promise<LoroWebsocketClientRoom> {
const id = crdtAdaptor.crdtType + roomId;
Expand All @@ -940,8 +964,8 @@ export class LoroWebsocketClient {
return Promise.resolve(active.room);
}

let resolve: (res: JoinResponseOk) => void;
let reject: (error: Error) => void;
let resolve!: (res: JoinResponseOk) => void;
let reject!: (error: Error) => void;

const response = new Promise<JoinResponseOk>((resolve_, reject_) => {
resolve = resolve_;
Expand Down Expand Up @@ -1005,31 +1029,45 @@ export class LoroWebsocketClient {
return room;
});

this.pendingRooms.set(id, {
room,
resolve: resolve!,
reject: reject!,
adaptor: crdtAdaptor,
roomId,
auth,
});
this.roomAuth.set(id, auth);

const joinPayload = encode({
type: MessageType.JoinRequest,
crdt: crdtAdaptor.crdtType,
roomId,
auth: auth ?? new Uint8Array(),
version: crdtAdaptor.getVersion(),
} as JoinRequest);
// Resolve auth before registering pending room to avoid race condition
// where JoinError retry might use undefined auth
void this.resolveAuth(auth)
.then(authValue => {
// Register pending room only after auth is resolved
this.pendingRooms.set(id, {
room,
resolve: resolve!,
reject: reject!,
adaptor: crdtAdaptor,
roomId,
auth: authValue,
});

if (this.ws && this.ws.readyState === WebSocket.OPEN) {
this.ws.send(joinPayload);
} else {
this.enqueueJoin(joinPayload);
// ensure a connection attempt is running
void this.connect();
}
const joinPayload = encode({
type: MessageType.JoinRequest,
crdt: crdtAdaptor.crdtType,
roomId,
auth: authValue,
version: crdtAdaptor.getVersion(),
} as JoinRequest);

if (this.ws && this.ws.readyState === WebSocket.OPEN) {
this.ws.send(joinPayload);
} else {
this.enqueueJoin(joinPayload);
// ensure a connection attempt is running
void this.connect();
}
})
.catch(err => {
const error = err instanceof Error ? err : new Error(String(err));
this.emitRoomStatus(id, RoomJoinStatus.Error);
reject(error);
this.cleanupRoom(roomId, crdtAdaptor.crdtType);
this.pendingRooms.delete(id);
});

return room;
}
Expand Down
47 changes: 47 additions & 0 deletions packages/loro-websocket/tests/e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,53 @@ describe("E2E: Client-Server Sync", () => {
await authServer.stop();
}, 15000);

it("fetches fresh auth on rejoin when auth provider is used", async () => {
const port = await getPort();
const tokens: string[] = [];

const server = new SimpleServer({
port,
authenticate: async (_roomId, _crdt, auth) => {
tokens.push(new TextDecoder().decode(auth));
return "write";
},
});
await server.start();

const client = new LoroWebsocketClient({
url: `ws://localhost:${port}`,
reconnect: { initialDelayMs: 20, maxDelayMs: 100, jitter: 0 },
});

let room: LoroWebsocketClientRoom | undefined;
try {
await client.waitConnected();
let call = 0;
const adaptor = new LoroAdaptor();

room = await client.join({
roomId: "auth-refresh",
crdtAdaptor: adaptor,
auth: async () => new TextEncoder().encode(`token-${++call}`),
});

await waitUntil(() => tokens.length >= 1, 5000, 25);

await server.stop();
await new Promise(resolve => setTimeout(resolve, 60));
await server.start();

await waitUntil(() => tokens.some(t => t === "token-2"), 10000, 50);

expect(tokens[0]).toBe("token-1");
expect(tokens.some(t => t === "token-2")).toBe(true);
} finally {
await room?.destroy();
client.destroy();
await server.stop();
}
}, 15000);

it("destroy rejects pending ping waiters", async () => {
const client = new LoroWebsocketClient({ url: `ws://localhost:${port}` });
await client.waitConnected();
Expand Down