@@ -433,159 +433,27 @@ impl AsyncPgConnection {
433433 // so there is no need to even access the query in the async block below
434434 let mut query_builder = PgQueryBuilder :: default ( ) ;
435435
436- let ( collect_bind_result, fake_oid_locations, generated_oids, bind_collector) = {
437- // we don't resolve custom types here yet, we do that later
438- // in the async block below as we might need to perform lookup
439- // queries for that.
440- //
441- // We apply this workaround to prevent requiring all the diesel
442- // serialization code to beeing async
443- //
444- // We give out constant fake oids here to optimize for the "happy" path
445- // without custom type lookup
446- let mut bind_collector_0 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
447- let mut metadata_lookup_0 = PgAsyncMetadataLookup {
448- custom_oid : false ,
449- generated_oids : None ,
450- oid_generator : |_, _| ( FAKE_OID , FAKE_OID ) ,
451- } ;
452- let collect_bind_result_0 =
453- query. collect_binds ( & mut bind_collector_0, & mut metadata_lookup_0, & Pg ) ;
454-
455- // we have encountered a custom type oid, so we need to perform more work here.
456- // These oids can occure in two locations:
457- //
458- // * In the collected metadata -> relativly easy to resolve, just need to replace them below
459- // * As part of the seralized bind blob -> hard to replace
460- //
461- // To address the second case, we perform a second run of the bind collector
462- // with a different set of fake oids. Then we compare the output of the two runs
463- // and use that information to infer where to replace bytes in the serialized output
464-
465- if metadata_lookup_0. custom_oid {
466- // we try to get the maxium oid we encountered here
467- // to be sure that we don't accidently give out a fake oid below that collides with
468- // something
469- let mut max_oid = bind_collector_0
470- . metadata
471- . iter ( )
472- . flat_map ( |t| {
473- [
474- t. oid ( ) . unwrap_or_default ( ) ,
475- t. array_oid ( ) . unwrap_or_default ( ) ,
476- ]
477- } )
478- . max ( )
479- . unwrap_or_default ( ) ;
480- let mut bind_collector_1 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
481- let mut metadata_lookup_1 = PgAsyncMetadataLookup {
482- custom_oid : false ,
483- generated_oids : Some ( HashMap :: new ( ) ) ,
484- oid_generator : move |_, _| {
485- max_oid += 2 ;
486- ( max_oid, max_oid + 1 )
487- } ,
488- } ;
489- let collect_bind_result_2 =
490- query. collect_binds ( & mut bind_collector_1, & mut metadata_lookup_1, & Pg ) ;
491-
492- assert_eq ! (
493- bind_collector_0. binds. len( ) ,
494- bind_collector_0. metadata. len( )
495- ) ;
496- let fake_oid_locations = std:: iter:: zip (
497- bind_collector_0
498- . binds
499- . iter ( )
500- . zip ( & bind_collector_0. metadata ) ,
501- & bind_collector_1. binds ,
502- )
503- . enumerate ( )
504- . flat_map ( |( bind_index, ( ( bytes_0, metadata_0) , bytes_1) ) | {
505- // custom oids might appear in the serialized bind arguments for arrays or composite (record) types
506- // in both cases the relevant buffer is a custom type on it's own
507- // so we only need to check the cases that contain a fake OID on their own
508- let ( bytes_0, bytes_1) = if matches ! ( metadata_0. oid( ) , Ok ( FAKE_OID ) ) {
509- (
510- bytes_0. as_deref ( ) . unwrap_or_default ( ) ,
511- bytes_1. as_deref ( ) . unwrap_or_default ( ) ,
512- )
513- } else {
514- // for all other cases, just return an empty
515- // list to make the iteration below a no-op
516- // and prevent the need of boxing
517- ( & [ ] as & [ _ ] , & [ ] as & [ _ ] )
518- } ;
519- let lookup_map = metadata_lookup_1
520- . generated_oids
521- . as_ref ( )
522- . map ( |map| {
523- map. values ( )
524- . flat_map ( |( oid, array_oid) | [ * oid, * array_oid] )
525- . collect :: < HashSet < _ > > ( )
526- } )
527- . unwrap_or_default ( ) ;
528- std:: iter:: zip (
529- bytes_0. windows ( std:: mem:: size_of_val ( & FAKE_OID ) ) ,
530- bytes_1. windows ( std:: mem:: size_of_val ( & FAKE_OID ) ) ,
531- )
532- . enumerate ( )
533- . filter_map ( move |( byte_index, ( l, r) ) | {
534- // here we infer if some byte sequence is a fake oid
535- // We use the following conditions for that:
536- //
537- // * The first byte sequence matches the constant FAKE_OID
538- // * The second sequence does not match the constant FAKE_OID
539- // * The second sequence is contained in the set of generated oid,
540- // otherwise we get false positives around the boundary
541- // of a to be replaced byte sequence
542- let r_val =
543- u32:: from_be_bytes ( r. try_into ( ) . expect ( "That's the right size" ) ) ;
544- ( l == FAKE_OID . to_be_bytes ( )
545- && r != FAKE_OID . to_be_bytes ( )
546- && lookup_map. contains ( & r_val) )
547- . then_some ( ( bind_index, byte_index) )
548- } )
549- } )
550- // Avoid storing the bind collectors in the returned Future
551- . collect :: < Vec < _ > > ( ) ;
552- (
553- collect_bind_result_0. and ( collect_bind_result_2) ,
554- fake_oid_locations,
555- metadata_lookup_1. generated_oids ,
556- bind_collector_1,
557- )
558- } else {
559- ( collect_bind_result_0, Vec :: new ( ) , None , bind_collector_0)
560- }
561- } ;
436+ let bind_data = construct_bind_data ( & query) ;
562437
563438 // The code that doesn't need the `T` generic parameter is in a separate function to reduce LLVM IR lines
564439 self . with_prepared_statement_after_sql_built (
565440 callback,
566441 query. is_safe_to_cache_prepared ( & Pg ) ,
567442 T :: query_id ( ) ,
568443 query. to_sql ( & mut query_builder, & Pg ) ,
569- collect_bind_result,
570444 query_builder,
571- bind_collector,
572- fake_oid_locations,
573- generated_oids,
445+ bind_data,
574446 )
575447 }
576448
577- #[ allow( clippy:: too_many_arguments) ]
578449 fn with_prepared_statement_after_sql_built < ' a , F , R > (
579450 & mut self ,
580451 callback : fn ( Arc < tokio_postgres:: Client > , Statement , Vec < ToSqlHelper > ) -> F ,
581452 is_safe_to_cache_prepared : QueryResult < bool > ,
582453 query_id : Option < std:: any:: TypeId > ,
583454 to_sql_result : QueryResult < ( ) > ,
584- collect_bind_result : QueryResult < ( ) > ,
585455 query_builder : PgQueryBuilder ,
586- mut bind_collector : RawBytesBindCollector < Pg > ,
587- fake_oid_locations : Vec < ( usize , usize ) > ,
588- generated_oids : GeneratedOidTypeMap ,
456+ bind_data : BindData ,
589457 ) -> BoxFuture < ' a , QueryResult < R > >
590458 where
591459 F : Future < Output = QueryResult < R > > + Send + ' a ,
@@ -596,6 +464,12 @@ impl AsyncPgConnection {
596464 let metadata_cache = self . metadata_cache . clone ( ) ;
597465 let tm = self . transaction_state . clone ( ) ;
598466 let instrumentation = self . instrumentation . clone ( ) ;
467+ let BindData {
468+ collect_bind_result,
469+ fake_oid_locations,
470+ generated_oids,
471+ mut bind_collector,
472+ } = bind_data;
599473
600474 async move {
601475 let sql = to_sql_result. map ( |_| query_builder. finish ( ) ) ?;
@@ -710,6 +584,142 @@ impl AsyncPgConnection {
710584 }
711585}
712586
587+ struct BindData {
588+ collect_bind_result : Result < ( ) , Error > ,
589+ fake_oid_locations : Vec < ( usize , usize ) > ,
590+ generated_oids : GeneratedOidTypeMap ,
591+ bind_collector : RawBytesBindCollector < Pg > ,
592+ }
593+
594+ fn construct_bind_data ( query : & dyn QueryFragment < diesel:: pg:: Pg > ) -> BindData {
595+ // we don't resolve custom types here yet, we do that later
596+ // in the async block below as we might need to perform lookup
597+ // queries for that.
598+ //
599+ // We apply this workaround to prevent requiring all the diesel
600+ // serialization code to beeing async
601+ //
602+ // We give out constant fake oids here to optimize for the "happy" path
603+ // without custom type lookup
604+ let mut bind_collector_0 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
605+ let mut metadata_lookup_0 = PgAsyncMetadataLookup {
606+ custom_oid : false ,
607+ generated_oids : None ,
608+ oid_generator : |_, _| ( FAKE_OID , FAKE_OID ) ,
609+ } ;
610+ let collect_bind_result_0 =
611+ query. collect_binds ( & mut bind_collector_0, & mut metadata_lookup_0, & Pg ) ;
612+ // we have encountered a custom type oid, so we need to perform more work here.
613+ // These oids can occure in two locations:
614+ //
615+ // * In the collected metadata -> relativly easy to resolve, just need to replace them below
616+ // * As part of the seralized bind blob -> hard to replace
617+ //
618+ // To address the second case, we perform a second run of the bind collector
619+ // with a different set of fake oids. Then we compare the output of the two runs
620+ // and use that information to infer where to replace bytes in the serialized output
621+ if metadata_lookup_0. custom_oid {
622+ // we try to get the maxium oid we encountered here
623+ // to be sure that we don't accidently give out a fake oid below that collides with
624+ // something
625+ let mut max_oid = bind_collector_0
626+ . metadata
627+ . iter ( )
628+ . flat_map ( |t| {
629+ [
630+ t. oid ( ) . unwrap_or_default ( ) ,
631+ t. array_oid ( ) . unwrap_or_default ( ) ,
632+ ]
633+ } )
634+ . max ( )
635+ . unwrap_or_default ( ) ;
636+ let mut bind_collector_1 = RawBytesBindCollector :: < diesel:: pg:: Pg > :: new ( ) ;
637+ let mut metadata_lookup_1 = PgAsyncMetadataLookup {
638+ custom_oid : false ,
639+ generated_oids : Some ( HashMap :: new ( ) ) ,
640+ oid_generator : move |_, _| {
641+ max_oid += 2 ;
642+ ( max_oid, max_oid + 1 )
643+ } ,
644+ } ;
645+ let collect_bind_result_1 =
646+ query. collect_binds ( & mut bind_collector_1, & mut metadata_lookup_1, & Pg ) ;
647+
648+ assert_eq ! (
649+ bind_collector_0. binds. len( ) ,
650+ bind_collector_0. metadata. len( )
651+ ) ;
652+ let fake_oid_locations = std:: iter:: zip (
653+ bind_collector_0
654+ . binds
655+ . iter ( )
656+ . zip ( & bind_collector_0. metadata ) ,
657+ & bind_collector_1. binds ,
658+ )
659+ . enumerate ( )
660+ . flat_map ( |( bind_index, ( ( bytes_0, metadata_0) , bytes_1) ) | {
661+ // custom oids might appear in the serialized bind arguments for arrays or composite (record) types
662+ // in both cases the relevant buffer is a custom type on it's own
663+ // so we only need to check the cases that contain a fake OID on their own
664+ let ( bytes_0, bytes_1) = if matches ! ( metadata_0. oid( ) , Ok ( FAKE_OID ) ) {
665+ (
666+ bytes_0. as_deref ( ) . unwrap_or_default ( ) ,
667+ bytes_1. as_deref ( ) . unwrap_or_default ( ) ,
668+ )
669+ } else {
670+ // for all other cases, just return an empty
671+ // list to make the iteration below a no-op
672+ // and prevent the need of boxing
673+ ( & [ ] as & [ _ ] , & [ ] as & [ _ ] )
674+ } ;
675+ let lookup_map = metadata_lookup_1
676+ . generated_oids
677+ . as_ref ( )
678+ . map ( |map| {
679+ map. values ( )
680+ . flat_map ( |( oid, array_oid) | [ * oid, * array_oid] )
681+ . collect :: < HashSet < _ > > ( )
682+ } )
683+ . unwrap_or_default ( ) ;
684+ std:: iter:: zip (
685+ bytes_0. windows ( std:: mem:: size_of_val ( & FAKE_OID ) ) ,
686+ bytes_1. windows ( std:: mem:: size_of_val ( & FAKE_OID ) ) ,
687+ )
688+ . enumerate ( )
689+ . filter_map ( move |( byte_index, ( l, r) ) | {
690+ // here we infer if some byte sequence is a fake oid
691+ // We use the following conditions for that:
692+ //
693+ // * The first byte sequence matches the constant FAKE_OID
694+ // * The second sequence does not match the constant FAKE_OID
695+ // * The second sequence is contained in the set of generated oid,
696+ // otherwise we get false positives around the boundary
697+ // of a to be replaced byte sequence
698+ let r_val = u32:: from_be_bytes ( r. try_into ( ) . expect ( "That's the right size" ) ) ;
699+ ( l == FAKE_OID . to_be_bytes ( )
700+ && r != FAKE_OID . to_be_bytes ( )
701+ && lookup_map. contains ( & r_val) )
702+ . then_some ( ( bind_index, byte_index) )
703+ } )
704+ } )
705+ // Avoid storing the bind collectors in the returned Future
706+ . collect :: < Vec < _ > > ( ) ;
707+ BindData {
708+ collect_bind_result : collect_bind_result_0. and ( collect_bind_result_1) ,
709+ fake_oid_locations,
710+ generated_oids : metadata_lookup_1. generated_oids ,
711+ bind_collector : bind_collector_1,
712+ }
713+ } else {
714+ BindData {
715+ collect_bind_result : collect_bind_result_0,
716+ fake_oid_locations : Vec :: new ( ) ,
717+ generated_oids : None ,
718+ bind_collector : bind_collector_0,
719+ }
720+ }
721+ }
722+
713723type GeneratedOidTypeMap = Option < HashMap < ( Option < String > , String ) , ( u32 , u32 ) > > ;
714724
715725/// Collects types that need to be looked up, and causes fake OIDs to be written into the bind collector
0 commit comments