Skip to content

Commit 81f8588

Browse files
authored
feat: provide blanket implementations for ClientHandler and ServerHandler traits (#609)
* feat!: implement ServerHandler for Box<H> and Arc<H> where H is a ServerHandler * feat!: implement ClientHandler for Box<H> and Arc<H> where H is a ClientHandler * test: test Box and Arc have blanket implementations for handler traits * refactor: deduplicate blanket implementations with macros
1 parent acb06ea commit 81f8588

4 files changed

Lines changed: 348 additions & 1 deletion

File tree

crates/rmcp/src/handler/client.rs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
pub mod progress;
2+
use std::sync::Arc;
3+
24
use crate::{
35
error::ErrorData as McpError,
46
model::*,
@@ -210,3 +212,115 @@ impl ClientHandler for ClientInfo {
210212
self.clone()
211213
}
212214
}
215+
216+
macro_rules! impl_client_handler_for_wrapper {
217+
($wrapper:ident) => {
218+
impl<T: ClientHandler> ClientHandler for $wrapper<T> {
219+
fn ping(
220+
&self,
221+
context: RequestContext<RoleClient>,
222+
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
223+
(**self).ping(context)
224+
}
225+
226+
fn create_message(
227+
&self,
228+
params: CreateMessageRequestParam,
229+
context: RequestContext<RoleClient>,
230+
) -> impl Future<Output = Result<CreateMessageResult, McpError>> + Send + '_ {
231+
(**self).create_message(params, context)
232+
}
233+
234+
fn list_roots(
235+
&self,
236+
context: RequestContext<RoleClient>,
237+
) -> impl Future<Output = Result<ListRootsResult, McpError>> + Send + '_ {
238+
(**self).list_roots(context)
239+
}
240+
241+
fn create_elicitation(
242+
&self,
243+
request: CreateElicitationRequestParam,
244+
context: RequestContext<RoleClient>,
245+
) -> impl Future<Output = Result<CreateElicitationResult, McpError>> + Send + '_ {
246+
(**self).create_elicitation(request, context)
247+
}
248+
249+
fn on_custom_request(
250+
&self,
251+
request: CustomRequest,
252+
context: RequestContext<RoleClient>,
253+
) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
254+
(**self).on_custom_request(request, context)
255+
}
256+
257+
fn on_cancelled(
258+
&self,
259+
params: CancelledNotificationParam,
260+
context: NotificationContext<RoleClient>,
261+
) -> impl Future<Output = ()> + Send + '_ {
262+
(**self).on_cancelled(params, context)
263+
}
264+
265+
fn on_progress(
266+
&self,
267+
params: ProgressNotificationParam,
268+
context: NotificationContext<RoleClient>,
269+
) -> impl Future<Output = ()> + Send + '_ {
270+
(**self).on_progress(params, context)
271+
}
272+
273+
fn on_logging_message(
274+
&self,
275+
params: LoggingMessageNotificationParam,
276+
context: NotificationContext<RoleClient>,
277+
) -> impl Future<Output = ()> + Send + '_ {
278+
(**self).on_logging_message(params, context)
279+
}
280+
281+
fn on_resource_updated(
282+
&self,
283+
params: ResourceUpdatedNotificationParam,
284+
context: NotificationContext<RoleClient>,
285+
) -> impl Future<Output = ()> + Send + '_ {
286+
(**self).on_resource_updated(params, context)
287+
}
288+
289+
fn on_resource_list_changed(
290+
&self,
291+
context: NotificationContext<RoleClient>,
292+
) -> impl Future<Output = ()> + Send + '_ {
293+
(**self).on_resource_list_changed(context)
294+
}
295+
296+
fn on_tool_list_changed(
297+
&self,
298+
context: NotificationContext<RoleClient>,
299+
) -> impl Future<Output = ()> + Send + '_ {
300+
(**self).on_tool_list_changed(context)
301+
}
302+
303+
fn on_prompt_list_changed(
304+
&self,
305+
context: NotificationContext<RoleClient>,
306+
) -> impl Future<Output = ()> + Send + '_ {
307+
(**self).on_prompt_list_changed(context)
308+
}
309+
310+
fn on_custom_notification(
311+
&self,
312+
notification: CustomNotification,
313+
context: NotificationContext<RoleClient>,
314+
) -> impl Future<Output = ()> + Send + '_ {
315+
(**self).on_custom_notification(notification, context)
316+
}
317+
318+
fn get_info(&self) -> ClientInfo {
319+
(**self).get_info()
320+
}
321+
}
322+
};
323+
}
324+
325+
impl_client_handler_for_wrapper!(Box);
326+
impl_client_handler_for_wrapper!(Arc);

crates/rmcp/src/handler/server.rs

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use crate::{
24
error::ErrorData as McpError,
35
model::*,
@@ -327,3 +329,206 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
327329
std::future::ready(Err(McpError::method_not_found::<CancelTaskMethod>()))
328330
}
329331
}
332+
333+
macro_rules! impl_server_handler_for_wrapper {
334+
($wrapper:ident) => {
335+
impl<T: ServerHandler> ServerHandler for $wrapper<T> {
336+
fn enqueue_task(
337+
&self,
338+
request: CallToolRequestParam,
339+
context: RequestContext<RoleServer>,
340+
) -> impl Future<Output = Result<CreateTaskResult, McpError>> + Send + '_ {
341+
(**self).enqueue_task(request, context)
342+
}
343+
344+
fn ping(
345+
&self,
346+
context: RequestContext<RoleServer>,
347+
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
348+
(**self).ping(context)
349+
}
350+
351+
fn initialize(
352+
&self,
353+
request: InitializeRequestParam,
354+
context: RequestContext<RoleServer>,
355+
) -> impl Future<Output = Result<InitializeResult, McpError>> + Send + '_ {
356+
(**self).initialize(request, context)
357+
}
358+
359+
fn complete(
360+
&self,
361+
request: CompleteRequestParam,
362+
context: RequestContext<RoleServer>,
363+
) -> impl Future<Output = Result<CompleteResult, McpError>> + Send + '_ {
364+
(**self).complete(request, context)
365+
}
366+
367+
fn set_level(
368+
&self,
369+
request: SetLevelRequestParam,
370+
context: RequestContext<RoleServer>,
371+
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
372+
(**self).set_level(request, context)
373+
}
374+
375+
fn get_prompt(
376+
&self,
377+
request: GetPromptRequestParam,
378+
context: RequestContext<RoleServer>,
379+
) -> impl Future<Output = Result<GetPromptResult, McpError>> + Send + '_ {
380+
(**self).get_prompt(request, context)
381+
}
382+
383+
fn list_prompts(
384+
&self,
385+
request: Option<PaginatedRequestParam>,
386+
context: RequestContext<RoleServer>,
387+
) -> impl Future<Output = Result<ListPromptsResult, McpError>> + Send + '_ {
388+
(**self).list_prompts(request, context)
389+
}
390+
391+
fn list_resources(
392+
&self,
393+
request: Option<PaginatedRequestParam>,
394+
context: RequestContext<RoleServer>,
395+
) -> impl Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
396+
(**self).list_resources(request, context)
397+
}
398+
399+
fn list_resource_templates(
400+
&self,
401+
request: Option<PaginatedRequestParam>,
402+
context: RequestContext<RoleServer>,
403+
) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + Send + '_
404+
{
405+
(**self).list_resource_templates(request, context)
406+
}
407+
408+
fn read_resource(
409+
&self,
410+
request: ReadResourceRequestParam,
411+
context: RequestContext<RoleServer>,
412+
) -> impl Future<Output = Result<ReadResourceResult, McpError>> + Send + '_ {
413+
(**self).read_resource(request, context)
414+
}
415+
416+
fn subscribe(
417+
&self,
418+
request: SubscribeRequestParam,
419+
context: RequestContext<RoleServer>,
420+
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
421+
(**self).subscribe(request, context)
422+
}
423+
424+
fn unsubscribe(
425+
&self,
426+
request: UnsubscribeRequestParam,
427+
context: RequestContext<RoleServer>,
428+
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
429+
(**self).unsubscribe(request, context)
430+
}
431+
432+
fn call_tool(
433+
&self,
434+
request: CallToolRequestParam,
435+
context: RequestContext<RoleServer>,
436+
) -> impl Future<Output = Result<CallToolResult, McpError>> + Send + '_ {
437+
(**self).call_tool(request, context)
438+
}
439+
440+
fn list_tools(
441+
&self,
442+
request: Option<PaginatedRequestParam>,
443+
context: RequestContext<RoleServer>,
444+
) -> impl Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
445+
(**self).list_tools(request, context)
446+
}
447+
448+
fn on_custom_request(
449+
&self,
450+
request: CustomRequest,
451+
context: RequestContext<RoleServer>,
452+
) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
453+
(**self).on_custom_request(request, context)
454+
}
455+
456+
fn on_cancelled(
457+
&self,
458+
notification: CancelledNotificationParam,
459+
context: NotificationContext<RoleServer>,
460+
) -> impl Future<Output = ()> + Send + '_ {
461+
(**self).on_cancelled(notification, context)
462+
}
463+
464+
fn on_progress(
465+
&self,
466+
notification: ProgressNotificationParam,
467+
context: NotificationContext<RoleServer>,
468+
) -> impl Future<Output = ()> + Send + '_ {
469+
(**self).on_progress(notification, context)
470+
}
471+
472+
fn on_initialized(
473+
&self,
474+
context: NotificationContext<RoleServer>,
475+
) -> impl Future<Output = ()> + Send + '_ {
476+
(**self).on_initialized(context)
477+
}
478+
479+
fn on_roots_list_changed(
480+
&self,
481+
context: NotificationContext<RoleServer>,
482+
) -> impl Future<Output = ()> + Send + '_ {
483+
(**self).on_roots_list_changed(context)
484+
}
485+
486+
fn on_custom_notification(
487+
&self,
488+
notification: CustomNotification,
489+
context: NotificationContext<RoleServer>,
490+
) -> impl Future<Output = ()> + Send + '_ {
491+
(**self).on_custom_notification(notification, context)
492+
}
493+
494+
fn get_info(&self) -> ServerInfo {
495+
(**self).get_info()
496+
}
497+
498+
fn list_tasks(
499+
&self,
500+
request: Option<PaginatedRequestParam>,
501+
context: RequestContext<RoleServer>,
502+
) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ {
503+
(**self).list_tasks(request, context)
504+
}
505+
506+
fn get_task_info(
507+
&self,
508+
request: GetTaskInfoParam,
509+
context: RequestContext<RoleServer>,
510+
) -> impl Future<Output = Result<GetTaskInfoResult, McpError>> + Send + '_ {
511+
(**self).get_task_info(request, context)
512+
}
513+
514+
fn get_task_result(
515+
&self,
516+
request: GetTaskResultParam,
517+
context: RequestContext<RoleServer>,
518+
) -> impl Future<Output = Result<TaskResult, McpError>> + Send + '_ {
519+
(**self).get_task_result(request, context)
520+
}
521+
522+
fn cancel_task(
523+
&self,
524+
request: CancelTaskParam,
525+
context: RequestContext<RoleServer>,
526+
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
527+
(**self).cancel_task(request, context)
528+
}
529+
}
530+
};
531+
}
532+
533+
impl_server_handler_for_wrapper!(Box);
534+
impl_server_handler_for_wrapper!(Arc);

crates/rmcp/src/handler/server/router.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,6 @@ where
133133
}
134134

135135
fn get_info(&self) -> <RoleServer as crate::service::ServiceRole>::Info {
136-
self.service.get_info()
136+
ServerHandler::get_info(&self.service)
137137
}
138138
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// cargo test --test test_handler_wrappers --features "client server"
2+
3+
mod common;
4+
5+
use std::sync::Arc;
6+
7+
use common::handlers::{TestClientHandler, TestServer};
8+
use rmcp::{ClientHandler, ServerHandler};
9+
10+
#[test]
11+
fn test_wrapped_server_handlers() {
12+
// This test asserts that, when T: ServerHandler, both Box<T> and Arc<T> also implement ServerHandler.
13+
fn accepts_server_handler<H: ServerHandler>(_handler: H) {}
14+
15+
accepts_server_handler(Box::new(TestServer::new()));
16+
accepts_server_handler(Arc::new(TestServer::new()));
17+
}
18+
19+
#[test]
20+
fn test_wrapped_client_handlers() {
21+
// This test asserts that, when T: ClientHandler, both Box<T> and Arc<T> also implement ClientHandler.
22+
fn accepts_client_handler<H: ClientHandler>(_handler: H) {}
23+
24+
let client = TestClientHandler::new(false, false);
25+
26+
accepts_client_handler(Box::new(client.clone()));
27+
accepts_client_handler(Arc::new(client));
28+
}

0 commit comments

Comments
 (0)