Skip to content

Commit 550762e

Browse files
committed
Subscription updates.
1 parent ec16362 commit 550762e

File tree

4 files changed

+81
-14
lines changed

4 files changed

+81
-14
lines changed

python/natsrpy/_natsrpy_rs/__init__.pyi

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Awaitable, Callable
12
from datetime import timedelta
23
from typing import Any
34

@@ -39,7 +40,11 @@ class Nats:
3940
async def request(self, subject: str, payload: bytes) -> None: ...
4041
async def drain(self) -> None: ...
4142
async def flush(self) -> None: ...
42-
async def subscribe(self, subject: str) -> Subscription: ...
43+
async def subscribe(
44+
self,
45+
subject: str,
46+
callback: Callable[[Message], Awaitable[None]] | None = None,
47+
) -> Subscription: ...
4348
async def jetstream(self) -> JetStream: ...
4449

4550
__all__ = ["Message", "Nats", "Subscription"]

src/message.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ pub struct Message {
1717
pub length: usize,
1818
}
1919

20-
impl TryFrom<async_nats::Message> for Message {
20+
impl TryFrom<&async_nats::Message> for Message {
2121
type Error = NatsrpyError;
2222

23-
fn try_from(value: async_nats::Message) -> Result<Self, Self::Error> {
23+
fn try_from(value: &async_nats::Message) -> Result<Self, Self::Error> {
2424
Python::attach(move |gil| {
25-
let headers = match value.headers {
25+
let headers = match &value.headers {
2626
Some(headermap) => headermap.to_pydict(gil)?.unbind(),
2727
None => PyDict::new(gil).unbind(),
2828
};
@@ -32,13 +32,21 @@ impl TryFrom<async_nats::Message> for Message {
3232
payload: PyBytes::new(gil, &value.payload).unbind(),
3333
headers,
3434
status: value.status.map(Into::<u16>::into),
35-
description: value.description,
35+
description: value.description.clone(),
3636
length: value.length,
3737
})
3838
})
3939
}
4040
}
4141

42+
impl TryFrom<async_nats::Message> for Message {
43+
type Error = NatsrpyError;
44+
45+
fn try_from(value: async_nats::Message) -> Result<Self, Self::Error> {
46+
Self::try_from(&value)
47+
}
48+
}
49+
4250
#[pyo3::pymethods]
4351
impl Message {
4452
#[must_use]

src/nats_cls.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use async_nats::{Subject, client::traits::Publisher, message::OutboundMessage};
22
use pyo3::{
3-
Bound, PyAny, Python,
3+
Bound, Py, PyAny, Python,
44
types::{PyBytes, PyBytesMethods, PyDict},
55
};
66
use std::{sync::Arc, time::Duration};
@@ -197,16 +197,21 @@ impl NatsCls {
197197
})
198198
}
199199

200+
#[pyo3(signature=(subject, callback=None))]
200201
pub fn subscribe<'py>(
201202
&self,
202203
py: Python<'py>,
203204
subject: String,
205+
callback: Option<Py<PyAny>>,
204206
) -> NatsrpyResult<Bound<'py, PyAny>> {
205207
log::debug!("Subscribing to '{subject}'");
206208
let session = self.nats_session.clone();
207209
natsrpy_future(py, async move {
208210
if let Some(session) = session.read().await.as_ref() {
209-
Ok(Subscription::new(session.subscribe(subject).await?))
211+
Ok(Subscription::new(
212+
session.subscribe(subject).await?,
213+
callback,
214+
)?)
210215
} else {
211216
Err(NatsrpyError::NotInitialized)
212217
}

src/subscription.rs

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use futures_util::StreamExt;
22
use std::sync::Arc;
33

4-
use pyo3::{Bound, PyAny, PyRef, Python};
4+
use pyo3::{Bound, Py, PyAny, PyRef, Python};
55
use tokio::sync::Mutex;
66

77
use crate::{
@@ -12,14 +12,56 @@ use crate::{
1212
#[pyo3::pyclass]
1313
pub struct Subscription {
1414
inner: Option<Arc<Mutex<async_nats::Subscriber>>>,
15+
reading_task: Option<tokio::task::AbortHandle>,
16+
}
17+
18+
async fn process_message(message: async_nats::message::Message, py_callback: Py<PyAny>) {
19+
let task = async || -> NatsrpyResult<()> {
20+
let message = crate::message::Message::try_from(&message)?;
21+
let awaitable = Python::attach(|gil| -> NatsrpyResult<_> {
22+
let res = py_callback.call1(gil, (message,))?;
23+
let rust_task = pyo3_async_runtimes::tokio::into_future(res.into_bound(gil))?;
24+
Ok(rust_task)
25+
})?;
26+
awaitable.await?;
27+
Ok(())
28+
};
29+
if let Err(err) = task().await {
30+
log::error!("Cannot process message {message:?}. Error: {err}");
31+
}
32+
}
33+
34+
async fn start_py_sub(
35+
sub: Arc<Mutex<async_nats::Subscriber>>,
36+
py_callback: Py<PyAny>,
37+
locals: pyo3_async_runtimes::TaskLocals,
38+
) {
39+
while let Some(message) = sub.lock().await.next().await {
40+
let py_cb = Python::attach(|py| py_callback.clone_ref(py));
41+
tokio::spawn(pyo3_async_runtimes::tokio::scope(
42+
locals.clone(),
43+
process_message(message, py_cb),
44+
));
45+
}
1546
}
1647

1748
impl Subscription {
18-
#[must_use]
19-
pub fn new(sub: async_nats::Subscriber) -> Self {
20-
Self {
21-
inner: Some(Arc::new(Mutex::new(sub))),
22-
}
49+
pub fn new(sub: async_nats::Subscriber, callback: Option<Py<PyAny>>) -> NatsrpyResult<Self> {
50+
let sub = Arc::new(Mutex::new(sub));
51+
let cb_sub = sub.clone();
52+
let task_locals = Python::attach(pyo3_async_runtimes::tokio::get_current_locals)?;
53+
let task_handle = callback.map(move |cb| {
54+
tokio::task::spawn(pyo3_async_runtimes::tokio::scope(
55+
task_locals.clone(),
56+
start_py_sub(cb_sub, cb, task_locals),
57+
))
58+
.abort_handle()
59+
});
60+
61+
Ok(Self {
62+
inner: Some(sub),
63+
reading_task: task_handle,
64+
})
2365
}
2466
}
2567

@@ -38,11 +80,15 @@ impl Subscription {
3880
let Some(inner) = self.inner.clone() else {
3981
unreachable!("Subscription used after del")
4082
};
83+
if self.reading_task.is_some() {
84+
log::warn!(
85+
"Callback is set. Getting messages from this subscription might produce unpredictable results."
86+
);
87+
}
4188
natsrpy_future_with_timeout(py, timeout, async move {
4289
let Some(message) = inner.lock().await.next().await else {
4390
return Err(NatsrpyError::AsyncStopIteration);
4491
};
45-
4692
crate::message::Message::try_from(message)
4793
})
4894
}
@@ -96,6 +142,9 @@ impl Drop for Subscription {
96142
fn drop(&mut self) {
97143
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
98144
self.inner = None;
145+
if let Some(reading) = self.reading_task.take() {
146+
reading.abort();
147+
}
99148
});
100149
}
101150
}

0 commit comments

Comments
 (0)