Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 79 additions & 27 deletions packages/loro-websocket/src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ import type { CrdtDocAdaptor } from "loro-adaptors";

export * from "loro-adaptors";

export type AuthProvider = () => Uint8Array | Promise<Uint8Array>;
type AuthOption = Uint8Array | AuthProvider;

interface FragmentBatch {
header: DocUpdateFragmentHeader;
fragments: Map<number, Uint8Array>;
Expand All @@ -36,7 +39,7 @@ interface PendingRoom {
reject: (error: Error) => void;
adaptor: CrdtDocAdaptor;
roomId: string;
auth?: Uint8Array;
auth?: AuthOption;
isRejoin?: boolean;
}

Expand Down Expand Up @@ -173,7 +176,7 @@ export class LoroWebsocketClient {
private roomAdaptors: Map<string, CrdtDocAdaptor> = new Map();
// Track roomId for each active id so we can rejoin on reconnect
private roomIds: Map<string, string> = new Map();
private roomAuth: Map<string, Uint8Array | undefined> = new Map();
private roomAuth: Map<string, AuthOption | undefined> = new Map();
private roomStatusListeners: Map<
string,
Set<(s: RoomJoinStatusValue) => void>
Expand Down Expand Up @@ -206,6 +209,17 @@ export class LoroWebsocketClient {
void this.connect();
}

private async resolveAuth(auth?: AuthOption): Promise<Uint8Array> {
if (typeof auth === "function") {
const value = await auth();
if (!(value instanceof Uint8Array)) {
throw new Error("Auth provider must return Uint8Array");
}
return value;
}
return auth ?? new Uint8Array();
}

get socket(): WebSocket {
return this.ws;
}
Expand Down Expand Up @@ -562,17 +576,27 @@ export class LoroWebsocketClient {
if (!roomId) continue;
const active = this.activeRooms.get(id);
if (!active) continue;
this.sendRejoinRequest(id, roomId, adaptor, active.room, this.roomAuth.get(id));
void this.sendRejoinRequest(id, roomId, adaptor, active.room, this.roomAuth.get(id));
}
}

private sendRejoinRequest(
private async sendRejoinRequest(
id: string,
roomId: string,
adaptor: CrdtDocAdaptor,
room: LoroWebsocketClientRoom,
auth?: Uint8Array
auth?: AuthOption
) {
let authValue: Uint8Array;
try {
authValue = await this.resolveAuth(auth);
} catch (e) {
console.error("Failed to resolve auth for rejoin:", e);
this.cleanupRoom(roomId, adaptor.crdtType);
this.emitRoomStatus(id, RoomJoinStatus.Error);
return;
}

// Prepare a lightweight pending entry so JoinError handling can retry version formats
const pending: PendingRoom = {
room: Promise.resolve(room),
Expand All @@ -589,6 +613,7 @@ export class LoroWebsocketClient {
},
reject: (error: Error) => {
console.error("Rejoin failed:", error);
this.pendingRooms.delete(id);
this.cleanupRoom(roomId, adaptor.crdtType);
this.emitRoomStatus(id, RoomJoinStatus.Error);
},
Expand All @@ -603,7 +628,7 @@ export class LoroWebsocketClient {
type: MessageType.JoinRequest,
crdt: adaptor.crdtType,
roomId,
auth: auth ?? new Uint8Array(),
auth: authValue,
version: adaptor.getVersion(),
} as JoinRequest);

Expand Down Expand Up @@ -677,7 +702,7 @@ export class LoroWebsocketClient {
// Drop any in-flight join since the server explicitly removed us
this.pendingRooms.delete(roomId);
if (shouldRejoin && active && adaptor) {
this.sendRejoinRequest(roomId, msg.roomId, adaptor, active.room, auth);
void this.sendRejoinRequest(roomId, msg.roomId, adaptor, active.room, auth);
} else {
// Remove local room state so client does not auto-retry unless requested
this.cleanupRoom(msg.roomId, msg.crdt);
Expand Down Expand Up @@ -815,6 +840,19 @@ export class LoroWebsocketClient {
roomId: string
) {
if (msg.code === JoinErrorCode.VersionUnknown) {
let authValue: Uint8Array;
try {
authValue = await this.resolveAuth(pending.auth);
} catch (e) {
pending.reject(e as Error);
this.pendingRooms.delete(roomId);
this.emitRoomStatus(
pending.adaptor.crdtType + pending.roomId,
RoomJoinStatus.Error
);
return;
}

// Try alternative version format
const currentVersion = pending.adaptor.getVersion();
const alternativeVersion =
Expand All @@ -826,7 +864,7 @@ export class LoroWebsocketClient {
type: MessageType.JoinRequest,
crdt: pending.adaptor.crdtType,
roomId: pending.roomId,
auth: pending.auth ?? new Uint8Array(),
auth: authValue,
version: alternativeVersion,
} as JoinRequest)
);
Expand All @@ -838,7 +876,7 @@ export class LoroWebsocketClient {
type: MessageType.JoinRequest,
crdt: pending.adaptor.crdtType,
roomId: pending.roomId,
auth: pending.auth ?? new Uint8Array(),
auth: authValue,
version: new Uint8Array(),
} as JoinRequest)
);
Expand Down Expand Up @@ -915,7 +953,11 @@ export class LoroWebsocketClient {
}

/**
* Join a room; `auth` carries application-defined join metadata forwarded to the server.
* Join a room.
* - `auth` may be a `Uint8Array` or a provider function.
* - The provider is invoked on the initial join and again on protocol-driven retries
* (e.g. `VersionUnknown`) and reconnect rejoins, so it can refresh short-lived tokens.
* If callers need a stable token, memoize in the provider.
*/
join({
roomId,
Expand All @@ -925,7 +967,7 @@ export class LoroWebsocketClient {
}: {
roomId: string;
crdtAdaptor: CrdtDocAdaptor;
auth?: Uint8Array;
auth?: AuthOption;
onStatusChange?: (s: RoomJoinStatusValue) => void;
}): Promise<LoroWebsocketClientRoom> {
const id = crdtAdaptor.crdtType + roomId;
Expand All @@ -940,8 +982,8 @@ export class LoroWebsocketClient {
return Promise.resolve(active.room);
}

let resolve: (res: JoinResponseOk) => void;
let reject: (error: Error) => void;
let resolve!: (res: JoinResponseOk) => void;
let reject!: (error: Error) => void;

const response = new Promise<JoinResponseOk>((resolve_, reject_) => {
resolve = resolve_;
Expand Down Expand Up @@ -1005,6 +1047,7 @@ export class LoroWebsocketClient {
return room;
});

// Register pending room immediately so concurrent join calls dedupe
this.pendingRooms.set(id, {
room,
resolve: resolve!,
Expand All @@ -1015,21 +1058,30 @@ export class LoroWebsocketClient {
});
this.roomAuth.set(id, auth);

const joinPayload = encode({
type: MessageType.JoinRequest,
crdt: crdtAdaptor.crdtType,
roomId,
auth: auth ?? new Uint8Array(),
version: crdtAdaptor.getVersion(),
} as JoinRequest);
void this.resolveAuth(auth)
.then(authValue => {
const joinPayload = encode({
type: MessageType.JoinRequest,
crdt: crdtAdaptor.crdtType,
roomId,
auth: authValue,
version: crdtAdaptor.getVersion(),
} as JoinRequest);

if (this.ws && this.ws.readyState === WebSocket.OPEN) {
this.ws.send(joinPayload);
} else {
this.enqueueJoin(joinPayload);
// ensure a connection attempt is running
void this.connect();
}
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
this.ws.send(joinPayload);
} else {
this.enqueueJoin(joinPayload);
// ensure a connection attempt is running
void this.connect();
}
})
.catch(err => {
const error = err instanceof Error ? err : new Error(String(err));
this.emitRoomStatus(id, RoomJoinStatus.Error);
reject(error);
this.cleanupRoom(roomId, crdtAdaptor.crdtType);
});

return room;
}
Expand Down
82 changes: 82 additions & 0 deletions packages/loro-websocket/tests/e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,88 @@ describe("E2E: Client-Server Sync", () => {
await authServer.stop();
}, 15000);

it("fetches fresh auth on rejoin when auth provider is used", async () => {
const port = await getPort();
const tokens: string[] = [];

const server = new SimpleServer({
port,
authenticate: async (_roomId, _crdt, auth) => {
tokens.push(new TextDecoder().decode(auth));
return "write";
},
});
await server.start();

const client = new LoroWebsocketClient({
url: `ws://localhost:${port}`,
reconnect: { initialDelayMs: 20, maxDelayMs: 100, jitter: 0 },
});

let room: LoroWebsocketClientRoom | undefined;
try {
await client.waitConnected();
let call = 0;
const adaptor = new LoroAdaptor();

room = await client.join({
roomId: "auth-refresh",
crdtAdaptor: adaptor,
auth: async () => new TextEncoder().encode(`token-${++call}`),
});

await waitUntil(() => tokens.length >= 1, 5000, 25);

await server.stop();
await new Promise(resolve => setTimeout(resolve, 60));
await server.start();

await waitUntil(() => tokens.some(t => t === "token-2"), 10000, 50);

expect(tokens[0]).toBe("token-1");
expect(tokens.some(t => t === "token-2")).toBe(true);
} finally {
await room?.destroy();
client.destroy();
await server.stop();
}
}, 15000);

it("dedupes concurrent join calls even before auth resolves", async () => {
const port = await getPort();
const tokens: string[] = [];

const server = new SimpleServer({
port,
authenticate: async (_roomId, _crdt, auth) => {
tokens.push(new TextDecoder().decode(auth));
return "write";
},
});
await server.start();

const client = new LoroWebsocketClient({ url: `ws://localhost:${port}` });
await client.waitConnected();

const adaptor = new LoroAdaptor();
const auth = () => new TextEncoder().encode("token-once");

const joinPromise1 = client.join({ roomId: "dedupe", crdtAdaptor: adaptor, auth });
const joinPromise2 = client.join({ roomId: "dedupe", crdtAdaptor: adaptor, auth });

expect(joinPromise1).toBe(joinPromise2);

const [room1, room2] = await Promise.all([joinPromise1, joinPromise2]);
expect(room1).toBe(room2);

await waitUntil(() => tokens.length >= 1, 5000, 25);
expect(tokens).toHaveLength(1);

await room1.destroy();
client.destroy();
await server.stop();
}, 15000);

it("destroy rejects pending ping waiters", async () => {
const client = new LoroWebsocketClient({ url: `ws://localhost:${port}` });
await client.waitConnected();
Expand Down