@@ -566,6 +566,82 @@ def test_same_duplicate_subtree(session):
566566 assert count_number_of_ctes (df_result2 .queries ["queries" ][- 1 ]) == 3
567567
568568
569+ @pytest .mark .parametrize ("use_different_df" , [True , False ])
570+ def test_cte_preserves_join_suffix_aliases (session , use_different_df ):
571+ df_ad_group = session .create_dataframe (
572+ [["1048771" , "group_1" , "campaign_1" ]],
573+ schema = ["ACCOUNT_ID" , "AD_GROUP_ID" , "CAMPAIGN_ID" ],
574+ )
575+
576+ df_ad_group_excv = session .create_dataframe (
577+ [["1048771" , "group_1" , "device" , "8308" ]],
578+ schema = ["ACCOUNT_ID" , "AD_GROUP_ID" , "DEVICE" , "EXTERNAL_CONVERSION_ID" ],
579+ )
580+
581+ df_ad_group_excv = df_ad_group_excv .join (
582+ df_ad_group ,
583+ df_ad_group .col ("AD_GROUP_ID" ) == df_ad_group_excv .col ("AD_GROUP_ID" ),
584+ rsuffix = "_WITH_AD_GROUP" ,
585+ ).select (
586+ col ("ACCOUNT_ID" ),
587+ col ("CAMPAIGN_ID" ),
588+ col ("AD_GROUP_ID" ),
589+ lit (None ).as_ ("AD_ID" ),
590+ )
591+
592+ if use_different_df :
593+ df_ad_group = session .create_dataframe (
594+ [["1048771" , "group_1" , "campaign_1" ]],
595+ schema = ["ACCOUNT_ID" , "AD_GROUP_ID" , "CAMPAIGN_ID" ],
596+ )
597+
598+ df_ad_group_ad = session .create_dataframe (
599+ [["1048771" , "ad_1" , "group_1" ]],
600+ schema = ["ACCOUNT_ID" , "AD_ID" , "AD_GROUP_ID" ],
601+ )
602+
603+ df_ad_excv = session .create_dataframe (
604+ [["1048771" , "group_1" , "ad_1" , "device" , "8308" ]],
605+ schema = [
606+ "ACCOUNT_ID" ,
607+ "AD_GROUP_ID" ,
608+ "AD_ID" ,
609+ "DEVICE" ,
610+ "EXTERNAL_CONVERSION_ID" ,
611+ ],
612+ )
613+
614+ df_ad_excv = (
615+ df_ad_excv .join (
616+ df_ad_group_ad ,
617+ df_ad_group_ad .col ("AD_ID" ) == df_ad_excv .col ("AD_ID" ),
618+ rsuffix = "_WITH_AD_GROUP_AD" ,
619+ )
620+ .join (
621+ df_ad_group ,
622+ df_ad_group .col ("AD_GROUP_ID" ) == df_ad_group_ad .col ("AD_GROUP_ID" ),
623+ rsuffix = "_WITH_AD_GROUP" ,
624+ )
625+ .select (
626+ col ("ACCOUNT_ID" ),
627+ col ("CAMPAIGN_ID" ),
628+ col ("AD_GROUP_ID" ),
629+ col ("AD_ID" ),
630+ )
631+ )
632+
633+ df_union = df_ad_group_excv .union_all (df_ad_excv )
634+ union_sql = df_union .queries ["queries" ][- 1 ]
635+
636+ # the second one is incorrect join condition as we have rsuffix for join alias
637+ assert 'ON ("AD_GROUP_ID_WITH_AD_GROUP" = "AD_GROUP_ID")' in union_sql
638+ assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql
639+ # when using different df_ad_group with disambiguation, because rsuffix in join,
640+ # they have different alias map (expr_to_alias), so they are considered different and we can't convert them to a CTE
641+ # However there is still a CTE for create_dataframe call
642+ assert count_number_of_ctes (Utils .normalize_sql (union_sql )) == 1
643+
644+
569645@pytest .mark .parametrize (
570646 "mode" , ["append" , "truncate" , "overwrite" , "errorifexists" , "ignore" ]
571647)
@@ -736,12 +812,12 @@ def test_sql_simplifier(session):
736812 describe_count_for_optimized = 1 if session ._join_alias_fix else None ,
737813 )
738814 with SqlCounter (query_count = 0 , describe_count = 0 ):
739- # When adding a lsuffix, the columns of right dataframe don't need to be renamed,
740- # so we will get a common CTE with filter
815+ # When adding a lsuffix, expr alias map will be updated, so df2 and df3 are considered
816+ # different and have different ids. So only df1 and df will be converted to a CTE
741817 assert (
742- count_number_of_ctes (Utils .normalize_sql (df6 .queries ["queries" ][- 1 ])) == 2
818+ count_number_of_ctes (Utils .normalize_sql (df6 .queries ["queries" ][- 1 ])) == 1
743819 )
744- assert Utils .normalize_sql (df6 .queries ["queries" ][- 1 ]).count (filter_clause ) == 2
820+ assert Utils .normalize_sql (df6 .queries ["queries" ][- 1 ]).count (filter_clause ) == 3
745821
746822 df7 = df1 .with_column ("c" , lit (1 ))
747823 df8 = df1 .with_column ("c" , lit (1 )).with_column ("d" , lit (1 ))
0 commit comments