diff --git a/src/bins/src/shared/util_common.rs b/src/bins/src/shared/util_common.rs index 1cf65b97..6cd5a48c 100644 --- a/src/bins/src/shared/util_common.rs +++ b/src/bins/src/shared/util_common.rs @@ -2,6 +2,7 @@ use anyhow::{anyhow, Result}; use rand::distr::{Alphanumeric, SampleString}; use regex::Regex; use std::{path::Path, thread, time::Duration}; +use velopack::process; #[derive(Debug, Clone, Copy, PartialEq)] pub enum OperationWait { @@ -12,11 +13,11 @@ pub enum OperationWait { pub fn operation_wait(wait: OperationWait) { if let OperationWait::WaitPid(pid) = wait { - if let Err(e) = super::wait_for_pid_to_exit(pid, 60_000) { + if let Err(e) = process::wait_for_pid_to_exit(pid, Duration::from_secs(60)) { warn!("Failed to wait for process ({}) to exit ({}). Continuing...", pid, e); } } else if let OperationWait::WaitParent = wait { - if let Err(e) = super::wait_for_parent_to_exit(60_000) { + if let Err(e) = process::wait_for_parent_to_exit(Duration::from_secs(60)) { warn!("Failed to wait for parent process to exit ({}). Continuing...", e); } } else { diff --git a/src/bins/src/shared/util_windows.rs b/src/bins/src/shared/util_windows.rs index 9f34b635..d5cd27e8 100644 --- a/src/bins/src/shared/util_windows.rs +++ b/src/bins/src/shared/util_windows.rs @@ -1,89 +1,21 @@ -use ::windows::Win32::System::ProcessStatus::EnumProcesses; -use ::windows::Win32::UI::WindowsAndMessaging::AllowSetForegroundWindow; -use anyhow::{anyhow, bail, Result}; +use crate::windows::strings; +use ::windows::{ + core::PWSTR, + Win32::{ + Foundation::CloseHandle, + System::ProcessStatus::EnumProcesses, + System::Threading::{OpenProcess, QueryFullProcessImageNameW, PROCESS_NAME_WIN32, PROCESS_QUERY_LIMITED_INFORMATION}, + }, +}; +use anyhow::{bail, Result}; use regex::Regex; use semver::Version; use std::{ collections::HashMap, fs, path::{Path, PathBuf}, - process::Command as Process, }; -use windows::Wdk::System::Threading::{NtQueryInformationProcess, ProcessBasicInformation}; -use windows::Win32::System::Threading::{GetCurrentProcess, PROCESS_BASIC_INFORMATION}; -use winsafe::{self as w, co, prelude::*}; - -use velopack::locator::VelopackLocator; - -pub fn wait_for_pid_to_exit(pid: u32, ms_to_wait: u32) -> Result<()> { - info!("Waiting {}ms for process ({}) to exit.", ms_to_wait, pid); - let handle = w::HPROCESS::OpenProcess(co::PROCESS::SYNCHRONIZE, false, pid)?; - match handle.WaitForSingleObject(Some(ms_to_wait)) { - Ok(co::WAIT::OBJECT_0) => Ok(()), - // Ok(co::WAIT::TIMEOUT) => Ok(()), - _ => Err(anyhow!("WaitForSingleObject Failed.")), - } -} - -pub fn wait_for_parent_to_exit(ms_to_wait: u32) -> Result<()> { - info!("Reading parent process information."); - let basic_info = ProcessBasicInformation; - let handle = unsafe { GetCurrentProcess() }; - let mut return_length: u32 = 0; - let return_length_ptr: *mut u32 = &mut return_length as *mut u32; - - let mut info = PROCESS_BASIC_INFORMATION { - AffinityMask: 0, - BasePriority: 0, - ExitStatus: Default::default(), - InheritedFromUniqueProcessId: 0, - PebBaseAddress: std::ptr::null_mut(), - UniqueProcessId: 0, - }; - - let info_ptr: *mut ::core::ffi::c_void = &mut info as *mut _ as *mut ::core::ffi::c_void; - let info_size = std::mem::size_of::() as u32; - let hres = unsafe { NtQueryInformationProcess(handle, basic_info, info_ptr, info_size, return_length_ptr) }; - if hres.is_err() { - return Err(anyhow!("Failed to query process information: {:?}", hres)); - } - - if info.InheritedFromUniqueProcessId <= 1 { - // the parent process has exited - info!("The parent process ({}) has already exited", info.InheritedFromUniqueProcessId); - return Ok(()); - } - - fn get_pid_start_time(process: w::HPROCESS) -> Result { - let mut creation = w::FILETIME::default(); - let mut exit = w::FILETIME::default(); - let mut kernel = w::FILETIME::default(); - let mut user = w::FILETIME::default(); - process.GetProcessTimes(&mut creation, &mut exit, &mut kernel, &mut user)?; - Ok(((creation.dwHighDateTime as u64) << 32) | creation.dwLowDateTime as u64) - } - - let permissions = co::PROCESS::QUERY_LIMITED_INFORMATION | co::PROCESS::SYNCHRONIZE; - let parent_handle = w::HPROCESS::OpenProcess(permissions, false, info.InheritedFromUniqueProcessId as u32)?; - let parent_start_time = get_pid_start_time(unsafe { parent_handle.raw_copy() })?; - let myself_start_time = get_pid_start_time(w::HPROCESS::GetCurrentProcess())?; - - if parent_start_time > myself_start_time { - // the parent process has exited and the id has been re-used - info!( - "The parent process ({}) has already exited. parent_start={}, my_start={}", - info.InheritedFromUniqueProcessId, parent_start_time, myself_start_time - ); - return Ok(()); - } - - info!("Waiting {}ms for parent process ({}) to exit.", ms_to_wait, info.InheritedFromUniqueProcessId); - match parent_handle.WaitForSingleObject(Some(ms_to_wait)) { - Ok(co::WAIT::OBJECT_0) => Ok(()), - // Ok(co::WAIT::TIMEOUT) => Ok(()), - _ => Err(anyhow!("WaitForSingleObject Failed.")), - } -} +use velopack::{locator::VelopackLocator, process}; // https://github.com/nushell/nushell/blob/4458aae3d41517d74ce1507ad3e8cd94021feb16/crates/nu-system/src/windows.rs#L593 fn get_pids() -> Result> { @@ -101,52 +33,46 @@ fn get_pids() -> Result> { Ok(pids.iter().map(|x| *x as u32).collect()) } -fn get_processes_running_in_directory>(dir: P) -> Result> { +unsafe fn get_processes_running_in_directory>(dir: P) -> Result> { let dir = dir.as_ref(); let mut oup = HashMap::new(); + let mut full_path_vec = vec![0; i16::MAX as usize]; + let full_path_ptr = PWSTR(full_path_vec.as_mut_ptr()); + for pid in get_pids()? { - // I don't like using catch_unwind, but QueryFullProcessImageName seems to panic - // when it reaches a mingw64 process. This is a workaround. - let process_path = std::panic::catch_unwind(|| { - let process = w::HPROCESS::OpenProcess(co::PROCESS::QUERY_LIMITED_INFORMATION, false, pid); - if let Err(_) = process { - // trace!("Failed to open process: {} ({})", pid, e); - return None; - } + let process = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, false, pid); + if process.is_err() { + continue; + } - let process = process.unwrap(); - let full_path = process.QueryFullProcessImageName(co::PROCESS_NAME::WIN32); - if let Err(_) = full_path { - // trace!("Failed to query process path: {} ({})", pid, e); - return None; - } - return Some(full_path.unwrap()); - }); + let process = process.unwrap(); + if process.is_invalid() { + continue; + } - match process_path { - Ok(Some(full_path)) => { - let full_path = Path::new(&full_path); - if let Ok(is_subpath) = crate::windows::is_sub_path(full_path, dir) { - if is_subpath { - oup.insert(pid, full_path.to_path_buf()); - } - } + let mut full_path_len = full_path_vec.len() as u32; + if QueryFullProcessImageNameW(process, PROCESS_NAME_WIN32, full_path_ptr, &mut full_path_len).is_err() { + let _ = CloseHandle(process); + continue; + } + + let full_path = strings::u16_to_string(&full_path_vec); + if let Err(_) = full_path { + continue; + } + + let full_path = PathBuf::from(full_path.unwrap()); + if let Ok(is_subpath) = crate::windows::is_sub_path(&full_path, dir) { + if is_subpath { + oup.insert(pid, full_path); } - Ok(None) => {} - Err(e) => error!("Fatal panic checking process: {} ({:?})", pid, e), } } Ok(oup) } -fn kill_pid(pid: u32) -> Result<()> { - let process = w::HPROCESS::OpenProcess(co::PROCESS::TERMINATE, false, pid)?; - process.TerminateProcess(1)?; - Ok(()) -} - pub fn force_stop_package>(root_dir: P) -> Result<()> { let root_dir = root_dir.as_ref(); super::retry_io(|| _force_stop_package(root_dir))?; @@ -156,7 +82,7 @@ pub fn force_stop_package>(root_dir: P) -> Result<()> { fn _force_stop_package>(root_dir: P) -> Result<()> { let dir = root_dir.as_ref(); info!("Checking for running processes in: {}", dir.display()); - let processes = get_processes_running_in_directory(dir)?; + let processes = unsafe { get_processes_running_in_directory(dir)? }; let my_pid = std::process::id(); for (pid, exe) in processes.iter() { if *pid == my_pid { @@ -164,7 +90,7 @@ fn _force_stop_package>(root_dir: P) -> Result<()> { continue; } warn!("Killing process: {} ({})", exe.display(), pid); - kill_pid(*pid)?; + process::kill_pid(*pid)?; } Ok(()) } @@ -177,21 +103,13 @@ pub fn start_package(locator: &VelopackLocator, exe_args: Option>, set bail!("Unable to find executable to start: '{}'", exe_to_execute.to_string_lossy()); } - let mut psi = Process::new(&exe_to_execute); - psi.current_dir(¤t); - if let Some(args) = exe_args { - psi.args(args); + let args: Vec = exe_args.unwrap_or_default().iter().map(|s| s.to_string()).collect(); + let mut environment = HashMap::new(); + if let Some(env_var) = set_env { + debug!("Setting environment variable: {}={}", env_var, "true"); + environment.insert(env_var.to_string(), "true".to_string()); } - if let Some(env) = set_env { - debug!("Setting environment variable: {}={}", env, "true"); - psi.env(env, "true"); - } - - info!("About to launch: '{:?}' in dir '{:?}'", exe_to_execute, current); - info!("Args: {:?}", psi.get_args()); - let child = psi.spawn().map_err(|z| anyhow!("Failed to start application ({}).", z))?; - let _ = unsafe { AllowSetForegroundWindow(child.id()) }; - + process::run_process(exe_to_execute, args, Some(current), true, Some(environment))?; Ok(()) } @@ -273,7 +191,7 @@ fn test_get_running_processes_finds_cargo() { let path = Path::new(&profile); let rustup = path.join(".rustup"); - let processes = get_processes_running_in_directory(&rustup).unwrap(); + let processes = unsafe { get_processes_running_in_directory(&rustup).unwrap() }; assert!(processes.len() > 0); let mut found = false; diff --git a/src/bins/src/windows/known_path.rs b/src/bins/src/windows/known_path.rs index db3b266b..ff0ca725 100644 --- a/src/bins/src/windows/known_path.rs +++ b/src/bins/src/windows/known_path.rs @@ -12,7 +12,7 @@ fn get_known_folder(rfid: *const GUID) -> Result { unsafe { let flag = windows::Win32::UI::Shell::KNOWN_FOLDER_FLAG(0); let result = SHGetKnownFolderPath(rfid, flag, None)?; - super::strings::pwstr_to_string(result) + super::strings::u16_to_string(result) } } diff --git a/src/bins/src/windows/strings.rs b/src/bins/src/windows/strings.rs index 83f94c94..8d85c003 100644 --- a/src/bins/src/windows/strings.rs +++ b/src/bins/src/windows/strings.rs @@ -1,30 +1,94 @@ use anyhow::Result; -use windows::core::{HSTRING, PCWSTR, PWSTR}; +use windows::core::{PCWSTR, PWSTR}; pub fn string_to_u16>(input: P) -> Vec { let input = input.as_ref(); input.encode_utf16().chain(Some(0)).collect::>() } -pub fn pwstr_to_string(input: PWSTR) -> Result { - unsafe { - let hstring = input.to_hstring(); - let string = hstring.to_string_lossy(); - Ok(string.trim_end_matches('\0').to_string()) +pub trait WideString { + fn to_wide_slice(&self) -> &[u16]; +} + +impl WideString for PWSTR { + fn to_wide_slice(&self) -> &[u16] { + unsafe { self.as_wide() } } } -pub fn pcwstr_to_string(input: PCWSTR) -> Result { - unsafe { - let hstring = input.to_hstring(); - let string = hstring.to_string_lossy(); - Ok(string.trim_end_matches('\0').to_string()) +impl WideString for PCWSTR { + fn to_wide_slice(&self) -> &[u16] { + unsafe { self.as_wide() } } } -pub fn u16_to_string>(input: T) -> Result { - let input = input.as_ref(); - let hstring = HSTRING::from_wide(input); - let string = hstring.to_string_lossy(); - Ok(string.trim_end_matches('\0').to_string()) +impl WideString for Vec { + fn to_wide_slice(&self) -> &[u16] { + self.as_ref() + } } + +impl WideString for &Vec { + fn to_wide_slice(&self) -> &[u16] { + self.as_ref() + } +} + +// impl WideString for [u16] { +// fn to_wide_slice(&self) -> &[u16] { +// self.as_ref() +// } +// } + +impl WideString for [u16; N] { + fn to_wide_slice(&self) -> &[u16] { + self.as_ref() + } +} + +pub fn u16_to_string_lossy(input: T) -> String { + let slice = input.to_wide_slice(); + let null_pos = slice.iter().position(|&x| x == 0).unwrap_or(slice.len()); + let trimmed_slice = &slice[..null_pos]; + String::from_utf16_lossy(trimmed_slice) +} + +pub fn u16_to_string(input: T) -> Result { + let slice = input.to_wide_slice(); + let null_pos = slice.iter().position(|&x| x == 0).unwrap_or(slice.len()); + let trimmed_slice = &slice[..null_pos]; + Ok(String::from_utf16(trimmed_slice)?) +} + +// pub fn pwstr_to_string(input: PWSTR) -> Result { +// unsafe { +// let hstring = input.to_hstring(); +// let string = hstring.to_string_lossy(); +// Ok(string.trim_end_matches('\0').to_string()) +// } +// } + +// pub fn pcwstr_to_string(input: PCWSTR) -> Result { +// unsafe { +// let hstring = input.to_hstring(); +// let string = hstring.to_string_lossy(); +// Ok(string.trim_end_matches('\0').to_string()) +// } +// } + +// pub fn u16_to_string>(input: T) -> Result { +// let input = input.as_ref(); +// let hstring = HSTRING::from_wide(input); +// let string = hstring.to_string_lossy(); +// Ok(string.trim_end_matches('\0').to_string()) +// } + +// pub fn u16_to_string>(input: T) -> Result { +// let input = input.as_ref(); +// // Find position of first null byte (0) +// let null_pos = input.iter().position(|&x| x == 0).unwrap_or(input.len()); +// // Take only up to the first null byte +// let trimmed_input = &input[..null_pos]; +// let hstring = HSTRING::from_wide(trimmed_input); +// Ok(hstring.to_string_lossy()) +// } diff --git a/src/bins/src/windows/util.rs b/src/bins/src/windows/util.rs index ac966a6b..e1fb24e8 100644 --- a/src/bins/src/windows/util.rs +++ b/src/bins/src/windows/util.rs @@ -1,57 +1,60 @@ -use std::{ - os::windows::process::CommandExt, - path::{Path, PathBuf}, - process::Command as Process, - time::Duration, +use crate::{ + shared::{self, runtime_arch::RuntimeArch}, + windows::strings::{string_to_u16, u16_to_string}, }; - -use velopack::locator::VelopackLocator; - use anyhow::{anyhow, Result}; use normpath::PathExt; -use wait_timeout::ChildExt; -use windows::core::PCWSTR; -use windows::Win32::Storage::FileSystem::GetLongPathNameW; -use windows::Win32::System::SystemInformation::{VerSetConditionMask, VerifyVersionInfoW, OSVERSIONINFOEXW, VER_FLAGS}; -use windows::Win32::UI::WindowsAndMessaging::AllowSetForegroundWindow; -use windows::Win32::Foundation; - -use crate::shared::{self, runtime_arch::RuntimeArch}; -use crate::windows::strings::{string_to_u16, u16_to_string}; +use std::{ + path::{Path, PathBuf}, + time::Duration, +}; +use velopack::{locator::VelopackLocator, process, process::WaitResult}; +use windows::{ + core::PCWSTR, + Win32::Storage::FileSystem::GetLongPathNameW, + Win32::System::SystemInformation::{VerSetConditionMask, VerifyVersionInfoW, OSVERSIONINFOEXW, VER_FLAGS}, +}; pub fn run_hook(locator: &VelopackLocator, hook_name: &str, timeout_secs: u64) -> bool { let sw = simple_stopwatch::Stopwatch::start_new(); let root_dir = locator.get_root_dir(); let current_path = locator.get_current_bin_dir(); - let main_exe_path = locator.get_main_exe_path(); + let main_exe_path = locator.get_main_exe_path_as_string(); let ver_string = locator.get_manifest_version_full_string(); let args = vec![hook_name, &ver_string]; let mut success = false; info!("Running {} hook...", hook_name); - const CREATE_NO_WINDOW: u32 = 0x08000000; - let cmd = Process::new(&main_exe_path).args(args).current_dir(¤t_path).creation_flags(CREATE_NO_WINDOW).spawn(); + let cmd = process::run_process( + main_exe_path, + args.iter().map(|f| f.to_string()).collect(), + Some(current_path.to_string_lossy().to_string()), + false, + None, + ); if let Err(e) = cmd { warn!("Failed to start hook {}: {}", hook_name, e); return false; } - let mut cmd = cmd.unwrap(); - let _ = unsafe { AllowSetForegroundWindow(cmd.id()) }; + let cmd = cmd.unwrap(); - match cmd.wait_timeout(Duration::from_secs(timeout_secs)) { - Ok(Some(status)) => { - if status.success() { + match process::wait_for_process_to_exit_with_timeout(cmd.handle(), Duration::from_secs(timeout_secs)) { + Ok(WaitResult::NoWaitRequired) => { + warn!("Was unable to wait for hook (it may have exited too quickly)."); + } + Ok(WaitResult::ExitCode(code)) => { + if code == 0 { info!("Hook executed successfully (took {}ms)", sw.ms()); success = true; } else { - warn!("Hook exited with non-zero exit code: {}", status.code().unwrap_or(0)); + warn!("Hook exited with non-zero exit code: {}", code); } } - Ok(None) => { - let _ = cmd.kill(); - error!("Process timed out after {}s", timeout_secs); + Ok(WaitResult::WaitTimeout) => { + let _ = process::kill_process(cmd.handle()); + error!("Process timed out after {}s and was killed.", timeout_secs); } Err(e) => { error!("Error waiting for process to finish: {}", e); diff --git a/src/lib-rust/Cargo.toml b/src/lib-rust/Cargo.toml index c08fbfb4..27359d42 100644 --- a/src/lib-rust/Cargo.toml +++ b/src/lib-rust/Cargo.toml @@ -19,6 +19,7 @@ delta = ["zstd"] async = ["async-std"] typescript = ["ts-rs"] file-logging = ["log-panics", "simplelog", "file-rotate", "time"] +public-utils = [] [package.metadata.docs.rs] features = ["async", "delta"] @@ -67,7 +68,22 @@ file-rotate = { workspace = true, optional = true } time = { workspace = true, optional = true } [target.'cfg(windows)'.dependencies] -windows = { workspace = true, features = ["Win32_Foundation", "Win32_Storage", "Win32_Storage_FileSystem", "Win32_System_IO", "Win32_UI_Shell"] } +windows = { workspace = true, features = [ + "Win32_Foundation", + "Win32_Storage", + "Win32_Storage_FileSystem", + "Win32_Security", + "Win32_System_IO", + "Win32_System_Threading", + "Win32_UI_WindowsAndMessaging", + "Win32_UI_Shell", + "Win32_System_Kernel", + "Win32_System_Registry", + "Wdk", + "Wdk_System", + "Wdk_System_Threading", +] } [target.'cfg(unix)'.dependencies] libc.workspace = true +waitpid-any.workspace = true diff --git a/src/lib-rust/src/bundle.rs b/src/lib-rust/src/bundle.rs index 39701294..f9f357e1 100644 --- a/src/lib-rust/src/bundle.rs +++ b/src/lib-rust/src/bundle.rs @@ -15,7 +15,7 @@ use xml::EventReader; use xml::reader::XmlEvent; use zip::ZipArchive; -use crate::{Error, util}; +use crate::{Error, misc}; #[cfg(target_os = "macos")] use std::os::unix::fs::PermissionsExt; @@ -39,7 +39,7 @@ pub struct BundleZip<'a> { pub fn load_bundle_from_file<'a, P: AsRef>(file_name: P) -> Result, Error> { let file_name = file_name.as_ref(); debug!("Loading bundle from file '{}'...", file_name.to_string_lossy()); - let file = util::retry_io(|| File::open(&file_name))?; + let file = misc::retry_io(|| File::open(&file_name))?; let cursor: Box = Box::new(file); let zip = ZipArchive::new(cursor)?; Ok(BundleZip { @@ -69,9 +69,9 @@ impl BundleZip<'_> { pub fn copy_bundle_to_file>(&self, output_file_path: T) -> Result<(), Error> { let nupkg_path = output_file_path.as_ref(); if self.zip_from_file { - util::retry_io(|| fs::copy(self.file_path.clone().unwrap(), nupkg_path))?; + misc::retry_io(|| fs::copy(self.file_path.clone().unwrap(), nupkg_path))?; } else { - util::retry_io(|| fs::write(nupkg_path, self.zip_range.unwrap()))?; + misc::retry_io(|| fs::write(nupkg_path, self.zip_range.unwrap()))?; } Ok(()) } @@ -146,12 +146,12 @@ impl BundleZip<'_> { if !parent.exists() { debug!("Creating parent directory: {:?}", parent); - util::retry_io(|| fs::create_dir_all(parent))?; + misc::retry_io(|| fs::create_dir_all(parent))?; } let mut archive = self.zip.borrow_mut(); let mut file = archive.by_index(index)?; - let mut outfile = util::retry_io(|| File::create(path))?; + let mut outfile = misc::retry_io(|| File::create(path))?; let mut buffer = [0; 64000]; // Use a 64KB buffer; good balance for large/small files. debug!("Writing file to disk with 64k buffer: {:?}", path); @@ -326,9 +326,9 @@ impl BundleZip<'_> { let parent = link_path.parent().unwrap(); if !parent.exists() { debug!("Creating parent directory: {:?}", parent); - util::retry_io(|| fs::create_dir_all(parent))?; + misc::retry_io(|| fs::create_dir_all(parent))?; } - util::retry_io(|| Self::create_symlink(&link_path, &contents))?; + misc::retry_io(|| Self::create_symlink(&link_path, &contents))?; } Ok(()) diff --git a/src/lib-rust/src/constants.rs b/src/lib-rust/src/constants.rs index 5d39569c..663e568a 100644 --- a/src/lib-rust/src/constants.rs +++ b/src/lib-rust/src/constants.rs @@ -6,4 +6,13 @@ pub const HOOK_ENV_RESTART: &str = "VELOPACK_RESTART"; pub const HOOK_CLI_INSTALL: &str = "--veloapp-install"; pub const HOOK_CLI_UPDATED: &str = "--veloapp-updated"; pub const HOOK_CLI_OBSOLETE: &str = "--veloapp-obsolete"; -pub const HOOK_CLI_UNINSTALL: &str = "--veloapp-uninstall"; \ No newline at end of file +pub const HOOK_CLI_UNINSTALL: &str = "--veloapp-uninstall"; + +#[cfg(target_os = "windows")] +pub const DEFAULT_CHANNEL_NAME: &str = "win"; + +#[cfg(target_os = "linux")] +pub const DEFAULT_CHANNEL_NAME: &str = "linux"; + +#[cfg(target_os = "macos")] +pub const DEFAULT_CHANNEL_NAME: &str = "osx"; \ No newline at end of file diff --git a/src/lib-rust/src/download.rs b/src/lib-rust/src/download.rs index d27094b8..c678c176 100644 --- a/src/lib-rust/src/download.rs +++ b/src/lib-rust/src/download.rs @@ -1,7 +1,7 @@ use std::fs::File; use std::io::{Read, Write}; -use crate::{util, Error}; +use crate::{misc, Error}; /// Downloads a file from a URL and writes it to a file while reporting progress from 0-100. pub fn download_url_to_file(url: &str, file_path: &str, mut progress: A) -> Result<(), Error> @@ -12,7 +12,7 @@ where let (head, body) = agent.get(url).call()?.into_parts(); let total_size = head.headers.get("Content-Length").and_then(|s| s.to_str().ok()).and_then(|s| s.parse::().ok()); - let mut file = util::retry_io(|| File::create(file_path))?; + let mut file = misc::retry_io(|| File::create(file_path))?; const CHUNK_SIZE: usize = 2 * 1024 * 1024; // 2MB let mut downloaded: u64 = 0; diff --git a/src/lib-rust/src/lib.rs b/src/lib-rust/src/lib.rs index ad7e8916..b33faea1 100644 --- a/src/lib-rust/src/lib.rs +++ b/src/lib-rust/src/lib.rs @@ -78,44 +78,68 @@ #![warn(missing_docs)] + + +macro_rules! maybe_pub { + ($($mod:ident),*) => { + $( + #[cfg(feature = "public-utils")] + #[allow(missing_docs)] + pub mod $mod; + + #[cfg(not(feature = "public-utils"))] + #[allow(unused)] + mod $mod; + )* + }; +} + +macro_rules! maybe_pub_os { + ($mod:ident, $win_path:expr, $unix_path:expr) => { + #[cfg(all(windows, feature = "public-utils"))] + #[path = $win_path] + #[allow(missing_docs)] + pub mod $mod; + + #[cfg(all(windows, not(feature = "public-utils")))] + #[path = $win_path] + #[allow(unused)] + mod $mod; + + #[cfg(all(not(windows), feature = "public-utils"))] + #[path = $unix_path] + #[allow(missing_docs)] + pub mod $mod; + + #[cfg(all(not(windows), not(feature = "public-utils")))] + #[path = $unix_path] + #[allow(unused)] + mod $mod; + }; +} + + mod app; +pub use app::*; + mod manager; -mod util; +pub use manager::*; -/// Utility functions for loading and working with Velopack bundles and manifests. -pub mod bundle; - -/// Utility function for downloading files with progress reporting. -pub mod download; - -/// Constant strings used internally by Velopack. -pub mod constants; - -/// Locator provides some utility functions for locating the current app important paths (eg. path to packages, update binary, and so forth). +/// Locator provides support for locating the current app important paths (eg. path to packages, update binary, and so forth). pub mod locator; -/// Sources contains abstractions for custom update sources (eg. url, local file, github releases, etc). +/// Sources are abstractions for custom update sources (eg. url, local file, github releases, etc). pub mod sources; -/// Functions to patch files and reconstruct Velopack delta packages. -pub mod delta; - -/// Acquire and manage file-system based lock files. -pub mod lockfile; - -/// Logging utilities and setup. -pub mod logging; - -pub use app::*; -pub use manager::*; +maybe_pub!(download, bundle, delta, constants, lockfile, logging, misc); +maybe_pub_os!(process, "process_win.rs", "process_unix.rs"); #[macro_use] extern crate log; #[derive(thiserror::Error, Debug)] #[allow(missing_docs, clippy::large_enum_variant)] -pub enum NetworkError -{ +pub enum NetworkError { #[error("Http error: {0}")] Http(#[from] ureq::Error), #[error("Url error: {0}")] @@ -124,8 +148,7 @@ pub enum NetworkError #[derive(thiserror::Error, Debug)] #[allow(missing_docs)] -pub enum Error -{ +pub enum Error { #[error("File does not exist: {0}")] FileNotFound(String), #[error("IO error: {0}")] @@ -148,6 +171,7 @@ pub enum Error NotInstalled(String), #[error("Generic error: {0}")] Generic(String), + #[cfg(target_os = "windows")] #[error("Win32 error: {0}")] Win32(#[from] windows::core::Error), } @@ -162,4 +186,4 @@ impl From for Error { fn from(err: ureq::Error) -> Self { Error::Network(Box::new(NetworkError::Http(err))) } -} \ No newline at end of file +} diff --git a/src/lib-rust/src/lockfile.rs b/src/lib-rust/src/lockfile.rs index 641299a7..6587d21c 100644 --- a/src/lib-rust/src/lockfile.rs +++ b/src/lib-rust/src/lockfile.rs @@ -16,7 +16,7 @@ impl LockFile { pub fn try_acquire_lock>(path: P) -> Result { let path: PathBuf = path.into(); - crate::util::retry_io(|| { + crate::misc::retry_io(|| { #[cfg(windows)] { let file = Self::windows_exclusive_lock(&path)?; diff --git a/src/lib-rust/src/manager.rs b/src/lib-rust/src/manager.rs index 30b007ed..8622c338 100644 --- a/src/lib-rust/src/manager.rs +++ b/src/lib-rust/src/manager.rs @@ -1,10 +1,6 @@ +use std::process::exit; #[cfg(target_os = "windows")] -use std::os::windows::process::CommandExt; -use std::{ - fs, - process::{exit, Command as Process}, - sync::mpsc::Sender, -}; +use std::{fs, sync::mpsc::Sender}; use semver::Version; use serde::{Deserialize, Serialize}; @@ -15,9 +11,11 @@ use async_std::channel::Sender as AsyncSender; use async_std::task::JoinHandle; use crate::{ + constants, locator::{self, LocationContext, VelopackLocator, VelopackLocatorConfig}, + misc, process, sources::UpdateSource, - util, Error, + Error, }; /// Configure how the update process should wait before applying updates. @@ -185,8 +183,8 @@ impl UpdateManager { let app_channel = self.locator.get_manifest_channel(); let mut channel = options_channel.unwrap_or(&app_channel).to_string(); if channel.is_empty() { - warn!("Channel is empty, picking default."); - channel = locator::default_channel_name(); + warn!("Channel is empty, using default."); + channel = constants::DEFAULT_CHANNEL_NAME.to_owned(); } info!("Chosen channel for updates: {:?} (explicit={:?}, memorized={:?})", channel, options_channel, app_channel); channel @@ -220,13 +218,14 @@ impl UpdateManager { let packages_dir = self.locator.get_packages_dir(); if let Some((path, manifest)) = locator::find_latest_full_package(&packages_dir) { if manifest.version > self.locator.get_manifest_version() { + let (sha1, sha256) = misc::calculate_sha1_sha256(&path).unwrap_or_default(); return Some(VelopackAsset { PackageId: manifest.id, Version: manifest.version.to_string(), Type: "Full".to_string(), FileName: path.file_name().unwrap().to_string_lossy().to_string(), - SHA1: util::calculate_file_sha1(&path).unwrap_or_default(), - SHA256: util::calculate_file_sha256(&path).unwrap_or_default(), + SHA1: sha1, + SHA256: sha256, Size: path.metadata().map(|m| m.len()).unwrap_or(0), NotesMarkdown: manifest.release_notes, NotesHtml: manifest.release_notes_html, @@ -462,7 +461,7 @@ impl UpdateManager { } /// This will launch the Velopack updater and optionally wait for a program to exit gracefully. - /// This method is unsafe because it does not necessarily wait for any / the correct process to exit + /// This method is unsafe because it does not necessarily wait for any / the correct process to exit /// before applying updates. The `wait_exit_then_apply_updates` method is recommended for most use cases. pub fn unsafe_apply_updates( &self, @@ -520,21 +519,8 @@ impl UpdateManager { } } - let mut p = Process::new(&self.locator.get_update_path()); - p.args(&args); - - if let Some(update_exe_parent) = self.locator.get_update_path().parent() { - p.current_dir(update_exe_parent); - } - - #[cfg(target_os = "windows")] - { - const CREATE_NO_WINDOW: u32 = 0x08000000; - p.creation_flags(CREATE_NO_WINDOW); - } - - info!("About to run Update.exe: {} {:?}", self.locator.get_update_path_as_string(), args); - p.spawn()?; + let update_path = self.locator.get_update_path(); + process::run_process(&update_path, args, update_path.parent(), false, None)?; Ok(()) } } diff --git a/src/lib-rust/src/util.rs b/src/lib-rust/src/misc.rs similarity index 70% rename from src/lib-rust/src/util.rs rename to src/lib-rust/src/misc.rs index b54ac460..68383bfc 100644 --- a/src/lib-rust/src/util.rs +++ b/src/lib-rust/src/misc.rs @@ -2,6 +2,7 @@ use crate::Error; use rand::distr::{Alphanumeric, SampleString}; use sha2::Digest; use std::fs::File; +use std::io::{BufReader, Read}; use std::path::Path; use std::thread; use std::time::Duration; @@ -42,20 +43,28 @@ pub fn random_string(len: usize) -> String { Alphanumeric.sample_string(&mut rand::rng(), len) } -pub fn calculate_file_sha256>(file: P) -> Result { - let mut file = File::open(file)?; - let mut sha256 = sha2::Sha256::new(); - std::io::copy(&mut file, &mut sha256)?; - let hash = sha256.finalize(); - Ok(format!("{:x}", hash)) -} +pub fn calculate_sha1_sha256>(file: P) -> Result<(String, String), Error> { + let file = File::open(file)?; + let mut reader = BufReader::new(file); -pub fn calculate_file_sha1>(file: P) -> Result { - let mut file = File::open(file)?; - let mut sha1o = sha1::Sha1::new(); - std::io::copy(&mut file, &mut sha1o)?; - let hash = sha1o.finalize(); - Ok(format!("{:x}", hash)) + let mut sha256 = sha2::Sha256::new(); + let mut sha1 = sha1::Sha1::new(); + + let mut buffer = [0u8; 1024 * 1024]; // 1MB buffer + loop { + let bytes_read = reader.read(&mut buffer)?; + if bytes_read == 0 { + break; + } + + sha256.update(&buffer[..bytes_read]); + sha1.update(&buffer[..bytes_read]); + } + + let sha256_hash = format!("{:x}", sha256.finalize()); + let sha1_hash = format!("{:x}", sha1.finalize()); + + Ok((sha1_hash, sha256_hash)) } pub fn is_directory_writable>(path: P1) -> bool { diff --git a/src/lib-rust/src/process_unix.rs b/src/lib-rust/src/process_unix.rs new file mode 100644 index 00000000..cdad652e --- /dev/null +++ b/src/lib-rust/src/process_unix.rs @@ -0,0 +1,87 @@ +use std::{ + collections::HashMap, + ffi::{OsStr, OsString}, + io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult}, + os::{raw::c_void, windows::ffi::OsStrExt}, + process::{Child, Command}, + time::Duration, +}; + +use crate::process; + +pub fn is_current_process_elevated() -> bool { + false +} + +fn string_to_u16>(input: P) -> Vec { + let input = input.as_ref(); + input.encode_utf16().chain(Some(0)).collect::>() +} + +pub fn run_process_as_admin( + exe_path: String, + args: Vec, + work_dir: Option, + show_window: bool, +) -> IoResult { + + // let mut cmd = Command::new(exe_path).args(args); + + // if let Some(dir) = work_dir { + // cmd.current_dir(dir); + // } +} + +pub fn run_process( + exe_path: String, + args: Vec, + work_dir: Option, + _show_window: bool, + set_env: Option>, +) -> IoResult { + let mut cmd = Command::new(exe_path); + cmd.args(args); + if let Some(dir) = work_dir { + cmd.current_dir(dir); + } + if let Some(env) = set_env { + for (key, value) in env { + cmd.env(key, value); + } + } + cmd.spawn() +} + +pub fn wait_for_process_exit_with_timeout(process: Child, dur: Duration) -> IoResult> { + + let mut status = process.wait_timeout(dur)?; + if status.is_none() { + return Err(IoError::new(IoErrorKind::TimedOut, "Process timed out")); + } + Ok(status.unwrap().code()) +} + +pub fn wait_for_pid_to_exit(pid: u32, dur: Duration) -> IoResult<()> { + info!("Waiting {}ms for process ({}) to exit.", ms_to_wait, pid); + let mut handle = waitpid_any::WaitHandle::open(pid.try_into()?)?; + let result = handle.wait_timeout(Duration::from_millis(ms_to_wait as u64))?; + if result.is_some() { + info!("Parent process exited."); + Ok(()) + } else { + bail!("Parent process timed out."); + } +} + +pub fn wait_for_parent_to_exit(dur: Duration) -> IoResult<()> { + let id = std::os::unix::process::parent_id(); + info!("Attempting to wait for parent process ({}) to exit.", id); + if id > 1 { + wait_for_pid_to_exit(id, ms_to_wait)?; + } + Ok(()) +} + +pub fn kill_process(mut process: Child) -> IoResult<()> { + process.kill() +} diff --git a/src/lib-rust/src/process_win.rs b/src/lib-rust/src/process_win.rs new file mode 100644 index 00000000..98c73f83 --- /dev/null +++ b/src/lib-rust/src/process_win.rs @@ -0,0 +1,457 @@ +use std::{ + collections::HashMap, + ffi::{OsStr, OsString}, + io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult}, + os::{raw::c_void, windows::ffi::OsStrExt}, + path::Path, + time::Duration, +}; +use windows::{ + core::{PCWSTR, PWSTR}, + Wdk::System::Threading::{NtQueryInformationProcess, ProcessBasicInformation}, + Win32::{ + Foundation::{CloseHandle, FILETIME, HANDLE, WAIT_OBJECT_0, WAIT_TIMEOUT}, + Security::{GetTokenInformation, TokenElevation, TOKEN_ELEVATION}, + System::Threading::{ + CreateProcessW, GetCurrentProcess, GetExitCodeProcess, GetProcessId, GetProcessTimes, OpenProcess, OpenProcessToken, + TerminateProcess, WaitForSingleObject, CREATE_NO_WINDOW, PROCESS_ACCESS_RIGHTS, PROCESS_BASIC_INFORMATION, + PROCESS_CREATION_FLAGS, PROCESS_QUERY_LIMITED_INFORMATION, PROCESS_SYNCHRONIZE, PROCESS_TERMINATE, + }, + UI::{ + Shell::{ShellExecuteExW, SEE_MASK_NOCLOSEPROCESS, SHELLEXECUTEINFOW}, + WindowsAndMessaging::AllowSetForegroundWindow, + }, + }, +}; + +enum Arg { + /// Add quotes (if needed) + Regular(OsString), + // Append raw string without quoting + #[allow(unused)] + Raw(OsString), +} + +enum Quote { + // Every arg is quoted + Always, + // Whitespace and empty args are quoted + Auto, + // Arg appended without any changes (#29494) + Never, +} + +fn ensure_no_nuls>(str: T) -> IoResult { + if str.as_ref().encode_wide().any(|b| b == 0) { + Err(IoError::new(IoErrorKind::InvalidInput, "nul byte found in provided data")) + } else { + Ok(str) + } +} + +fn append_arg(cmd: &mut Vec, arg: &Arg, force_quotes: bool) -> IoResult<()> { + let (arg, quote) = match arg { + Arg::Regular(arg) => (arg, if force_quotes { Quote::Always } else { Quote::Auto }), + Arg::Raw(arg) => (arg, Quote::Never), + }; + + // If an argument has 0 characters then we need to quote it to ensure + // that it actually gets passed through on the command line or otherwise + // it will be dropped entirely when parsed on the other end. + ensure_no_nuls(arg)?; + let arg_bytes = arg.as_encoded_bytes(); + let (quote, escape) = match quote { + Quote::Always => (true, true), + Quote::Auto => (arg_bytes.iter().any(|c| *c == b' ' || *c == b'\t') || arg_bytes.is_empty(), true), + Quote::Never => (false, false), + }; + if quote { + cmd.push('"' as u16); + } + + let mut backslashes: usize = 0; + for x in arg.encode_wide() { + if escape { + if x == '\\' as u16 { + backslashes += 1; + } else { + if x == '"' as u16 { + // Add n+1 backslashes to total 2n+1 before internal '"'. + cmd.extend((0..=backslashes).map(|_| '\\' as u16)); + } + backslashes = 0; + } + } + cmd.push(x); + } + + if quote { + // Add n backslashes to total 2n before ending '"'. + cmd.extend((0..backslashes).map(|_| '\\' as u16)); + cmd.push('"' as u16); + } + Ok(()) +} + +fn make_command_line(argv0: Option<&OsStr>, args: &[Arg], force_quotes: bool) -> IoResult> { + // Encode the command and arguments in a command line string such + // that the spawned process may recover them using CommandLineToArgvW. + let mut cmd: Vec = Vec::new(); + + // Always quote the program name so CreateProcess to avoid ambiguity when + // the child process parses its arguments. + // Note that quotes aren't escaped here because they can't be used in arg0. + // But that's ok because file paths can't contain quotes. + if let Some(argv0) = argv0 { + cmd.push(b'"' as u16); + cmd.extend(argv0.encode_wide()); + cmd.push(b'"' as u16); + cmd.push(' ' as u16); + } + + for arg in args { + append_arg(&mut cmd, arg, force_quotes)?; + cmd.push(' ' as u16); + } + + cmd.push(0); + Ok(cmd) +} + +fn make_envp(maybe_env: Option>) -> IoResult<(Option<*const c_void>, Vec)> { + // On Windows we pass an "environment block" which is not a char**, but + // rather a concatenation of null-terminated k=v\0 sequences, with a final + // \0 to terminate. + if let Some(env) = maybe_env { + let mut blk = Vec::new(); + + // If there are no environment variables to set then signal this by + // pushing a null. + if env.is_empty() { + blk.push(0); + } + + for (k, v) in env { + let os_key = OsString::from(k); + let os_value = OsString::from(v); + blk.extend(ensure_no_nuls(os_key)?.encode_wide()); + blk.push('=' as u16); + blk.extend(ensure_no_nuls(os_value)?.encode_wide()); + blk.push(0); + } + blk.push(0); + Ok((Some(blk.as_ptr() as *mut c_void), blk)) + } else { + Ok((None, Vec::new())) + } +} + +pub fn is_current_process_elevated() -> bool { + // Get the current process handle + let process = unsafe { GetCurrentProcess() }; + + // Variable to hold the process token + let mut token: HANDLE = HANDLE::default(); + + // Open the process token with the TOKEN_QUERY access rights + unsafe { + if OpenProcessToken(process, windows::Win32::Security::TOKEN_QUERY, &mut token).is_ok() { + // Allocate a buffer for the TOKEN_ELEVATION structure + let mut elevation = TOKEN_ELEVATION::default(); + let mut size: u32 = 0; + + let elevation_ptr: *mut core::ffi::c_void = &mut elevation as *mut _ as *mut _; + + // Query the token information to check if it is elevated + if GetTokenInformation(token, TokenElevation, Some(elevation_ptr), std::mem::size_of::() as u32, &mut size) + .is_ok() + { + // Return whether the token is elevated + let _ = CloseHandle(token); + return elevation.TokenIsElevated != 0; + } + } + } + + // Clean up the token handle + if !token.is_invalid() { + unsafe { + let _ = CloseHandle(token); + }; + } + + false +} + +pub struct SafeProcessHandle { + handle: HANDLE, + pid: u32, +} + +impl Drop for SafeProcessHandle { + fn drop(&mut self) { + if !self.handle.is_invalid() { + let _ = unsafe { CloseHandle(self.handle) }; + } + } +} + +impl SafeProcessHandle { + pub fn handle(&self) -> HANDLE { + self.handle + } + + pub fn pid(&self) -> u32 { + self.pid + } +} + +impl Into for SafeProcessHandle { + fn into(self) -> u32 { + self.pid() + } +} + +impl Into for SafeProcessHandle { + fn into(self) -> HANDLE { + self.handle() + } +} + +fn open_pid_safe( + dwdesiredaccess: PROCESS_ACCESS_RIGHTS, + binherithandle: bool, + dwprocessid: u32, +) -> windows::core::Result { + let handle = unsafe { OpenProcess(dwdesiredaccess, binherithandle, dwprocessid)? }; + return Ok(SafeProcessHandle { handle, pid: dwprocessid }); +} + +fn os_to_pcwstr>(d: P) -> IoResult<(PCWSTR, Vec)> { + let d = d.as_ref(); + let d = OsString::from(d); + let mut d_str: Vec = ensure_no_nuls(d)?.encode_wide().collect(); + d_str.push(0); + Ok((PCWSTR(d_str.as_ptr()), d_str)) +} + +fn pathopt_to_pcwstr>(d: Option

) -> IoResult<(PCWSTR, Vec)> { + match d { + Some(dir) => { + let dir = dir.as_ref(); + os_to_pcwstr(dir) + } + None => Ok((PCWSTR::null(), Vec::new())), + } +} + +pub fn run_process_as_admin, P2: AsRef>( + exe_path: P1, + args: Vec, + work_dir: Option, + show_window: bool, +) -> IoResult { + let verb = os_to_pcwstr("runas")?; + let exe = os_to_pcwstr(exe_path.as_ref())?; + let wrapped_args: Vec = args.iter().map(|a| Arg::Regular(a.into())).collect(); + let params = make_command_line(None, &wrapped_args, false)?; + let params = PCWSTR(params.as_ptr()); + let work_dir = pathopt_to_pcwstr(work_dir.as_ref())?; + + let n_show = + if show_window { windows::Win32::UI::WindowsAndMessaging::SW_NORMAL.0 } else { windows::Win32::UI::WindowsAndMessaging::SW_HIDE.0 }; + + let mut exe_info: SHELLEXECUTEINFOW = SHELLEXECUTEINFOW { + cbSize: std::mem::size_of::() as u32, + fMask: SEE_MASK_NOCLOSEPROCESS, + lpVerb: verb.0, + lpFile: exe.0, + lpParameters: params, + lpDirectory: work_dir.0, + nShow: n_show, + ..Default::default() + }; + + unsafe { + info!("About to launch [AS ADMIN]: '{:?}' in dir '{:?}' with arguments: {:?}", exe, work_dir, args); + ShellExecuteExW(&mut exe_info as *mut SHELLEXECUTEINFOW)?; + let process_id = GetProcessId(exe_info.hProcess); + let _ = AllowSetForegroundWindow(process_id); + Ok(SafeProcessHandle { handle: exe_info.hProcess, pid: process_id }) + } +} + +pub fn run_process, P2: AsRef>( + exe_path: P1, + args: Vec, + work_dir: Option, + show_window: bool, + set_env: Option>, +) -> IoResult { + let exe_path = exe_path.as_ref(); + let exe_path = OsString::from(exe_path); + let exe_name = PCWSTR(exe_path.encode_wide().chain(Some(0)).collect::>().as_mut_ptr()); + + let wrapped_args: Vec = args.iter().map(|a| Arg::Regular(a.into())).collect(); + let mut params = make_command_line(Some(&exe_path), &wrapped_args, false)?; + let params = PWSTR(params.as_mut_ptr()); + + let mut pi = windows::Win32::System::Threading::PROCESS_INFORMATION::default(); + + // let si = STARTUPINFOW { + // cb: std::mem::size_of::() as u32, + // lpReserved: PWSTR::null(), + // lpDesktop: PWSTR::null(), + // lpTitle: PWSTR::null(), + // dwX: 0, + // dwY: 0, + // dwXSize: 0, + // dwYSize: 0, + // dwXCountChars: 0, + // dwYCountChars: 0, + // dwFillAttribute: 0, + // dwFlags: STARTUPINFOW_FLAGS(0), + // wShowWindow: 0, + // cbReserved2: 0, + // lpReserved2: std::ptr::null_mut(), + // hStdInput: HANDLE(std::ptr::null_mut()), + // hStdOutput: HANDLE(std::ptr::null_mut()), + // hStdError: HANDLE(std::ptr::null_mut()), + // }; + + let envp = make_envp(set_env)?; + let dirp = pathopt_to_pcwstr(work_dir)?; + + let flags = if show_window { PROCESS_CREATION_FLAGS(0) } else { CREATE_NO_WINDOW }; + + unsafe { + info!("About to launch: '{:?}' in dir '{:?}' with arguments: {:?}", exe_name, dirp, args); + CreateProcessW(exe_name, Option::Some(params), None, None, false, flags, envp.0, dirp.0, std::ptr::null(), &mut pi)?; + let _ = AllowSetForegroundWindow(pi.dwProcessId); + let _ = CloseHandle(pi.hThread); + } + + Ok(SafeProcessHandle { handle: pi.hProcess, pid: pi.dwProcessId }) +} + +fn duration_to_ms(dur: Duration) -> u32 { + let ms = dur + .as_secs() + .checked_mul(1000) + .and_then(|amt| amt.checked_add((dur.subsec_nanos() / 1_000_000) as u64)) + .expect("failed to convert duration to milliseconds"); + if ms > (u32::max_value() as u64) { + u32::max_value() + } else { + ms as u32 + } +} + +pub fn kill_process>(process: T) -> IoResult<()> { + let process = process.into(); + unsafe { + if process.is_invalid() { + return Ok(()); + } + TerminateProcess(process, 1)?; + } + Ok(()) +} + +pub fn kill_pid(pid: u32) -> IoResult<()> { + let handle = open_pid_safe(PROCESS_TERMINATE, false, pid)?; + kill_process(handle)?; + Ok(()) +} + +pub enum WaitResult { + WaitTimeout, + ExitCode(u32), + NoWaitRequired, +} + +pub fn wait_for_process_to_exit_with_timeout>(process: T, dur: Duration) -> IoResult { + let process = process.into(); + if process.is_invalid() { + return Ok(WaitResult::NoWaitRequired); + } + + let ms = duration_to_ms(dur); + info!("Waiting {}ms for process handle to exit.", ms); + + unsafe { + match WaitForSingleObject(process, ms) { + WAIT_OBJECT_0 => {} + WAIT_TIMEOUT => return Ok(WaitResult::WaitTimeout), + _ => return Err(IoError::last_os_error()), + } + + let mut exit_code = 0; + GetExitCodeProcess(process, &mut exit_code)?; + Ok(WaitResult::ExitCode(exit_code)) + } +} + +pub fn wait_for_pid_to_exit(pid: u32, dur: Duration) -> IoResult { + info!("Waiting for process pid-{} to exit.", pid); + let handle = open_pid_safe(PROCESS_SYNCHRONIZE, false, pid)?; + wait_for_process_to_exit_with_timeout(handle, dur) +} + +pub fn wait_for_parent_to_exit(dur: Duration) -> IoResult { + info!("Reading parent process information."); + let basic_info = ProcessBasicInformation; + let my_handle = unsafe { GetCurrentProcess() }; + let mut return_length: u32 = 0; + let return_length_ptr: *mut u32 = &mut return_length as *mut u32; + + let mut info = PROCESS_BASIC_INFORMATION { + AffinityMask: 0, + BasePriority: 0, + ExitStatus: Default::default(), + InheritedFromUniqueProcessId: 0, + PebBaseAddress: std::ptr::null_mut(), + UniqueProcessId: 0, + }; + + let info_ptr: *mut ::core::ffi::c_void = &mut info as *mut _ as *mut ::core::ffi::c_void; + let info_size = std::mem::size_of::() as u32; + let hres = unsafe { NtQueryInformationProcess(my_handle, basic_info, info_ptr, info_size, return_length_ptr) }; + if hres.is_err() { + return Err(IoError::new(IoErrorKind::Other, format!("NtQueryInformationProcess failed: {:?}", hres))); + } + + if info.InheritedFromUniqueProcessId <= 1 { + // the parent process has exited + info!("The parent process ({}) has already exited", info.InheritedFromUniqueProcessId); + return Ok(WaitResult::NoWaitRequired); + } + + fn get_pid_start_time(process: HANDLE) -> IoResult { + let mut creation = FILETIME::default(); + let mut exit = FILETIME::default(); + let mut kernel = FILETIME::default(); + let mut user = FILETIME::default(); + unsafe { + GetProcessTimes(process, &mut creation, &mut exit, &mut kernel, &mut user)?; + } + Ok(((creation.dwHighDateTime as u64) << 32) | creation.dwLowDateTime as u64) + } + + let permissions = PROCESS_SYNCHRONIZE | PROCESS_QUERY_LIMITED_INFORMATION; + let parent_handle = open_pid_safe(permissions, false, info.InheritedFromUniqueProcessId as u32)?; + let parent_start_time = get_pid_start_time(parent_handle.handle())?; + let myself_start_time = get_pid_start_time(my_handle)?; + + if parent_start_time > myself_start_time { + // the parent process has exited and the id has been re-used + info!( + "The parent process ({}) has already exited. parent_start={}, my_start={}", + info.InheritedFromUniqueProcessId, parent_start_time, myself_start_time + ); + return Ok(WaitResult::NoWaitRequired); + } + + info!("Waiting for parent process ({}) to exit.", info.InheritedFromUniqueProcessId); + wait_for_process_to_exit_with_timeout(parent_handle, dur) +}