11use anyhow:: Context as _;
2+ use spin_core:: async_trait;
23use spin_factor_sqlite:: SqliteFactor ;
34use spin_factors:: RuntimeFactors ;
45use spin_factors_executor:: ExecutorHooks ;
@@ -21,68 +22,65 @@ impl SqlStatementExecutorHook {
2122 pub fn new ( sql_statements : Vec < String > ) -> Self {
2223 Self { sql_statements }
2324 }
24- }
2525
26- impl < F : RuntimeFactors , U > ExecutorHooks < F , U > for SqlStatementExecutorHook {
27- fn configure_app (
28- & mut self ,
29- configured_app : & spin_factors:: ConfiguredApp < F > ,
30- ) -> anyhow:: Result < ( ) > {
26+ /// Executes the sql statements.
27+ pub async fn execute ( & self , sqlite : & spin_factor_sqlite:: AppState ) -> anyhow:: Result < ( ) > {
3128 if self . sql_statements . is_empty ( ) {
3229 return Ok ( ( ) ) ;
3330 }
34- let Some ( sqlite) = configured_app. app_state :: < SqliteFactor > ( ) . ok ( ) else {
35- return Ok ( ( ) ) ;
36- } ;
37- if let Ok ( current) = tokio:: runtime:: Handle :: try_current ( ) {
38- let _ = current. spawn ( execute ( sqlite. clone ( ) , self . sql_statements . clone ( ) ) ) ;
39- }
40- Ok ( ( ) )
41- }
42- }
43-
44- /// Executes the sql statements.
45- pub async fn execute (
46- sqlite : spin_factor_sqlite:: AppState ,
47- sql_statements : Vec < String > ,
48- ) -> anyhow:: Result < ( ) > {
49- let get_database = |label| {
50- let sqlite = & sqlite;
51- async move {
31+ let get_database = |label| async move {
5232 sqlite
5333 . get_connection ( label)
5434 . await
5535 . transpose ( )
5636 . with_context ( || format ! ( "failed connect to database with label '{label}'" ) )
57- }
58- } ;
37+ } ;
5938
60- for statement in & sql_statements {
61- if let Some ( config) = statement. strip_prefix ( '@' ) {
62- let ( file, label) = parse_file_and_label ( config) ?;
63- let database = get_database ( label) . await ?. with_context ( || {
39+ for statement in & self . sql_statements {
40+ if let Some ( config) = statement. strip_prefix ( '@' ) {
41+ let ( file, label) = parse_file_and_label ( config) ?;
42+ let database = get_database ( label) . await ?. with_context ( || {
6443 format ! (
6544 "based on the '@{config}' a registered database named '{label}' was expected but not found."
6645 )
6746 } ) ?;
68- let sql = std:: fs:: read_to_string ( file) . with_context ( || {
69- format ! ( "could not read file '{file}' containing sql statements" )
70- } ) ?;
71- database. execute_batch ( & sql) . await . with_context ( || {
72- format ! ( "failed to execute sql against database '{label}' from file '{file}'" )
73- } ) ?;
74- } else {
75- let Some ( default) = get_database ( DEFAULT_SQLITE_LABEL ) . await ? else {
76- debug_assert ! ( false , "the '{DEFAULT_SQLITE_LABEL}' sqlite database should always be available but for some reason was not" ) ;
77- return Ok ( ( ) ) ;
78- } ;
79- default
47+ let sql = std:: fs:: read_to_string ( file) . with_context ( || {
48+ format ! ( "could not read file '{file}' containing sql statements" )
49+ } ) ?;
50+ database. execute_batch ( & sql) . await . with_context ( || {
51+ format ! ( "failed to execute sql against database '{label}' from file '{file}'" )
52+ } ) ?;
53+ } else {
54+ let Some ( default) = get_database ( DEFAULT_SQLITE_LABEL ) . await ? else {
55+ debug_assert ! ( false , "the '{DEFAULT_SQLITE_LABEL}' sqlite database should always be available but for some reason was not" ) ;
56+ return Ok ( ( ) ) ;
57+ } ;
58+ default
8059 . query ( statement, Vec :: new ( ) )
8160 . await
8261 . with_context ( || format ! ( "failed to execute following sql statement against default database: '{statement}'" ) ) ?;
62+ }
8363 }
64+ Ok ( ( ) )
65+ }
66+ }
67+
68+ #[ async_trait]
69+ impl < F , U > ExecutorHooks < F , U > for SqlStatementExecutorHook
70+ where
71+ F : RuntimeFactors ,
72+ F :: AppState : Sync ,
73+ {
74+ async fn configure_app (
75+ & mut self ,
76+ configured_app : & spin_factors:: ConfiguredApp < F > ,
77+ ) -> anyhow:: Result < ( ) > {
78+ let Some ( sqlite) = configured_app. app_state :: < SqliteFactor > ( ) . ok ( ) else {
79+ return Ok ( ( ) ) ;
80+ } ;
81+ self . execute ( & sqlite) . await ?;
82+ Ok ( ( ) )
8483 }
85- Ok ( ( ) )
8684}
8785
8886/// Parses a @{file:label} sqlite statement
0 commit comments