@@ -49,24 +49,27 @@ async fn execute_statements(
49
49
if statements. is_empty ( ) {
50
50
return Ok ( ( ) ) ;
51
51
}
52
- let Some ( default) = databases. get ( "default" ) else {
53
- debug_assert ! (
54
- false ,
55
- "the 'default' sqlite database should always be available but for some reason was not"
56
- ) ;
57
- return Ok ( ( ) ) ;
58
- } ;
59
52
60
53
for m in statements {
61
- if let Some ( file) = m. strip_prefix ( '@' ) {
54
+ if let Some ( config) = m. strip_prefix ( '@' ) {
55
+ let ( file, database) = parse_file_and_label ( config) ?;
56
+ let database = databases. get ( database) . with_context ( || {
57
+ format ! (
58
+ "based on the '@{config}' a registered database named '{database}' was expected but not found. The registered databases are '{:?}'" , databases. keys( )
59
+ )
60
+ } ) ?;
62
61
let sql = std:: fs:: read_to_string ( file) . with_context ( || {
63
62
format ! ( "could not read file '{file}' containing sql statements" )
64
63
} ) ?;
65
- default
64
+ database
66
65
. execute_batch ( & sql)
67
66
. await
68
67
. with_context ( || format ! ( "failed to execute sql from file '{file}'" ) ) ?;
69
68
} else {
69
+ let Some ( default) = databases. get ( "default" ) else {
70
+ debug_assert ! ( false , "the 'default' sqlite database should always be available but for some reason was not" ) ;
71
+ return Ok ( ( ) ) ;
72
+ } ;
70
73
default
71
74
. query ( m, Vec :: new ( ) )
72
75
. await
@@ -76,6 +79,19 @@ async fn execute_statements(
76
79
Ok ( ( ) )
77
80
}
78
81
82
+ /// Parses a @{file:label} sqlite statement
83
+ fn parse_file_and_label ( config : & str ) -> anyhow:: Result < ( & str , & str ) > {
84
+ let config = config. trim ( ) ;
85
+ let ( file, label) = match config. split_once ( ':' ) {
86
+ Some ( ( _, label) ) if label. trim ( ) . is_empty ( ) => {
87
+ anyhow:: bail!( "database label is empty in the '@{config}' sqlite statement" )
88
+ }
89
+ Some ( ( file, label) ) => ( file. trim ( ) , label. trim ( ) ) ,
90
+ None => ( config, "default" ) ,
91
+ } ;
92
+ Ok ( ( file, label) )
93
+ }
94
+
79
95
// Holds deserialized options from a `[sqlite_database.<name>]` runtime config section.
80
96
#[ derive( Clone , Debug , serde:: Deserialize ) ]
81
97
#[ serde( rename_all = "snake_case" , tag = "type" ) ]
@@ -202,3 +218,23 @@ impl TriggerHooks for SqlitePersistenceMessageHook {
202
218
Ok ( ( ) )
203
219
}
204
220
}
221
+
222
+ #[ cfg( test) ]
223
+ mod tests {
224
+ use super :: * ;
225
+
226
+ #[ test]
227
+ fn can_parse_file_and_label ( ) {
228
+ let config = "file:label" ;
229
+ let result = parse_file_and_label ( config) . unwrap ( ) ;
230
+ assert_eq ! ( result, ( "file" , "label" ) ) ;
231
+
232
+ let config = "file:" ;
233
+ let result = parse_file_and_label ( config) ;
234
+ assert ! ( result. is_err( ) ) ;
235
+
236
+ let config = "file" ;
237
+ let result = parse_file_and_label ( config) . unwrap ( ) ;
238
+ assert_eq ! ( result, ( "file" , "default" ) ) ;
239
+ }
240
+ }
0 commit comments