Skip to content

Commit f821e38

Browse files
committed
refactor: improve wrapper types with derive_more and PathBuf
- Add derive_more dependency for automatic trait derivation - Refactor SshPrivateKeyFile to use PathBuf instead of String for proper file path handling - Replace manual trait implementations with derive_more for Display trait - Use serde(transparent) for automatic serialization of newtype wrappers - Replace anyhow with explicit error enums using thiserror for better error handling - Change from infallible constructors (From) to fallible ones (TryFrom) with proper validation - Add comprehensive unit tests for both wrapper types (27 new tests total) - Update integration tests to handle new validation requirements - All linters pass and type safety is improved while maintaining backward compatibility
1 parent 2a0a0cb commit f821e38

File tree

5 files changed

+310
-53
lines changed

5 files changed

+310
-53
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ path = "src/bin/linter.rs"
2828
tokio = { version = "1.0", features = [ "full" ] }
2929
anyhow = "1.0"
3030
clap = { version = "4.0", features = [ "derive" ] }
31+
derive_more = "0.99"
3132
serde = { version = "1.0", features = [ "derive" ] }
3233
serde_json = "1.0"
3334
tempfile = "3.0"
Lines changed: 157 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
//! Ansible host wrapper type for IP address validation and serialization
22
3-
use anyhow::{Context, Result};
3+
use derive_more::{Display, From};
44
use serde::Serialize;
5-
use std::fmt;
65
use std::net::IpAddr;
76
use std::str::FromStr;
7+
use thiserror::Error;
8+
9+
/// Errors that can occur when working with Ansible hosts
10+
#[derive(Debug, Error, PartialEq)]
11+
pub enum AnsibleHostError {
12+
#[error("Invalid IP address format: {input}")]
13+
InvalidIpAddress { input: String },
14+
}
815

916
/// Wrapper type for Ansible host address using the newtype pattern
1017
///
@@ -16,49 +23,179 @@ use std::str::FromStr;
1623
/// - SSH proxy configurations
1724
///
1825
/// For this implementation, we only support IP addresses (IPv4 and IPv6) for simplicity.
19-
#[derive(Debug, Clone, PartialEq, Eq)]
20-
pub struct AnsibleHost(IpAddr);
26+
#[derive(Debug, Clone, PartialEq, Eq, Display, From, Serialize)]
27+
#[display(fmt = "{ip}")]
28+
#[serde(transparent)]
29+
pub struct AnsibleHost {
30+
ip: IpAddr,
31+
}
2132

2233
impl AnsibleHost {
2334
/// Create a new `AnsibleHost` from an IP address
2435
#[must_use]
2536
pub fn new(ip: IpAddr) -> Self {
26-
Self(ip)
37+
Self { ip }
2738
}
2839

2940
/// Get the inner IP address
3041
#[must_use]
3142
pub fn as_ip_addr(&self) -> &IpAddr {
32-
&self.0
43+
&self.ip
3344
}
3445

3546
/// Convert to string representation
3647
#[must_use]
3748
pub fn as_str(&self) -> String {
38-
self.0.to_string()
49+
self.ip.to_string()
3950
}
4051
}
4152

4253
impl FromStr for AnsibleHost {
43-
type Err = anyhow::Error;
54+
type Err = AnsibleHostError;
4455

4556
fn from_str(s: &str) -> Result<Self, Self::Err> {
46-
let ip = IpAddr::from_str(s).with_context(|| format!("Invalid IP address format: {s}"))?;
47-
Ok(Self(ip))
57+
let ip = IpAddr::from_str(s).map_err(|_| AnsibleHostError::InvalidIpAddress {
58+
input: s.to_string(),
59+
})?;
60+
Ok(Self::new(ip))
4861
}
4962
}
5063

51-
impl fmt::Display for AnsibleHost {
52-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53-
write!(f, "{}", self.0)
64+
#[cfg(test)]
65+
mod tests {
66+
use super::*;
67+
use serde_json;
68+
use std::net::{Ipv4Addr, Ipv6Addr};
69+
70+
#[test]
71+
fn test_new_with_ipv4() {
72+
let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
73+
let host = AnsibleHost::new(ip);
74+
assert_eq!(host.as_ip_addr(), &ip);
75+
assert_eq!(host.as_str(), "192.168.1.1");
76+
}
77+
78+
#[test]
79+
fn test_new_with_ipv6() {
80+
let ip = IpAddr::V6(Ipv6Addr::new(
81+
0x2001, 0x0db8, 0x85a3, 0, 0, 0x8a2e, 0x0370, 0x7334,
82+
));
83+
let host = AnsibleHost::new(ip);
84+
assert_eq!(host.as_ip_addr(), &ip);
85+
assert_eq!(host.as_str(), "2001:db8:85a3::8a2e:370:7334");
86+
}
87+
88+
#[test]
89+
fn test_from_str_valid_ipv4() {
90+
let result = AnsibleHost::from_str("192.168.1.1");
91+
assert!(result.is_ok());
92+
let host = result.unwrap();
93+
assert_eq!(host.as_str(), "192.168.1.1");
94+
}
95+
96+
#[test]
97+
fn test_from_str_valid_ipv6() {
98+
let result = AnsibleHost::from_str("2001:db8:85a3::8a2e:370:7334");
99+
assert!(result.is_ok());
100+
let host = result.unwrap();
101+
assert_eq!(host.as_str(), "2001:db8:85a3::8a2e:370:7334");
102+
}
103+
104+
#[test]
105+
fn test_from_str_localhost_ipv4() {
106+
let result = AnsibleHost::from_str("127.0.0.1");
107+
assert!(result.is_ok());
108+
let host = result.unwrap();
109+
assert_eq!(host.as_str(), "127.0.0.1");
110+
}
111+
112+
#[test]
113+
fn test_from_str_localhost_ipv6() {
114+
let result = AnsibleHost::from_str("::1");
115+
assert!(result.is_ok());
116+
let host = result.unwrap();
117+
assert_eq!(host.as_str(), "::1");
118+
}
119+
120+
#[test]
121+
fn test_from_str_invalid_ip() {
122+
let result = AnsibleHost::from_str("invalid.ip.address");
123+
assert_eq!(
124+
result,
125+
Err(AnsibleHostError::InvalidIpAddress {
126+
input: "invalid.ip.address".to_string()
127+
})
128+
);
129+
}
130+
131+
#[test]
132+
fn test_from_str_invalid_ipv4() {
133+
let result = AnsibleHost::from_str("256.256.256.256");
134+
assert_eq!(
135+
result,
136+
Err(AnsibleHostError::InvalidIpAddress {
137+
input: "256.256.256.256".to_string()
138+
})
139+
);
140+
}
141+
142+
#[test]
143+
fn test_from_str_empty_string() {
144+
let result = AnsibleHost::from_str("");
145+
assert_eq!(
146+
result,
147+
Err(AnsibleHostError::InvalidIpAddress {
148+
input: String::new()
149+
})
150+
);
151+
}
152+
153+
#[test]
154+
fn test_display_trait() {
155+
let host = AnsibleHost::from_str("192.168.1.100").unwrap();
156+
assert_eq!(format!("{host}"), "192.168.1.100");
157+
}
158+
159+
#[test]
160+
fn test_serialization_ipv4() {
161+
let host = AnsibleHost::from_str("10.0.0.1").unwrap();
162+
let json = serde_json::to_string(&host).unwrap();
163+
assert_eq!(json, "\"10.0.0.1\"");
164+
}
165+
166+
#[test]
167+
fn test_serialization_ipv6() {
168+
let host = AnsibleHost::from_str("::1").unwrap();
169+
let json = serde_json::to_string(&host).unwrap();
170+
assert_eq!(json, "\"::1\"");
171+
}
172+
173+
#[test]
174+
fn test_clone_and_equality() {
175+
let host1 = AnsibleHost::from_str("192.168.1.1").unwrap();
176+
let host2 = host1.clone();
177+
assert_eq!(host1, host2);
178+
}
179+
180+
#[test]
181+
fn test_from_trait_ipv4() {
182+
let ip = IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1));
183+
let host = AnsibleHost::from(ip);
184+
assert_eq!(host.as_str(), "172.16.0.1");
185+
}
186+
187+
#[test]
188+
fn test_from_trait_ipv6() {
189+
let ip = IpAddr::V6(Ipv6Addr::LOCALHOST);
190+
let host = AnsibleHost::from(ip);
191+
assert_eq!(host.as_str(), "::1");
54192
}
55-
}
56193

57-
impl Serialize for AnsibleHost {
58-
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
59-
where
60-
S: serde::Serializer,
61-
{
62-
serializer.serialize_str(&self.0.to_string())
194+
#[test]
195+
fn test_error_display() {
196+
let error = AnsibleHostError::InvalidIpAddress {
197+
input: "bad_input".to_string(),
198+
};
199+
assert_eq!(format!("{error}"), "Invalid IP address format: bad_input");
63200
}
64201
}

src/template/wrappers/ansible/inventory/mod.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,12 @@ impl InventoryContext {
3434
/// # Errors
3535
///
3636
/// Returns an error if the `ansible_host` cannot be parsed as a valid IP address
37+
/// or if the `ansible_ssh_private_key_file` path is invalid
3738
pub fn new(ansible_host: &str, ansible_ssh_private_key_file: &str) -> Result<Self> {
38-
let ansible_host = AnsibleHost::from_str(ansible_host)?;
39-
let ansible_ssh_private_key_file = SshPrivateKeyFile::new(ansible_ssh_private_key_file);
39+
let ansible_host = AnsibleHost::from_str(ansible_host)
40+
.map_err(|e| anyhow::anyhow!("Invalid ansible host: {}", e))?;
41+
let ansible_ssh_private_key_file = SshPrivateKeyFile::new(ansible_ssh_private_key_file)
42+
.map_err(|e| anyhow::anyhow!("Invalid SSH private key file path: {}", e))?;
4043

4144
Ok(Self {
4245
ansible_host,
@@ -53,7 +56,7 @@ impl InventoryContext {
5356
/// Get the ansible SSH private key file path as a string
5457
#[must_use]
5558
pub fn ansible_ssh_private_key_file(&self) -> String {
56-
self.ansible_ssh_private_key_file.as_string()
59+
self.ansible_ssh_private_key_file.as_str()
5760
}
5861

5962
/// Get the ansible host wrapper

0 commit comments

Comments
 (0)