diff --git a/std/src/sys/pal/uefi/helpers.rs b/std/src/sys/pal/uefi/helpers.rs index bd8a6684b64f9..4ced7065c826d 100644 --- a/std/src/sys/pal/uefi/helpers.rs +++ b/std/src/sys/pal/uefi/helpers.rs @@ -177,16 +177,8 @@ pub(crate) fn device_path_to_text(path: NonNull) -> io::R ) }; - // SAFETY: `convert_device_path_to_text` returns a pointer to a null-terminated UTF-16 - // string, and that string cannot be deallocated prior to dropping the `WStrUnits`, so - // it's safe for `WStrUnits` to use. - let path_len = unsafe { - WStrUnits::new(path_ptr) - .ok_or(io::const_io_error!(io::ErrorKind::InvalidData, "Invalid path"))? - .count() - }; - - let path = OsString::from_wide(unsafe { slice::from_raw_parts(path_ptr.cast(), path_len) }); + let path = os_string_from_raw(path_ptr) + .ok_or(io::const_io_error!(io::ErrorKind::InvalidData, "Invalid path"))?; if let Some(boot_services) = crate::os::uefi::env::boot_services() { let boot_services: NonNull = boot_services.cast(); @@ -420,3 +412,15 @@ impl Drop for OwnedTable { unsafe { crate::alloc::dealloc(self.ptr as *mut u8, self.layout) }; } } + +/// Create OsString from a pointer to NULL terminated UTF-16 string +pub(crate) fn os_string_from_raw(ptr: *mut r_efi::efi::Char16) -> Option { + let path_len = unsafe { WStrUnits::new(ptr)?.count() }; + Some(OsString::from_wide(unsafe { slice::from_raw_parts(ptr.cast(), path_len) })) +} + +/// Create NULL terminated UTF-16 string +pub(crate) fn os_string_to_raw(s: &OsStr) -> Option> { + let temp = s.encode_wide().chain(Some(0)).collect::>(); + if temp[..temp.len() - 1].contains(&0) { None } else { Some(temp) } +} diff --git a/std/src/sys/pal/uefi/os.rs b/std/src/sys/pal/uefi/os.rs index 9aee67d622fad..4eb7698b43aa9 100644 --- a/std/src/sys/pal/uefi/os.rs +++ b/std/src/sys/pal/uefi/os.rs @@ -1,7 +1,7 @@ use r_efi::efi::Status; use r_efi::efi::protocols::{device_path, loaded_image_device_path}; -use super::{RawOsError, helpers, unsupported}; +use super::{RawOsError, helpers, unsupported_err}; use crate::error::Error as StdError; use crate::ffi::{OsStr, OsString}; use crate::marker::PhantomData; @@ -125,11 +125,32 @@ pub fn error_string(errno: RawOsError) -> String { } pub fn getcwd() -> io::Result { - unsupported() + match uefi_shell::open_shell() { + Some(shell) => { + // SAFETY: path_ptr is managed by UEFI shell and should not be deallocated + let path_ptr = unsafe { ((*shell.as_ptr()).get_cur_dir)(crate::ptr::null_mut()) }; + helpers::os_string_from_raw(path_ptr) + .map(PathBuf::from) + .ok_or(io::const_io_error!(io::ErrorKind::InvalidData, "Invalid path")) + } + None => { + let mut t = current_exe()?; + // SAFETY: This should never fail since the disk prefix will be present even for root + // executables + assert!(t.pop()); + Ok(t) + } + } } -pub fn chdir(_: &path::Path) -> io::Result<()> { - unsupported() +pub fn chdir(p: &path::Path) -> io::Result<()> { + let shell = uefi_shell::open_shell().ok_or(unsupported_err())?; + + let mut p = helpers::os_string_to_raw(p.as_os_str()) + .ok_or(io::const_io_error!(io::ErrorKind::InvalidData, "Invalid path"))?; + + let r = unsafe { ((*shell.as_ptr()).set_cur_dir)(crate::ptr::null_mut(), p.as_mut_ptr()) }; + if r.is_error() { Err(io::Error::from_raw_os_error(r.as_usize())) } else { Ok(()) } } pub struct SplitPaths<'a>(!, PhantomData<&'a ()>); @@ -239,3 +260,37 @@ pub fn exit(code: i32) -> ! { pub fn getpid() -> u32 { panic!("no pids on this platform") } + +mod uefi_shell { + use r_efi::protocols::shell; + + use super::super::helpers; + use crate::ptr::NonNull; + use crate::sync::atomic::{AtomicPtr, Ordering}; + + pub fn open_shell() -> Option> { + static LAST_VALID_HANDLE: AtomicPtr = + AtomicPtr::new(crate::ptr::null_mut()); + + if let Some(handle) = NonNull::new(LAST_VALID_HANDLE.load(Ordering::Acquire)) { + if let Ok(protocol) = helpers::open_protocol::( + handle, + r_efi::protocols::shell::PROTOCOL_GUID, + ) { + return Some(protocol); + } + } + + let handles = helpers::locate_handles(shell::PROTOCOL_GUID).ok()?; + for handle in handles { + if let Ok(protocol) = + helpers::open_protocol::(handle, shell::PROTOCOL_GUID) + { + LAST_VALID_HANDLE.store(handle.as_ptr(), Ordering::Release); + return Some(protocol); + } + } + + None + } +}