|
1 | 1 | use anyhow::{Context, Result}; |
2 | 2 | use async_trait::async_trait; |
| 3 | +use bytes::Bytes; |
3 | 4 | use http::StatusCode; |
4 | 5 | use serde::de::DeserializeOwned; |
5 | 6 | use shuttle_common::models::error::ApiError; |
6 | 7 |
|
7 | | -/// A to_json wrapper for handling our error states |
| 8 | +/// Helpers for consuming and parsing response bodies and handling parsing of an ApiError if the response is 4xx/5xx |
8 | 9 | #[async_trait] |
9 | | -pub trait ToJson { |
| 10 | +pub trait ToBodyContent { |
10 | 11 | async fn to_json<T: DeserializeOwned>(self) -> Result<T>; |
| 12 | + async fn to_text(self) -> Result<String>; |
| 13 | + async fn to_bytes(self) -> Result<Bytes>; |
| 14 | + async fn to_empty(self) -> Result<()>; |
| 15 | +} |
| 16 | + |
| 17 | +fn into_api_error(body: &str, status_code: StatusCode) -> ApiError { |
| 18 | + #[cfg(feature = "tracing")] |
| 19 | + tracing::trace!("Parsing response as API error"); |
| 20 | + |
| 21 | + let res: ApiError = match serde_json::from_str(body) { |
| 22 | + Ok(res) => res, |
| 23 | + _ => ApiError::new( |
| 24 | + format!("Failed to parse error response from the server:\n{}", body), |
| 25 | + status_code, |
| 26 | + ), |
| 27 | + }; |
| 28 | + |
| 29 | + res |
| 30 | +} |
| 31 | + |
| 32 | +/// Tries to convert bytes to string. If not possible, returns a string symbolizing the bytes and the length |
| 33 | +fn bytes_to_string_with_fallback(bytes: Bytes) -> String { |
| 34 | + String::from_utf8(bytes.to_vec()).unwrap_or_else(|_| format!("[{} bytes]", bytes.len())) |
11 | 35 | } |
12 | 36 |
|
13 | 37 | #[async_trait] |
14 | | -impl ToJson for reqwest::Response { |
| 38 | +impl ToBodyContent for reqwest::Response { |
15 | 39 | async fn to_json<T: DeserializeOwned>(self) -> Result<T> { |
16 | 40 | let status_code = self.status(); |
17 | 41 | let bytes = self.bytes().await?; |
18 | | - let string = String::from_utf8(bytes.to_vec()) |
19 | | - .unwrap_or_else(|_| format!("[{} bytes]", bytes.len())); |
| 42 | + let string = bytes_to_string_with_fallback(bytes); |
20 | 43 |
|
21 | 44 | #[cfg(feature = "tracing")] |
22 | 45 | tracing::trace!(response = %string, "Parsing response as JSON"); |
23 | 46 |
|
24 | | - if matches!( |
25 | | - status_code, |
26 | | - StatusCode::OK | StatusCode::SWITCHING_PROTOCOLS |
27 | | - ) { |
28 | | - serde_json::from_str(&string).context("failed to parse a successful response") |
29 | | - } else { |
30 | | - #[cfg(feature = "tracing")] |
31 | | - tracing::trace!("Parsing response as API error"); |
32 | | - |
33 | | - let res: ApiError = match serde_json::from_str(&string) { |
34 | | - Ok(res) => res, |
35 | | - _ => ApiError::new( |
36 | | - format!( |
37 | | - "Failed to parse error response from the server:\n{}", |
38 | | - string |
39 | | - ), |
40 | | - status_code, |
41 | | - ), |
42 | | - }; |
43 | | - |
44 | | - Err(res.into()) |
| 47 | + if status_code.is_client_error() || status_code.is_server_error() { |
| 48 | + return Err(into_api_error(&string, status_code).into()); |
| 49 | + } |
| 50 | + |
| 51 | + serde_json::from_str(&string).context("failed to parse a successful response") |
| 52 | + } |
| 53 | + |
| 54 | + async fn to_text(self) -> Result<String> { |
| 55 | + let status_code = self.status(); |
| 56 | + let bytes = self.bytes().await?; |
| 57 | + let string = bytes_to_string_with_fallback(bytes); |
| 58 | + |
| 59 | + #[cfg(feature = "tracing")] |
| 60 | + tracing::trace!(response = %string, "Parsing response as text"); |
| 61 | + |
| 62 | + if status_code.is_client_error() || status_code.is_server_error() { |
| 63 | + return Err(into_api_error(&string, status_code).into()); |
45 | 64 | } |
| 65 | + |
| 66 | + Ok(string) |
| 67 | + } |
| 68 | + |
| 69 | + async fn to_bytes(self) -> Result<Bytes> { |
| 70 | + let status_code = self.status(); |
| 71 | + let bytes = self.bytes().await?; |
| 72 | + |
| 73 | + #[cfg(feature = "tracing")] |
| 74 | + tracing::trace!(response_length = bytes.len(), "Got response bytes"); |
| 75 | + |
| 76 | + if status_code.is_client_error() || status_code.is_server_error() { |
| 77 | + let string = bytes_to_string_with_fallback(bytes); |
| 78 | + return Err(into_api_error(&string, status_code).into()); |
| 79 | + } |
| 80 | + |
| 81 | + Ok(bytes) |
| 82 | + } |
| 83 | + |
| 84 | + async fn to_empty(self) -> Result<()> { |
| 85 | + let status_code = self.status(); |
| 86 | + |
| 87 | + if status_code.is_client_error() || status_code.is_server_error() { |
| 88 | + let bytes = self.bytes().await?; |
| 89 | + let string = bytes_to_string_with_fallback(bytes); |
| 90 | + return Err(into_api_error(&string, status_code).into()); |
| 91 | + } |
| 92 | + |
| 93 | + Ok(()) |
46 | 94 | } |
47 | 95 | } |
0 commit comments