Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 114 additions & 2 deletions rust/tests/e2e/tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use github_copilot_sdk::handler::{ApproveAllHandler, PermissionHandler, Permissi
use github_copilot_sdk::tool::ToolHandler;
use github_copilot_sdk::{
Error, PermissionRequestData, RequestId, SessionConfig, SessionId, Tool, ToolInvocation,
ToolResult,
ToolResult, ToolSet,
};
use serde_json::json;
use tokio::sync::mpsc;
use tokio::sync::{Mutex, mpsc};

use super::support::{assistant_message_content, recv_with_timeout, with_e2e_context};

Expand Down Expand Up @@ -73,6 +73,55 @@ async fn invokes_custom_tool() {
.await;
}

#[tokio::test]
async fn low_level_tool_definition() {
with_e2e_context("tools", "low_level_tool_definition", |ctx| {
Box::pin(async move {
ctx.set_default_copilot_user();
let client = ctx.start_client().await;
let __perm = Arc::new(ApproveAllHandler);
let current_phase = Arc::new(Mutex::new(String::new()));
let tools = vec![
set_current_phase_tool(current_phase.clone()),
search_items_tool(),
];
let available_tools = ToolSet::new()
.add_custom("*")
.expect("add custom wildcard")
.add_builtin("web_fetch")
.expect("add web_fetch")
.into_vec();
let session = client
.create_session(
SessionConfig::default()
.with_github_token(super::support::DEFAULT_TEST_TOKEN)
.with_permission_handler(__perm)
.with_tools(tools)
.with_available_tools(available_tools),
)
.await
.expect("create session");

let answer = session
.send_and_wait(
"First, set the current phase to 'analyzing'. Then search for items with keyword 'copilot'. Report the phase and search results.",
)
.await
.expect("send")
.expect("assistant message");
let content = assistant_message_content(&answer);
assert!(!content.is_empty());
assert!(content.to_lowercase().contains("analyzing"));
assert!(content.contains("item_alpha") || content.contains("item_beta"));
assert_eq!(current_phase.lock().await.clone(), "analyzing");

session.disconnect().await.expect("disconnect session");
client.stop().await.expect("stop client");
})
})
.await;
}

#[tokio::test]
async fn handles_tool_calling_errors() {
with_e2e_context("tools", "handles_tool_calling_errors", |ctx| {
Expand Down Expand Up @@ -502,6 +551,69 @@ impl ToolHandler for ErrorTool {

struct CustomGrepTool;

struct SetCurrentPhaseTool {
current_phase: Arc<Mutex<String>>,
}

fn set_current_phase_tool(current_phase: Arc<Mutex<String>>) -> Tool {
Tool::new("set_current_phase")
.with_description("Sets the current phase of the agent")
.with_parameters(json!({
"type": "object",
"properties": {
"phase": {
"type": "string",
"description": "Current phase",
"pattern": "^(searching|analyzing|done)$"
}
},
"required": ["phase"]
}))
.with_handler(Arc::new(SetCurrentPhaseTool { current_phase }))
}

#[async_trait::async_trait]
impl ToolHandler for SetCurrentPhaseTool {
async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error> {
let phase = invocation
.arguments
.get("phase")
.and_then(serde_json::Value::as_str)
.unwrap_or_default()
.to_string();
*self.current_phase.lock().await = phase.clone();
Ok(ToolResult::Text(format!("Phase set to {phase}")))
}
}

struct SearchItemsTool;

fn search_items_tool() -> Tool {
Tool::new("search_items")
.with_description("Search for items by keyword")
.with_parameters(json!({
"type": "object",
"properties": {
"keyword": { "type": "string" }
},
"required": ["keyword"]
}))
.with_handler(Arc::new(SearchItemsTool))
}

#[async_trait::async_trait]
impl ToolHandler for SearchItemsTool {
async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error> {
let keyword = invocation
.arguments
.get("keyword")
.and_then(serde_json::Value::as_str)
.unwrap_or_default();
assert_eq!(keyword, "copilot");
Ok(ToolResult::Text("Found: item_alpha, item_beta".to_string()))
}
}

fn custom_grep_tool() -> Tool {
Tool::new("grep")
.with_description("A custom grep implementation that overrides the built-in")
Expand Down
Loading