@@ -757,7 +757,7 @@ mod tests {
757757 for got in gots {
758758 // First, make sure that `got` has the right number of items,
759759 // equal to the sum of sizes of all expected groups
760- let combined_groups_len = self . groups . iter ( ) . map ( |s| s. len ( ) ) . sum ( ) ;
760+ let combined_groups_len: usize = self . groups . iter ( ) . map ( |s| s. len ( ) ) . sum ( ) ;
761761 assert_eq ! ( got. len( ) , combined_groups_len) ;
762762
763763 // Now, split `got` into groups of expected sizes
@@ -1619,16 +1619,27 @@ mod latency_awareness {
16191619 }
16201620
16211621 pub ( super ) fn report_query ( & self , node : & Node , latency : Duration ) {
1622- if let Some ( node_avg) = self . node_avgs . read ( ) . unwrap ( ) . get ( & node. host_id ) {
1622+ let node_avgs_guard = self . node_avgs . read ( ) . unwrap ( ) ;
1623+ if let Some ( previous_node_avg) = node_avgs_guard. get ( & node. host_id ) {
16231624 // The usual path, the node has been already noticed.
1624- let mut node_avg = node_avg . write ( ) . unwrap ( ) ;
1625- let previous = * node_avg ;
1626- * node_avg = TimestampedAverage :: compute_next ( previous , latency) ;
1625+ let mut node_avg_guard = previous_node_avg . write ( ) . unwrap ( ) ;
1626+ let previous_node_avg = * node_avg_guard ;
1627+ * node_avg_guard = TimestampedAverage :: compute_next ( previous_node_avg , latency) ;
16271628 } else {
1628- // We need to add the node to the map.
1629- self . node_avgs . write ( ) . unwrap ( ) . insert (
1629+ // We drop the read lock not to deadlock while taking write lock.
1630+ std:: mem:: drop ( node_avgs_guard) ;
1631+ let mut node_avgs_guard = self . node_avgs . write ( ) . unwrap ( ) ;
1632+
1633+ // We have to read this again, as other threads may race with us.
1634+ let previous_node_avg = node_avgs_guard
1635+ . get ( & node. host_id )
1636+ . and_then ( |rwlock| * rwlock. read ( ) . unwrap ( ) ) ;
1637+
1638+ // We most probably need to add the node to the map.
1639+ // (this will be Some only in an unlikely case that another thread raced with us and won)
1640+ node_avgs_guard. insert (
16301641 node. host_id ,
1631- RwLock :: new ( TimestampedAverage :: compute_next ( None , latency) ) ,
1642+ RwLock :: new ( TimestampedAverage :: compute_next ( previous_node_avg , latency) ) ,
16321643 ) ;
16331644 }
16341645 }
@@ -1908,6 +1919,7 @@ mod latency_awareness {
19081919 } ,
19091920 ClusterData , NodeAddr ,
19101921 } ,
1922+ ExecutionProfile , SessionBuilder ,
19111923 } ;
19121924 use std:: time:: Instant ;
19131925
@@ -2582,5 +2594,29 @@ mod latency_awareness {
25822594 . await ;
25832595 }
25842596 }
2597+
2598+ // This is a regression test for #696.
2599+ #[ tokio:: test]
2600+ #[ ntest:: timeout( 1000 ) ]
2601+ async fn latency_aware_query_completes ( ) {
2602+ let uri = std:: env:: var ( "SCYLLA_URI" ) . unwrap_or_else ( |_| "127.0.0.1:9042" . to_string ( ) ) ;
2603+
2604+ let policy = DefaultPolicy :: builder ( )
2605+ . latency_awareness ( LatencyAwarenessBuilder :: default ( ) )
2606+ . build ( ) ;
2607+ let handle = ExecutionProfile :: builder ( )
2608+ . load_balancing_policy ( policy)
2609+ . build ( )
2610+ . into_handle ( ) ;
2611+
2612+ let session = SessionBuilder :: new ( )
2613+ . known_node ( uri)
2614+ . default_execution_profile_handle ( handle)
2615+ . build ( )
2616+ . await
2617+ . unwrap ( ) ;
2618+
2619+ session. query ( "whatever" , ( ) ) . await . unwrap_err ( ) ;
2620+ }
25852621 }
25862622}
0 commit comments