diff --git a/api/comms/chat.go b/api/comms/chat.go index 67202ce7..71d96d47 100644 --- a/api/comms/chat.go +++ b/api/comms/chat.go @@ -217,6 +217,21 @@ func chatReadMessages(db dbv1.DBTX, ctx context.Context, userId int32, chatId st return err } +// chatReadAllMessages clears unread state for every chat this user belongs to. +// The (last_active_at IS NULL OR last_active_at < $1) guard mirrors the +// per-chat chat.read handler so out-of-order RPCs can't roll back a more +// recent read. The unread_count > 0 filter keeps the write set small. +func chatReadAllMessages(db dbv1.DBTX, ctx context.Context, userId int32, readTimestamp time.Time) error { + _, err := db.Exec(ctx, ` + update chat_member + set unread_count = 0, last_active_at = $1 + where user_id = $2 + and (last_active_at is null or last_active_at < $1) + and unread_count > 0`, + readTimestamp.UTC(), userId) + return err +} + var permissions = []ChatPermission{ ChatPermissionFollowees, ChatPermissionFollowers, diff --git a/api/comms/chat_test.go b/api/comms/chat_test.go index 53a141ac..44e12d3c 100644 --- a/api/comms/chat_test.go +++ b/api/comms/chat_test.go @@ -139,3 +139,68 @@ func TestChat(t *testing.T) { assertReaction(user1Id, replyMessageId, nil) } + +func TestChatReadAllMessages(t *testing.T) { + pool := database.CreateTestDatabase(t, "test_comms") + defer pool.Close() + + ctx := context.Background() + seededRand := rand.New(rand.NewSource(time.Now().UnixNano())) + + user1Id := int32(1) + user2Id := int32(2) + user3Id := int32(3) + + chatA := trashid.ChatID(int(user1Id), int(user2Id)) + chatB := trashid.ChatID(int(user1Id), int(user3Id)) + SetupChatWithMembers(t, pool, ctx, chatA, user1Id, user2Id, "a1", "a2") + SetupChatWithMembers(t, pool, ctx, chatB, user1Id, user3Id, "b1", "b3") + + assertUnreadCount := func(chatId string, userId int32, expected int) { + t.Helper() + unreadCount := 0 + err := pool.QueryRow(ctx, "select unread_count from chat_member where chat_id = $1 and user_id = $2", chatId, userId).Scan(&unreadCount) + assert.NoError(t, err) + assert.Equal(t, expected, unreadCount, "unread for chat %s user %d", chatId, userId) + } + + // Send user1Id one message in each chat from the other party. + err := chatSendMessage(pool, ctx, user2Id, chatA, strconv.Itoa(seededRand.Int()), time.Now(), "hi from 2") + assert.NoError(t, err) + err = chatSendMessage(pool, ctx, user3Id, chatB, strconv.Itoa(seededRand.Int()), time.Now(), "hi from 3") + assert.NoError(t, err) + + assertUnreadCount(chatA, user1Id, 1) + assertUnreadCount(chatB, user1Id, 1) + // Senders' own unread counts stay at zero. + assertUnreadCount(chatA, user2Id, 0) + assertUnreadCount(chatB, user3Id, 0) + + // Single call clears every unread chat for user1Id without touching + // the other members' chats. + readTs := time.Now() + err = chatReadAllMessages(pool, ctx, user1Id, readTs) + assert.NoError(t, err) + + assertUnreadCount(chatA, user1Id, 0) + assertUnreadCount(chatB, user1Id, 0) + + // Re-confirm: a stale read (older timestamp) does NOT roll back. + // Add a new unread, advance via chatReadAllMessages, then try a stale read. + err = chatSendMessage(pool, ctx, user2Id, chatA, strconv.Itoa(seededRand.Int()), time.Now(), "another from 2") + assert.NoError(t, err) + assertUnreadCount(chatA, user1Id, 1) + + freshTs := time.Now() + err = chatReadAllMessages(pool, ctx, user1Id, freshTs) + assert.NoError(t, err) + assertUnreadCount(chatA, user1Id, 0) + + // Older timestamp must be a no-op for last_active_at. + err = chatReadAllMessages(pool, ctx, user1Id, freshTs.Add(-time.Hour)) + assert.NoError(t, err) + var lastActive time.Time + err = pool.QueryRow(ctx, "select last_active_at from chat_member where chat_id = $1 and user_id = $2", chatA, user1Id).Scan(&lastActive) + assert.NoError(t, err) + assert.WithinDuration(t, freshTs.UTC(), lastActive.UTC(), time.Second, "stale read should not roll back last_active_at") +} diff --git a/api/comms/rpc_processor.go b/api/comms/rpc_processor.go index c3d4c11f..c05268bb 100644 --- a/api/comms/rpc_processor.go +++ b/api/comms/rpc_processor.go @@ -270,6 +270,13 @@ select last_active_at from chat_member where chat_id = $1 and user_id = $2` return err } } + case RPCMethodChatReadAll: + // No params to unmarshal. The per-row last_active_at guard lives + // inside chatReadAllMessages so we don't have to read first. + err = chatReadAllMessages(tx, ctx, userId, messageTs) + if err != nil { + return err + } case RPCMethodChatPermit: var params ChatPermitRPCParams err = json.Unmarshal(rawRpc.Params, ¶ms) diff --git a/api/comms/schema.go b/api/comms/schema.go index 3aac7523..1d527512 100644 --- a/api/comms/schema.go +++ b/api/comms/schema.go @@ -95,6 +95,13 @@ type ChatReadRPCParams struct { ChatID string `json:"chat_id"` } +type ChatReadAllRPC struct { + Method ChatReadAllRPCMethod `json:"method"` + Params ChatReadAllRPCParams `json:"params"` +} + +type ChatReadAllRPCParams struct{} + type ChatBlockRPC struct { Method ChatBlockRPCMethod `json:"method"` Params ChatBlockRPCParams `json:"params"` @@ -367,6 +374,12 @@ const ( MethodChatRead ChatReadRPCMethod = "chat.read" ) +type ChatReadAllRPCMethod string + +const ( + MethodChatReadAll ChatReadAllRPCMethod = "chat.read_all" +) + type ChatBlockRPCMethod string const ( @@ -410,6 +423,7 @@ const ( RPCMethodChatPermit RPCMethod = "chat.permit" RPCMethodChatReact RPCMethod = "chat.react" RPCMethodChatRead RPCMethod = "chat.read" + RPCMethodChatReadAll RPCMethod = "chat.read_all" RPCMethodChatUnblock RPCMethod = "chat.unblock" RPCMethodUserValidateCanChat RPCMethod = "user.validate_can_chat" ) diff --git a/api/comms/validator.go b/api/comms/validator.go index d425d517..1d2948d5 100644 --- a/api/comms/validator.go +++ b/api/comms/validator.go @@ -67,6 +67,9 @@ func (vtor *Validator) Validate(ctx context.Context, userId int32, rawRpc RawRPC return vtor.validateChatReact(vtor.pool, ctx, userId, rawRpc) case RPCMethodChatRead: return vtor.validateChatRead(userId, rawRpc) + case RPCMethodChatReadAll: + // No params to validate; ban check above already gates this call. + return nil case RPCMethodChatPermit: return vtor.validateChatPermit(userId, rawRpc) case RPCMethodChatBlock: