Skip to content

Commit 14fc644

Browse files
authored
refa(common): better ApiError constructors & getters, feature gate tracing (#2045)
* refa(common): better ApiError constructors & getters, feature gate tracing * axum status code fallback * use display for error types * clippy
1 parent 8fa79e6 commit 14fc644

File tree

4 files changed

+110
-166
lines changed

4 files changed

+110
-166
lines changed

api-client/src/util.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ impl ToJson for reqwest::Response {
3232

3333
let res: ApiError = match serde_json::from_str(&string) {
3434
Ok(res) => res,
35-
_ => ApiError {
36-
message: format!("Failed to parse response from the server:\n{}", string),
37-
status_code: status_code.as_u16(),
38-
},
35+
_ => ApiError::new(
36+
format!(
37+
"Failed to parse error response from the server:\n{}",
38+
string
39+
),
40+
status_code,
41+
),
3942
};
4043

4144
Err(res.into())

cargo-shuttle/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,8 +611,8 @@ impl Shuttle {
611611
// If API error contains message regarding format of error name, print that error and prompt again
612612
if let Ok(api_error) = e.downcast::<ApiError>() {
613613
// If the returned error string changes, this could break
614-
if api_error.message.contains("Invalid project name") {
615-
eprintln!("{}", api_error.message.yellow());
614+
if api_error.message().contains("Invalid project name") {
615+
eprintln!("{}", api_error.message().yellow());
616616
eprintln!("{}", "Try a different name.".yellow());
617617
return false;
618618
}

common/Cargo.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,19 @@ rstest = "0.24.0"
3232

3333
[features]
3434
# main features
35-
models = ["chrono/serde", "dep:tracing"]
35+
models = ["chrono/serde"]
3636
config = ["anyhow", "dirs", "toml"]
3737

3838
# additional sub-features
39-
axum = ["dep:axum", "dep:tracing"]
39+
axum = ["dep:axum"]
4040
display = ["chrono/clock", "dep:crossterm"]
4141
tables = ["models", "display", "dep:comfy-table"]
42+
tracing-in-errors = ["dep:tracing"]
4243
unknown-variants = [] # add fallback to Unknown variant on enum model deser
4344
utoipa = ["dep:utoipa"] # derive OpenAPI definitions for models
4445

4546
# internal / utility features
46-
integration-tests = []
47+
integration-tests = ["dep:tracing"]
4748

4849
[package.metadata.docs.rs]
4950
features = [

common/src/models/error.rs

Lines changed: 97 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -1,226 +1,166 @@
1-
use std::fmt::{Display, Formatter};
1+
use std::fmt::Display;
22

3-
use http::StatusCode;
3+
use http::{status::InvalidStatusCode, StatusCode};
44
use serde::{Deserialize, Serialize};
55

6-
#[cfg(feature = "display")]
7-
use crossterm::style::Stylize;
8-
96
#[cfg(feature = "axum")]
107
impl axum::response::IntoResponse for ApiError {
118
fn into_response(self) -> axum::response::Response {
9+
#[cfg(feature = "tracing-in-errors")]
1210
tracing::warn!("{}", self.message);
1311

14-
(self.status(), axum::Json(self)).into_response()
12+
(
13+
self.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
14+
axum::Json(self),
15+
)
16+
.into_response()
1517
}
1618
}
1719

1820
#[derive(Serialize, Deserialize, Debug)]
1921
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
2022
#[typeshare::typeshare]
2123
pub struct ApiError {
22-
pub message: String,
23-
pub status_code: u16,
24+
message: String,
25+
status_code: u16,
2426
}
2527

2628
impl ApiError {
27-
pub fn internal(message: &str) -> Self {
29+
#[inline(always)]
30+
pub fn new(message: impl Display, status_code: StatusCode) -> Self {
2831
Self {
2932
message: message.to_string(),
30-
status_code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
33+
status_code: status_code.as_u16(),
3134
}
3235
}
36+
#[inline(always)]
37+
pub fn status(&self) -> Result<StatusCode, InvalidStatusCode> {
38+
StatusCode::from_u16(self.status_code)
39+
}
40+
#[inline(always)]
41+
pub fn message(&self) -> &str {
42+
self.message.as_str()
43+
}
44+
45+
/// Create a one-off internal error with a string message exposed to the user.
46+
#[inline(always)]
47+
pub fn internal(message: impl AsRef<str>) -> Self {
48+
#[cfg(feature = "tracing-in-errors")]
49+
{
50+
/// Dummy wrapper to allow logging a string `as &dyn std::error::Error`
51+
#[derive(Debug)]
52+
struct InternalError(String);
53+
impl std::error::Error for InternalError {}
54+
impl std::fmt::Display for InternalError {
55+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56+
f.write_str(self.0.as_str())
57+
}
58+
}
59+
60+
tracing::error!(
61+
error = &InternalError(message.as_ref().to_owned()) as &dyn std::error::Error,
62+
"Internal API Error"
63+
);
64+
}
65+
66+
Self::_internal(message.as_ref())
67+
}
3368

3469
/// Creates an internal error without exposing sensitive information to the user.
3570
#[inline(always)]
3671
#[allow(unused_variables)]
37-
pub fn internal_safe<E>(message: &str, error: E) -> Self
38-
where
39-
E: std::error::Error + 'static,
40-
{
41-
tracing::error!(error = &error as &dyn std::error::Error, "{message}");
72+
pub fn internal_safe<E: std::error::Error + 'static>(safe_msg: impl Display, error: E) -> Self {
73+
#[cfg(feature = "tracing-in-errors")]
74+
tracing::error!(error = &error as &dyn std::error::Error, "{}", safe_msg);
4275

4376
// Return the raw error during debug builds
4477
#[cfg(debug_assertions)]
4578
{
46-
ApiError::internal(&error.to_string())
79+
Self::_internal(error)
4780
}
4881
// Return the safe message during release builds
4982
#[cfg(not(debug_assertions))]
5083
{
51-
ApiError::internal(message)
84+
Self::_internal(safe_msg)
5285
}
5386
}
5487

55-
pub fn unavailable(error: impl std::error::Error) -> Self {
56-
Self {
57-
message: error.to_string(),
58-
status_code: StatusCode::SERVICE_UNAVAILABLE.as_u16(),
59-
}
88+
// 5xx
89+
#[inline(always)]
90+
fn _internal(error: impl Display) -> Self {
91+
Self::new(error.to_string(), StatusCode::INTERNAL_SERVER_ERROR)
6092
}
61-
62-
pub fn bad_request(error: impl std::error::Error) -> Self {
63-
Self {
64-
message: error.to_string(),
65-
status_code: StatusCode::BAD_REQUEST.as_u16(),
66-
}
93+
#[inline(always)]
94+
pub fn service_unavailable(error: impl Display) -> Self {
95+
Self::new(error.to_string(), StatusCode::SERVICE_UNAVAILABLE)
6796
}
68-
69-
pub fn unauthorized() -> Self {
70-
Self {
71-
message: "Unauthorized".to_string(),
72-
status_code: StatusCode::UNAUTHORIZED.as_u16(),
73-
}
97+
// 4xx
98+
#[inline(always)]
99+
pub fn bad_request(error: impl Display) -> Self {
100+
Self::new(error.to_string(), StatusCode::BAD_REQUEST)
74101
}
75-
76-
pub fn forbidden() -> Self {
77-
Self {
78-
message: "Forbidden".to_string(),
79-
status_code: StatusCode::FORBIDDEN.as_u16(),
80-
}
102+
#[inline(always)]
103+
pub fn unauthorized(error: impl Display) -> Self {
104+
Self::new(error.to_string(), StatusCode::UNAUTHORIZED)
81105
}
82-
83-
pub fn status(&self) -> StatusCode {
84-
StatusCode::from_u16(self.status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
106+
#[inline(always)]
107+
pub fn forbidden(error: impl Display) -> Self {
108+
Self::new(error.to_string(), StatusCode::FORBIDDEN)
109+
}
110+
#[inline(always)]
111+
pub fn not_found(error: impl Display) -> Self {
112+
Self::new(error.to_string(), StatusCode::NOT_FOUND)
85113
}
86114
}
87115

88116
pub trait ErrorContext<T> {
89117
/// Make a new internal server error with the given message.
90-
#[inline(always)]
91-
fn context_internal_error(self, message: &str) -> Result<T, ApiError>
92-
where
93-
Self: Sized,
94-
{
95-
self.with_context_internal_error(move || message.to_string())
96-
}
118+
fn context_internal_error(self, message: impl Display) -> Result<T, ApiError>;
97119

98120
/// Make a new internal server error using the given function to create the message.
99-
fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError>;
100-
101-
/// Make a new bad request error with the given message.
102-
#[inline(always)]
103-
fn context_bad_request(self, message: &str) -> Result<T, ApiError>
104-
where
105-
Self: Sized,
106-
{
107-
self.with_context_bad_request(move || message.to_string())
108-
}
109-
110-
/// Make a new bad request error using the given function to create the message.
111-
fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result<T, ApiError>;
112-
113-
/// Make a new not found error with the given message.
114121
#[inline(always)]
115-
fn context_not_found(self, message: &str) -> Result<T, ApiError>
122+
fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError>
116123
where
117124
Self: Sized,
118125
{
119-
self.with_context_not_found(move || message.to_string())
126+
self.context_internal_error(message())
120127
}
121-
122-
/// Make a new not found error using the given function to create the message.
123-
fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result<T, ApiError>;
124128
}
125129

126130
impl<T, E> ErrorContext<T> for Result<T, E>
127131
where
128132
E: std::error::Error + 'static,
129133
{
130134
#[inline(always)]
131-
fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
132-
match self {
133-
Ok(value) => Ok(value),
134-
Err(error) => Err(ApiError::internal_safe(message().as_ref(), error)),
135-
}
136-
}
137-
138-
#[inline(always)]
139-
fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
140-
match self {
141-
Ok(value) => Ok(value),
142-
Err(error) => Err({
143-
let message = message();
144-
tracing::warn!(
145-
error = &error as &dyn std::error::Error,
146-
"bad request: {message}"
147-
);
148-
149-
ApiError {
150-
message,
151-
status_code: StatusCode::BAD_REQUEST.as_u16(),
152-
}
153-
}),
154-
}
155-
}
156-
157-
#[inline(always)]
158-
fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
159-
match self {
160-
Ok(value) => Ok(value),
161-
Err(error) => Err({
162-
let message = message();
163-
tracing::warn!(
164-
error = &error as &dyn std::error::Error,
165-
"not found: {message}"
166-
);
167-
168-
ApiError {
169-
message,
170-
status_code: StatusCode::NOT_FOUND.as_u16(),
171-
}
172-
}),
173-
}
174-
}
175-
}
176-
177-
impl<T> ErrorContext<T> for Option<T> {
178-
#[inline]
179-
fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
180-
match self {
181-
Some(value) => Ok(value),
182-
None => Err(ApiError::internal(message().as_ref())),
183-
}
184-
}
185-
186-
#[inline]
187-
fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
188-
match self {
189-
Some(value) => Ok(value),
190-
None => Err({
191-
ApiError {
192-
message: message(),
193-
status_code: StatusCode::BAD_REQUEST.as_u16(),
194-
}
195-
}),
196-
}
197-
}
198-
199-
#[inline]
200-
fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result<T, ApiError> {
201-
match self {
202-
Some(value) => Ok(value),
203-
None => Err({
204-
ApiError {
205-
message: message(),
206-
status_code: StatusCode::NOT_FOUND.as_u16(),
207-
}
208-
}),
209-
}
135+
fn context_internal_error(self, message: impl Display) -> Result<T, ApiError> {
136+
self.map_err(|error| ApiError::internal_safe(message, error))
210137
}
211138
}
212139

213-
impl Display for ApiError {
214-
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
215-
#[cfg(feature = "display")]
216-
return write!(
140+
impl std::fmt::Display for ApiError {
141+
#[cfg(feature = "display")]
142+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143+
use crossterm::style::Stylize;
144+
write!(
217145
f,
218146
"{}\nMessage: {}",
219-
self.status().to_string().bold(),
147+
self.status()
148+
.map(|s| s.to_string())
149+
.unwrap_or("Unknown".to_owned())
150+
.bold(),
220151
self.message.to_string().red()
221-
);
222-
#[cfg(not(feature = "display"))]
223-
return write!(f, "{}\nMessage: {}", self.status(), self.message);
152+
)
153+
}
154+
#[cfg(not(feature = "display"))]
155+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156+
write!(
157+
f,
158+
"{}\nMessage: {}",
159+
self.status()
160+
.map(|s| s.to_string())
161+
.unwrap_or("Unknown".to_owned()),
162+
self.message,
163+
)
224164
}
225165
}
226166

0 commit comments

Comments
 (0)