diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 42a33b0a3..6126e6676 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1577,6 +1577,7 @@ dependencies = [ "base64 0.22.1", "clap", "defguard-client-common", + "defguard-client-config-sync", "defguard-client-core", "defguard-client-posture", "defguard-client-proto", diff --git a/src-tauri/client-cli/Cargo.toml b/src-tauri/client-cli/Cargo.toml index 4c54254ce..cb40bcbac 100644 --- a/src-tauri/client-cli/Cargo.toml +++ b/src-tauri/client-cli/Cargo.toml @@ -17,6 +17,7 @@ owo-colors = { version = "4", features = ["supports-colors"] } common = { package = "defguard-client-common", path = "../common" } defguard_core = { package = "defguard-client-core", path = "../core" } +defguard_client_config_sync = { package = "defguard-client-config-sync", path = "../enterprise/config-sync" } defguard_client_posture = { package = "defguard-client-posture", path = "../enterprise/posture" } defguard_client_proto = { package = "defguard-client-proto", path = "../client-proto" } base64.workspace = true diff --git a/src-tauri/client-cli/src/brand.rs b/src-tauri/client-cli/src/brand.rs index 4cf96e55e..29a0dbfbd 100644 --- a/src-tauri/client-cli/src/brand.rs +++ b/src-tauri/client-cli/src/brand.rs @@ -62,6 +62,7 @@ pub fn print_banner() { #[cfg(windows)] pub fn print_banner() { let project = common::version_string("defguard-cli"); + println!(); println!(" {project}"); println!(" {COPYRIGHT}"); println!(); diff --git a/src-tauri/client-cli/src/config_poll.rs b/src-tauri/client-cli/src/config_poll.rs new file mode 100644 index 000000000..af790b1f8 --- /dev/null +++ b/src-tauri/client-cli/src/config_poll.rs @@ -0,0 +1,84 @@ +use std::collections::{HashMap, HashSet}; + +use defguard_client_config_sync::{poll_instances, PollInstanceResult}; +use defguard_core::{ + connection::active_state::active_state, + database::models::{location::Location, Id}, + error::Error, + ConnectionType, +}; +use tracing::debug; + +use crate::state::State; + +pub async fn poll_config(state: &State) { + let active_instance_ids = match active_instance_ids(state).await { + Ok(ids) => ids, + Err(err) => { + debug!("Skipping configuration polling, failed to detect active connections: {err}"); + return; + } + }; + + let outcomes = match poll_instances(&state.pool, &active_instance_ids).await { + Ok(outcomes) => outcomes, + Err(err) => { + debug!("Skipping configuration polling: {err}"); + return; + } + }; + + for outcome in outcomes { + match outcome.result { + Ok(PollInstanceResult::ChangedWhileActive { .. }) => { + eprintln!( + "Instance {} configuration changed, disconnect to apply changes", + outcome.instance_name + ); + } + Ok(PollInstanceResult::Updated { .. } | PollInstanceResult::Unchanged { .. }) => {} + Err(Error::CoreNotEnterprise) => { + debug!( + "Instance {} is not enterprise, skipping configuration polling", + outcome.instance_name + ); + } + Err(Error::NoToken) => { + debug!( + "Instance {} has no polling token, skipping configuration polling", + outcome.instance_name + ); + } + Err(err) => { + debug!( + "Failed to poll configuration for instance {}: {err}", + outcome.instance_name + ); + } + } + } +} + +async fn active_instance_ids(state: &State) -> Result, Error> { + let active_location_ids = active_state(&state.pool) + .await? + .into_iter() + .filter(|connection| connection.connection_type == ConnectionType::Location) + .map(|connection| connection.target_id) + .collect::>(); + + if active_location_ids.is_empty() { + return Ok(HashSet::new()); + } + + let location_instances = Location::all(&state.pool, false) + .await? + .into_iter() + .map(|location| (location.id, location.instance_id)) + .collect::>(); + + Ok(active_location_ids + .into_iter() + .filter_map(|location_id| location_instances.get(&location_id).copied()) + .collect()) +} diff --git a/src-tauri/client-cli/src/main.rs b/src-tauri/client-cli/src/main.rs index 923b2e6d8..81850fd2e 100644 --- a/src-tauri/client-cli/src/main.rs +++ b/src-tauri/client-cli/src/main.rs @@ -6,6 +6,7 @@ use common::check_version_flag; mod brand; mod cli; mod commands; +mod config_poll; mod exit; mod logging; mod mfa; @@ -51,6 +52,8 @@ async fn main() -> ExitCode { } }; + config_poll::poll_config(&state).await; + // Dispatch command. match cli.command { Commands::List => output::finish(list::handle(&state).await, cli.json), diff --git a/src-tauri/enterprise/config-sync/src/lib.rs b/src-tauri/enterprise/config-sync/src/lib.rs index 82bd18580..c1fd79cd6 100644 --- a/src-tauri/enterprise/config-sync/src/lib.rs +++ b/src-tauri/enterprise/config-sync/src/lib.rs @@ -1,12 +1,15 @@ #[macro_use] extern crate log; -use std::{cmp::Ordering, str::FromStr}; +use std::{cmp::Ordering, collections::HashSet, str::FromStr}; pub mod commands; use defguard_client_core::{ - database::models::{instance::Instance, Id}, + database::{ + models::{instance::Instance, Id}, + DbPool, + }, error::Error, proxy::post_with_headers, version::{MIN_CORE_VERSION, MIN_PROXY_VERSION}, @@ -17,7 +20,7 @@ use semver::Version; use serde::Serialize; use sqlx::{Sqlite, Transaction}; -use crate::commands::disable_enterprise_features; +use crate::commands::{disable_enterprise_features, do_update_instance}; static POLLING_ENDPOINT: &str = "/api/v1/poll"; @@ -32,6 +35,29 @@ pub struct FetchedConfig { pub version_mismatch: Option, } +/// Result of polling a single instance once. +#[derive(Debug)] +pub enum PollInstanceResult { + Unchanged { + version_mismatch: Option, + }, + Updated { + locations_changed: bool, + version_mismatch: Option, + }, + ChangedWhileActive { + version_mismatch: Option, + }, +} + +/// Outcome of polling a single instance in a batch. +#[derive(Debug)] +pub struct PollInstanceOutcome { + pub instance_id: Id, + pub instance_name: String, + pub result: Result, +} + /// Payload emitted when a version mismatch is detected. #[derive(Clone, Debug, Serialize)] pub struct VersionMismatchPayload { @@ -130,6 +156,82 @@ pub async fn fetch_instance_config( }) } +/// Polls one instance once and applies changed configuration only when safe. +/// +/// The caller owns scheduling, active-connection detection, and user-facing notifications. +pub async fn poll_instance( + transaction: &mut Transaction<'_, Sqlite>, + instance: &mut Instance, + has_active_connections: bool, +) -> Result { + let fetched = fetch_instance_config(transaction, instance).await?; + let version_mismatch = fetched.version_mismatch; + + let device_config = + fetched.response.device_config.as_ref().ok_or_else(|| { + Error::InternalError("Device config not present in response".to_string()) + })?; + + if !config_changed(transaction, instance, device_config).await? { + debug!( + "Config for instance {}({}) didn't change", + instance.name, instance.id + ); + return Ok(PollInstanceResult::Unchanged { version_mismatch }); + } + + debug!( + "Config for instance {}({}) changed", + instance.name, instance.id + ); + + if has_active_connections { + return Ok(PollInstanceResult::ChangedWhileActive { version_mismatch }); + } + + debug!( + "Updating instance {}({}) configuration: {device_config:?}", + instance.name, instance.id, + ); + let locations_changed = + do_update_instance(transaction, instance, device_config.clone()).await?; + info!( + "Updated instance {}({}) configuration based on core's response", + instance.name, instance.id + ); + + Ok(PollInstanceResult::Updated { + locations_changed, + version_mismatch, + }) +} + +/// Polls all instances that have a polling token and commits any safe configuration updates. +/// +/// The caller owns active-connection detection and all user-facing side effects. +pub async fn poll_instances( + pool: &DbPool, + active_instance_ids: &HashSet, +) -> Result, Error> { + let mut transaction = pool.begin().await?; + let mut instances = Instance::all_with_token(&mut *transaction).await?; + let mut outcomes = Vec::with_capacity(instances.len()); + + for instance in &mut instances { + let has_active_connections = active_instance_ids.contains(&instance.id); + let instance_id = instance.id; + let result = poll_instance(&mut transaction, instance, has_active_connections).await; + outcomes.push(PollInstanceOutcome { + instance_id, + instance_name: instance.name.clone(), + result, + }); + } + + transaction.commit().await?; + Ok(outcomes) +} + /// Checks if config has changed compared to what's in the database. pub async fn config_changed( transaction: &mut Transaction<'_, Sqlite>, @@ -298,10 +400,107 @@ fn check_min_version( #[cfg(test)] mod tests { - use defguard_client_core::database::models::instance::ClientTrafficPolicy; + use std::{ + collections::HashSet, + io::{ErrorKind, Read, Write}, + net::{SocketAddr, TcpListener, TcpStream}, + thread::{sleep, spawn, JoinHandle}, + time::Duration, + }; + + use defguard_client_core::database::models::{ + instance::ClientTrafficPolicy, + location::{Location, LocationMfaMode, ServiceLocationMode}, + NoId, + }; + use defguard_client_proto::defguard::client_types::{ + DeviceConfig, DeviceConfigResponse, InstanceInfo, + }; + use sqlx::SqlitePool; use super::*; + const READ_TIMEOUT: Duration = Duration::from_secs(5); + const CONNECT_TIMEOUT: Duration = Duration::from_millis(50); + const WAIT_TIMEOUT: Duration = Duration::from_millis(10); + + struct MockResponse { + status: u16, + body: String, + } + + struct MockPollServer { + addr: SocketAddr, + handle: Option>, + } + + impl MockPollServer { + fn new(responses: Vec) -> Self { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + listener.set_nonblocking(true).unwrap(); + let addr = listener.local_addr().unwrap(); + + let handle = spawn(move || { + for response in responses { + let mut stream = loop { + match listener.accept() { + Ok((stream, _)) => break stream, + Err(ref err) if err.kind() == ErrorKind::WouldBlock => { + sleep(WAIT_TIMEOUT); + } + Err(_) => return, + } + }; + stream.set_nonblocking(false).ok(); + stream.set_read_timeout(Some(READ_TIMEOUT)).ok(); + let mut data = Vec::new(); + let mut buf = [0u8; 4096]; + loop { + match stream.read(&mut buf) { + Ok(0) => break, + Ok(n) => { + data.extend_from_slice(&buf[..n]); + if data.windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + Err(_) => break, + } + } + + let body = format!( + "HTTP/1.1 {} OK\r\nContent-Type: application/json\r\n{}: 1.6.0\r\n{}: 1.6.0\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + response.status, + CORE_VERSION_HEADER, + PROXY_VERSION_HEADER, + response.body.len(), + response.body, + ); + let _ = stream.write_all(body.as_bytes()); + } + }); + + Self { + addr, + handle: Some(handle), + } + } + + fn url(&self) -> String { + format!("http://{}/", self.addr) + } + } + + impl Drop for MockPollServer { + fn drop(&mut self) { + // Unblock accept if the test did not consume all prepared responses. + let _ = TcpStream::connect_timeout(&self.addr, CONNECT_TIMEOUT); + if let Some(handle) = self.handle.take() { + let _ = handle.join(); + } + } + } + fn instance_with_token(token: Option<&str>) -> Instance { Instance { id: 1, @@ -325,6 +524,110 @@ mod tests { reqwest::Response::from(builder.body(String::new()).unwrap()) } + fn instance_info(name: &str, proxy_url: &str) -> InstanceInfo { + InstanceInfo { + id: format!("uuid-{name}"), + name: name.into(), + url: format!("https://{name}.example"), + proxy_url: proxy_url.into(), + username: "alice".into(), + enterprise_enabled: true, + ..Default::default() + } + } + + fn device_config(network_id: Id, name: &str, endpoint: &str) -> DeviceConfig { + DeviceConfig { + network_id, + network_name: name.into(), + endpoint: endpoint.into(), + assigned_ip: "10.6.0.2".into(), + pubkey: format!("pk-{network_id}"), + allowed_ips: "0.0.0.0/0".into(), + keepalive_interval: 25, + ..Default::default() + } + } + + fn device_config_response( + instance: &Instance, + config: DeviceConfig, + ) -> DeviceConfigResponse { + DeviceConfigResponse { + instance: Some(instance_info(&instance.name, &instance.proxy_url)), + configs: vec![config], + token: instance.token.clone(), + ..Default::default() + } + } + + fn poll_response(response: DeviceConfigResponse) -> MockResponse { + let body = serde_json::to_string(&InstanceInfoResponse { + device_config: Some(response), + }) + .unwrap(); + MockResponse { status: 200, body } + } + + fn error_response() -> MockResponse { + MockResponse { + status: 500, + body: "not-json".into(), + } + } + + async fn seed_instance( + pool: &SqlitePool, + name: &str, + proxy_url: &str, + token: Option<&str>, + ) -> Instance { + Instance { + id: NoId, + name: name.into(), + uuid: format!("uuid-{name}"), + url: format!("https://{name}.example"), + proxy_url: proxy_url.into(), + username: "alice".into(), + token: token.map(str::to_string), + client_traffic_policy: ClientTrafficPolicy::None, + enterprise_enabled: true, + openid_display_name: None, + } + .save(pool) + .await + .unwrap() + } + + async fn seed_location( + pool: &SqlitePool, + instance_id: Id, + network_id: Id, + name: &str, + endpoint: &str, + ) -> Location { + Location { + id: NoId, + instance_id, + network_id, + name: name.into(), + address: "10.6.0.2".into(), + pubkey: format!("pk-{network_id}"), + endpoint: endpoint.into(), + allowed_ips: "0.0.0.0/0".into(), + dns: None, + route_all_traffic: false, + keepalive_interval: 25, + location_mfa_mode: LocationMfaMode::Disabled, + service_location_mode: ServiceLocationMode::Disabled, + mfa_method: None, + posture_check_required: false, + } + .save(pool) + .await + .unwrap() + } + #[test] fn test_build_request_no_token_errors() { let instance = instance_with_token(None); @@ -383,4 +686,118 @@ mod tests { let instance = instance_with_token(Some("tok")); assert!(check_min_version(&response, &instance).is_none()); } + + #[sqlx::test(migrations = "../../migrations")] + async fn test_config_changed_false_when_instance_and_locations_match(pool: SqlitePool) { + let instance = seed_instance(&pool, "acme", "https://proxy.example", Some("tok")).await; + seed_location(&pool, instance.id, 1, "office", "1.2.3.4:51820").await; + let response = + device_config_response(&instance, device_config(1, "office", "1.2.3.4:51820")); + + let mut transaction = pool.begin().await.unwrap(); + let changed = config_changed(&mut transaction, &instance, &response) + .await + .unwrap(); + + assert!(!changed); + } + + #[sqlx::test(migrations = "../../migrations")] + async fn test_config_changed_true_when_instance_metadata_changes(pool: SqlitePool) { + let instance = seed_instance(&pool, "acme", "https://proxy.example", Some("tok")).await; + seed_location(&pool, instance.id, 1, "office", "1.2.3.4:51820").await; + let mut response = + device_config_response(&instance, device_config(1, "office", "1.2.3.4:51820")); + response.instance.as_mut().unwrap().name = "renamed".into(); + + let mut transaction = pool.begin().await.unwrap(); + let changed = config_changed(&mut transaction, &instance, &response) + .await + .unwrap(); + + assert!(changed); + } + + #[sqlx::test(migrations = "../../migrations")] + async fn test_config_changed_true_when_location_changes(pool: SqlitePool) { + let instance = seed_instance(&pool, "acme", "https://proxy.example", Some("tok")).await; + seed_location(&pool, instance.id, 1, "office", "1.2.3.4:51820").await; + let response = + device_config_response(&instance, device_config(1, "office", "5.6.7.8:51820")); + + let mut transaction = pool.begin().await.unwrap(); + let changed = config_changed(&mut transaction, &instance, &response) + .await + .unwrap(); + + assert!(changed); + } + + #[sqlx::test(migrations = "../../migrations")] + async fn test_poll_instance_changed_while_active_does_not_update_db(pool: SqlitePool) { + let mut instance = seed_instance(&pool, "acme", "https://proxy.example", Some("tok")).await; + seed_location(&pool, instance.id, 1, "office", "1.2.3.4:51820").await; + + let response = + device_config_response(&instance, device_config(1, "office", "5.6.7.8:51820")); + let server = MockPollServer::new(vec![poll_response(response)]); + instance.proxy_url = server.url(); + instance.save(&pool).await.unwrap(); + + let mut transaction = pool.begin().await.unwrap(); + let result = poll_instance(&mut transaction, &mut instance, true) + .await + .unwrap(); + transaction.commit().await.unwrap(); + + assert!(matches!( + result, + PollInstanceResult::ChangedWhileActive { .. } + )); + let location = Location::find_by_instance_id(&pool, instance.id, true) + .await + .unwrap() + .pop() + .unwrap(); + assert_eq!(location.endpoint, "1.2.3.4:51820"); + } + + #[sqlx::test(migrations = "../../migrations")] + async fn test_poll_instances_returns_success_and_error_outcomes(pool: SqlitePool) { + let error_server = MockPollServer::new(vec![error_response()]); + + let instance_active = + seed_instance(&pool, "active", "https://proxy.example", Some("tok-1")).await; + seed_location(&pool, instance_active.id, 1, "office", "1.2.3.4:51820").await; + let response = device_config_response( + &instance_active, + device_config(1, "office", "5.6.7.8:51820"), + ); + let success_server = MockPollServer::new(vec![poll_response(response)]); + let mut instance_active = instance_active; + instance_active.proxy_url = success_server.url(); + instance_active.save(&pool).await.unwrap(); + + let instance_error = + seed_instance(&pool, "error", &error_server.url(), Some("tok-2")).await; + + let outcomes = poll_instances(&pool, &HashSet::from([instance_active.id])) + .await + .unwrap(); + + assert_eq!(outcomes.len(), 2); + let active_outcome = outcomes + .iter() + .find(|outcome| outcome.instance_id == instance_active.id) + .unwrap(); + assert!(matches!( + active_outcome.result, + Ok(PollInstanceResult::ChangedWhileActive { .. }) + )); + let error_outcome = outcomes + .iter() + .find(|outcome| outcome.instance_id == instance_error.id) + .unwrap(); + assert!(error_outcome.result.is_err()); + } } diff --git a/src-tauri/src/appstate.rs b/src-tauri/src/appstate.rs index 2b2c08ca1..ec013e1cb 100644 --- a/src-tauri/src/appstate.rs +++ b/src-tauri/src/appstate.rs @@ -10,11 +10,11 @@ use tokio_util::sync::CancellationToken; use crate::{ app_config::AppConfig, database::models::{connection::ActiveConnection, Id}, - enterprise::provisioning::ProvisioningConfig, session_state::SessionState, utils::stats_handler, ConnectionType, }; +use defguard_client_provisioning::ProvisioningConfig; pub struct AppState { pub log_watchers: Mutex>, diff --git a/src-tauri/src/bin/defguard-client.rs b/src-tauri/src/bin/defguard-client.rs index f284b794d..637f4bfa9 100644 --- a/src-tauri/src/bin/defguard-client.rs +++ b/src-tauri/src/bin/defguard-client.rs @@ -24,9 +24,9 @@ use defguard_client::{ models::{location_stats::LocationStats, tunnel::TunnelStats}, DB_POOL, }, - enterprise::provisioning::handle_client_initialization, events::handle_deep_link, periodic::run_periodic_tasks, + provisioning::handle_client_initialization, service, session_state, tray::{configure_tray_icon, setup_tray}, utils::load_log_targets, diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index 097b72831..4711a9a99 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -8,6 +8,7 @@ use defguard_client_core::connection::{ active_connections::{find_connection, get_connection_id_by_type, ACTIVE_CONNECTIONS}, disconnect_interface, }; +use defguard_client_posture::authorize_posture_session; #[cfg(not(target_os = "macos"))] use defguard_client_proto::defguard::client::v1::{ DeleteServiceLocationsRequest, RemoveInterfaceRequest, SaveServiceLocationsRequest, @@ -15,6 +16,8 @@ use defguard_client_proto::defguard::client::v1::{ use defguard_client_proto::defguard::{ client_types::DeviceConfigResponse, enterprise::posture::v2::DevicePostureData, }; +use defguard_client_provisioning::ProvisioningConfig; +use defguard_client_service_locations::to_service_location; use serde::{Deserialize, Serialize}; use struct_patch::Patch; use tauri::{AppHandle, Emitter, Manager, State}; @@ -38,12 +41,6 @@ use crate::{ }, DB_POOL, }, - enterprise::{ - self, - periodic::config::{do_update_instance, poll_instance}, - posture::authorize_posture_session, - provisioning::ProvisioningConfig, - }, error::Error, events::EventKey, into_location, @@ -51,6 +48,7 @@ use crate::{ global_log_watcher::{spawn_global_log_watcher_task, stop_global_log_watcher_task}, service_log_watcher::stop_log_watcher_task, }, + periodic::config::{do_update_instance, poll_instance_with_events}, proxy::construct_platform_header, tray::{configure_tray_icon, reload_tray_menu}, utils::{ @@ -310,7 +308,7 @@ async fn maybe_update_instance_config(location_id: Id, handle: &AppHandle) -> Re ); return Err(Error::NotFound); }; - poll_instance(&mut transaction, &mut instance, handle).await?; + poll_instance_with_events(&mut transaction, &mut instance, handle).await?; transaction.commit().await?; handle .emit(EventKey::InstanceUpdate.into(), ()) @@ -441,9 +439,7 @@ async fn push_service_locations( "Adding service location {}({}) for instance {}({}) to be saved to the daemon.", saved_location.name, saved_location.id, instance.name, instance.id, ); - service_locations.push(crate::enterprise::service_locations::to_service_location( - saved_location, - )?); + service_locations.push(to_service_location(saved_location)?); } } @@ -1291,7 +1287,7 @@ pub fn get_platform_header() -> String { #[tauri::command(async)] pub async fn get_posture_data() -> Result { debug!("Received a command to prepare posture report"); - enterprise::posture::get_posture_data().await + defguard_client_posture::get_posture_data().await } #[derive(Debug, Serialize)] diff --git a/src-tauri/src/enterprise/mod.rs b/src-tauri/src/enterprise/mod.rs deleted file mode 100644 index e5b0dc05c..000000000 --- a/src-tauri/src/enterprise/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub use defguard_client_posture::{inspector, posture}; -pub use defguard_client_provisioning::{try_get_provisioning_config, ProvisioningConfig}; -pub mod periodic; -pub mod provisioning; -pub mod service_locations; diff --git a/src-tauri/src/enterprise/periodic/config.rs b/src-tauri/src/enterprise/periodic/config.rs deleted file mode 100644 index 4a0b48aee..000000000 --- a/src-tauri/src/enterprise/periodic/config.rs +++ /dev/null @@ -1,191 +0,0 @@ -use std::{ - collections::HashSet, - sync::{LazyLock, Mutex}, - time::Duration, -}; - -pub use defguard_client_config_sync::commands::{ - disable_enterprise_features, do_update_instance, locations_changed, -}; -use defguard_client_config_sync::{config_changed, fetch_instance_config}; -use defguard_client_core::{ - connection::active_connections::active_connections, - database::{ - models::{instance::Instance, Id}, - DB_POOL, - }, - error::Error, - events::EventKey, -}; -use log::{debug, error, info}; -use sqlx::{Sqlite, Transaction}; -use tauri::{AppHandle, Emitter}; -use tokio::time::sleep; - -const INTERVAL_SECONDS: Duration = Duration::from_secs(30); - -/// Tracks instance IDs for which we already sent a version-mismatch notification, -/// to prevent duplicate notifications in the app's lifetime. -static NOTIFIED_INSTANCES: LazyLock>> = - LazyLock::new(|| Mutex::new(HashSet::new())); - -/// Periodically retrieves and updates configuration for all [`Instance`]s. -/// Updates are only performed if no connections are established to the [`Instance`], -/// otherwise event is emitted and UI message is displayed. -pub async fn poll_config(handle: AppHandle) { - debug!("Starting the configuration polling loop."); - // Polling starts sooner than app's frontend may load in dev builds, causing events (toasts) - // to be lost; you may want to wait here before starting if you want to debug it. - loop { - let Ok(mut transaction) = DB_POOL.begin().await else { - error!( - "Failed to begin database transaction for config polling, retrying in {}s", - INTERVAL_SECONDS.as_secs() - ); - sleep(INTERVAL_SECONDS).await; - continue; - }; - let Ok(mut instances) = Instance::all_with_token(&mut *transaction).await else { - error!( - "Failed to retrieve instances for config polling, retrying in {}s", - INTERVAL_SECONDS.as_secs() - ); - let _ = transaction.rollback().await; - sleep(INTERVAL_SECONDS).await; - continue; - }; - debug!( - "Found {} instances with a config polling token, proceeding with polling their \ - configuration.", - instances.len() - ); - let mut config_retrieved = 0; - for instance in &mut instances { - if instance.token.is_some() { - if let Err(err) = poll_instance(&mut transaction, instance, &handle).await { - match err { - Error::CoreNotEnterprise => { - debug!( - "Tried to contact core for instance {instance} config but it's \ - not enterprise, can't retrieve config" - ); - } - Error::NoToken => { - debug!( - "Instance {instance} has no token, can't retrieve its config from \ - the core", - ); - } - _ => { - error!( - "Failed to retrieve instance {instance} config from core: {err}" - ); - } - } - } else { - config_retrieved += 1; - debug!( - "Finished processing configuration polling request for instance {instance}" - ); - } - } - } - if let Err(err) = transaction.commit().await { - error!( - "Failed to commit config polling transaction, configuration won't be updated: \ - {err}" - ); - } - if let Err(err) = handle.emit(EventKey::InstanceUpdate.into(), ()) { - error!("Failed to emit instance update event to the frontend: {err}"); - } - if config_retrieved > 0 { - info!( - "Automatically retrieved the newest instance configuration from core for \ - {config_retrieved} instances, sleeping for {}s", - INTERVAL_SECONDS.as_secs(), - ); - debug!("Instances for which configuration was retrieved from core: {instances:?}"); - } else { - debug!( - "No configuration updates retrieved, sleeping {}s", - INTERVAL_SECONDS.as_secs(), - ); - } - sleep(INTERVAL_SECONDS).await; - } -} - -/// Retrieves configuration for a given [`Instance`]. -/// Updates the instance if there aren't any active connections, otherwise emits -/// a ConfigChanged event so the frontend can prompt the user to reconnect. -pub async fn poll_instance( - transaction: &mut Transaction<'_, Sqlite>, - instance: &mut Instance, - handle: &AppHandle, -) -> Result<(), Error> { - let fetched = fetch_instance_config(transaction, instance).await?; - - // Emit version-mismatch event if applicable and not already notified - if let Some(payload) = fetched.version_mismatch { - let mut notified_instances = NOTIFIED_INSTANCES.lock().unwrap(); - if notified_instances.insert(instance.id) { - if let Err(err) = handle.emit(EventKey::VersionMismatch.into(), payload) { - error!("Failed to emit version mismatch event to the frontend: {err}"); - // Remove so we can retry next cycle - notified_instances.remove(&instance.id); - } - } - } - - let device_config = - fetched.response.device_config.as_ref().ok_or_else(|| { - Error::InternalError("Device config not present in response".to_string()) - })?; - - // Early return if config didn't change - if !config_changed(transaction, instance, device_config).await? { - debug!( - "Config for instance {}({}) didn't change", - instance.name, instance.id - ); - return Ok(()); - } - - debug!( - "Config for instance {}({}) changed", - instance.name, instance.id - ); - - // Config changed. If there are no active connections for this instance, update the database. - // Otherwise just display a message to reconnect. - if active_connections(instance).await?.is_empty() { - debug!( - "Updating instance {}({}) configuration: {device_config:?}", - instance.name, instance.id, - ); - let locations_changed = - do_update_instance(transaction, instance, device_config.clone()).await?; - info!( - "Updated instance {}({}) configuration based on core's response", - instance.name, instance.id - ); - if locations_changed { - if let Err(err) = handle.emit(EventKey::InstanceUpdated.into(), ()) { - error!("Failed to emit instance-updated event: {err}"); - } - } - } else { - debug!( - "Emitting config-changed event for instance {}({})", - instance.name, instance.id, - ); - let _ = handle.emit(EventKey::ConfigChanged.into(), &instance.name); - info!( - "Emitted config-changed event for instance {}({})", - instance.name, instance.id, - ); - } - - Ok(()) -} diff --git a/src-tauri/src/enterprise/periodic/mod.rs b/src-tauri/src/enterprise/periodic/mod.rs deleted file mode 100644 index ef68c3694..000000000 --- a/src-tauri/src/enterprise/periodic/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod config; diff --git a/src-tauri/src/enterprise/service_locations/mod.rs b/src-tauri/src/enterprise/service_locations/mod.rs deleted file mode 100644 index b67d780e4..000000000 --- a/src-tauri/src/enterprise/service_locations/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -#[cfg(windows)] -pub use defguard_client_service_locations::windows; -pub use defguard_client_service_locations::{ - to_service_location, ServiceLocationData, ServiceLocationError, ServiceLocationManager, - SingleServiceLocationData, -}; diff --git a/src-tauri/src/enterprise/service_locations/windows.rs b/src-tauri/src/enterprise/service_locations/windows.rs deleted file mode 100644 index 2cf9b8e75..000000000 --- a/src-tauri/src/enterprise/service_locations/windows.rs +++ /dev/null @@ -1,968 +0,0 @@ -use std::{ - collections::HashMap, - ffi::OsStr, - fs::{self, create_dir_all}, - path::PathBuf, - result::Result, - str::FromStr, - sync::{Arc, RwLock}, - time::Duration, -}; - -use defguard_client_common::{dns_borrow, find_free_tcp_port, get_interface_name}; -use defguard_wireguard_rs::{ - key::Key, net::IpAddrMask, peer::Peer, InterfaceConfiguration, WireguardInterfaceApi, -}; -use known_folders::get_known_folder_path; -use log::{debug, error, warn}; -use windows::{ - core::PSTR, - Win32::System::RemoteDesktop::{ - self, WTSQuerySessionInformationA, WTSWaitSystemEvent, WTS_CURRENT_SERVER_HANDLE, - WTS_EVENT_LOGOFF, WTS_EVENT_LOGON, WTS_SESSION_INFOA, - }, -}; -use windows_acl::acl::ACL; -use windows_sys::Win32::NetworkManagement::IpHelper::NotifyAddrChange; - -use crate::{ - enterprise::service_locations::{ - ServiceLocationData, ServiceLocationError, ServiceLocationManager, - SingleServiceLocationData, - }, - service::{ - daemon::setup_wgapi, - proto::defguard::client::v1::{ServiceLocation, ServiceLocationMode}, - }, -}; - -const LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS: u64 = 5; -// How long to wait after a network change before attempting to connect. -// Gives DHCP time to complete and DNS to become available. -const NETWORK_STABILIZATION_DELAY: Duration = Duration::from_secs(3); -// How long to wait before restarting the network change watcher on error. -const NETWORK_CHANGE_MONITOR_RESTART_DELAY: Duration = Duration::from_secs(5); -const DEFAULT_WIREGUARD_PORT: u16 = 51820; -const DEFGUARD_DIR: &str = "Defguard"; -const SERVICE_LOCATIONS_SUBDIR: &str = "service_locations"; - -/// Watches for IP address changes on any network interface and attempts to connect to any -/// service locations that are not yet connected. This handles the case where the endpoint -/// hostname cannot be resolved at service startup because the network (e.g. Wi-Fi) is not -/// yet available. When the network comes up and an IP is assigned, this watcher fires and -/// retries the connection. -/// -/// Note: `NotifyAddrChange` also fires when WireGuard interfaces are created. This is -/// harmless because `connect_to_service_locations` skips already-connected locations. -/// -/// Runs on a dedicated OS thread because `NotifyAddrChange` is a blocking syscall. -pub(crate) fn watch_for_network_change( - service_location_manager: Arc>, -) { - loop { - // NotifyAddrChange blocks until any IP address is added or removed on any interface. - // Passing NULL for both handle and overlapped selects the synchronous (blocking) mode. - let result = unsafe { NotifyAddrChange(std::ptr::null_mut(), std::ptr::null()) }; - - if result != 0 { - error!("NotifyAddrChange failed with error code: {result}"); - std::thread::sleep(NETWORK_CHANGE_MONITOR_RESTART_DELAY); - continue; - } - - debug!( - "Network address change detected, waiting {NETWORK_STABILIZATION_DELAY:?}s for \ - network to stabilize before attempting service location connections..." - ); - std::thread::sleep(NETWORK_STABILIZATION_DELAY); - - debug!("Attempting to connect to service locations after network change"); - let connect_result = service_location_manager - .write() - .unwrap() - .connect_to_service_locations(); - match connect_result { - Ok(_) => { - debug!("Service location connect attempt after network change completed"); - } - Err(err) => { - warn!("Failed to connect to service locations after network change: {err}"); - } - } - } -} - -/// Watches for user logon/logoff events and connects/disconnects pre-logon service locations -/// accordingly. -/// -/// Runs on a dedicated OS thread because `WTSWaitSystemEvent` is a blocking syscall. -pub(crate) fn watch_for_login_logoff( - service_location_manager: Arc>, -) -> Result<(), ServiceLocationError> { - loop { - let mut event_flags: u32 = 0; - let success = unsafe { - WTSWaitSystemEvent( - Some(WTS_CURRENT_SERVER_HANDLE), - WTS_EVENT_LOGON | WTS_EVENT_LOGOFF, - &mut event_flags, - ) - }; - - match success { - Ok(_) => { - debug!("Waiting for system event returned with event_flags: 0x{event_flags:x}"); - } - Err(err) => { - error!("Failed waiting for login/logoff event: {err:?}"); - std::thread::sleep(Duration::from_secs(LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS)); - continue; - } - }; - - if event_flags & WTS_EVENT_LOGON != 0 { - debug!("Detected user logon, attempting to auto-disconnect from service locations."); - service_location_manager - .write() - .unwrap() - .disconnect_service_locations(Some(ServiceLocationMode::PreLogon))?; - } - if event_flags & WTS_EVENT_LOGOFF != 0 { - debug!("Detected user logoff, attempting to auto-connect to service locations."); - service_location_manager - .write() - .unwrap() - .connect_to_service_locations()?; - } - } -} - -fn get_shared_directory() -> Result { - match get_known_folder_path(known_folders::KnownFolder::ProgramData) { - Some(mut path) => { - path.push(DEFGUARD_DIR); - path.push(SERVICE_LOCATIONS_SUBDIR); - Ok(path) - } - None => Err(ServiceLocationError::LoadError( - "Could not find ProgramData known folder".to_string(), - )), - } -} - -fn set_protected_acls(path: &str) -> Result<(), ServiceLocationError> { - debug!("Setting secure ACLs on: {path}"); - - const SYSTEM_SID: &str = "S-1-5-18"; // NT AUTHORITY\SYSTEM - const ADMINISTRATORS_SID: &str = "S-1-5-32-544"; // BUILTIN\Administrators - - const FILE_ALL_ACCESS: u32 = 0x001F_01FF; - - match ACL::from_file_path(path, false) { - Ok(mut acl) => { - // Remove everything else from access - debug!("Removing all existing ACL entries for {path}"); - let all_entries = acl.all().map_err(|e| { - ServiceLocationError::LoadError(format!("Failed to get ACL entries: {e}")) - })?; - - for entry in all_entries { - if let Some(sid) = entry.sid { - if let Err(e) = acl.remove(sid.as_ptr() as *mut _, None, None) { - debug!("Note: Could not remove ACL entry (might be expected): {e}"); - } - } - } - - debug!("Cleared existing ACL entries, now adding secure entries"); - - // Add SYSTEM with full control - debug!("Adding SYSTEM with full control"); - let system_sid_result = windows_acl::helper::string_to_sid(SYSTEM_SID); - match system_sid_result { - Ok(system_sid) => { - acl.allow(system_sid.as_ptr() as *mut _, true, FILE_ALL_ACCESS) - .map_err(|e| { - ServiceLocationError::LoadError(format!( - "Failed to add SYSTEM ACL: {e}" - )) - })?; - } - Err(e) => { - return Err(ServiceLocationError::LoadError(format!( - "Failed to convert SYSTEM SID: {e}" - ))); - } - } - - // Add Administrators with full control - debug!("Adding Administrators with full control"); - let admin_sid_result = windows_acl::helper::string_to_sid(ADMINISTRATORS_SID); - match admin_sid_result { - Ok(admin_sid) => { - acl.allow(admin_sid.as_ptr() as *mut _, true, FILE_ALL_ACCESS) - .map_err(|e| { - ServiceLocationError::LoadError(format!( - "Failed to add Administrators ACL: {e}" - )) - })?; - } - Err(e) => { - return Err(ServiceLocationError::LoadError(format!( - "Failed to convert Administrators SID: {e}" - ))); - } - } - - debug!("Successfully set secure ACLs on {path} for SYSTEM and Administrators"); - Ok(()) - } - Err(e) => { - error!("Failed to get ACL for {path}: {e}"); - Err(ServiceLocationError::LoadError(format!( - "Failed to get ACL for {path}: {e}" - ))) - } - } -} - -fn get_instance_file_path(instance_id: &str) -> Result { - let mut path = get_shared_directory()?; - path.push(format!("{instance_id}.json")); - Ok(path) -} - -pub(crate) fn is_user_logged_in() -> bool { - debug!("Starting checking if user is logged in..."); - - unsafe { - let mut pp_sessions: *mut WTS_SESSION_INFOA = std::ptr::null_mut(); - let mut count: u32 = 0; - - debug!("Calling WTSEnumerateSessionsA..."); - let ret = RemoteDesktop::WTSEnumerateSessionsA(None, 0, 1, &mut pp_sessions, &mut count); - - match ret { - Ok(_) => { - debug!("WTSEnumerateSessionsA succeeded, found {count} sessions"); - let sessions = std::slice::from_raw_parts(pp_sessions, count as usize); - - for (index, session) in sessions.iter().enumerate() { - debug!( - "Session {index}: SessionId={}, State={:?}, WinStationName={:?}", - session.SessionId, - session.State, - std::ffi::CStr::from_ptr(session.pWinStationName.0 as *const i8) - .to_string_lossy() - ); - - if session.State == windows::Win32::System::RemoteDesktop::WTSActive { - let mut buffer = PSTR::null(); - let mut bytes_returned: u32 = 0; - - let result = WTSQuerySessionInformationA( - None, - session.SessionId, - windows::Win32::System::RemoteDesktop::WTSUserName, - &mut buffer, - &mut bytes_returned, - ); - - match result { - Ok(_) => { - if !buffer.is_null() { - let username = std::ffi::CStr::from_ptr(buffer.0 as *const i8) - .to_string_lossy() - .into_owned(); - - debug!( - "Found session {} username: {username}", - session.SessionId - ); - - windows::Win32::System::RemoteDesktop::WTSFreeMemory( - buffer.0 as *mut _, - ); - - // We found an active session with a username. - // Free the session list before returning to avoid a leak. - windows::Win32::System::RemoteDesktop::WTSFreeMemory( - pp_sessions as _, - ); - return true; - } - } - Err(err) => { - debug!( - "Failed to get username for session {}: {err:?}", - session.SessionId - ); - } - } - } - } - windows::Win32::System::RemoteDesktop::WTSFreeMemory(pp_sessions as _); - debug!("No active sessions found"); - } - Err(err) => { - error!("Failed to enumerate user sessions: {err:?}"); - debug!("WTSEnumerateSessionsA failed: {err:?}"); - } - } - } - - debug!("User is not logged in."); - false -} - -impl ServiceLocationManager { - pub fn init() -> Result { - debug!("Initializing ServiceLocationApi"); - let path = get_shared_directory()?; - - debug!("Creating directory: {path:?}"); - create_dir_all(&path)?; - - if let Some(path_str) = path.to_str() { - debug!("Setting ACLs on service locations directory"); - if let Err(e) = set_protected_acls(path_str) { - warn!("Failed to set ACLs on service locations directory: {e}. Continuing anyway."); - } - } else { - warn!("Failed to convert path to string for ACL setting"); - } - - let manager = Self { - wgapis: HashMap::new(), - connected_service_locations: HashMap::new(), - }; - - debug!("ServiceLocationApi initialized successfully"); - Ok(manager) - } - - /// Check if a specific service location is already connected - fn is_service_location_connected(&self, instance_id: &str, location_pubkey: &str) -> bool { - if let Some(locations) = self.connected_service_locations.get(instance_id) { - for location in locations { - if location.pubkey == location_pubkey { - return true; - } - } - } - false - } - - /// Add a connected service location - fn add_connected_service_location( - &mut self, - instance_id: &str, - location: &ServiceLocation, - ) -> Result<(), ServiceLocationError> { - self.connected_service_locations - .entry(instance_id.to_string()) - .or_default() - .push(location.clone()); - - debug!( - "Added connected service location for instance '{instance_id}', location '{}'", - location.name - ); - Ok(()) - } - - /// Remove connected service locations by filter (write disk-first, then memory) - fn remove_connected_service_locations( - &mut self, - filter: F, - ) -> Result<(), ServiceLocationError> - where - F: Fn(&str, &ServiceLocation) -> bool, - { - // Iterate through connected_service_locations and remove matching locations - let mut instances_to_remove = Vec::new(); - - for (instance_id, locations) in self.connected_service_locations.iter_mut() { - locations.retain(|location| !filter(instance_id, location)); - - // Mark instance for removal if it has no more locations - if locations.is_empty() { - instances_to_remove.push(instance_id.clone()); - } - } - - // Remove instances with no locations - for instance_id in instances_to_remove { - self.connected_service_locations.remove(&instance_id); - } - - debug!("Removed connected service locations matching filter"); - Ok(()) - } - - // Resets the state of the service location: - // 1. If it's an always on location, disconnects and reconnects it. - // 2. Otherwise, just disconnects it if the user is not logged in. - pub(crate) fn reset_service_location_state( - &mut self, - instance_id: &str, - location_pubkey: &str, - ) -> Result<(), ServiceLocationError> { - debug!( - "Reseting the state of service location for instance_id: {instance_id}, \ - location_pubkey: {location_pubkey}" - ); - - let service_location_data = self - .load_service_location(instance_id, location_pubkey)? - .ok_or_else(|| { - ServiceLocationError::LoadError(format!( - "Service location with pubkey {} for instance {} not found", - location_pubkey, instance_id - )) - })?; - - debug!( - "Disconnecting service location for instance_id: {instance_id}, location_pubkey: \ - {location_pubkey} ({})", - service_location_data.service_location.name - ); - - self.disconnect_service_location(instance_id, location_pubkey)?; - - debug!( - "Disconnected service location for instance_id: {instance_id}, \ - location_pubkey: {location_pubkey} ({})", - service_location_data.service_location.name - ); - - debug!( - "Reconnecting service location if needed for instance_id: {instance_id}, \ - location_pubkey: {location_pubkey} ({})", - service_location_data.service_location.name - ); - - // We should reconnect only if: - // 1. It's an always on location - // 2. It's a pre-logon location and the user is not logged in - if service_location_data.service_location.mode == ServiceLocationMode::AlwaysOn as i32 - || (service_location_data.service_location.mode == ServiceLocationMode::PreLogon as i32 - && !is_user_logged_in()) - { - debug!( - "Reconnecting service location for instance_id: {instance_id}, location_pubkey: \ - {location_pubkey} ({})", - service_location_data.service_location.name - ); - self.connect_to_service_location(&service_location_data)?; - } - - debug!("Service location state reset completed."); - - Ok(()) - } - - pub(crate) fn disconnect_service_locations_by_instance( - &mut self, - instance_id: &str, - ) -> Result<(), ServiceLocationError> { - debug!("Disconnecting all service locations for instance_id: {instance_id}"); - - if let Some(locations) = self.connected_service_locations.get(instance_id) { - // Collect locations to disconnect to avoid borrowing issues - let locations_to_disconnect = locations.to_vec(); - - for location in locations_to_disconnect { - let ifname = get_interface_name(&location.name); - debug!("Tearing down interface: {ifname}"); - if let Some(mut wgapi) = self.wgapis.remove(&ifname) { - if let Err(err) = wgapi.remove_interface() { - error!("Failed to remove interface {ifname}: {err}"); - } else { - debug!("Interface {ifname} removed successfully"); - } - debug!( - "Removing connected service location for instance_id: {instance_id}, \ - location_pubkey: {}", - location.pubkey - ); - debug!( - "Disconnected service location for instance_id: {instance_id}, \ - location_pubkey: {}", - location.pubkey - ); - } else { - error!("Failed to find WireGuard API for interface {ifname}"); - } - } - - self.connected_service_locations.remove(instance_id); - } else { - debug!( - "No connected service locations found for instance_id: {instance_id}. Skipping disconnect" - ); - return Ok(()); - } - - debug!("Disconnected all service locations for instance_id: {instance_id}"); - - Ok(()) - } - - pub(crate) fn disconnect_service_location( - &mut self, - instance_id: &str, - location_pubkey: &str, - ) -> Result<(), ServiceLocationError> { - debug!( - "Disconnecting service location for instance_id: {instance_id}, location_pubkey: \ - {location_pubkey}" - ); - - if let Some(locations) = self.connected_service_locations.get_mut(instance_id) { - if let Some(pos) = locations - .iter() - .position(|loc| loc.pubkey == location_pubkey) - { - let location = locations.remove(pos); - let ifname = get_interface_name(&location.name); - debug!("Tearing down interface: {ifname}"); - if let Some(mut wgapi) = self.wgapis.remove(&ifname) { - if let Err(err) = wgapi.remove_interface() { - error!("Failed to remove interface {ifname}: {err}"); - } else { - debug!("Interface {ifname} removed successfully."); - } - } else { - error!("Failed to find WireGuard API for interface {ifname}. "); - } - } else { - debug!( - "Service location with pubkey {location_pubkey} for instance {instance_id} is \ - not connected, skipping disconnect" - ); - return Ok(()); - } - } else { - debug!( - "No connected service locations found for instance_id: {instance_id}, skipping \ - disconnect" - ); - return Ok(()); - } - - debug!( - "Disconnected service location for instance_id: {instance_id}, location_pubkey: \ - {location_pubkey}" - ); - - Ok(()) - } - - /// Helper function to setup a WireGuard interface for a service location - fn setup_service_location_interface( - &mut self, - location: &ServiceLocation, - private_key: &str, - ) -> Result<(), ServiceLocationError> { - let peer_key = Key::from_str(&location.pubkey)?; - - let mut peer = Peer::new(peer_key.clone()); - peer.set_endpoint(&location.endpoint)?; - - peer.persistent_keepalive_interval = location.keepalive_interval.try_into().ok(); - - let allowed_ips = location - .allowed_ips - .split(',') - .map(str::to_string) - .collect::>(); - - for allowed_ip in &allowed_ips { - match IpAddrMask::from_str(allowed_ip) { - Ok(addr) => { - peer.allowed_ips.push(addr); - } - Err(err) => { - error!( - "Error parsing IP address {allowed_ip} while setting up interface for \ - location {location:?}, error details: {err}" - ); - } - } - } - - let mut addresses = Vec::new(); - - for address in location.address.split(',') { - addresses.push(IpAddrMask::from_str(address.trim())?); - } - - let config = InterfaceConfiguration { - name: location.name.clone(), - prvkey: private_key.to_string(), - addresses, - port: find_free_tcp_port().unwrap_or(DEFAULT_WIREGUARD_PORT), - peers: vec![peer.clone()], - mtu: None, - fwmark: None, // TODO: add - }; - - let ifname = location.name.clone(); - let ifname = get_interface_name(&ifname); - let mut wgapi = match setup_wgapi(&ifname) { - Ok(api) => api, - Err(err) => { - let msg = format!("Failed to setup WireGuard API for interface {ifname}: {err:?}"); - debug!("{msg}"); - return Err(ServiceLocationError::InterfaceError(msg)); - } - }; - - wgapi.create_interface()?; - - // Extract DNS configuration if available - let dns_config = Some(location.dns.clone()); - let (dns, search_domains) = dns_borrow(&dns_config); - debug!( - "Configuring interface {ifname} with DNS: {dns:?} and search domains: \ - {search_domains:?}", - ); - debug!("Interface Configuration: {config:?}"); - - wgapi.configure_interface(&config)?; - wgapi.configure_dns(&dns, &search_domains)?; - - self.wgapis.insert(ifname.clone(), wgapi); - - debug!("Interface {ifname} configured successfully."); - Ok(()) - } - - pub(crate) fn connect_to_service_location( - &mut self, - location_data: &SingleServiceLocationData, - ) -> Result<(), ServiceLocationError> { - let instance_id = &location_data.instance_id; - let location_pubkey = &location_data.service_location.pubkey; - debug!( - "Connecting to service location for instance_id: {instance_id}, location_pubkey: \ - {location_pubkey}" - ); - - // Check if already connected to this service location - if self.is_service_location_connected(instance_id, location_pubkey) { - debug!( - "Service location with pubkey {location_pubkey} for instance {instance_id} is \ - already connected, skipping" - ); - return Ok(()); - } - - let location_data = self - .load_service_location(instance_id, location_pubkey)? - .ok_or_else(|| { - ServiceLocationError::LoadError(format!( - "Service location with pubkey {location_pubkey} for instance {instance_id} not \ - found", - )) - })?; - - self.setup_service_location_interface( - &location_data.service_location, - &location_data.private_key, - )?; - self.add_connected_service_location( - &location_data.instance_id, - &location_data.service_location, - )?; - let ifname = get_interface_name(&location_data.service_location.name); - debug!("Successfully connected to service location '{ifname}'"); - - Ok(()) - } - - pub(crate) fn disconnect_service_locations( - &mut self, - mode: Option, - ) -> Result<(), ServiceLocationError> { - debug!("Disconnecting service locations with mode: {mode:?}"); - - for (instance, locations) in &self.connected_service_locations { - for location in locations { - debug!( - "Found connected service location for instance_id: {instance}, \ - location_pubkey: {}", - location.pubkey - ); - if let Some(m) = mode { - let location_mode: ServiceLocationMode = location.mode.try_into()?; - if location_mode != m { - debug!( - "Skipping interface {} due to the service location mode doesn't match the \ - requested mode (expected {m:?}, found {:?})", - location.name, location.mode - ); - continue; - } - } - - let ifname = get_interface_name(&location.name); - debug!("Tearing down interface: {ifname}"); - if let Some(mut wgapi) = self.wgapis.remove(&ifname) { - if let Err(err) = wgapi.remove_interface() { - error!("Failed to remove interface {ifname}: {err}"); - } else { - debug!("Interface {ifname} removed successfully."); - } - } else { - error!("Failed to find WireGuard API for interface {ifname}"); - } - } - } - - self.remove_connected_service_locations(|_, location| { - if let Some(m) = mode { - let location_mode: ServiceLocationMode = location - .mode - .try_into() - .unwrap_or(ServiceLocationMode::AlwaysOn); - location_mode == m - } else { - true - } - })?; - - debug!("Service locations disconnected."); - - Ok(()) - } - - /// Attempts to connect to all service locations that are not already connected. - /// - /// Returns `Ok(true)` if every location is now connected (either it was already connected or - /// it was successfully connected during this call), and `Ok(false)` if at least one location - /// failed to connect (indicating that a retry may be worthwhile). - pub(crate) fn connect_to_service_locations(&mut self) -> Result { - debug!("Attempting to auto-connect to VPN..."); - - let data = self.load_service_locations()?; - debug!("Loaded {} instance(s) from ServiceLocationApi", data.len()); - - let mut all_connected = true; - - for instance_data in data { - debug!( - "Found service locations for instance ID: {}", - instance_data.instance_id - ); - debug!( - "Instance has {} service location(s)", - instance_data.service_locations.len() - ); - for location in instance_data.service_locations { - debug!("Service Location: {location:?}"); - - if location.mode == ServiceLocationMode::PreLogon as i32 { - if is_user_logged_in() { - debug!( - "Skipping pre-logon service location '{}' because user is logged in", - location.name - ); - continue; - } - debug!( - "Proceeding to connect pre-logon service location '{}' because no user \ - is logged in", - location.name - ); - } - - if self.is_service_location_connected(&instance_data.instance_id, &location.pubkey) - { - debug!( - "Skipping service location '{}' because it's already connected", - location.name - ); - continue; - } - - if let Err(err) = - self.setup_service_location_interface(&location, &instance_data.private_key) - { - warn!( - "Failed to setup service location interface for '{}': {err:?}", - location.name - ); - all_connected = false; - continue; - } - - if let Err(err) = - self.add_connected_service_location(&instance_data.instance_id, &location) - { - debug!( - "Failed to persist connected service location after auto-connect: {err:?}" - ); - } - - debug!( - "Successfully connected to service location '{}'", - location.name - ); - } - } - - debug!("Auto-connect attempt completed"); - - Ok(all_connected) - } - - pub fn save_service_locations( - &self, - service_locations: &[ServiceLocation], - instance_id: &str, - private_key: &str, - ) -> Result<(), ServiceLocationError> { - debug!( - "Received a request to save {} service location(s) for instance {instance_id}", - service_locations.len(), - ); - - debug!("Service locations to save: {service_locations:?}"); - - create_dir_all(get_shared_directory()?)?; - - let instance_file_path = get_instance_file_path(instance_id)?; - - let service_location_data = ServiceLocationData { - service_locations: service_locations.to_vec(), - instance_id: instance_id.to_string(), - private_key: private_key.to_string(), - }; - - let json = serde_json::to_string_pretty(&service_location_data)?; - - debug!( - "Writing service location data to file: {}", - instance_file_path.display() - ); - - fs::write(&instance_file_path, &json)?; - - if let Some(file_path_str) = instance_file_path.to_str() { - debug!("Setting ACLs on service location file: {file_path_str}"); - if let Err(err) = set_protected_acls(file_path_str) { - warn!( - "Failed to set ACLs on service location file {file_path_str}: {err}. \ - File saved but may have insecure permissions." - ); - } else { - debug!("Successfully set ACLs on service location file"); - } - } else { - warn!("Failed to convert file path to string for ACL setting"); - } - - debug!( - "Service locations saved successfully for instance {instance_id} to {}", - instance_file_path.display() - ); - Ok(()) - } - - fn load_service_locations(&self) -> Result, ServiceLocationError> { - let base_dir = get_shared_directory()?; - let mut all_locations_data = Vec::new(); - - if base_dir.exists() { - for entry in fs::read_dir(base_dir)? { - let entry = entry?; - let file_path = entry.path(); - - if file_path.is_file() && file_path.extension() == Some(OsStr::new("json")) { - match fs::read_to_string(&file_path) { - Ok(data) => match serde_json::from_str::(&data) { - Ok(locations_data) => { - all_locations_data.push(locations_data); - } - Err(err) => { - error!( - "Failed to parse service locations from file {}: {err}", - file_path.display() - ); - } - }, - Err(err) => { - error!( - "Failed to read service locations file {}: {err}", - file_path.display() - ); - } - } - } - } - } - - debug!( - "Loaded service locations data for {} instances", - all_locations_data.len() - ); - Ok(all_locations_data) - } - - fn load_service_location( - &self, - instance_id: &str, - location_pubkey: &str, - ) -> Result, ServiceLocationError> { - debug!("Loading service location for instance {instance_id} and pubkey {location_pubkey}"); - - let instance_file_path = get_instance_file_path(instance_id)?; - - if instance_file_path.exists() { - let data = fs::read_to_string(&instance_file_path)?; - let service_location_data = serde_json::from_str::(&data)?; - - for location in service_location_data.service_locations { - if location.pubkey == location_pubkey { - debug!( - "Successfully loaded service location for instance {instance_id} and \ - pubkey {location_pubkey}" - ); - return Ok(Some(SingleServiceLocationData { - service_location: location, - instance_id: service_location_data.instance_id, - private_key: service_location_data.private_key, - })); - } - } - - debug!( - "No service location found for instance {instance_id} with pubkey {location_pubkey}" - ); - Ok(None) - } else { - debug!("No service location file found for instance {instance_id}"); - Ok(None) - } - } - - pub(crate) fn delete_all_service_locations_for_instance( - &self, - instance_id: &str, - ) -> Result<(), ServiceLocationError> { - debug!("Deleting all service locations for instance {instance_id}"); - - let instance_file_path = get_instance_file_path(instance_id)?; - - if instance_file_path.exists() { - fs::remove_file(&instance_file_path)?; - debug!("Successfully deleted all service locations for instance {instance_id}"); - } else { - debug!("No service location file found for instance {instance_id}"); - } - - Ok(()) - } -} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 1ed0d46ee..477c3cb68 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -5,10 +5,10 @@ pub mod apple; pub mod appstate; pub mod commands; -pub mod enterprise; pub mod events; pub mod log_watcher; pub mod periodic; +pub mod provisioning; pub mod service; pub mod session_state; pub mod tray; diff --git a/src-tauri/src/periodic/config.rs b/src-tauri/src/periodic/config.rs new file mode 100644 index 000000000..b56593248 --- /dev/null +++ b/src-tauri/src/periodic/config.rs @@ -0,0 +1,224 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::{LazyLock, Mutex}, + time::Duration, +}; + +use defguard_client_config_sync::{ + poll_instance, poll_instances, PollInstanceResult, VersionMismatchPayload, +}; +use defguard_client_core::{ + connection::active_connections::{active_connections, ACTIVE_CONNECTIONS}, + database::{ + models::{instance::Instance, location::Location, Id}, + DB_POOL, + }, + error::Error, + events::EventKey, + ConnectionType, +}; +use log::{debug, error, info}; +use sqlx::{Sqlite, Transaction}; +use tauri::{AppHandle, Emitter}; +use tokio::time::sleep; + +pub use defguard_client_config_sync::commands::{ + disable_enterprise_features, do_update_instance, locations_changed, +}; + +const INTERVAL_SECONDS: Duration = Duration::from_secs(30); + +/// Tracks instance IDs for which we already sent a version-mismatch notification, +/// to prevent duplicate notifications in the app's lifetime. +static NOTIFIED_INSTANCES: LazyLock>> = + LazyLock::new(|| Mutex::new(HashSet::new())); + +/// Periodically retrieves and updates configuration for all [`Instance`]s. +/// Updates are only performed if no connections are established to the [`Instance`], +/// otherwise event is emitted and UI message is displayed. +pub async fn poll_config(handle: AppHandle) { + debug!("Starting the configuration polling loop."); + // Polling starts sooner than app's frontend may load in dev builds, causing events (toasts) + // to be lost; you may want to wait here before starting if you want to debug it. + loop { + let active_instance_ids = match active_instance_ids().await { + Ok(ids) => ids, + Err(err) => { + error!( + "Failed to detect active instances for config polling, retrying in {}s: {err}", + INTERVAL_SECONDS.as_secs() + ); + sleep(INTERVAL_SECONDS).await; + continue; + } + }; + + let outcomes = match poll_instances(&DB_POOL, &active_instance_ids).await { + Ok(outcomes) => outcomes, + Err(err) => { + error!( + "Failed to poll instance configuration, retrying in {}s: {err}", + INTERVAL_SECONDS.as_secs() + ); + sleep(INTERVAL_SECONDS).await; + continue; + } + }; + + debug!( + "Found {} instances with a config polling token, processed configuration polling.", + outcomes.len() + ); + + let mut config_retrieved = 0; + for outcome in outcomes { + let instance_name = outcome.instance_name; + let instance_id = outcome.instance_id; + match outcome.result { + Ok(result) => { + config_retrieved += 1; + emit_version_mismatch(&handle, instance_id, version_mismatch(&result)); + emit_poll_result_events(&handle, instance_id, &instance_name, result); + debug!( + "Finished processing configuration polling request for instance {}(ID: {})", + instance_name, instance_id + ); + } + Err(Error::CoreNotEnterprise) => { + debug!( + "Tried to contact core for instance {}(ID: {}) config but it's not enterprise, can't retrieve config", + instance_name, instance_id + ); + } + Err(Error::NoToken) => { + debug!( + "Instance {}(ID: {}) has no token, can't retrieve its config from the core", + instance_name, instance_id, + ); + } + Err(err) => { + error!( + "Failed to retrieve instance {}(ID: {}) config from core: {err}", + instance_name, instance_id + ); + } + } + } + + if let Err(err) = handle.emit(EventKey::InstanceUpdate.into(), ()) { + error!("Failed to emit instance update event to the frontend: {err}"); + } + if config_retrieved > 0 { + info!( + "Automatically retrieved the newest instance configuration from core for {config_retrieved} instances, sleeping for {}s", + INTERVAL_SECONDS.as_secs(), + ); + } else { + debug!( + "No configuration updates retrieved, sleeping {}s", + INTERVAL_SECONDS.as_secs(), + ); + } + sleep(INTERVAL_SECONDS).await; + } +} + +/// Retrieves configuration for a given [`Instance`]. +/// Updates the instance if there aren't any active connections, otherwise emits +/// a ConfigChanged event so the frontend can prompt the user to reconnect. +pub async fn poll_instance_with_events( + transaction: &mut Transaction<'_, Sqlite>, + instance: &mut Instance, + handle: &AppHandle, +) -> Result<(), Error> { + let has_active_connections = !active_connections(instance).await?.is_empty(); + let result = poll_instance(transaction, instance, has_active_connections).await?; + + emit_version_mismatch(handle, instance.id, version_mismatch(&result)); + emit_poll_result_events(handle, instance.id, &instance.name, result); + + Ok(()) +} + +fn emit_version_mismatch( + handle: &AppHandle, + instance_id: Id, + payload: Option<&VersionMismatchPayload>, +) { + if let Some(payload) = payload { + let mut notified_instances = NOTIFIED_INSTANCES.lock().unwrap(); + if notified_instances.insert(instance_id) { + if let Err(err) = handle.emit(EventKey::VersionMismatch.into(), payload.clone()) { + error!("Failed to emit version mismatch event to the frontend: {err}"); + // Remove so we can retry next cycle. + notified_instances.remove(&instance_id); + } + } + } +} + +fn emit_poll_result_events( + handle: &AppHandle, + instance_id: Id, + instance_name: &str, + result: PollInstanceResult, +) { + match result { + PollInstanceResult::Unchanged { .. } => {} + PollInstanceResult::Updated { + locations_changed, .. + } => { + if locations_changed { + if let Err(err) = handle.emit(EventKey::InstanceUpdated.into(), ()) { + error!("Failed to emit instance-updated event: {err}"); + } + } + } + PollInstanceResult::ChangedWhileActive { .. } => { + debug!( + "Emitting config-changed event for instance {}({})", + instance_name, instance_id, + ); + let _ = handle.emit(EventKey::ConfigChanged.into(), instance_name); + info!( + "Emitted config-changed event for instance {}({})", + instance_name, instance_id, + ); + } + } +} + +fn version_mismatch(result: &PollInstanceResult) -> Option<&VersionMismatchPayload> { + match result { + PollInstanceResult::Unchanged { version_mismatch } + | PollInstanceResult::Updated { + version_mismatch, .. + } + | PollInstanceResult::ChangedWhileActive { version_mismatch } => version_mismatch.as_ref(), + } +} + +async fn active_instance_ids() -> Result, Error> { + let active_location_ids = ACTIVE_CONNECTIONS + .lock() + .await + .iter() + .filter(|connection| connection.connection_type == ConnectionType::Location) + .map(|connection| connection.location_id) + .collect::>(); + + if active_location_ids.is_empty() { + return Ok(HashSet::new()); + } + + let location_instances = Location::all(&*DB_POOL, false) + .await? + .into_iter() + .map(|location| (location.id, location.instance_id)) + .collect::>(); + + Ok(active_location_ids + .into_iter() + .filter_map(|location_id| location_instances.get(&location_id).copied()) + .collect()) +} diff --git a/src-tauri/src/periodic/mod.rs b/src-tauri/src/periodic/mod.rs index 37daff37c..dedc7b2a6 100644 --- a/src-tauri/src/periodic/mod.rs +++ b/src-tauri/src/periodic/mod.rs @@ -2,10 +2,11 @@ use tauri::AppHandle; use tokio::select; use self::{ - connection::verify_active_connections, purge_stats::purge_stats, version::poll_version, + config::poll_config, connection::verify_active_connections, purge_stats::purge_stats, + version::poll_version, }; -use crate::enterprise::periodic::config::poll_config; +pub mod config; pub mod connection; pub mod purge_stats; pub mod version; diff --git a/src-tauri/src/enterprise/provisioning/mod.rs b/src-tauri/src/provisioning.rs similarity index 94% rename from src-tauri/src/enterprise/provisioning/mod.rs rename to src-tauri/src/provisioning.rs index 6e2f76b8e..22eb71edf 100644 --- a/src-tauri/src/enterprise/provisioning/mod.rs +++ b/src-tauri/src/provisioning.rs @@ -1,9 +1,9 @@ use defguard_client_core::database::{models::instance::Instance, DB_POOL}; -pub use defguard_client_provisioning::{try_get_provisioning_config, ProvisioningConfig}; +use defguard_client_provisioning::{try_get_provisioning_config, ProvisioningConfig}; use tauri::{AppHandle, Manager}; /// Checks if the client has already been initialized -/// and tries to load provisioning config from file if necessary +/// and tries to load provisioning config from file if necessary. pub async fn handle_client_initialization(app_handle: &AppHandle) -> Option { match Instance::all(&*DB_POOL).await { Ok(instances) => {