|
1 | | -use std::fmt::{Display, Formatter}; |
| 1 | +use std::fmt::Display; |
2 | 2 |
|
3 | | -use http::StatusCode; |
| 3 | +use http::{status::InvalidStatusCode, StatusCode}; |
4 | 4 | use serde::{Deserialize, Serialize}; |
5 | 5 |
|
6 | | -#[cfg(feature = "display")] |
7 | | -use crossterm::style::Stylize; |
8 | | - |
9 | 6 | #[cfg(feature = "axum")] |
10 | 7 | impl axum::response::IntoResponse for ApiError { |
11 | 8 | fn into_response(self) -> axum::response::Response { |
| 9 | + #[cfg(feature = "tracing-in-errors")] |
12 | 10 | tracing::warn!("{}", self.message); |
13 | 11 |
|
14 | | - (self.status(), axum::Json(self)).into_response() |
| 12 | + ( |
| 13 | + self.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), |
| 14 | + axum::Json(self), |
| 15 | + ) |
| 16 | + .into_response() |
15 | 17 | } |
16 | 18 | } |
17 | 19 |
|
18 | 20 | #[derive(Serialize, Deserialize, Debug)] |
19 | 21 | #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] |
20 | 22 | #[typeshare::typeshare] |
21 | 23 | pub struct ApiError { |
22 | | - pub message: String, |
23 | | - pub status_code: u16, |
| 24 | + message: String, |
| 25 | + status_code: u16, |
24 | 26 | } |
25 | 27 |
|
26 | 28 | impl ApiError { |
27 | | - pub fn internal(message: &str) -> Self { |
| 29 | + #[inline(always)] |
| 30 | + pub fn new(message: impl Display, status_code: StatusCode) -> Self { |
28 | 31 | Self { |
29 | 32 | message: message.to_string(), |
30 | | - status_code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), |
| 33 | + status_code: status_code.as_u16(), |
31 | 34 | } |
32 | 35 | } |
| 36 | + #[inline(always)] |
| 37 | + pub fn status(&self) -> Result<StatusCode, InvalidStatusCode> { |
| 38 | + StatusCode::from_u16(self.status_code) |
| 39 | + } |
| 40 | + #[inline(always)] |
| 41 | + pub fn message(&self) -> &str { |
| 42 | + self.message.as_str() |
| 43 | + } |
| 44 | + |
| 45 | + /// Create a one-off internal error with a string message exposed to the user. |
| 46 | + #[inline(always)] |
| 47 | + pub fn internal(message: impl AsRef<str>) -> Self { |
| 48 | + #[cfg(feature = "tracing-in-errors")] |
| 49 | + { |
| 50 | + /// Dummy wrapper to allow logging a string `as &dyn std::error::Error` |
| 51 | + #[derive(Debug)] |
| 52 | + struct InternalError(String); |
| 53 | + impl std::error::Error for InternalError {} |
| 54 | + impl std::fmt::Display for InternalError { |
| 55 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 56 | + f.write_str(self.0.as_str()) |
| 57 | + } |
| 58 | + } |
| 59 | + |
| 60 | + tracing::error!( |
| 61 | + error = &InternalError(message.as_ref().to_owned()) as &dyn std::error::Error, |
| 62 | + "Internal API Error" |
| 63 | + ); |
| 64 | + } |
| 65 | + |
| 66 | + Self::_internal(message.as_ref()) |
| 67 | + } |
33 | 68 |
|
34 | 69 | /// Creates an internal error without exposing sensitive information to the user. |
35 | 70 | #[inline(always)] |
36 | 71 | #[allow(unused_variables)] |
37 | | - pub fn internal_safe<E>(message: &str, error: E) -> Self |
38 | | - where |
39 | | - E: std::error::Error + 'static, |
40 | | - { |
41 | | - tracing::error!(error = &error as &dyn std::error::Error, "{message}"); |
| 72 | + pub fn internal_safe<E: std::error::Error + 'static>(safe_msg: impl Display, error: E) -> Self { |
| 73 | + #[cfg(feature = "tracing-in-errors")] |
| 74 | + tracing::error!(error = &error as &dyn std::error::Error, "{}", safe_msg); |
42 | 75 |
|
43 | 76 | // Return the raw error during debug builds |
44 | 77 | #[cfg(debug_assertions)] |
45 | 78 | { |
46 | | - ApiError::internal(&error.to_string()) |
| 79 | + Self::_internal(error) |
47 | 80 | } |
48 | 81 | // Return the safe message during release builds |
49 | 82 | #[cfg(not(debug_assertions))] |
50 | 83 | { |
51 | | - ApiError::internal(message) |
| 84 | + Self::_internal(safe_msg) |
52 | 85 | } |
53 | 86 | } |
54 | 87 |
|
55 | | - pub fn unavailable(error: impl std::error::Error) -> Self { |
56 | | - Self { |
57 | | - message: error.to_string(), |
58 | | - status_code: StatusCode::SERVICE_UNAVAILABLE.as_u16(), |
59 | | - } |
| 88 | + // 5xx |
| 89 | + #[inline(always)] |
| 90 | + fn _internal(error: impl Display) -> Self { |
| 91 | + Self::new(error.to_string(), StatusCode::INTERNAL_SERVER_ERROR) |
60 | 92 | } |
61 | | - |
62 | | - pub fn bad_request(error: impl std::error::Error) -> Self { |
63 | | - Self { |
64 | | - message: error.to_string(), |
65 | | - status_code: StatusCode::BAD_REQUEST.as_u16(), |
66 | | - } |
| 93 | + #[inline(always)] |
| 94 | + pub fn service_unavailable(error: impl Display) -> Self { |
| 95 | + Self::new(error.to_string(), StatusCode::SERVICE_UNAVAILABLE) |
67 | 96 | } |
68 | | - |
69 | | - pub fn unauthorized() -> Self { |
70 | | - Self { |
71 | | - message: "Unauthorized".to_string(), |
72 | | - status_code: StatusCode::UNAUTHORIZED.as_u16(), |
73 | | - } |
| 97 | + // 4xx |
| 98 | + #[inline(always)] |
| 99 | + pub fn bad_request(error: impl Display) -> Self { |
| 100 | + Self::new(error.to_string(), StatusCode::BAD_REQUEST) |
74 | 101 | } |
75 | | - |
76 | | - pub fn forbidden() -> Self { |
77 | | - Self { |
78 | | - message: "Forbidden".to_string(), |
79 | | - status_code: StatusCode::FORBIDDEN.as_u16(), |
80 | | - } |
| 102 | + #[inline(always)] |
| 103 | + pub fn unauthorized(error: impl Display) -> Self { |
| 104 | + Self::new(error.to_string(), StatusCode::UNAUTHORIZED) |
81 | 105 | } |
82 | | - |
83 | | - pub fn status(&self) -> StatusCode { |
84 | | - StatusCode::from_u16(self.status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) |
| 106 | + #[inline(always)] |
| 107 | + pub fn forbidden(error: impl Display) -> Self { |
| 108 | + Self::new(error.to_string(), StatusCode::FORBIDDEN) |
| 109 | + } |
| 110 | + #[inline(always)] |
| 111 | + pub fn not_found(error: impl Display) -> Self { |
| 112 | + Self::new(error.to_string(), StatusCode::NOT_FOUND) |
85 | 113 | } |
86 | 114 | } |
87 | 115 |
|
88 | 116 | pub trait ErrorContext<T> { |
89 | 117 | /// Make a new internal server error with the given message. |
90 | | - #[inline(always)] |
91 | | - fn context_internal_error(self, message: &str) -> Result<T, ApiError> |
92 | | - where |
93 | | - Self: Sized, |
94 | | - { |
95 | | - self.with_context_internal_error(move || message.to_string()) |
96 | | - } |
| 118 | + fn context_internal_error(self, message: impl Display) -> Result<T, ApiError>; |
97 | 119 |
|
98 | 120 | /// Make a new internal server error using the given function to create the message. |
99 | | - fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError>; |
100 | | - |
101 | | - /// Make a new bad request error with the given message. |
102 | | - #[inline(always)] |
103 | | - fn context_bad_request(self, message: &str) -> Result<T, ApiError> |
104 | | - where |
105 | | - Self: Sized, |
106 | | - { |
107 | | - self.with_context_bad_request(move || message.to_string()) |
108 | | - } |
109 | | - |
110 | | - /// Make a new bad request error using the given function to create the message. |
111 | | - fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result<T, ApiError>; |
112 | | - |
113 | | - /// Make a new not found error with the given message. |
114 | 121 | #[inline(always)] |
115 | | - fn context_not_found(self, message: &str) -> Result<T, ApiError> |
| 122 | + fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError> |
116 | 123 | where |
117 | 124 | Self: Sized, |
118 | 125 | { |
119 | | - self.with_context_not_found(move || message.to_string()) |
| 126 | + self.context_internal_error(message()) |
120 | 127 | } |
121 | | - |
122 | | - /// Make a new not found error using the given function to create the message. |
123 | | - fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result<T, ApiError>; |
124 | 128 | } |
125 | 129 |
|
126 | 130 | impl<T, E> ErrorContext<T> for Result<T, E> |
127 | 131 | where |
128 | 132 | E: std::error::Error + 'static, |
129 | 133 | { |
130 | 134 | #[inline(always)] |
131 | | - fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError> { |
132 | | - match self { |
133 | | - Ok(value) => Ok(value), |
134 | | - Err(error) => Err(ApiError::internal_safe(message().as_ref(), error)), |
135 | | - } |
136 | | - } |
137 | | - |
138 | | - #[inline(always)] |
139 | | - fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result<T, ApiError> { |
140 | | - match self { |
141 | | - Ok(value) => Ok(value), |
142 | | - Err(error) => Err({ |
143 | | - let message = message(); |
144 | | - tracing::warn!( |
145 | | - error = &error as &dyn std::error::Error, |
146 | | - "bad request: {message}" |
147 | | - ); |
148 | | - |
149 | | - ApiError { |
150 | | - message, |
151 | | - status_code: StatusCode::BAD_REQUEST.as_u16(), |
152 | | - } |
153 | | - }), |
154 | | - } |
155 | | - } |
156 | | - |
157 | | - #[inline(always)] |
158 | | - fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result<T, ApiError> { |
159 | | - match self { |
160 | | - Ok(value) => Ok(value), |
161 | | - Err(error) => Err({ |
162 | | - let message = message(); |
163 | | - tracing::warn!( |
164 | | - error = &error as &dyn std::error::Error, |
165 | | - "not found: {message}" |
166 | | - ); |
167 | | - |
168 | | - ApiError { |
169 | | - message, |
170 | | - status_code: StatusCode::NOT_FOUND.as_u16(), |
171 | | - } |
172 | | - }), |
173 | | - } |
174 | | - } |
175 | | -} |
176 | | - |
177 | | -impl<T> ErrorContext<T> for Option<T> { |
178 | | - #[inline] |
179 | | - fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result<T, ApiError> { |
180 | | - match self { |
181 | | - Some(value) => Ok(value), |
182 | | - None => Err(ApiError::internal(message().as_ref())), |
183 | | - } |
184 | | - } |
185 | | - |
186 | | - #[inline] |
187 | | - fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result<T, ApiError> { |
188 | | - match self { |
189 | | - Some(value) => Ok(value), |
190 | | - None => Err({ |
191 | | - ApiError { |
192 | | - message: message(), |
193 | | - status_code: StatusCode::BAD_REQUEST.as_u16(), |
194 | | - } |
195 | | - }), |
196 | | - } |
197 | | - } |
198 | | - |
199 | | - #[inline] |
200 | | - fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result<T, ApiError> { |
201 | | - match self { |
202 | | - Some(value) => Ok(value), |
203 | | - None => Err({ |
204 | | - ApiError { |
205 | | - message: message(), |
206 | | - status_code: StatusCode::NOT_FOUND.as_u16(), |
207 | | - } |
208 | | - }), |
209 | | - } |
| 135 | + fn context_internal_error(self, message: impl Display) -> Result<T, ApiError> { |
| 136 | + self.map_err(|error| ApiError::internal_safe(message, error)) |
210 | 137 | } |
211 | 138 | } |
212 | 139 |
|
213 | | -impl Display for ApiError { |
214 | | - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
215 | | - #[cfg(feature = "display")] |
216 | | - return write!( |
| 140 | +impl std::fmt::Display for ApiError { |
| 141 | + #[cfg(feature = "display")] |
| 142 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 143 | + use crossterm::style::Stylize; |
| 144 | + write!( |
217 | 145 | f, |
218 | 146 | "{}\nMessage: {}", |
219 | | - self.status().to_string().bold(), |
| 147 | + self.status() |
| 148 | + .map(|s| s.to_string()) |
| 149 | + .unwrap_or("Unknown".to_owned()) |
| 150 | + .bold(), |
220 | 151 | self.message.to_string().red() |
221 | | - ); |
222 | | - #[cfg(not(feature = "display"))] |
223 | | - return write!(f, "{}\nMessage: {}", self.status(), self.message); |
| 152 | + ) |
| 153 | + } |
| 154 | + #[cfg(not(feature = "display"))] |
| 155 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 156 | + write!( |
| 157 | + f, |
| 158 | + "{}\nMessage: {}", |
| 159 | + self.status() |
| 160 | + .map(|s| s.to_string()) |
| 161 | + .unwrap_or("Unknown".to_owned()), |
| 162 | + self.message, |
| 163 | + ) |
224 | 164 | } |
225 | 165 | } |
226 | 166 |
|
|
0 commit comments