Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions uefi-test-runner/src/proto/shell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,65 @@ pub fn test_current_dir(shell: &ScopedProtocol<Shell>) {
assert_eq!(cur_fs_str, expected_fs_str);
}

/// Test `var()`, `vars()`, and `set_var()`
pub fn test_var(shell: &ScopedProtocol<Shell>) {
/* Test retrieving list of environment variable names */
let mut cur_env_vec = shell.vars();
assert_eq!(cur_env_vec.next().unwrap().0, cstr16!("path"));
// check pre-defined shell variables; see UEFI Shell spec
assert_eq!(cur_env_vec.next().unwrap().0, cstr16!("nonesting"));
let cur_env_vec = shell.vars();
let default_len = cur_env_vec.count();

/* Test setting and getting a specific environment variable */
let cur_env_vec = shell.vars();
let test_var = cstr16!("test_var");
let test_val = cstr16!("test_val");
assert!(shell.var(test_var).is_none());
let status = shell.set_var(test_var, test_val, false);
assert!(status.is_ok());
let cur_env_str = shell
.var(test_var)
.expect("Could not get environment variable");
assert_eq!(cur_env_str, test_val);

let mut found_var = false;
for (env_var, _) in cur_env_vec {
if env_var == test_var {
found_var = true;
}
}
assert!(!found_var);
let cur_env_vec = shell.vars();
let mut found_var = false;
for (env_var, _) in cur_env_vec {
if env_var == test_var {
found_var = true;
}
}
assert!(found_var);

let cur_env_vec = shell.vars();
assert_eq!(cur_env_vec.count(), default_len + 1);

/* Test deleting environment variable */
let test_val = cstr16!("");
let status = shell.set_var(test_var, test_val, false);
assert!(status.is_ok());
assert!(shell.var(test_var).is_none());

let cur_env_vec = shell.vars();
let mut found_var = false;
for (env_var, _) in cur_env_vec {
if env_var == test_var {
found_var = true;
}
}
assert!(!found_var);
let cur_env_vec = shell.vars();
assert_eq!(cur_env_vec.count(), default_len);
}

pub fn test() {
info!("Running shell protocol tests");

Expand All @@ -109,4 +168,5 @@ pub fn test() {
boot::open_protocol_exclusive::<Shell>(handle).expect("Failed to open Shell protocol");

test_current_dir(&shell);
test_var(&shell);
}
1 change: 1 addition & 0 deletions uefi/src/proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub mod rng;
#[cfg(feature = "alloc")]
pub mod scsi;
pub mod security;
#[cfg(feature = "alloc")]
pub mod shell;
pub mod shell_params;
pub mod shim;
Expand Down
192 changes: 192 additions & 0 deletions uefi/src/proto/shell/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

use crate::proto::unsafe_protocol;
use crate::{CStr16, Char16, Error, Result, Status, StatusExt};

use core::marker::PhantomData;
use core::ptr;
use uefi_raw::protocol::shell::ShellProtocol;

Expand All @@ -13,6 +15,45 @@ use uefi_raw::protocol::shell::ShellProtocol;
#[unsafe_protocol(ShellProtocol::GUID)]
pub struct Shell(ShellProtocol);

/// Trait for implementing the var function
pub trait ShellVar {
/// Gets the value of the specified environment variable
fn var(&self, name: &CStr16) -> Option<&CStr16>;
}

/// Iterator over the names of environmental variables obtained from the Shell protocol.
#[derive(Debug)]
pub struct Vars<'a, T: ShellVar> {
/// Char16 containing names of environment variables
names: *const Char16,
/// Reference to Shell Protocol
protocol: *const T,
/// Placeholder to attach a lifetime to `Vars`
placeholder: PhantomData<&'a CStr16>,
}

impl<'a, T: ShellVar + 'a> Iterator for Vars<'a, T> {
type Item = (&'a CStr16, Option<&'a CStr16>);
// We iterate a list of NUL terminated CStr16s.
// The list is terminated with a double NUL.
fn next(&mut self) -> Option<Self::Item> {
let s = unsafe { CStr16::from_ptr(self.names) };
if s.is_empty() {
None
} else {
self.names = unsafe { self.names.add(s.num_chars() + 1) };
Some((s, unsafe { self.protocol.as_ref().unwrap().var(s) }))
}
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the implementation could be simplified, something like this:

    fn next(&mut self) -> Option<Self::Item> {
        let s = unsafe { CStr16::from_ptr(self.inner) };
        if s.is_empty() {
            None
        } else {
            self.inner = unsafe { self.inner.add(s.num_chars() + 1) };
            Some(s)
        }
    }


impl ShellVar for Shell {
/// Gets the value of the specified environment variable
fn var(&self, name: &CStr16) -> Option<&CStr16> {
self.var(name)
}
}

impl Shell {
/// Returns the current directory on the specified device.
///
Expand Down Expand Up @@ -54,4 +95,155 @@ impl Shell {
let dir_ptr: *const Char16 = directory.map_or(ptr::null(), |x| x.as_ptr());
unsafe { (self.0.set_cur_dir)(fs_ptr.cast(), dir_ptr.cast()) }.to_result()
}

/// Gets the value of the specified environment variable
///
/// # Arguments
///
/// * `name` - The environment variable name of which to retrieve the
/// value.
///
/// # Returns
///
/// * `Some(<env_value>)` - &CStr16 containing the value of the
/// environment variable
/// * `None` - If environment variable does not exist
#[must_use]
pub fn var(&self, name: &CStr16) -> Option<&CStr16> {
let name_ptr: *const Char16 = name.as_ptr();
let var_val = unsafe { (self.0.get_env)(name_ptr.cast()) };
if var_val.is_null() {
None
} else {
unsafe { Some(CStr16::from_ptr(var_val.cast())) }
}
}

/// Gets an iterator over the names of all environment variables
#[must_use]
pub fn vars(&self) -> Vars<'_, Self> {
let env_ptr = unsafe { (self.0.get_env)(ptr::null()) };
Vars {
names: env_ptr.cast::<Char16>(),
protocol: self,
placeholder: PhantomData,
}
}

/// Sets the environment variable
///
/// # Arguments
///
/// * `name` - The environment variable for which to set the value
/// * `value` - The new value of the environment variable
/// * `volatile` - Indicates whether the variable is volatile or
/// not
///
/// # Returns
///
/// * `Status::SUCCESS` - The variable was successfully set
pub fn set_var(&self, name: &CStr16, value: &CStr16, volatile: bool) -> Result {
let name_ptr: *const Char16 = name.as_ptr();
let value_ptr: *const Char16 = value.as_ptr();
unsafe { (self.0.set_env)(name_ptr.cast(), value_ptr.cast(), volatile) }.to_result()
}
}

#[cfg(test)]
mod tests {
use super::*;
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use uefi::cstr16;

struct ShellMock<'a> {
inner: BTreeMap<&'a CStr16, &'a CStr16>,
}

impl<'a> ShellMock<'a> {
fn new(names: Vec<&'a CStr16>, values: Vec<&'a CStr16>) -> ShellMock<'a> {
let mut inner_map = BTreeMap::new();
for (name, val) in names.iter().zip(values.iter()) {
inner_map.insert(*name, *val);
}
ShellMock { inner: inner_map }
}
}
impl<'a> ShellVar for ShellMock<'a> {
fn var(&self, name: &CStr16) -> Option<&CStr16> {
if let Some(val) = self.inner.get(name) {
Some(*val)
} else {
None
}
}
}

/// Testing Vars struct
#[test]
fn test_vars() {
// Empty Vars
let mut vars_mock = Vec::<u16>::new();
vars_mock.push(0);
vars_mock.push(0);
let mut vars = Vars {
names: vars_mock.as_ptr().cast(),
protocol: &ShellMock::new(Vec::new(), Vec::new()),
placeholder: PhantomData,
};

assert!(vars.next().is_none());

// One environment variable in Vars
let mut vars_mock = Vec::<u16>::new();
vars_mock.push(b'f' as u16);
vars_mock.push(b'o' as u16);
vars_mock.push(b'o' as u16);
vars_mock.push(0);
vars_mock.push(0);
let vars = Vars {
names: vars_mock.as_ptr().cast(),
protocol: &ShellMock::new(Vec::from([cstr16!("foo")]), Vec::from([cstr16!("value")])),
placeholder: PhantomData,
};
assert_eq!(
vars.collect::<Vec<_>>(),
Vec::from([(cstr16!("foo"), Some(cstr16!("value")))])
);

// Multiple environment variables in Vars
let mut vars_mock = Vec::<u16>::new();
vars_mock.push(b'f' as u16);
vars_mock.push(b'o' as u16);
vars_mock.push(b'o' as u16);
vars_mock.push(b'1' as u16);
vars_mock.push(0);
vars_mock.push(b'b' as u16);
vars_mock.push(b'a' as u16);
vars_mock.push(b'r' as u16);
vars_mock.push(0);
vars_mock.push(b'b' as u16);
vars_mock.push(b'a' as u16);
vars_mock.push(b'z' as u16);
vars_mock.push(b'2' as u16);
vars_mock.push(0);
vars_mock.push(0);

let vars = Vars {
names: vars_mock.as_ptr().cast(),
protocol: &ShellMock::new(
Vec::from([cstr16!("foo1"), cstr16!("bar"), cstr16!("baz2")]),
Vec::from([cstr16!("value"), cstr16!("one"), cstr16!("two")]),
),
placeholder: PhantomData,
};
assert_eq!(
vars.collect::<Vec<_>>(),
Vec::from([
(cstr16!("foo1"), Some(cstr16!("value"))),
(cstr16!("bar"), Some(cstr16!("one"))),
(cstr16!("baz2"), Some(cstr16!("two")))
])
);
}
}
Loading