@@ -14,6 +14,7 @@ use tokio::time::timeout;
1414use tracing:: { debug, error, trace, trace_span, Instrument } ;
1515use uuid:: Uuid ;
1616
17+ use super :: connection:: NonErrorQueryResponse ;
1718use super :: connection:: QueryResponse ;
1819use super :: errors:: { BadQuery , NewSessionError , QueryError } ;
1920use crate :: cql_to_rust:: FromRow ;
@@ -470,6 +471,7 @@ impl Session {
470471 connection
471472 . query ( query_ref, values_ref, paging_state_ref. clone ( ) )
472473 . await
474+ . and_then ( QueryResponse :: into_non_error_query_response)
473475 }
474476 } ,
475477 )
@@ -484,7 +486,7 @@ impl Session {
484486
485487 async fn handle_set_keyspace_response (
486488 & self ,
487- response : & QueryResponse ,
489+ response : & NonErrorQueryResponse ,
488490 ) -> Result < ( ) , QueryError > {
489491 if let Some ( set_keyspace) = response. as_set_keyspace ( ) {
490492 debug ! (
@@ -501,7 +503,7 @@ impl Session {
501503 async fn handle_auto_await_schema_agreement (
502504 & self ,
503505 contents : & str ,
504- response : & QueryResponse ,
506+ response : & NonErrorQueryResponse ,
505507 ) -> Result < ( ) , QueryError > {
506508 if let Some ( timeout) = self . auto_await_schema_agreement_timeout {
507509 if response. as_schema_change ( ) . is_some ( )
@@ -748,7 +750,7 @@ impl Session {
748750 "Request" ,
749751 prepared_id = format!( "{:X}" , prepared. get_id( ) ) . as_str( )
750752 ) ;
751- let response = self
753+ let response: NonErrorQueryResponse = self
752754 . run_query (
753755 statement_info,
754756 & prepared. config ,
@@ -762,6 +764,7 @@ impl Session {
762764 connection
763765 . execute ( prepared, values_ref, paging_state_ref. clone ( ) )
764766 . await
767+ . and_then ( QueryResponse :: into_non_error_query_response)
765768 } ,
766769 )
767770 . instrument ( span)
@@ -1104,6 +1107,7 @@ impl Session {
11041107 where
11051108 ConnFut : Future < Output = Result < Arc < Connection > , QueryError > > ,
11061109 QueryFut : Future < Output = Result < ResT , QueryError > > ,
1110+ ResT : AllowedRunQueryResTType ,
11071111 {
11081112 let cluster_data = self . cluster . get_data ( ) ;
11091113 let query_plan = self . load_balancer . plan ( & statement_info, & cluster_data) ;
@@ -1196,6 +1200,7 @@ impl Session {
11961200 where
11971201 ConnFut : Future < Output = Result < Arc < Connection > , QueryError > > ,
11981202 QueryFut : Future < Output = Result < ResT , QueryError > > ,
1203+ ResT : AllowedRunQueryResTType ,
11991204 {
12001205 let mut last_error: Option < QueryError > = None ;
12011206
@@ -1303,6 +1308,7 @@ impl Session {
13031308 ) -> Result < ResT , QueryError >
13041309 where
13051310 QueryFut : Future < Output = Result < ResT , QueryError > > ,
1311+ ResT : AllowedRunQueryResTType ,
13061312 {
13071313 let info = Statement :: default ( ) ;
13081314 let config = StatementConfig {
@@ -1395,3 +1401,18 @@ async fn resolve_hostname(hostname: &str) -> Result<SocketAddr, NewSessionError>
13951401
13961402 ret. ok_or ( failed_err)
13971403}
1404+
1405+ // run_query, execute_query, etc have a template type called ResT.
1406+ // There was a bug where ResT was set to QueryResponse, which could
1407+ // be an error response. This was not caught by retry policy which
1408+ // assumed all errors would come from analyzing Result<ResT, QueryError>.
1409+ // This trait is a guard to make sure that this mistake doesn't
1410+ // happen again.
1411+ // When using run_query make sure that the ResT type is NOT able
1412+ // to contain any errors.
1413+ // See https://github.com/scylladb/scylla-rust-driver/issues/501
1414+ pub trait AllowedRunQueryResTType { }
1415+
1416+ impl AllowedRunQueryResTType for Uuid { }
1417+ impl AllowedRunQueryResTType for BatchResult { }
1418+ impl AllowedRunQueryResTType for NonErrorQueryResponse { }
0 commit comments