diff --git a/Cargo.lock b/Cargo.lock index 433c12fd5..62fd50a36 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1577,6 +1577,7 @@ dependencies = [ "time", "tokio 1.52.3", "tokio-rustls", + "tokio-tungstenite 0.26.2", "tracing", "url", "uuid", diff --git a/devolutions-agent/Cargo.toml b/devolutions-agent/Cargo.toml index 4a09d0aff..661d3c26e 100644 --- a/devolutions-agent/Cargo.toml +++ b/devolutions-agent/Cargo.toml @@ -41,6 +41,7 @@ sha2 = "0.10" serde_json = "1" serde = { version = "1", features = ["derive"] } tap = "1.0" +tokio-tungstenite = { version = "0.26", features = ["rustls-tls-native-roots"] } tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12", "ring"] } tracing = "0.1" url = { version = "2.5", features = ["serde"] } @@ -67,6 +68,7 @@ features = [ "parking_lot", "fs", "process", + "sync", ] [target.'cfg(windows)'.dependencies] diff --git a/devolutions-agent/src/config.rs b/devolutions-agent/src/config.rs index c773d7e93..a9560c48c 100644 --- a/devolutions-agent/src/config.rs +++ b/devolutions-agent/src/config.rs @@ -21,6 +21,7 @@ pub struct Conf { pub pedm: dto::PedmConf, pub session: dto::SessionConf, pub tunnel: TunnelConf, + pub psu_event_hub: dto::PsuEventHubConf, pub proxy: dto::ProxyConf, pub debug: dto::DebugConf, } @@ -122,6 +123,7 @@ impl Conf { remote_desktop, pedm: conf_file.pedm.clone().unwrap_or_default(), session: conf_file.session.clone().unwrap_or_default(), + psu_event_hub: conf_file.psu_event_hub.clone().unwrap_or_default(), tunnel: conf_file .tunnel .clone() @@ -268,7 +270,7 @@ fn load_conf_file(conf_path: &Utf8Path) -> anyhow::Result> pub fn load_conf_file_or_generate_new() -> anyhow::Result { let conf_file_path = get_conf_file_path(); - let conf_file = match load_conf_file(&conf_file_path).context("failed to load configuration")? { + let mut conf_file = match load_conf_file(&conf_file_path).context("failed to load configuration")? { Some(conf_file) => conf_file, None => { let defaults = dto::ConfFile::generate_new(); @@ -278,9 +280,243 @@ pub fn load_conf_file_or_generate_new() -> anyhow::Result { } }; + merge_psu_event_hub_compat_config(&mut conf_file) + .context("failed to load PowerShell Universal agent configuration")?; + Ok(conf_file) } +fn merge_psu_event_hub_compat_config(conf_file: &mut dto::ConfFile) -> anyhow::Result<()> { + let Some(compat_conf) = load_psu_event_hub_compat_config()? else { + return Ok(()); + }; + + match &mut conf_file.psu_event_hub { + None => conf_file.psu_event_hub = Some(compat_conf), + Some(current) if current.enabled && current.connections.is_empty() => { + current.connections = compat_conf.connections; + } + Some(_) => {} + } + + Ok(()) +} + +fn load_psu_event_hub_compat_config() -> anyhow::Result> { + let mut connections = Vec::new(); + + for path in psu_event_hub_compat_config_paths() { + let Some(file) = load_psu_event_hub_compat_file(&path)? else { + continue; + }; + + if !file.connections.is_empty() { + connections = file.connections; + } + } + + apply_psu_event_hub_env_overrides(&mut connections)?; + + if connections.is_empty() { + return Ok(None); + } + + Ok(Some(dto::PsuEventHubConf { + enabled: true, + connections, + power_shell: dto::PsuPowerShellConf::default(), + })) +} + +fn load_psu_event_hub_compat_file(path: &Utf8Path) -> anyhow::Result> { + match File::open(path) { + Ok(file) => BufReader::new(file) + .pipe(serde_json::from_reader) + .map(Some) + .with_context(|| format!("invalid PowerShell Universal agent config file at {path}")), + Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(None), + Err(error) => Err(anyhow::anyhow!(error).context(format!( + "couldn't open PowerShell Universal agent config file at {path}" + ))), + } +} + +fn psu_event_hub_compat_config_paths() -> Vec { + let mut paths = Vec::new(); + + if let Some(program_data) = env_path("ProgramData") { + paths.push(program_data.join("PowerShellUniversal").join("eventHubClient.json")); + paths.push(program_data.join("PowerShellUniversal").join("agent.json")); + } + + if let Some(app_data) = env_path("APPDATA") { + paths.push(app_data.join("PowerShellUniversal").join("agent.json")); + } + + paths +} + +fn env_path(name: &str) -> Option { + std::env::var_os(name).and_then(|path| Utf8PathBuf::from_path_buf(path.into()).ok()) +} + +#[derive(Default)] +struct PsuEventHubConnectionPatch { + hub: Option, + url: Option, + app_token: Option>, + use_default_credentials: Option, + script_path: Option>, + description: Option>, +} + +impl PsuEventHubConnectionPatch { + fn apply_to(&self, connection: &mut dto::PsuEventHubConnectionConf) -> anyhow::Result<()> { + if let Some(hub) = &self.hub { + connection.hub = hub.clone(); + } + if let Some(url) = &self.url { + connection.url = + Url::parse(url).with_context(|| format!("invalid PSU Event Hub URL from environment: {url}"))?; + } + if let Some(app_token) = &self.app_token { + connection.app_token = app_token.clone(); + } + if let Some(use_default_credentials) = self.use_default_credentials { + connection.use_default_credentials = use_default_credentials; + } + if let Some(script_path) = &self.script_path { + connection.script_path = script_path.clone(); + } + if let Some(description) = &self.description { + connection.description = description.clone(); + } + + Ok(()) + } + + fn try_build(&self) -> anyhow::Result> { + let (Some(hub), Some(url)) = (&self.hub, &self.url) else { + return Ok(None); + }; + + Ok(Some(dto::PsuEventHubConnectionConf { + hub: hub.clone(), + url: Url::parse(url).with_context(|| format!("invalid PSU Event Hub URL from environment: {url}"))?, + app_token: self.app_token.clone().flatten(), + use_default_credentials: self.use_default_credentials.unwrap_or(false), + script_path: self.script_path.clone().flatten(), + description: self.description.clone().flatten(), + })) + } + + fn is_empty(&self) -> bool { + self.hub.is_none() + && self.url.is_none() + && self.app_token.is_none() + && self.use_default_credentials.is_none() + && self.script_path.is_none() + && self.description.is_none() + } +} + +fn apply_psu_event_hub_env_overrides(connections: &mut Vec) -> anyhow::Result<()> { + let mut patches = std::collections::BTreeMap::::new(); + + for (name, value) in std::env::vars() { + let Some(key) = name.strip_prefix("PSU_") else { + continue; + }; + + let key = key.replace("__", ":"); + if let Some((index, field)) = parse_psu_connection_env_key(&key)? { + apply_psu_connection_patch_field(patches.entry(index).or_default(), field, value)?; + } else if let Some(field) = psu_connection_field_name(&key) { + apply_psu_connection_patch_field(patches.entry(0).or_default(), field, value)?; + } + } + + for (index, patch) in patches { + if patch.is_empty() { + continue; + } + + if let Some(connection) = connections.get_mut(index) { + patch.apply_to(connection)?; + } else if let Some(connection) = patch.try_build()? { + connections.push(connection); + } + } + + Ok(()) +} + +fn parse_psu_connection_env_key(key: &str) -> anyhow::Result> { + let parts = key.split(':').collect::>(); + if parts.len() != 3 || !parts[0].eq_ignore_ascii_case("Connections") { + return Ok(None); + } + + let index = parts[1] + .parse::() + .with_context(|| format!("invalid PSU connection environment index: {}", parts[1]))?; + let Some(field) = psu_connection_field_name(parts[2]) else { + return Ok(None); + }; + + Ok(Some((index, field))) +} + +fn psu_connection_field_name(key: &str) -> Option<&'static str> { + if key.eq_ignore_ascii_case("Hub") { + Some("Hub") + } else if key.eq_ignore_ascii_case("Url") { + Some("Url") + } else if key.eq_ignore_ascii_case("AppToken") { + Some("AppToken") + } else if key.eq_ignore_ascii_case("UseDefaultCredentials") { + Some("UseDefaultCredentials") + } else if key.eq_ignore_ascii_case("ScriptPath") { + Some("ScriptPath") + } else if key.eq_ignore_ascii_case("Description") { + Some("Description") + } else { + None + } +} + +fn apply_psu_connection_patch_field( + patch: &mut PsuEventHubConnectionPatch, + field: &str, + value: String, +) -> anyhow::Result<()> { + match field { + "Hub" => patch.hub = Some(value), + "Url" => patch.url = Some(value), + "AppToken" => patch.app_token = Some(non_empty_string(value)), + "UseDefaultCredentials" => patch.use_default_credentials = Some(parse_psu_bool(&value)?), + "ScriptPath" => patch.script_path = Some(non_empty_string(value).map(Utf8PathBuf::from)), + "Description" => patch.description = Some(non_empty_string(value)), + _ => unreachable!("unsupported PSU Event Hub connection field"), + } + + Ok(()) +} + +fn non_empty_string(value: String) -> Option { + if value.is_empty() { None } else { Some(value) } +} + +fn parse_psu_bool(value: &str) -> anyhow::Result { + if value.eq_ignore_ascii_case("true") || value == "1" || value.eq_ignore_ascii_case("yes") { + Ok(true) + } else if value.eq_ignore_ascii_case("false") || value == "0" || value.eq_ignore_ascii_case("no") { + Ok(false) + } else { + bail!("invalid PSU boolean environment value: {value}"); + } +} + pub mod dto { use devolutions_agent_shared::UpdateProductKey; @@ -481,6 +717,98 @@ pub mod dto { pub server_spki_sha256: Option, } + #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] + #[serde(rename_all = "PascalCase")] + pub struct PsuEventHubConf { + /// Enable PowerShell Universal Event Hub compatibility. + pub enabled: bool, + + /// Event Hub connections to maintain. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub connections: Vec, + + /// PowerShell worker process configuration. + #[serde(default, skip_serializing_if = "PsuPowerShellConf::is_default")] + pub power_shell: PsuPowerShellConf, + } + + #[derive(Debug, Clone, Serialize, Deserialize)] + #[serde(rename_all = "PascalCase")] + pub struct PsuEventHubCompatFile { + #[serde(default)] + pub connections: Vec, + } + + #[allow(clippy::derivable_impls)] // Just to be explicit about default disabled behavior. + impl Default for PsuEventHubConf { + fn default() -> Self { + Self { + enabled: false, + connections: Vec::new(), + power_shell: PsuPowerShellConf::default(), + } + } + } + + #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] + #[serde(rename_all = "PascalCase")] + pub struct PsuEventHubConnectionConf { + pub hub: String, + pub url: Url, + #[serde(skip_serializing_if = "Option::is_none")] + pub app_token: Option, + #[serde(default)] + pub use_default_credentials: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub script_path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + } + + #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] + #[serde(rename_all = "PascalCase")] + pub struct PsuPowerShellConf { + #[serde(skip_serializing_if = "Option::is_none")] + pub executable_path: Option, + #[serde(default)] + pub use_windows_power_shell: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub version_selector: Option, + #[serde(default = "default_worker_pool_size")] + pub worker_pool_size: usize, + #[serde(default = "default_max_worker_pool_size")] + pub max_worker_pool_size: usize, + #[serde(skip_serializing_if = "Option::is_none")] + pub virtual_environment: Option, + } + + impl Default for PsuPowerShellConf { + fn default() -> Self { + Self { + executable_path: None, + use_windows_power_shell: false, + version_selector: None, + worker_pool_size: default_worker_pool_size(), + max_worker_pool_size: default_max_worker_pool_size(), + virtual_environment: None, + } + } + } + + impl PsuPowerShellConf { + pub fn is_default(&self) -> bool { + Self::default().eq(self) + } + } + + fn default_worker_pool_size() -> usize { + 1 + } + + fn default_max_worker_pool_size() -> usize { + 25 + } + fn default_true() -> bool { true } @@ -538,6 +866,10 @@ pub mod dto { #[serde(skip_serializing_if = "Option::is_none")] pub tunnel: Option, + /// PowerShell Universal Event Hub compatibility. + #[serde(skip_serializing_if = "Option::is_none")] + pub psu_event_hub: Option, + /// HTTP/SOCKS proxy configuration for outbound requests #[serde(skip_serializing_if = "Option::is_none")] pub proxy: Option, @@ -568,6 +900,7 @@ pub mod dto { debug: None, session: Some(SessionConf { enabled: false }), tunnel: None, + psu_event_hub: None, rest: serde_json::Map::new(), } } @@ -748,3 +1081,261 @@ pub fn handle_cli(command: &str) -> Result<(), anyhow::Error> { Ok(()) } + +#[cfg(test)] +mod tests { + use std::ffi::OsString; + + use parking_lot::{Mutex, MutexGuard}; + + use super::*; + + static ENV_LOCK: Mutex<()> = Mutex::new(()); + + struct EnvGuard { + _guard: MutexGuard<'static, ()>, + saved: Vec<(OsString, Option)>, + } + + impl EnvGuard { + fn new(vars: &[(&str, &str)]) -> Self { + let guard = ENV_LOCK.lock(); + let mut saved = std::env::vars_os() + .filter(|(name, _)| { + let name = name.to_string_lossy(); + name == "ProgramData" || name == "APPDATA" || name.starts_with("PSU_") + }) + .map(|(name, value)| (name, Some(value))) + .collect::>(); + + for (name, _) in &saved { + // SAFETY: These tests hold ENV_LOCK while mutating process environment. + unsafe { + std::env::remove_var(name); + } + } + + for (name, value) in vars { + let name = OsString::from(name); + if !saved.iter().any(|(saved_name, _)| saved_name == &name) { + saved.push((name.clone(), None)); + } + // SAFETY: These tests hold ENV_LOCK while mutating process environment. + unsafe { + std::env::set_var(name, value); + } + } + + Self { _guard: guard, saved } + } + } + + impl Drop for EnvGuard { + fn drop(&mut self) { + for (name, value) in &self.saved { + match value { + Some(value) => { + // SAFETY: These tests hold ENV_LOCK while mutating process environment. + unsafe { + std::env::set_var(name, value); + } + } + None => { + // SAFETY: These tests hold ENV_LOCK while mutating process environment. + unsafe { + std::env::remove_var(name); + } + } + } + } + } + } + + #[test] + fn psu_event_hub_config_is_disabled_by_default() { + let conf = Conf::from_conf_file(&dto::ConfFile::generate_new()).expect("load generated config"); + assert!(!conf.psu_event_hub.enabled); + assert!(conf.psu_event_hub.connections.is_empty()); + } + + #[test] + fn psu_event_hub_config_deserializes() { + let conf_file: dto::ConfFile = serde_json::from_value(serde_json::json!({ + "PsuEventHub": { + "Enabled": true, + "Connections": [ + { + "Hub": "Hub", + "Url": "http://localhost:5000", + "AppToken": "token", + "UseDefaultCredentials": false, + "ScriptPath": "event.ps1", + "Description": "test agent" + } + ], + "PowerShell": { + "VersionSelector": "7.4", + "WorkerPoolSize": 1, + "MaxWorkerPoolSize": 25 + } + } + })) + .expect("deserialize config"); + + let conf = Conf::from_conf_file(&conf_file).expect("load config"); + assert!(conf.psu_event_hub.enabled); + assert_eq!(conf.psu_event_hub.connections[0].hub, "Hub"); + assert_eq!(conf.psu_event_hub.power_shell.version_selector.as_deref(), Some("7.4")); + } + + #[test] + fn psu_event_hub_imports_compat_config_when_missing() { + let temp_dir = tempfile::tempdir().expect("create temp dir"); + let program_data = Utf8PathBuf::from_path_buf(temp_dir.path().to_owned()).expect("temp path is UTF-8"); + let psu_dir = program_data.join("PowerShellUniversal"); + std::fs::create_dir_all(&psu_dir).expect("create PSU dir"); + std::fs::write( + psu_dir.join("eventHubClient.json"), + r#"{"Connections":[{"Hub":"Compat","Url":"http://localhost:5000"}]}"#, + ) + .expect("write compat config"); + + let _env = EnvGuard::new(&[ + ("ProgramData", program_data.as_str()), + ("APPDATA", program_data.as_str()), + ]); + let mut conf_file = dto::ConfFile::generate_new(); + + merge_psu_event_hub_compat_config(&mut conf_file).expect("merge compat config"); + + let psu_event_hub = conf_file.psu_event_hub.expect("compat config"); + assert!(psu_event_hub.enabled); + assert_eq!(psu_event_hub.connections[0].hub, "Compat"); + } + + #[test] + fn psu_event_hub_imports_compat_connections_when_enabled_empty() { + let temp_dir = tempfile::tempdir().expect("create temp dir"); + let program_data = Utf8PathBuf::from_path_buf(temp_dir.path().to_owned()).expect("temp path is UTF-8"); + let psu_dir = program_data.join("PowerShellUniversal"); + std::fs::create_dir_all(&psu_dir).expect("create PSU dir"); + std::fs::write( + psu_dir.join("eventHubClient.json"), + r#"{"Connections":[{"Hub":"Compat","Url":"http://localhost:5000"}]}"#, + ) + .expect("write compat config"); + + let _env = EnvGuard::new(&[ + ("ProgramData", program_data.as_str()), + ("APPDATA", program_data.as_str()), + ]); + let mut conf_file = dto::ConfFile::generate_new(); + conf_file.psu_event_hub = Some(dto::PsuEventHubConf { + enabled: true, + connections: Vec::new(), + power_shell: dto::PsuPowerShellConf::default(), + }); + + merge_psu_event_hub_compat_config(&mut conf_file).expect("merge compat config"); + + let psu_event_hub = conf_file.psu_event_hub.expect("compat config"); + assert!(psu_event_hub.enabled); + assert_eq!(psu_event_hub.connections[0].hub, "Compat"); + } + + #[test] + fn psu_event_hub_explicit_connections_win_over_compat_config() { + let temp_dir = tempfile::tempdir().expect("create temp dir"); + let program_data = Utf8PathBuf::from_path_buf(temp_dir.path().to_owned()).expect("temp path is UTF-8"); + let psu_dir = program_data.join("PowerShellUniversal"); + std::fs::create_dir_all(&psu_dir).expect("create PSU dir"); + std::fs::write( + psu_dir.join("eventHubClient.json"), + r#"{"Connections":[{"Hub":"Compat","Url":"http://localhost:5000"}]}"#, + ) + .expect("write compat config"); + + let _env = EnvGuard::new(&[ + ("ProgramData", program_data.as_str()), + ("APPDATA", program_data.as_str()), + ]); + let mut conf_file: dto::ConfFile = serde_json::from_value(serde_json::json!({ + "PsuEventHub": { + "Enabled": true, + "Connections": [{"Hub":"Explicit","Url":"http://localhost:5001"}] + } + })) + .expect("deserialize config"); + + merge_psu_event_hub_compat_config(&mut conf_file).expect("merge compat config"); + + let psu_event_hub = conf_file.psu_event_hub.expect("compat config"); + assert_eq!(psu_event_hub.connections[0].hub, "Explicit"); + } + + #[test] + fn psu_event_hub_explicit_disabled_config_stays_disabled() { + let temp_dir = tempfile::tempdir().expect("create temp dir"); + let program_data = Utf8PathBuf::from_path_buf(temp_dir.path().to_owned()).expect("temp path is UTF-8"); + let psu_dir = program_data.join("PowerShellUniversal"); + std::fs::create_dir_all(&psu_dir).expect("create PSU dir"); + std::fs::write( + psu_dir.join("eventHubClient.json"), + r#"{"Connections":[{"Hub":"Compat","Url":"http://localhost:5000"}]}"#, + ) + .expect("write compat config"); + + let _env = EnvGuard::new(&[ + ("ProgramData", program_data.as_str()), + ("APPDATA", program_data.as_str()), + ]); + let mut conf_file = dto::ConfFile::generate_new(); + conf_file.psu_event_hub = Some(dto::PsuEventHubConf::default()); + + merge_psu_event_hub_compat_config(&mut conf_file).expect("merge compat config"); + + let psu_event_hub = conf_file.psu_event_hub.expect("compat config"); + assert!(!psu_event_hub.enabled); + assert!(psu_event_hub.connections.is_empty()); + } + + #[test] + fn psu_event_hub_reads_scalar_env_connection() { + let _env = EnvGuard::new(&[ + ("PSU_Hub", "EnvHub"), + ("PSU_Url", "http://localhost:5000"), + ("PSU_AppToken", "token"), + ("PSU_UseDefaultCredentials", "true"), + ("PSU_ScriptPath", "event.ps1"), + ("PSU_Description", "env agent"), + ]); + + let compat = load_psu_event_hub_compat_config() + .expect("load compat config") + .expect("env compat config"); + + assert!(compat.enabled); + assert_eq!(compat.connections[0].hub, "EnvHub"); + assert_eq!(compat.connections[0].app_token.as_deref(), Some("token")); + assert!(compat.connections[0].use_default_credentials); + assert_eq!( + compat.connections[0].script_path.as_deref(), + Some(Utf8Path::new("event.ps1")) + ); + assert_eq!(compat.connections[0].description.as_deref(), Some("env agent")); + } + + #[test] + fn psu_event_hub_reads_indexed_env_connection() { + let _env = EnvGuard::new(&[ + ("PSU_Connections__0__Hub", "IndexedHub"), + ("PSU_Connections__0__Url", "http://localhost:5000"), + ]); + + let compat = load_psu_event_hub_compat_config() + .expect("load compat config") + .expect("env compat config"); + + assert_eq!(compat.connections[0].hub, "IndexedHub"); + } +} diff --git a/devolutions-agent/src/lib.rs b/devolutions-agent/src/lib.rs index 970637c5d..894ed3a89 100644 --- a/devolutions-agent/src/lib.rs +++ b/devolutions-agent/src/lib.rs @@ -9,6 +9,7 @@ pub mod config; pub mod domain_detect; pub mod enrollment; pub mod log; +pub mod psu_event_hub; pub mod remote_desktop; pub mod tunnel; mod tunnel_helpers; diff --git a/devolutions-agent/src/psu_event_hub/executor.rs b/devolutions-agent/src/psu_event_hub/executor.rs new file mode 100644 index 000000000..8769a42d1 --- /dev/null +++ b/devolutions-agent/src/psu_event_hub/executor.rs @@ -0,0 +1,146 @@ +use std::sync::Arc; + +use anyhow::Context as _; +use camino::Utf8PathBuf; +use serde_json::Value; +use uuid::Uuid; + +use crate::config::dto::{PsuEventHubConnectionConf, PsuPowerShellConf}; +use crate::psu_event_hub::models::WebsocketEventResponse; +use crate::psu_event_hub::powershell_worker::PowerShellWorker; +use crate::psu_event_hub::result_store::ResultStore; + +#[derive(Debug, Clone)] +pub(super) struct EventHubExecutor { + hub: String, + script_path: Option, + worker: Arc, + result_store: ResultStore, +} + +impl EventHubExecutor { + pub(super) fn new(connection: &PsuEventHubConnectionConf, power_shell: PsuPowerShellConf) -> Self { + Self { + hub: connection.hub.clone(), + script_path: connection.script_path.as_ref().map(normalize_script_path), + worker: Arc::new(PowerShellWorker::new(power_shell)), + result_store: ResultStore::default(), + } + } + + pub(super) fn handle_invocation(&self, target: &str, arguments: &[Value]) -> anyhow::Result> { + if target == "GetResult" { + let execution_id = required_string_argument(arguments, 0, "event id")?; + let result = self.result_store.take(execution_id); + return serde_json::to_value(result) + .map(Some) + .context("failed to serialize PSU GetResult response"); + } + + if target == self.hub { + let data = required_string_argument(arguments, 0, "event data")?.to_owned(); + let execution_id = self.execute_script(data, true); + return Ok(Some(Value::String(execution_id))); + } + + if target == format!("{}Void", self.hub) { + let data = required_string_argument(arguments, 0, "event data")?.to_owned(); + self.execute_script(data, false); + return Ok(None); + } + + if target == format!("{}Module", self.hub) { + let command = required_string_argument(arguments, 0, "command")?.to_owned(); + let data = required_string_argument(arguments, 1, "event data")?.to_owned(); + let execution_id = self.execute_command(command, data, true); + return Ok(Some(Value::String(execution_id))); + } + + if target == format!("{}ModuleVoid", self.hub) { + let command = required_string_argument(arguments, 0, "command")?.to_owned(); + let data = required_string_argument(arguments, 1, "event data")?.to_owned(); + self.execute_command(command, data, false); + return Ok(None); + } + + warn!(%target, hub = %self.hub, "Received unknown PSU Event Hub invocation"); + Ok(None) + } + + fn execute_command(&self, command: String, data: String, return_result: bool) -> String { + let execution_id = Uuid::new_v4().to_string(); + let worker = Arc::clone(&self.worker); + let result_store = self.result_store.clone(); + let stored_execution_id = execution_id.clone(); + + tokio::spawn(async move { + match worker.execute_command(command, data, return_result).await { + Ok(response) if return_result => result_store.insert(stored_execution_id, response), + Ok(_) => {} + Err(error) if return_result => { + result_store.insert( + stored_execution_id, + WebsocketEventResponse::terminating_error(error.to_string()), + ); + } + Err(error) => warn!(error = format!("{error:#}"), "PSU command execution failed"), + } + }); + + execution_id + } + + fn execute_script(&self, data: String, return_result: bool) -> String { + let execution_id = Uuid::new_v4().to_string(); + let Some(script_path) = self.script_path.clone() else { + if return_result { + self.result_store.insert( + execution_id.clone(), + WebsocketEventResponse::terminating_error("No script block found."), + ); + } + return execution_id; + }; + + let worker = Arc::clone(&self.worker); + let result_store = self.result_store.clone(); + let stored_execution_id = execution_id.clone(); + + tokio::spawn(async move { + match worker.execute_script(script_path, data, return_result).await { + Ok(response) if return_result => result_store.insert(stored_execution_id, response), + Ok(_) => {} + Err(error) if return_result => { + result_store.insert( + stored_execution_id, + WebsocketEventResponse::terminating_error(error.to_string()), + ); + } + Err(error) => warn!(error = format!("{error:#}"), "PSU script execution failed"), + } + }); + + execution_id + } +} + +fn required_string_argument<'a>(arguments: &'a [Value], index: usize, name: &str) -> anyhow::Result<&'a str> { + arguments + .get(index) + .and_then(Value::as_str) + .with_context(|| format!("missing or invalid PSU invocation argument: {name}")) +} + +fn normalize_script_path(path: &Utf8PathBuf) -> Utf8PathBuf { + if path.is_absolute() { + return path.clone(); + } + + if let Some(program_data) = + std::env::var_os("ProgramData").and_then(|path| Utf8PathBuf::from_path_buf(path.into()).ok()) + { + return program_data.join("PowerShellUniversal").join(path); + } + + path.clone() +} diff --git a/devolutions-agent/src/psu_event_hub/mod.rs b/devolutions-agent/src/psu_event_hub/mod.rs new file mode 100644 index 000000000..845525ee1 --- /dev/null +++ b/devolutions-agent/src/psu_event_hub/mod.rs @@ -0,0 +1,86 @@ +mod executor; +mod models; +mod powershell_worker; +mod result_store; +mod signalr; + +use async_trait::async_trait; +use devolutions_gateway_task::{ShutdownSignal, Task}; +use tokio::task::JoinSet; + +use crate::config::ConfHandle; +use crate::psu_event_hub::executor::EventHubExecutor; +use crate::psu_event_hub::powershell_worker::PowerShellWorker; + +pub struct PsuEventHubTask { + conf_handle: ConfHandle, +} + +impl PsuEventHubTask { + pub fn new(conf_handle: ConfHandle) -> Self { + Self { conf_handle } + } +} + +#[async_trait] +impl Task for PsuEventHubTask { + type Output = anyhow::Result<()>; + + const NAME: &'static str = "psu event hub"; + + async fn run(self, shutdown_signal: ShutdownSignal) -> anyhow::Result<()> { + let conf = self.conf_handle.get_conf(); + let psu_conf = conf.psu_event_hub.clone(); + + if psu_conf.connections.is_empty() { + warn!("PSU Event Hub feature is enabled, but no connections are configured"); + return Ok(()); + } + + info!( + connection_count = psu_conf.connections.len(), + "Starting PSU Event Hub compatibility feature" + ); + + let mut join_set = JoinSet::new(); + + let secret_resolver = PowerShellWorker::new(psu_conf.power_shell.clone()); + + for mut connection in psu_conf.connections { + if connection.hub.trim().is_empty() { + warn!(url = %connection.url, "Skipping PSU Event Hub connection without a hub name"); + continue; + } + + if let Some(app_token) = connection.app_token.as_deref() { + match secret_resolver.resolve_app_token(app_token).await { + Ok(resolved) => connection.app_token = Some(resolved), + Err(error) => { + error!( + hub = %connection.hub, + error = format!("{error:#}"), + "Skipping PSU Event Hub connection because AppToken secret resolution failed" + ); + continue; + } + } + } + + let executor = EventHubExecutor::new(&connection, psu_conf.power_shell.clone()); + let connection_shutdown_signal = shutdown_signal.clone(); + + join_set + .spawn(async move { signalr::run_connection(connection, executor, connection_shutdown_signal).await }); + } + + while let Some(result) = join_set.join_next().await { + match result { + Ok(Ok(())) => trace!("PSU Event Hub connection task terminated gracefully"), + Ok(Err(error)) => error!(error = format!("{error:#}"), "PSU Event Hub connection task failed"), + Err(error) => error!(%error, "PSU Event Hub connection task panicked"), + } + } + + Ok(()) + } +} diff --git a/devolutions-agent/src/psu_event_hub/models.rs b/devolutions-agent/src/psu_event_hub/models.rs new file mode 100644 index 000000000..dddd6106a --- /dev/null +++ b/devolutions-agent/src/psu_event_hub/models.rs @@ -0,0 +1,160 @@ +use std::fmt; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(super) struct WebsocketEventResponse { + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, + #[serde(default)] + pub job_outputs: Vec, + #[serde(default)] + pub complete: bool, + #[serde(default)] + pub timeout: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub terminating_error: Option, +} + +impl WebsocketEventResponse { + pub(super) fn pending() -> Self { + Self { + data: None, + job_outputs: Vec::new(), + complete: false, + timeout: false, + terminating_error: None, + } + } + + pub(super) fn terminating_error(message: impl Into) -> Self { + Self { + data: None, + job_outputs: Vec::new(), + complete: true, + timeout: false, + terminating_error: Some(message.into()), + } + } +} + +impl Default for WebsocketEventResponse { + fn default() -> Self { + Self::pending() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(super) struct JobOutput { + #[serde(default)] + pub id: i64, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, + #[serde(rename = "type")] + pub output_type: JobOutputType, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, + #[serde(default)] + pub timestamp: String, + #[serde(default)] + pub job_id: i64, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub(super) enum JobOutputType { + Information = 0, + Verbose = 1, + Debug = 2, + Warning = 3, + Error = 4, + Progress = 5, +} + +impl JobOutputType { + fn as_u8(self) -> u8 { + self as u8 + } + + fn from_u8(value: u8) -> Option { + match value { + 0 => Some(Self::Information), + 1 => Some(Self::Verbose), + 2 => Some(Self::Debug), + 3 => Some(Self::Warning), + 4 => Some(Self::Error), + 5 => Some(Self::Progress), + _ => None, + } + } +} + +impl Serialize for JobOutputType { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_u8(self.as_u8()) + } +} + +impl<'de> Deserialize<'de> for JobOutputType { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct Visitor; + + impl serde::de::Visitor<'_> for Visitor { + type Value = JobOutputType; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("a PSU JobOutputType numeric value or name") + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + let value = u8::try_from(value).map_err(|_| E::custom("JobOutputType value is out of range"))?; + JobOutputType::from_u8(value).ok_or_else(|| E::custom("unknown JobOutputType value")) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + "Information" => Ok(JobOutputType::Information), + "Verbose" => Ok(JobOutputType::Verbose), + "Debug" => Ok(JobOutputType::Debug), + "Warning" => Ok(JobOutputType::Warning), + "Error" => Ok(JobOutputType::Error), + "Progress" => Ok(JobOutputType::Progress), + _ => Err(E::custom("unknown JobOutputType name")), + } + } + } + + deserializer.deserialize_any(Visitor) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn job_output_type_serializes_as_psu_numeric_value() { + let json = serde_json::to_string(&JobOutputType::Error).expect("serialize output type"); + assert_eq!(json, "4"); + } + + #[test] + fn job_output_type_accepts_worker_names() { + let output_type: JobOutputType = serde_json::from_str("\"Warning\"").expect("deserialize output type"); + assert_eq!(output_type, JobOutputType::Warning); + } +} diff --git a/devolutions-agent/src/psu_event_hub/powershell_worker.rs b/devolutions-agent/src/psu_event_hub/powershell_worker.rs new file mode 100644 index 000000000..41ae0928a --- /dev/null +++ b/devolutions-agent/src/psu_event_hub/powershell_worker.rs @@ -0,0 +1,485 @@ +use std::ffi::OsString; +use std::process::Stdio; +use std::sync::Arc; + +use anyhow::{Context as _, bail}; +use camino::{Utf8Path, Utf8PathBuf}; +use serde::Serialize; +use tokio::process::Command; +use tokio::sync::Semaphore; +use uuid::Uuid; + +use crate::config::dto::PsuPowerShellConf; +use crate::psu_event_hub::models::WebsocketEventResponse; + +const WORKER_SCRIPT: &str = r#" +param([string] $RequestPath) + +function New-PsuResponse { + [ordered]@{ + data = $null + jobOutputs = @() + complete = $true + timeout = $false + terminatingError = $null + } +} + +function Add-PsuJobOutput { + param( + [System.Collections.IDictionary] $Response, + [int] $Type, + [object] $Record + ) + + $data = if ($null -eq $Record) { + $null + } else { + ($Record | Out-String).TrimEnd() + } + + $Response.jobOutputs += ,([ordered]@{ + id = 0 + message = $null + type = $Type + data = $data + timestamp = [DateTime]::UtcNow.ToString('O') + jobId = 0 + }) +} + +function Split-PsuPipelineOutput { + param( + [System.Collections.IDictionary] $Response, + [object[]] $Items + ) + + $pipeline = New-Object System.Collections.ArrayList + + foreach ($item in $Items) { + if ($item -is [System.Management.Automation.ErrorRecord]) { + Add-PsuJobOutput -Response $Response -Type 4 -Record $item + } elseif ($item -is [System.Management.Automation.WarningRecord]) { + Add-PsuJobOutput -Response $Response -Type 3 -Record $item + } elseif ($item -is [System.Management.Automation.InformationRecord]) { + Add-PsuJobOutput -Response $Response -Type 0 -Record $item + } elseif ($item -is [System.Management.Automation.VerboseRecord]) { + Add-PsuJobOutput -Response $Response -Type 1 -Record $item + } elseif ($item -is [System.Management.Automation.DebugRecord]) { + Add-PsuJobOutput -Response $Response -Type 2 -Record $item + } elseif ($item -is [System.Management.Automation.ProgressRecord]) { + Add-PsuJobOutput -Response $Response -Type 5 -Record $item + } else { + [void] $pipeline.Add($item) + } + } + + $pipeline.ToArray() +} + +$response = New-PsuResponse + +try { + $request = Get-Content -Raw -LiteralPath $RequestPath | ConvertFrom-Json + + $VerbosePreference = 'Continue' + $DebugPreference = 'Continue' + $InformationPreference = 'Continue' + $WarningPreference = 'Continue' + + if ($request.kind -eq 'command') { + $item = [System.Management.Automation.PSSerializer]::Deserialize([string] $request.data) + + if ($item -is [System.Management.Automation.PSObject] -and $item.GetType().FullName -eq 'System.Management.Automation.PSObject') { + $item = $item.BaseObject + } + + if ($item -isnot [hashtable]) { + $response.terminatingError = 'Data was not a hashtable' + } else { + $powerShell = [System.Management.Automation.PowerShell]::Create() + try { + [void] $powerShell.AddCommand([string] $request.command) + + foreach ($key in $item.Keys) { + [void] $powerShell.AddParameter([string] $key, $item[$key]) + } + + $pipeline = $powerShell.Invoke() + + foreach ($record in $powerShell.Streams.Error) { + Add-PsuJobOutput -Response $response -Type 4 -Record $record + } + foreach ($record in $powerShell.Streams.Warning) { + Add-PsuJobOutput -Response $response -Type 3 -Record $record + } + foreach ($record in $powerShell.Streams.Information) { + Add-PsuJobOutput -Response $response -Type 0 -Record $record + } + foreach ($record in $powerShell.Streams.Verbose) { + Add-PsuJobOutput -Response $response -Type 1 -Record $record + } + foreach ($record in $powerShell.Streams.Debug) { + Add-PsuJobOutput -Response $response -Type 2 -Record $record + } + foreach ($record in $powerShell.Streams.Progress) { + Add-PsuJobOutput -Response $response -Type 5 -Record $record + } + + if ($request.returnResult) { + $response.data = [System.Management.Automation.PSSerializer]::Serialize($pipeline) + } + } finally { + $powerShell.Dispose() + } + } + } elseif ($request.kind -eq 'script') { + if ([string]::IsNullOrWhiteSpace([string] $request.scriptPath) -or -not (Test-Path -LiteralPath ([string] $request.scriptPath))) { + $response.terminatingError = 'No script block found.' + } else { + $eventData = [System.Management.Automation.PSSerializer]::Deserialize([string] $request.data) + Set-Variable -Name EventData -Value $eventData -Scope Local -Force + Set-Variable -Name _ -Value $eventData -Scope Local -Force + + $items = . ([string] $request.scriptPath) *>&1 + $pipeline = Split-PsuPipelineOutput -Response $response -Items @($items) + + if ($request.returnResult) { + $response.data = [System.Management.Automation.PSSerializer]::Serialize($pipeline) + } + } + } elseif ($request.kind -eq 'secret') { + $secretName = [string] $request.data + $secret = Get-Secret -Name $secretName -AsPlainText -ErrorAction Stop + if ($null -eq $secret) { + $response.terminatingError = "Secret not found: $secretName" + } else { + $response.data = [string] $secret + } + } else { + $response.terminatingError = "Unknown PSU worker request kind: $($request.kind)" + } +} catch { + $response.terminatingError = $_.Exception.Message +} + +$response | ConvertTo-Json -Compress -Depth 16 +"#; + +#[derive(Debug, Clone)] +pub(super) struct PowerShellWorker { + conf: PsuPowerShellConf, + permits: Arc, +} + +impl PowerShellWorker { + pub(super) fn new(conf: PsuPowerShellConf) -> Self { + let worker_limit = effective_worker_limit(&conf); + Self { + conf, + permits: Arc::new(Semaphore::new(worker_limit)), + } + } + + pub(super) async fn resolve_app_token(&self, app_token: &str) -> anyhow::Result { + let Some(secret_name) = secret_reference_name(app_token) else { + return Ok(app_token.to_owned()); + }; + + let response = self.run_request(WorkerRequest::secret(secret_name.to_owned())).await?; + if let Some(error) = response.terminating_error { + bail!("failed to resolve PSU AppToken secret {secret_name}: {error}"); + } + + response + .data + .filter(|secret| !secret.is_empty()) + .with_context(|| format!("PSU AppToken secret {secret_name} resolved to an empty value")) + } + + pub(super) async fn execute_command( + &self, + command: String, + data: String, + return_result: bool, + ) -> anyhow::Result { + self.run_request(WorkerRequest::command(command, data, return_result)) + .await + } + + pub(super) async fn execute_script( + &self, + script_path: Utf8PathBuf, + data: String, + return_result: bool, + ) -> anyhow::Result { + self.run_request(WorkerRequest::script(script_path, data, return_result)) + .await + } + + async fn run_request(&self, request: WorkerRequest) -> anyhow::Result { + let _permit = self + .permits + .acquire() + .await + .context("PSU PowerShell worker pool is closed")?; + let temp_dir = Utf8PathBuf::from_path_buf(std::env::temp_dir()) + .map_err(|path| anyhow::anyhow!("non-UTF-8 temp path: {path:?}"))?; + let request_path = temp_dir.join(format!("devolutions-agent-psu-{}.json", Uuid::new_v4())); + let script_path = temp_dir.join(format!("devolutions-agent-psu-{}.ps1", Uuid::new_v4())); + + let request_json = serde_json::to_vec(&request).context("failed to serialize PSU worker request")?; + tokio::fs::write(&request_path, request_json) + .await + .with_context(|| format!("failed to write PSU worker request at {request_path}"))?; + tokio::fs::write(&script_path, WORKER_SCRIPT) + .await + .with_context(|| format!("failed to write PSU worker script at {script_path}"))?; + + let output = self.invoke_worker(&script_path, &request_path).await; + + remove_temp_file(&request_path).await; + remove_temp_file(&script_path).await; + + output + } + + async fn invoke_worker( + &self, + script_path: &Utf8Path, + request_path: &Utf8Path, + ) -> anyhow::Result { + let executable = resolve_powershell_executable(&self.conf); + let mut command = Command::new(&executable); + command + .arg("-NoLogo") + .arg("-NoProfile") + .arg("-NonInteractive") + .arg("-ExecutionPolicy") + .arg("Bypass") + .arg("-File") + .arg(script_path.as_std_path()) + .arg(request_path.as_std_path()) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + if let Some(virtual_environment) = &self.conf.virtual_environment { + command.env("PSMODULE_VENV_PATH", virtual_environment); + } + + let output = command.output().await.with_context(|| { + format!( + "failed to start PowerShell worker using {}", + executable.to_string_lossy() + ) + })?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + bail!( + "PowerShell worker exited with status {}: {}", + output.status, + stderr.trim() + ); + } + + serde_json::from_slice(&output.stdout).context("failed to parse PowerShell worker response") + } +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct WorkerRequest { + kind: &'static str, + #[serde(skip_serializing_if = "Option::is_none")] + command: Option, + #[serde(skip_serializing_if = "Option::is_none")] + script_path: Option, + data: String, + return_result: bool, +} + +impl WorkerRequest { + fn command(command: String, data: String, return_result: bool) -> Self { + Self { + kind: "command", + command: Some(command), + script_path: None, + data, + return_result, + } + } + + fn script(script_path: Utf8PathBuf, data: String, return_result: bool) -> Self { + Self { + kind: "script", + command: None, + script_path: Some(script_path), + data, + return_result, + } + } + + fn secret(secret_name: String) -> Self { + Self { + kind: "secret", + command: None, + script_path: None, + data: secret_name, + return_result: true, + } + } +} + +fn secret_reference_name(app_token: &str) -> Option<&str> { + let prefix = "$secret:"; + app_token + .get(..prefix.len()) + .filter(|candidate| candidate.eq_ignore_ascii_case(prefix)) + .and_then(|_| app_token.get(prefix.len()..)) + .filter(|name| !name.is_empty()) +} + +fn effective_worker_limit(conf: &PsuPowerShellConf) -> usize { + let max_worker_pool_size = conf.max_worker_pool_size.max(1); + if conf.worker_pool_size > max_worker_pool_size { + warn!( + worker_pool_size = conf.worker_pool_size, + max_worker_pool_size, + "PSU worker pool size exceeds maximum, limiting concurrent workers to MaxWorkerPoolSize" + ); + } + max_worker_pool_size +} + +fn resolve_powershell_executable(conf: &PsuPowerShellConf) -> OsString { + if let Some(path) = &conf.executable_path { + return path.as_str().into(); + } + + if let Some(selector) = &conf.version_selector { + if selector.eq_ignore_ascii_case("pwsh") + || selector.eq_ignore_ascii_case("pwsh-preview") + || selector.eq_ignore_ascii_case("pwsh-lts") + || selector.starts_with("pwsh-") + { + return selector.into(); + } + + return format!("pwsh-{selector}").into(); + } + + if conf.use_windows_power_shell { + "powershell.exe".into() + } else { + "pwsh".into() + } +} + +async fn remove_temp_file(path: &Utf8Path) { + if let Err(error) = tokio::fs::remove_file(path).await { + debug!(%path, %error, "Failed to remove temporary PSU worker file"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::psu_event_hub::models::JobOutputType; + + const HASHTABLE_PS_VERSION_TABLE: &str = r#" + + + System.Collections.Hashtable + System.Object + + + + ValueOnly + true + + + Name + PSVersionTable + + + +"#; + + const HASHTABLE_MESSAGE: &str = r#" + + + System.Collections.Hashtable + System.Object + + + + Message + Hello World + + + +"#; + + #[tokio::test] + async fn command_execution_returns_clixml_result() { + let worker = PowerShellWorker::new(PsuPowerShellConf::default()); + let response = worker + .execute_command("Get-Variable".to_owned(), HASHTABLE_PS_VERSION_TABLE.to_owned(), true) + .await + .expect("execute command"); + + assert!(response.complete); + assert!(response.terminating_error.is_none()); + assert!(response.data.expect("serialized response").contains(">>, +} + +impl ResultStore { + pub(super) fn insert(&self, execution_id: String, response: WebsocketEventResponse) { + self.inner.lock().insert(execution_id, response); + } + + pub(super) fn take(&self, execution_id: &str) -> WebsocketEventResponse { + self.inner + .lock() + .remove(execution_id) + .unwrap_or_else(WebsocketEventResponse::pending) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn take_removes_result_after_first_read() { + let store = ResultStore::default(); + store.insert( + "execution-id".to_owned(), + WebsocketEventResponse { + complete: true, + ..WebsocketEventResponse::default() + }, + ); + + assert!(store.take("execution-id").complete); + assert!(!store.take("execution-id").complete); + } +} diff --git a/devolutions-agent/src/psu_event_hub/signalr.rs b/devolutions-agent/src/psu_event_hub/signalr.rs new file mode 100644 index 000000000..9481104b6 --- /dev/null +++ b/devolutions-agent/src/psu_event_hub/signalr.rs @@ -0,0 +1,280 @@ +use anyhow::{Context as _, bail}; +use futures::{SinkExt as _, StreamExt as _}; +use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue}; +use serde::Deserialize; +use serde_json::{Value, json}; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::tungstenite::client::IntoClientRequest as _; +use tokio_tungstenite::tungstenite::http::header::AUTHORIZATION as WS_AUTHORIZATION; +use tokio_tungstenite::tungstenite::http::{ + HeaderMap as WsHeaderMap, HeaderName as WsHeaderName, HeaderValue as WsHeaderValue, +}; +use url::Url; + +use crate::config::dto::PsuEventHubConnectionConf; +use crate::psu_event_hub::executor::EventHubExecutor; + +const RECORD_SEPARATOR: char = '\u{1e}'; + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct NegotiateResponse { + connection_id: Option, + connection_token: Option, +} + +pub(super) async fn run_connection( + connection: PsuEventHubConnectionConf, + executor: EventHubExecutor, + mut shutdown_signal: devolutions_gateway_task::ShutdownSignal, +) -> anyhow::Result<()> { + loop { + match run_single_connection(&connection, &executor, &mut shutdown_signal).await { + Ok(()) => { + info!(hub = %connection.hub, "Stopping PSU Event Hub connection"); + return Ok(()); + } + Err(error) => { + warn!( + hub = %connection.hub, + url = %connection.url, + error = format!("{error:#}"), + "PSU Event Hub connection failed" + ); + } + } + + tokio::select! { + _ = shutdown_signal.wait() => return Ok(()), + _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => {} + } + } +} + +async fn run_single_connection( + connection: &PsuEventHubConnectionConf, + executor: &EventHubExecutor, + shutdown_signal: &mut devolutions_gateway_task::ShutdownSignal, +) -> anyhow::Result<()> { + if connection.use_default_credentials && connection.app_token.is_none() { + warn!( + hub = %connection.hub, + "PSU Event Hub UseDefaultCredentials is configured, but Windows default credentials are not implemented yet" + ); + } + + let endpoint = endpoint_url(connection)?; + let negotiate = negotiate_url(&endpoint)?; + let headers = psu_headers(connection)?; + let client = reqwest::Client::new(); + + let mut request = client.post(negotiate.clone()).headers(headers.clone()); + if let Some(token) = &connection.app_token { + request = request.bearer_auth(token); + } + + let negotiate_response: NegotiateResponse = request + .send() + .await + .with_context(|| format!("failed to negotiate SignalR connection at {negotiate}"))? + .error_for_status() + .with_context(|| format!("SignalR negotiate failed at {negotiate}"))? + .json() + .await + .context("failed to parse SignalR negotiate response")?; + + let connection_token = negotiate_response + .connection_token + .or(negotiate_response.connection_id) + .context("SignalR negotiate response did not include a connection token")?; + + let ws_url = websocket_url(&endpoint, &connection_token, connection.app_token.as_deref())?; + let mut ws_request = ws_url.as_str().into_client_request()?; + apply_ws_headers(ws_request.headers_mut(), &headers)?; + if let Some(token) = &connection.app_token { + let value = format!("Bearer {token}"); + ws_request + .headers_mut() + .insert(WS_AUTHORIZATION, WsHeaderValue::from_str(&value)?); + } + + info!(hub = %connection.hub, url = %connection.url, "Connecting to PSU Event Hub"); + let (mut socket, _) = connect_async(ws_request) + .await + .with_context(|| format!("failed to connect PSU Event Hub WebSocket at {ws_url}"))?; + + socket + .send(Message::Text( + format!(r#"{{"protocol":"json","version":1}}{RECORD_SEPARATOR}"#).into(), + )) + .await + .context("failed to send SignalR handshake")?; + + info!(hub = %connection.hub, "Connected to PSU Event Hub"); + + loop { + tokio::select! { + _ = shutdown_signal.wait() => { + let _ = socket.close(None).await; + return Ok(()); + } + message = socket.next() => { + let Some(message) = message else { + bail!("SignalR WebSocket closed"); + }; + + match message.context("failed to read SignalR WebSocket message")? { + Message::Text(text) => handle_text_message(&mut socket, executor, &text).await?, + Message::Binary(bytes) => { + let text = String::from_utf8(bytes.to_vec()).context("SignalR binary message was not UTF-8")?; + handle_text_message(&mut socket, executor, &text).await?; + } + Message::Close(frame) => bail!("SignalR WebSocket closed: {frame:?}"), + Message::Ping(payload) => socket.send(Message::Pong(payload)).await?, + Message::Pong(_) => {} + Message::Frame(_) => {} + } + } + } + } +} + +async fn handle_text_message(socket: &mut S, executor: &EventHubExecutor, text: &str) -> anyhow::Result<()> +where + S: futures::Sink + Unpin, +{ + for frame in text.split(RECORD_SEPARATOR).filter(|frame| !frame.is_empty()) { + let value: Value = + serde_json::from_str(frame).with_context(|| format!("invalid SignalR JSON frame: {frame}"))?; + let message_type = value.get("type").and_then(Value::as_u64); + + match message_type { + None => {} + Some(1) => handle_invocation(socket, executor, value).await?, + Some(6) => {} + Some(7) => bail!("SignalR server sent close message"), + Some(message_type) => trace!(message_type, "Ignoring unsupported SignalR message"), + } + } + + Ok(()) +} + +async fn handle_invocation(socket: &mut S, executor: &EventHubExecutor, value: Value) -> anyhow::Result<()> +where + S: futures::Sink + Unpin, +{ + let target = value + .get("target") + .and_then(Value::as_str) + .context("SignalR invocation missing target")?; + let arguments = value + .get("arguments") + .and_then(Value::as_array) + .map(Vec::as_slice) + .unwrap_or(&[]); + let invocation_id = value.get("invocationId").and_then(Value::as_str); + + let result = executor.handle_invocation(target, arguments)?; + if let Some(invocation_id) = invocation_id { + let completion = if let Some(result) = result { + json!({ + "type": 3, + "invocationId": invocation_id, + "result": result, + }) + } else { + json!({ + "type": 3, + "invocationId": invocation_id, + }) + }; + + socket + .send(Message::Text(format!("{completion}{RECORD_SEPARATOR}").into())) + .await + .context("failed to send SignalR completion")?; + } + + Ok(()) +} + +fn endpoint_url(connection: &PsuEventHubConnectionConf) -> anyhow::Result { + let endpoint = if connection.app_token.is_some() || connection.use_default_credentials { + "autheventhub" + } else { + "eventhub" + }; + + let mut url = Url::parse(&format!("{}/{endpoint}", connection.url.as_str().trim_end_matches('/'))) + .context("failed to build PSU Event Hub URL")?; + url.query_pairs_mut().append_pair("group", &connection.hub); + Ok(url) +} + +fn negotiate_url(endpoint: &Url) -> anyhow::Result { + let mut url = endpoint.clone(); + let path = format!("{}/negotiate", endpoint.path().trim_end_matches('/')); + url.set_path(&path); + url.query_pairs_mut().append_pair("negotiateVersion", "1"); + Ok(url) +} + +fn websocket_url(endpoint: &Url, connection_token: &str, access_token: Option<&str>) -> anyhow::Result { + let mut url = endpoint.clone(); + let scheme = match endpoint.scheme() { + "http" => "ws", + "https" => "wss", + scheme => bail!("unsupported SignalR endpoint scheme: {scheme}"), + }; + url.set_scheme(scheme) + .map_err(|_| anyhow::anyhow!("failed to set SignalR WebSocket URL scheme"))?; + url.query_pairs_mut().append_pair("id", connection_token); + if let Some(access_token) = access_token { + url.query_pairs_mut().append_pair("access_token", access_token); + } + Ok(url) +} + +fn psu_headers(connection: &PsuEventHubConnectionConf) -> anyhow::Result { + let mut headers = HeaderMap::new(); + headers.insert("PSUComputerName", HeaderValue::from_str(&computer_name())?); + headers.insert("PSUUserName", HeaderValue::from_str(&user_name())?); + headers.insert("PSUDomainName", HeaderValue::from_str(&domain_name())?); + headers.insert("PSUVersion", HeaderValue::from_static(env!("CARGO_PKG_VERSION"))); + headers.insert( + "PSUDescription", + HeaderValue::from_str(connection.description.as_deref().unwrap_or_default())?, + ); + if let Some(token) = &connection.app_token { + headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {token}"))?); + } + Ok(headers) +} + +fn apply_ws_headers(target: &mut WsHeaderMap, source: &HeaderMap) -> anyhow::Result<()> { + for (name, value) in source { + let name = WsHeaderName::from_bytes(name.as_str().as_bytes())?; + let value = WsHeaderValue::from_bytes(value.as_bytes())?; + target.insert(name, value); + } + Ok(()) +} + +fn computer_name() -> String { + std::env::var("COMPUTERNAME") + .ok() + .or_else(|| hostname::get().ok().and_then(|name| name.into_string().ok())) + .unwrap_or_else(|| "localhost".to_owned()) +} + +fn user_name() -> String { + std::env::var("USERNAME") + .or_else(|_| std::env::var("USER")) + .unwrap_or_default() +} + +fn domain_name() -> String { + std::env::var("USERDOMAIN").unwrap_or_default() +} diff --git a/devolutions-agent/src/service.rs b/devolutions-agent/src/service.rs index 276a2e4f6..d1ab5e6cc 100644 --- a/devolutions-agent/src/service.rs +++ b/devolutions-agent/src/service.rs @@ -4,6 +4,7 @@ use anyhow::Context; use devolutions_agent::AgentServiceEvent; use devolutions_agent::config::ConfHandle; use devolutions_agent::log::AgentLog; +use devolutions_agent::psu_event_hub::PsuEventHubTask; use devolutions_agent::remote_desktop::RemoteDesktopTask; #[cfg(windows)] use devolutions_agent::session_manager::SessionManager; @@ -232,7 +233,11 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result { } if conf.tunnel.enabled { - tasks.register(TunnelTask::new(conf_handle)); + tasks.register(TunnelTask::new(conf_handle.clone())); + } + + if conf.psu_event_hub.enabled { + tasks.register(PsuEventHubTask::new(conf_handle)); } Ok(TasksCtx { diff --git a/package/AgentWindowsManaged/Actions/CustomActions.cs b/package/AgentWindowsManaged/Actions/CustomActions.cs index ebd15bdd0..009528b1f 100644 --- a/package/AgentWindowsManaged/Actions/CustomActions.cs +++ b/package/AgentWindowsManaged/Actions/CustomActions.cs @@ -266,22 +266,28 @@ static ActionResult ToggleAgentFeature(Session session, string feature, bool ena try { - Dictionary config = []; + JObject config = new JObject(); try { using StreamReader reader = new StreamReader(path); - config = JsonConvert.DeserializeObject>(reader.ReadToEnd()); + config = JObject.Parse(reader.ReadToEnd()); } catch (Exception) { // ignored. Previous config is either invalid or non-existent. } - config[feature] = new Dictionary {{"Enabled", enable}}; + if (config[feature] is not JObject featureConfig) + { + featureConfig = new JObject(); + config[feature] = featureConfig; + } + + featureConfig["Enabled"] = enable; using StreamWriter writer = new StreamWriter(path); - writer.Write(JsonConvert.SerializeObject(config)); + writer.Write(config.ToString(Formatting.None)); return ActionResult.Success; } @@ -302,6 +308,7 @@ public static ActionResult SetFeaturesToConfigure(Session session) [ (Features.SESSION_FEATURE.Id, Features.SESSION_FEATURE.Id.Substring(Features.FEATURE_ID_PREFIX.Length)), (Features.AGENT_UPDATER_FEATURE.Id, Features.AGENT_UPDATER_FEATURE.Id.Substring(Features.FEATURE_ID_PREFIX.Length)), + (Features.PSU_EVENT_HUB_FEATURE.Id, Features.PSU_EVENT_HUB_FEATURE.Id.Substring(Features.FEATURE_ID_PREFIX.Length)), (Features.PEDM_FEATURE.Id, Features.PEDM_FEATURE.Id.Substring(Features.FEATURE_ID_PREFIX.Length)), ]; @@ -575,6 +582,7 @@ public static ActionResult ConfigureFeatures(Session session) [ Features.SESSION_FEATURE.Id.Substring(Features.FEATURE_ID_PREFIX.Length), Features.AGENT_UPDATER_FEATURE.Id.Substring(Features.FEATURE_ID_PREFIX.Length), + Features.PSU_EVENT_HUB_FEATURE.Id.Substring(Features.FEATURE_ID_PREFIX.Length), Features.PEDM_FEATURE.Id.Substring(Features.FEATURE_ID_PREFIX.Length), ]; diff --git a/package/AgentWindowsManaged/Program.cs b/package/AgentWindowsManaged/Program.cs index 2d9ead71c..46054fe84 100644 --- a/package/AgentWindowsManaged/Program.cs +++ b/package/AgentWindowsManaged/Program.cs @@ -335,6 +335,13 @@ static void Main() Win64 = project.Platform == Platform.x64, RegistryKeyAction = RegistryKeyAction.create, Feature = Features.AGENT_TUNNEL_FEATURE, + }, + new (RegistryHive.LocalMachine, $"SOFTWARE\\{Includes.VENDOR_NAME}\\{Includes.SHORT_NAME}", "PsuEventHubEnabled", "1") + { + AttributesDefinition = "Type=string", + Win64 = project.Platform == Platform.x64, + RegistryKeyAction = RegistryKeyAction.create, + Feature = Features.PSU_EVENT_HUB_FEATURE, } }; diff --git a/package/AgentWindowsManaged/Resources/DevolutionsAgent_en-us.wxl b/package/AgentWindowsManaged/Resources/DevolutionsAgent_en-us.wxl index d92de55f8..b5e4d4dbf 100644 --- a/package/AgentWindowsManaged/Resources/DevolutionsAgent_en-us.wxl +++ b/package/AgentWindowsManaged/Resources/DevolutionsAgent_en-us.wxl @@ -5,6 +5,8 @@ Devolutions Agent Connects the agent to a Devolutions Gateway. Requires an enrollment string from your gateway operator. Agent Tunnel + Enables PowerShell Universal Event Hub remote agent compatibility. + PowerShell Universal Event Hub Enables the Devolutions Gateway updater Devolutions Gateway Updater Enables PEDM features and installs the shell extension diff --git a/package/AgentWindowsManaged/Resources/DevolutionsAgent_fr-fr.wxl b/package/AgentWindowsManaged/Resources/DevolutionsAgent_fr-fr.wxl index d1603d39a..3adee9a55 100644 --- a/package/AgentWindowsManaged/Resources/DevolutionsAgent_fr-fr.wxl +++ b/package/AgentWindowsManaged/Resources/DevolutionsAgent_fr-fr.wxl @@ -3,6 +3,8 @@ Connecte l'agent à une passerelle Devolutions. Nécessite une chaîne d'enrôlement fournie par l'opérateur de votre passerelle. Tunnel d'agent + Active la compatibilité d'agent distant pour PowerShell Universal Event Hub. + PowerShell Universal Event Hub Installe l'extension RDP Extension RDP 1036 diff --git a/package/AgentWindowsManaged/Resources/Features.cs b/package/AgentWindowsManaged/Resources/Features.cs index 4408d44b0..01bfd037d 100644 --- a/package/AgentWindowsManaged/Resources/Features.cs +++ b/package/AgentWindowsManaged/Resources/Features.cs @@ -22,11 +22,16 @@ internal static class Features Id = $"{FEATURE_ID_PREFIX}Tunnel" }; + internal static Feature PSU_EVENT_HUB_FEATURE = new("!(loc.FeaturePsuEventHubName)", "!(loc.FeaturePsuEventHubDescription)", isEnabled: false, allowChange: true) + { + Id = $"{FEATURE_ID_PREFIX}PsuEventHub" + }; + internal static Feature AGENT_FEATURE = new("!(loc.FeatureAgentName)", isEnabled: true, allowChange: false) { Id = $"{FEATURE_ID_PREFIX}Agent", Description = "!(loc.FeatureAgentDescription)", - Children = [ AGENT_UPDATER_FEATURE, AGENT_TUNNEL_FEATURE ] + Children = [ AGENT_UPDATER_FEATURE, AGENT_TUNNEL_FEATURE, PSU_EVENT_HUB_FEATURE ] }; internal static Feature PEDM_FEATURE = new("!(loc.FeaturePedmName)", "!(loc.FeaturePedmDescription)", isEnabled: false)