diff --git a/taskiq_redis/redis_broker.py b/taskiq_redis/redis_broker.py
index 9786f79..9d05828 100644
--- a/taskiq_redis/redis_broker.py
+++ b/taskiq_redis/redis_broker.py
@@ -161,6 +161,7 @@ def __init__(
         approximate: bool = True,
         idle_timeout: int = 600000,  # 10 minutes
         unacknowledged_batch_size: int = 100,
+        unacknowledged_lock_timeout: float | None = None,
         xread_count: int | None = 100,
         additional_streams: dict[str, str | int] | None = None,
         **connection_kwargs: Any,
@@ -188,8 +189,10 @@ def __init__(
         :param xread_count: number of messages to fetch from the stream at once.
         :param additional_streams: additional streams to read from.
             Each key is a stream name, value is a consumer id.
-        :param redeliver_timeout: time in ms to wait before redelivering a message.
         :param unacknowledged_batch_size: number of unacknowledged messages to fetch.
+        :param unacknowledged_lock_timeout: time in seconds before auto-releasing
+            the lock. Useful when the worker crashes or gets killed.
+            If not set, the lock can remain locked indefinitely.
         """
         super().__init__(
             url,
@@ -209,6 +212,7 @@ def __init__(
         self.additional_streams = additional_streams or {}
         self.idle_timeout = idle_timeout
         self.unacknowledged_batch_size = unacknowledged_batch_size
+        self.unacknowledged_lock_timeout = unacknowledged_lock_timeout
         self.count = xread_count
 
     async def _declare_consumer_group(self) -> None:
@@ -290,6 +294,7 @@ async def listen(self) -> AsyncGenerator[AckableMessage, None]:
                 for stream in [self.queue_name, *self.additional_streams.keys()]:
                     lock = redis_conn.lock(
                         f"autoclaim:{self.consumer_group_name}:{stream}",
+                        timeout=self.unacknowledged_lock_timeout,
                     )
                     if await lock.locked():
                         continue
diff --git a/tests/test_broker.py b/tests/test_broker.py
index e501acf..3ee4665 100644
--- a/tests/test_broker.py
+++ b/tests/test_broker.py
@@ -432,3 +432,38 @@ async def test_maxlen_in_sentinel_stream_broker(
     async with broker._acquire_master_conn() as redis_conn:
         assert await redis_conn.xlen(broker.queue_name) == maxlen
     await broker.shutdown()
+
+
+@pytest.mark.anyio
+async def test_unacknowledged_lock_timeout_in_stream_broker(
+    redis_url: str,
+    valid_broker_message: BrokerMessage,
+) -> None:
+    unacknowledged_lock_timeout = 1
+    queue_name = uuid.uuid4().hex
+    consumer_group_name = uuid.uuid4().hex
+
+    broker = RedisStreamBroker(
+        url=redis_url,
+        approximate=False,
+        queue_name=queue_name,
+        consumer_group_name=consumer_group_name,
+        unacknowledged_lock_timeout=unacknowledged_lock_timeout,
+    )
+
+    await broker.startup()
+    await broker.kick(valid_broker_message)
+
+    message = await get_message(broker)
+    assert isinstance(message, AckableMessage)
+    assert message.data == valid_broker_message.message
+
+    async with Redis(connection_pool=broker.connection_pool) as redis:
+        lock_key = f"autoclaim:{consumer_group_name}:{queue_name}"
+        await redis.exists(lock_key)
+        await asyncio.sleep(unacknowledged_lock_timeout + 0.5)
+
+        lock_exists_after_timeout = await redis.exists(lock_key)
+        assert lock_exists_after_timeout == 0, "Lock should be released after timeout"
+
+    await broker.shutdown()