diff --git a/include/libkrun.h b/include/libkrun.h index 462a7166c..7b3a9360f 100644 --- a/include/libkrun.h +++ b/include/libkrun.h @@ -327,6 +327,10 @@ int32_t krun_add_virtiofs2(uint32_t ctx_id, as required by gvproxy in vfkit mode. */ #define NET_FLAG_VFKIT 1 << 0 +/* TSI (Transparent Socket Impersonation) feature flags for vsock */ +#define KRUN_TSI_HIJACK_INET (1 << 0) +#define KRUN_TSI_HIJACK_UNIX (1 << 1) + /* Taken from uapi/linux/virtio_net.h */ #define NET_FEATURE_CSUM 1 << 0 #define NET_FEATURE_GUEST_CSUM 1 << 1 @@ -862,6 +866,27 @@ int32_t krun_add_vsock_port2(uint32_t ctx_id, uint32_t port, const char *c_filepath, bool listen); + +/** + * Add a vsock device with specified TSI features. + * + * By default, libkrun creates a vsock device implicitly with TSI hijacking + * enabled based on heuristics. To use this function, first disable the implicit + * device with krun_disable_implicit_vsock(); otherwise -EEXIST is returned. + * + * Currently only one vsock device is supported. Calling this function + * multiple times will return an error. + * + * Arguments: + * "ctx_id" - the configuration context ID. + * "tsi_features" - bitmask of TSI features (KRUN_TSI_HIJACK_INET, KRUN_TSI_HIJACK_UNIX) + * Use 0 to add vsock without any TSI hijacking. + * + * Returns: + * Zero on success or a negative error number on failure. + */ +int32_t krun_add_vsock(uint32_t ctx_id, uint32_t tsi_features); + /** * Returns the eventfd file descriptor to signal the guest to shut down orderly. This must be * called before starting the microVM with "krun_start_event". Only available in libkrun-efi. @@ -1009,6 +1034,20 @@ int32_t krun_nitro_set_start_flags(uint32_t ctx_id, uint64_t start_flags); */ int32_t krun_disable_implicit_console(uint32_t ctx_id); +/** + * Disable the implicit vsock device. + * + * By default, libkrun creates a vsock device automatically. This function + * disables that behavior entirely - no vsock device will be created. + * + * Arguments: + * "ctx_id" - the configuration context ID. 
+ * + * Returns: + * Zero on success or a negative error number on failure. + */ +int32_t krun_disable_implicit_vsock(uint32_t ctx_id); + /* * Specify the value of `console=` in the kernel commandline. * diff --git a/src/devices/src/virtio/vsock/device.rs b/src/devices/src/virtio/vsock/device.rs index 35b80b1ae..2708327ca 100644 --- a/src/devices/src/virtio/vsock/device.rs +++ b/src/devices/src/virtio/vsock/device.rs @@ -18,6 +18,7 @@ use super::super::{ }; use super::muxer::VsockMuxer; use super::packet::VsockPacket; +use super::TsiFlags; use super::{defs, defs::uapi}; use crate::virtio::InterruptTransport; @@ -52,8 +53,7 @@ impl Vsock { host_port_map: Option>, queues: Vec, unix_ipc_port_map: Option>, - enable_tsi: bool, - enable_tsi_unix: bool, + tsi_flags: TsiFlags, ) -> super::Result { let mut queue_events = Vec::new(); for _ in 0..queues.len() { @@ -66,13 +66,7 @@ impl Vsock { Ok(Vsock { cid, - muxer: VsockMuxer::new( - cid, - host_port_map, - unix_ipc_port_map, - enable_tsi, - enable_tsi_unix, - ), + muxer: VsockMuxer::new(cid, host_port_map, unix_ipc_port_map, tsi_flags), queue_rx, queue_tx, queues, @@ -90,21 +84,13 @@ impl Vsock { cid: u64, host_port_map: Option>, unix_ipc_port_map: Option>, - enable_tsi: bool, - enable_tsi_unix: bool, + tsi_flags: TsiFlags, ) -> super::Result { let queues: Vec = defs::QUEUE_SIZES .iter() .map(|&max_size| VirtQueue::new(max_size)) .collect(); - Self::with_queues( - cid, - host_port_map, - queues, - unix_ipc_port_map, - enable_tsi, - enable_tsi_unix, - ) + Self::with_queues(cid, host_port_map, queues, unix_ipc_port_map, tsi_flags) } pub fn id(&self) -> &str { diff --git a/src/devices/src/virtio/vsock/mod.rs b/src/devices/src/virtio/vsock/mod.rs index b2d3d2648..7288de0bd 100644 --- a/src/devices/src/virtio/vsock/mod.rs +++ b/src/devices/src/virtio/vsock/mod.rs @@ -21,11 +21,39 @@ mod tsi_stream; mod unix; pub use self::defs::uapi::VIRTIO_ID_VSOCK as TYPE_VSOCK; +pub use self::defs::TsiFlags; pub use self::device::Vsock; +use 
bitflags::bitflags; use vm_memory::GuestMemoryError; mod defs { + use super::bitflags; + + bitflags! { + /// TSI (Transparent Socket Impersonation) feature flags. + /// + /// These flags control which socket families are hijacked by TSI. + pub struct TsiFlags: u32 { + /// Hijack AF_INET and AF_INET6 sockets + const HIJACK_INET = 1 << 0; + /// Hijack AF_UNIX sockets + const HIJACK_UNIX = 1 << 1; + } + } + + impl TsiFlags { + /// Returns true if any TSI hijacking is enabled. + pub fn tsi_enabled(&self) -> bool { + !self.is_empty() + } + } + + impl Default for TsiFlags { + fn default() -> Self { + TsiFlags::empty() + } + } /// Device ID used in MMIO device identification. /// Because Vsock is unique per-vm, this ID can be hardcoded. pub const VSOCK_DEV_ID: &str = "vsock"; diff --git a/src/devices/src/virtio/vsock/muxer.rs b/src/devices/src/virtio/vsock/muxer.rs index 0f91882dc..d3287586a 100644 --- a/src/devices/src/virtio/vsock/muxer.rs +++ b/src/devices/src/virtio/vsock/muxer.rs @@ -16,6 +16,7 @@ use super::timesync::TimesyncThread; use super::tsi_dgram::TsiDgramProxy; use super::tsi_stream::TsiStreamProxy; use super::unix::UnixProxy; +use super::TsiFlags; use super::VsockError; use crossbeam_channel::{unbounded, Sender}; use utils::epoll::{ControlOperation, Epoll, EpollEvent, EventSet}; @@ -106,8 +107,7 @@ pub struct VsockMuxer { proxy_map: ProxyMap, reaper_sender: Option>, unix_ipc_port_map: Option>, - enable_tsi: bool, - enable_tsi_unix: bool, + tsi_flags: TsiFlags, } impl VsockMuxer { @@ -115,8 +115,7 @@ impl VsockMuxer { cid: u64, host_port_map: Option>, unix_ipc_port_map: Option>, - enable_tsi: bool, - enable_tsi_unix: bool, + tsi_flags: TsiFlags, ) -> Self { VsockMuxer { cid, @@ -129,8 +128,7 @@ impl VsockMuxer { proxy_map: Arc::new(RwLock::new(HashMap::new())), reaper_sender: None, unix_ipc_port_map, - enable_tsi, - enable_tsi_unix, + tsi_flags, } } @@ -285,8 +283,16 @@ impl VsockMuxer { defs::SOCK_STREAM => { debug!("proxy create stream"); let id = 
((req.peer_port as u64) << 32) | (defs::TSI_PROXY_PORT as u64); - if req.family as i32 == libc::AF_UNIX && !self.enable_tsi_unix { - warn!("rejecting tcp unix proxy because tsi_unix is disabled"); + if req.family as i32 == libc::AF_UNIX + && !self.tsi_flags.contains(TsiFlags::HIJACK_UNIX) + { + warn!("rejecting stream unix proxy because HIJACK_UNIX is disabled"); + return; + } + if (req.family as i32 == libc::AF_INET || req.family as i32 == libc::AF_INET6) + && !self.tsi_flags.contains(TsiFlags::HIJACK_INET) + { + warn!("rejecting stream inet proxy because HIJACK_INET is disabled"); return; } match TsiStreamProxy::new( @@ -312,8 +318,16 @@ impl VsockMuxer { defs::SOCK_DGRAM => { debug!("proxy create dgram"); let id = ((req.peer_port as u64) << 32) | (defs::TSI_PROXY_PORT as u64); - if req.family as i32 == libc::AF_UNIX && !self.enable_tsi_unix { - warn!("rejecting udp unix proxy because tsi_unix is disabled"); + if req.family as i32 == libc::AF_UNIX + && !self.tsi_flags.contains(TsiFlags::HIJACK_UNIX) + { + warn!("rejecting dgram unix proxy because HIJACK_UNIX is disabled"); + return; + } + if (req.family as i32 == libc::AF_INET || req.family as i32 == libc::AF_INET6) + && !self.tsi_flags.contains(TsiFlags::HIJACK_INET) + { + warn!("rejecting dgram inet proxy because HIJACK_INET is disabled"); return; } match TsiDgramProxy::new( @@ -498,14 +512,18 @@ impl VsockMuxer { } match pkt.dst_port() { - defs::TSI_PROXY_CREATE if self.enable_tsi => self.process_proxy_create(pkt), - defs::TSI_CONNECT if self.enable_tsi => self.process_connect(pkt), - defs::TSI_GETNAME if self.enable_tsi => self.process_getname(pkt), - defs::TSI_SENDTO_ADDR if self.enable_tsi => self.process_sendto_addr(pkt), - defs::TSI_SENDTO_DATA if self.enable_tsi => self.process_sendto_data(pkt), - defs::TSI_LISTEN if self.enable_tsi => self.process_listen_request(pkt), - defs::TSI_ACCEPT if self.enable_tsi => self.process_accept_request(pkt), - defs::TSI_PROXY_RELEASE if self.enable_tsi => 
self.process_proxy_release(pkt), + defs::TSI_PROXY_CREATE if self.tsi_flags.tsi_enabled() => { + self.process_proxy_create(pkt) + } + defs::TSI_CONNECT if self.tsi_flags.tsi_enabled() => self.process_connect(pkt), + defs::TSI_GETNAME if self.tsi_flags.tsi_enabled() => self.process_getname(pkt), + defs::TSI_SENDTO_ADDR if self.tsi_flags.tsi_enabled() => self.process_sendto_addr(pkt), + defs::TSI_SENDTO_DATA if self.tsi_flags.tsi_enabled() => self.process_sendto_data(pkt), + defs::TSI_LISTEN if self.tsi_flags.tsi_enabled() => self.process_listen_request(pkt), + defs::TSI_ACCEPT if self.tsi_flags.tsi_enabled() => self.process_accept_request(pkt), + defs::TSI_PROXY_RELEASE if self.tsi_flags.tsi_enabled() => { + self.process_proxy_release(pkt) + } _ => { if pkt.op() == uapi::VSOCK_OP_RW { self.process_dgram_rw(pkt); diff --git a/src/libkrun/src/lib.rs b/src/libkrun/src/lib.rs index a85f1b0ca..e042a23e0 100644 --- a/src/libkrun/src/lib.rs +++ b/src/libkrun/src/lib.rs @@ -38,8 +38,8 @@ use std::sync::LazyLock; use std::sync::Mutex; use utils::eventfd::EventFd; use vmm::resources::{ - DefaultVirtioConsoleConfig, PortConfig, SerialConsoleConfig, VirtioConsoleConfigMode, - VmResources, + DefaultVirtioConsoleConfig, PortConfig, SerialConsoleConfig, TsiFlags, VirtioConsoleConfigMode, + VmResources, VsockConfig, }; #[cfg(feature = "blk")] use vmm::vmm_config::block::{BlockDeviceConfig, BlockRootConfig}; @@ -144,6 +144,7 @@ struct ContextConfig { legacy_mac: Option<[u8; 6]>, net_index: u8, tsi_port_map: Option>, + vsock_config: VsockConfig, #[cfg(feature = "blk")] block_cfgs: Vec, #[cfg(feature = "blk")] @@ -1181,6 +1182,9 @@ pub unsafe extern "C" fn krun_set_port_map(ctx_id: u32, c_port_map: *const *cons match CTX_MAP.lock().unwrap().entry(ctx_id) { Entry::Occupied(mut ctx_cfg) => { let cfg = ctx_cfg.get_mut(); + if cfg.vsock_config == VsockConfig::Disabled { + return -libc::ENODEV; + } if cfg.set_port_map(port_map).is_err() { return -libc::EINVAL; } @@ -1405,6 +1409,9 @@ pub 
unsafe extern "C" fn krun_add_vsock_port2( match CTX_MAP.lock().unwrap().entry(ctx_id) { Entry::Occupied(mut ctx_cfg) => { let cfg = ctx_cfg.get_mut(); + if cfg.vsock_config == VsockConfig::Disabled { + return -libc::ENODEV; + } cfg.add_vsock_port(port, filepath, listen); } Entry::Vacant(_) => return -libc::ENOENT, @@ -2285,6 +2292,40 @@ pub extern "C" fn krun_disable_implicit_console(ctx_id: u32) -> i32 { KRUN_SUCCESS } +#[no_mangle] +pub extern "C" fn krun_disable_implicit_vsock(ctx_id: u32) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.vsock_config = VsockConfig::Disabled; + } + Entry::Vacant(_) => return -libc::ENOENT, + } + + KRUN_SUCCESS +} + +#[no_mangle] +pub extern "C" fn krun_add_vsock(ctx_id: u32, tsi_features: u32) -> i32 { + let tsi_flags = match TsiFlags::from_bits(tsi_features) { + Some(flags) => flags, + None => return -libc::EINVAL, + }; + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + if cfg.vsock_config != VsockConfig::Disabled { + return -libc::EEXIST; + } + cfg.vsock_config = VsockConfig::Explicit { tsi_flags }; + } + Entry::Vacant(_) => return -libc::ENOENT, + } + + KRUN_SUCCESS +} + #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn krun_add_virtio_console_default( @@ -2554,46 +2595,45 @@ pub extern "C" fn krun_start_enter(ctx_id: u32) -> i32 { } } - #[allow(unused_assignments)] - let mut vsock_set = false; - let mut vsock_config = VsockDeviceConfig { - vsock_id: "vsock0".to_string(), - guest_cid: 3, - host_port_map: None, - unix_ipc_port_map: None, - enable_tsi: false, - enable_tsi_unix: false, - }; + match &ctx_cfg.vsock_config { + VsockConfig::Disabled => (), + VsockConfig::Explicit { tsi_flags } => { + let vsock_device_config = VsockDeviceConfig { + vsock_id: "vsock0".to_string(), + guest_cid: 3, + host_port_map: ctx_cfg.tsi_port_map, + unix_ipc_port_map: 
ctx_cfg.unix_ipc_port_map.clone(), + tsi_flags: *tsi_flags, + }; + ctx_cfg.vmr.set_vsock_device(vsock_device_config).unwrap(); + } + VsockConfig::Implicit => { + // Implicit vsock configuration - use heuristics + // Check if TSI should be enabled based on network configuration + #[cfg(feature = "net")] + let enable_tsi = ctx_cfg.vmr.net.list.is_empty() && ctx_cfg.legacy_net_cfg.is_none(); + #[cfg(not(feature = "net"))] + let enable_tsi = true; - #[cfg(feature = "net")] - if ctx_cfg.vmr.net.list.is_empty() && ctx_cfg.legacy_net_cfg.is_none() { - vsock_config.host_port_map = ctx_cfg.tsi_port_map; - vsock_config.enable_tsi = true; - vsock_set = true; - } - #[cfg(not(feature = "net"))] - { - vsock_config.host_port_map = ctx_cfg.tsi_port_map; - vsock_config.enable_tsi = true; - vsock_set = true; - } + let has_ipc_map = ctx_cfg.unix_ipc_port_map.is_some(); - if let Some(ref map) = ctx_cfg.unix_ipc_port_map { - vsock_config.unix_ipc_port_map = Some(map.clone()); - vsock_set = true; - } + if enable_tsi || has_ipc_map { + let (tsi_flags, host_port_map) = if enable_tsi { + (TsiFlags::HIJACK_INET, ctx_cfg.tsi_port_map) + } else { + (TsiFlags::empty(), None) + }; - if vsock_set { - if vsock_config.enable_tsi { - // We only support using TSI for AF_UNIX in a containerized context, - // so only enable it when we have a single virtio-fs device pointing - // to root. 
- #[cfg(not(feature = "tee"))] - if ctx_cfg.vmr.fs.len() == 1 && ctx_cfg.vmr.fs[0].shared_dir == "/" { - vsock_config.enable_tsi_unix = true; + let vsock_device_config = VsockDeviceConfig { + vsock_id: "vsock0".to_string(), + guest_cid: 3, + host_port_map, + unix_ipc_port_map: ctx_cfg.unix_ipc_port_map.clone(), + tsi_flags, + }; + ctx_cfg.vmr.set_vsock_device(vsock_device_config).unwrap(); } } - ctx_cfg.vmr.set_vsock_device(vsock_config).unwrap(); } if let Some(virgl_flags) = ctx_cfg.gpu_virgl_flags { diff --git a/src/vmm/src/builder.rs b/src/vmm/src/builder.rs index 71bc2bfb1..aaade57d5 100644 --- a/src/vmm/src/builder.rs +++ b/src/vmm/src/builder.rs @@ -24,7 +24,7 @@ use super::{Error, Vmm}; use crate::device_manager::legacy::PortIODeviceManager; use crate::device_manager::mmio::MMIODeviceManager; use crate::resources::{ - DefaultVirtioConsoleConfig, PortConfig, VirtioConsoleConfigMode, VmResources, + DefaultVirtioConsoleConfig, PortConfig, TsiFlags, VirtioConsoleConfigMode, VmResources, }; use crate::vmm_config::external_kernel::{ExternalKernel, KernelFormat}; #[cfg(feature = "net")] @@ -1045,16 +1045,18 @@ pub fn build_microvm( )?; #[cfg(feature = "blk")] attach_block_devices(&mut vmm, &vm_resources.block, intc.clone())?; + if let Some(vsock) = vm_resources.vsock.get() { attach_unixsock_vsock_device(&mut vmm, vsock, event_manager, intc.clone())?; - #[cfg(not(feature = "net"))] - vmm.kernel_cmdline.insert_str("tsi_hijack")?; - #[cfg(feature = "net")] - if vm_resources.net.list.is_empty() { - // Only enable TSI if we don't have any network devices. 
+ let tsi_flags = vm_resources.vsock.tsi_flags(); + if tsi_flags.contains(TsiFlags::HIJACK_INET) { vmm.kernel_cmdline.insert_str("tsi_hijack")?; } + if tsi_flags.contains(TsiFlags::HIJACK_UNIX) { + vmm.kernel_cmdline.insert_str("tsi_hijack_unix")?; + } } + #[cfg(feature = "net")] attach_net_devices(&mut vmm, &vm_resources.net, intc.clone())?; #[cfg(feature = "snd")] diff --git a/src/vmm/src/resources.rs b/src/vmm/src/resources.rs index 800146e8f..d8d0fff24 100644 --- a/src/vmm/src/resources.rs +++ b/src/vmm/src/resources.rs @@ -37,6 +37,9 @@ use krun_display::DisplayBackend; type Result = std::result::Result<(), E>; +// Re-export TsiFlags from devices crate +pub use devices::virtio::TsiFlags; + /// Errors encountered when configuring microVM resources. #[derive(Debug)] pub enum Error { @@ -109,6 +112,18 @@ pub enum PortConfig { }, } +/// Configuration for the vsock device +#[derive(Debug, Default, Clone, Eq, PartialEq)] +pub enum VsockConfig { + /// Default behavior - vsock created implicitly with heuristics-based TSI + #[default] + Implicit, + /// Explicit configuration with specified TSI features + Explicit { tsi_flags: TsiFlags }, + /// Vsock device disabled + Disabled, +} + /// A data structure that encapsulates the device configurations /// held in the Vmm. #[derive(Default)] diff --git a/src/vmm/src/vmm_config/vsock.rs b/src/vmm/src/vmm_config/vsock.rs index 732505fb9..1e3d6300d 100644 --- a/src/vmm/src/vmm_config/vsock.rs +++ b/src/vmm/src/vmm_config/vsock.rs @@ -6,7 +6,7 @@ use std::fmt; use std::path::PathBuf; use std::sync::{Arc, Mutex}; -use devices::virtio::{Vsock, VsockError}; +use devices::virtio::{TsiFlags, Vsock, VsockError}; type MutexVsock = Arc>; @@ -40,10 +40,8 @@ pub struct VsockDeviceConfig { pub host_port_map: Option>, /// An optional map of guest port to host UNIX domain sockets for IPC. 
pub unix_ipc_port_map: Option>, - /// Whether to enable TSI - pub enable_tsi: bool, - /// Whether to enable TSI for AF_UNIX - pub enable_tsi_unix: bool, + /// TSI feature flags + pub tsi_flags: TsiFlags, } struct VsockWrapper { @@ -54,17 +52,22 @@ struct VsockWrapper { #[derive(Default)] pub struct VsockBuilder { inner: Option, + tsi_flags: TsiFlags, } impl VsockBuilder { /// Creates an empty Vsock. pub fn new() -> Self { - Self { inner: None } + Self { + inner: None, + tsi_flags: TsiFlags::empty(), + } } /// Inserts a Vsock in the store. /// If an entry already exists, it will overwrite it. pub fn insert(&mut self, cfg: VsockDeviceConfig) -> Result<()> { + self.tsi_flags = cfg.tsi_flags; self.inner = Some(VsockWrapper { vsock: Arc::new(Mutex::new(Self::create_vsock(cfg)?)), }); @@ -76,14 +79,17 @@ impl VsockBuilder { self.inner.as_ref().map(|pair| &pair.vsock) } + pub fn tsi_flags(&self) -> TsiFlags { + self.tsi_flags + } + /// Creates a Vsock device from a VsockDeviceConfig. pub fn create_vsock(cfg: VsockDeviceConfig) -> Result { Vsock::new( u64::from(cfg.guest_cid), cfg.host_port_map, cfg.unix_ipc_port_map, - cfg.enable_tsi, - cfg.enable_tsi_unix, + cfg.tsi_flags, ) .map_err(VsockConfigError::CreateVsockDevice) } @@ -121,8 +127,7 @@ pub(crate) mod tests { guest_cid: 3, host_port_map: None, unix_ipc_port_map: None, - enable_tsi: false, - enable_tsi_unix: false, + tsi_flags: TsiFlags::empty(), } }