@@ -4,6 +4,7 @@ use std::path::Path;
4
4
5
5
use bytes:: Bytes ;
6
6
use chrono:: Utc ;
7
+ use http:: { HeaderValue , StatusCode } ;
7
8
use hyper:: Body ;
8
9
use tokio:: io:: AsyncWriteExt as _;
9
10
use uuid:: Uuid ;
@@ -12,11 +13,46 @@ const METADATA_VERSION: u32 = 0;
12
13
13
14
const DEFAULT_MAX_RETRIES : usize = 5 ;
14
15
16
+ #[ derive( thiserror:: Error , Debug ) ]
17
+ #[ non_exhaustive]
18
+ pub enum SyncError {
19
+ #[ error( "io: msg={msg}, err={err}" ) ]
20
+ Io {
21
+ msg : & ' static str ,
22
+ #[ source]
23
+ err : std:: io:: Error ,
24
+ } ,
25
+ #[ error( "invalid auth header: {0}" ) ]
26
+ InvalidAuthHeader ( http:: header:: InvalidHeaderValue ) ,
27
+ #[ error( "http dispatch error: {0}" ) ]
28
+ HttpDispatch ( hyper:: Error ) ,
29
+ #[ error( "body error: {0}" ) ]
30
+ HttpBody ( hyper:: Error ) ,
31
+ #[ error( "json decode error: {0}" ) ]
32
+ JsonDecode ( serde_json:: Error ) ,
33
+ #[ error( "json value error, unexpected value: {0}" ) ]
34
+ JsonValue ( serde_json:: Value ) ,
35
+ #[ error( "json encode error: {0}" ) ]
36
+ JsonEncode ( serde_json:: Error ) ,
37
+ #[ error( "failed to push frame: status={0}, error={1}" ) ]
38
+ PushFrame ( StatusCode , String ) ,
39
+ #[ error( "failed to verify metadata file version: expected={0}, got={1}" ) ]
40
+ VerifyVersion ( u32 , u32 ) ,
41
+ #[ error( "failed to verify metadata file hash: expected={0}, got={1}" ) ]
42
+ VerifyHash ( u32 , u32 ) ,
43
+ }
44
+
45
+ impl SyncError {
46
+ fn io ( msg : & ' static str ) -> impl FnOnce ( std:: io:: Error ) -> SyncError {
47
+ move |err| SyncError :: Io { msg, err }
48
+ }
49
+ }
50
+
15
51
pub struct SyncContext {
16
52
db_path : String ,
17
53
client : hyper:: Client < ConnectorService , Body > ,
18
54
sync_url : String ,
19
- auth_token : Option < String > ,
55
+ auth_token : Option < HeaderValue > ,
20
56
max_retries : usize ,
21
57
/// Represents the max_frame_no from the server.
22
58
durable_frame_num : u32 ,
@@ -33,6 +69,14 @@ impl SyncContext {
33
69
) -> Result < Self > {
34
70
let client = hyper:: client:: Client :: builder ( ) . build :: < _ , hyper:: Body > ( connector) ;
35
71
72
+ let auth_token = match auth_token {
73
+ Some ( t) => Some (
74
+ HeaderValue :: try_from ( format ! ( "Bearer {}" , t) )
75
+ . map_err ( SyncError :: InvalidAuthHeader ) ?,
76
+ ) ,
77
+ None => None ,
78
+ } ;
79
+
36
80
let mut me = Self {
37
81
db_path,
38
82
sync_url,
@@ -85,36 +129,53 @@ impl SyncContext {
85
129
86
130
match & self . auth_token {
87
131
Some ( auth_token) => {
88
- let auth_header =
89
- http:: HeaderValue :: try_from ( format ! ( "Bearer {}" , auth_token. to_owned( ) ) )
90
- . unwrap ( ) ;
91
-
92
132
req. headers_mut ( )
93
133
. expect ( "valid http request" )
94
- . insert ( "Authorization" , auth_header ) ;
134
+ . insert ( "Authorization" , auth_token . clone ( ) ) ;
95
135
}
96
136
None => { }
97
137
}
98
138
99
139
let req = req. body ( frame. clone ( ) . into ( ) ) . expect ( "valid body" ) ;
100
140
101
- let res = self . client . request ( req) . await . unwrap ( ) ;
141
+ let res = self
142
+ . client
143
+ . request ( req)
144
+ . await
145
+ . map_err ( SyncError :: HttpDispatch ) ?;
102
146
103
147
// TODO(lucio): only retry on server side errors
104
148
if res. status ( ) . is_success ( ) {
105
- let res_body = hyper:: body:: to_bytes ( res. into_body ( ) ) . await . unwrap ( ) ;
106
- let resp = serde_json:: from_slice :: < serde_json:: Value > ( & res_body[ ..] ) . unwrap ( ) ;
149
+ let res_body = hyper:: body:: to_bytes ( res. into_body ( ) )
150
+ . await
151
+ . map_err ( SyncError :: HttpBody ) ?;
152
+
153
+ let resp = serde_json:: from_slice :: < serde_json:: Value > ( & res_body[ ..] )
154
+ . map_err ( SyncError :: JsonDecode ) ?;
155
+
156
+ let max_frame_no = resp
157
+ . get ( "max_frame_no" )
158
+ . ok_or_else ( || SyncError :: JsonValue ( resp. clone ( ) ) ) ?;
159
+
160
+ let max_frame_no = max_frame_no
161
+ . as_u64 ( )
162
+ . ok_or_else ( || SyncError :: JsonValue ( max_frame_no. clone ( ) ) ) ?;
107
163
108
- let max_frame_no = resp. get ( "max_frame_no" ) . unwrap ( ) . as_u64 ( ) . unwrap ( ) ;
109
164
return Ok ( max_frame_no as u32 ) ;
110
165
}
111
166
112
167
if nr_retries > max_retries {
113
- return Err ( crate :: errors:: Error :: ConnectionFailed ( format ! (
114
- "Failed to push frame: {}" ,
115
- res. status( )
116
- ) ) ) ;
168
+ let status = res. status ( ) ;
169
+
170
+ let res_body = hyper:: body:: to_bytes ( res. into_body ( ) )
171
+ . await
172
+ . map_err ( SyncError :: HttpBody ) ?;
173
+
174
+ let msg = String :: from_utf8_lossy ( & res_body[ ..] ) ;
175
+
176
+ return Err ( SyncError :: PushFrame ( status, msg. to_string ( ) ) . into ( ) ) ;
117
177
}
178
+
118
179
let delay = std:: time:: Duration :: from_millis ( 100 * ( 1 << nr_retries) ) ;
119
180
tokio:: time:: sleep ( delay) . await ;
120
181
nr_retries += 1 ;
@@ -141,32 +202,33 @@ impl SyncContext {
141
202
142
203
metadata. set_hash ( ) ;
143
204
144
- let contents = serde_json:: to_vec ( & metadata) . unwrap ( ) ;
205
+ let contents = serde_json:: to_vec ( & metadata) . map_err ( SyncError :: JsonEncode ) ? ;
145
206
146
- atomic_write ( path, & contents[ ..] ) . await . unwrap ( ) ;
207
+ atomic_write ( path, & contents[ ..] ) . await ? ;
147
208
148
209
Ok ( ( ) )
149
210
}
150
211
151
212
async fn read_metadata ( & mut self ) -> Result < ( ) > {
152
213
let path = format ! ( "{}-info" , self . db_path) ;
153
214
154
- if !std:: fs:: exists ( & path) . unwrap ( ) {
215
+ if !std:: fs:: exists ( & path) . map_err ( SyncError :: io ( "metadata file exists" ) ) ? {
155
216
tracing:: debug!( "no metadata info file found" ) ;
156
217
return Ok ( ( ) ) ;
157
218
}
158
219
159
- let contents = tokio:: fs:: read ( & path) . await . unwrap ( ) ;
220
+ let contents = tokio:: fs:: read ( & path)
221
+ . await
222
+ . map_err ( SyncError :: io ( "metadata read" ) ) ?;
160
223
161
- let metadata = serde_json:: from_slice :: < MetadataJson > ( & contents[ ..] ) . unwrap ( ) ;
224
+ let metadata =
225
+ serde_json:: from_slice :: < MetadataJson > ( & contents[ ..] ) . map_err ( SyncError :: JsonDecode ) ?;
162
226
163
227
metadata. verify_hash ( ) ?;
164
228
165
- // TODO(lucio): convert this into a proper error
166
- assert_eq ! (
167
- metadata. version, METADATA_VERSION ,
168
- "Reading metadata from a different version than expected"
169
- ) ;
229
+ if metadata. version != METADATA_VERSION {
230
+ return Err ( SyncError :: VerifyVersion ( metadata. version , METADATA_VERSION ) . into ( ) ) ;
231
+ }
170
232
171
233
self . durable_frame_num = metadata. durable_frame_num ;
172
234
self . generation = metadata. generation ;
@@ -205,37 +267,46 @@ impl MetadataJson {
205
267
if self . hash == calculated_hash {
206
268
Ok ( ( ) )
207
269
} else {
208
- // TODO(lucio): convert this into a proper error rather than
209
- // an panic.
210
- panic ! (
211
- "metadata hash mismatch, expected={}, got={}" ,
212
- self . hash, calculated_hash
213
- ) ;
270
+ Err ( SyncError :: VerifyHash ( self . hash , calculated_hash) . into ( ) )
214
271
}
215
272
}
216
273
}
217
274
218
275
async fn atomic_write < P : AsRef < Path > > ( path : P , data : & [ u8 ] ) -> Result < ( ) > {
219
276
// Create a temporary file in the same directory as the target file
220
- let directory = path. as_ref ( ) . parent ( ) . unwrap ( ) ;
277
+ let directory = path. as_ref ( ) . parent ( ) . ok_or_else ( || {
278
+ SyncError :: io ( "parent path" ) ( std:: io:: Error :: other (
279
+ "unable to get parent of the provided path" ,
280
+ ) )
281
+ } ) ?;
221
282
222
283
let timestamp = Utc :: now ( ) . format ( "%Y%m%d_%H%M%S" ) ;
223
284
let temp_name = format ! ( ".tmp.{}.{}" , timestamp, Uuid :: new_v4( ) ) ;
224
285
let temp_path = directory. join ( temp_name) ;
225
286
226
287
// Write data to temporary file
227
- let mut temp_file = tokio:: fs:: File :: create ( & temp_path) . await . unwrap ( ) ;
288
+ let mut temp_file = tokio:: fs:: File :: create ( & temp_path)
289
+ . await
290
+ . map_err ( SyncError :: io ( "temp file create" ) ) ?;
228
291
229
- temp_file. write_all ( data) . await . unwrap ( ) ;
292
+ temp_file
293
+ . write_all ( data)
294
+ . await
295
+ . map_err ( SyncError :: io ( "temp file write_all" ) ) ?;
230
296
231
297
// Ensure all data is flushed to disk
232
- temp_file. sync_all ( ) . await . unwrap ( ) ;
298
+ temp_file
299
+ . sync_all ( )
300
+ . await
301
+ . map_err ( SyncError :: io ( "temp file sync_all" ) ) ?;
233
302
234
303
// Close the file explicitly
235
304
drop ( temp_file) ;
236
305
237
306
// Atomically rename temporary file to target file
238
- tokio:: fs:: rename ( & temp_path, & path) . await . unwrap ( ) ;
307
+ tokio:: fs:: rename ( & temp_path, & path)
308
+ . await
309
+ . map_err ( SyncError :: io ( "atomic rename" ) ) ?;
239
310
240
311
Ok ( ( ) )
241
312
}
0 commit comments