@@ -93,7 +93,12 @@ def test_axis_1_basic_types_without_type_hints(data, func, return_type):
     native_df = native_pd.DataFrame(data, columns=["A", "b"])
     snow_df = pd.DataFrame(data, columns=["A", "b"])
     # np.min is mapped to sql builtin function.
-    with SqlCounter(query_count=1 if func == np.min else 5):
+    query_count = 1 if func == np.min else 5
+    join_count = 0 if func == np.min else 2
+    udtf_count = 0 if func == np.min else 1
+    with SqlCounter(
+        query_count=query_count, join_count=join_count, udtf_count=udtf_count
+    ):
         eval_snowpark_pandas_result(snow_df, native_df, lambda x: x.apply(func, axis=1))
 
 
@@ -107,7 +112,7 @@ def test_axis_1_basic_types_with_type_hints(data, func, return_type):
     snow_df = pd.DataFrame(data, columns=["A", "b"])
     func_with_type_hint = create_func_with_return_type_hint(func, return_type)
     # Invoking a single UDF typically requires 3 queries (package management, code upload, UDF registration) upfront.
-    with SqlCounter(query_count=4, join_count=0, udtf_count=0):
+    with SqlCounter(query_count=4, join_count=0, udtf_count=0, udf_count=1):
         eval_snowpark_pandas_result(
             snow_df, native_df, lambda x: x.apply(func_with_type_hint, axis=1)
         )
@@ -144,7 +149,7 @@ def foo(row) -> str:
 
     snow_df = pd.DataFrame(df)
     # Invoking a single UDF typically requires 3 queries (package management, code upload, UDF registration) upfront.
-    with SqlCounter(query_count=4, join_count=0, udtf_count=0):
+    with SqlCounter(query_count=4, join_count=0, udtf_count=0, udf_count=1):
         eval_snowpark_pandas_result(snow_df, df, lambda x: x.apply(foo, axis=1))
 
 
@@ -350,12 +355,12 @@ def f(x, y, z=1) -> int:
         assert_exception_equal=False,
     )
 
-    with SqlCounter(query_count=4):
+    with SqlCounter(query_count=4, udf_count=1):
         eval_snowpark_pandas_result(
             snow_df, native_df, lambda x: x.apply(f, axis=1, args=(1,))
         )
 
-    with SqlCounter(query_count=4):
+    with SqlCounter(query_count=4, udf_count=1):
         eval_snowpark_pandas_result(
             snow_df, native_df, lambda x: x.apply(f, axis=1, args=(1,), z=2)
         )
@@ -640,6 +645,12 @@ def test_basic_dataframe_transform(data, apply_func, expected_query_count):
     snow_df = pd.DataFrame(data)
     with SqlCounter(
         query_count=expected_query_count,
+        join_count=5
+        if expected_query_count == 16
+        else (3 if expected_query_count > 1 else 0),
+        udtf_count=3
+        if expected_query_count == 16
+        else (2 if expected_query_count > 1 else 0),
         high_count_expected=expected_query_count > 8,
         high_count_reason=msg,
     ):
@@ -900,6 +911,7 @@ def test_apply_axis1_with_3rd_party_libraries_and_decorator(
 
     with SqlCounter(
         query_count=expected_query_count,
+        udf_count=1,
         high_count_expected=True,
         high_count_reason="Snowpark package upload requires many queries.",
     ):
@@ -1052,6 +1064,7 @@ def f(s, default_arg=2):
 
     axis_0_no_cache_kwargs = {
         "query_count": 16,
+        "join_count": 5,
         "udtf_count": 3,
         "high_count_expected": True,
         "high_count_reason": "UDTF creation on multiple columns",
@@ -1065,13 +1078,13 @@ def f(s, default_arg=2):
     )
     # second application should trigger cache hit, even with explicit axis argument or via transform call
     # unclear why SQL counter still parses UDTF creations here despite having lower query counts
-    with SqlCounter(query_count=7, udtf_count=3):
+    with SqlCounter(query_count=7, join_count=5, udtf_count=3):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
             lambda df: df.transform(f),
         )
-    with SqlCounter(query_count=7, udtf_count=3):
+    with SqlCounter(query_count=7, join_count=5, udtf_count=3):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
@@ -1085,27 +1098,27 @@ def f(s, default_arg=2):
             lambda df: df.apply(f, default_arg=3),
         )
     # application on a new dataframe with the same schema should hit the cache
-    with SqlCounter(query_count=7, udtf_count=3):
+    with SqlCounter(query_count=7, join_count=5, udtf_count=3):
         eval_snowpark_pandas_result(
             *create_test_dfs(test_data),
             lambda df: df.transform(f),
         )
     # calling on axis=1 creates a new UDTF
-    with SqlCounter(query_count=5, udtf_count=1):
+    with SqlCounter(query_count=5, join_count=2, udtf_count=1):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
             lambda df: df.apply(f, axis=1),
         )
     # a second call with axis=1 does hit the cache (not sure why SQL counter registers a udtf creation)
-    with SqlCounter(query_count=2, udtf_count=1):
+    with SqlCounter(query_count=2, join_count=2, udtf_count=1):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
             lambda df: df.apply(f, axis=1),
         )
     # calling on axis=1 with different argument value does not hit the cache
-    with SqlCounter(query_count=5, udtf_count=1):
+    with SqlCounter(query_count=5, join_count=2, udtf_count=1):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
@@ -1128,6 +1141,7 @@ def f(s, default_arg=2):
     with SqlCounter(
         query_count=11,
         udtf_count=2,
+        join_count=3,
         high_count_expected=True,
         high_count_reason="UDTF creation on multiple columns",
     ):
@@ -1137,21 +1151,21 @@ def f(s, default_arg=2):
             lambda df: df.apply(f, axis=0),
         )
     # A second call hits the cache.
-    with SqlCounter(query_count=5, udtf_count=2):
+    with SqlCounter(query_count=5, udtf_count=2, join_count=3):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
             lambda df: df.apply(f, axis=0),
         )
     # The same rules apply with a different axis argument.
-    with SqlCounter(query_count=5, udtf_count=1):
+    with SqlCounter(query_count=5, udtf_count=1, join_count=2):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
             lambda df: df.apply(f, axis=1),
         )
     # A second call still does not hit the cache.
-    with SqlCounter(query_count=2, udtf_count=1):
+    with SqlCounter(query_count=2, udtf_count=1, join_count=2):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
@@ -1174,6 +1188,7 @@ def f(s, default_arg=2):
     with SqlCounter(
         query_count=11,
         udtf_count=2,
+        join_count=3,
         high_count_expected=True,
         high_count_reason="UDTF creation on multiple columns",
     ):
@@ -1183,21 +1198,21 @@ def f(s, default_arg=2):
             lambda df: df.apply(f, axis=0),
         )
     # A second call hits the cache.
-    with SqlCounter(query_count=5, udtf_count=2):
+    with SqlCounter(query_count=5, udtf_count=2, join_count=3):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
             lambda df: df.apply(f, axis=0),
         )
     # The same rules apply with a different axis argument.
-    with SqlCounter(query_count=5, udtf_count=1):
+    with SqlCounter(query_count=5, udtf_count=1, join_count=2):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
             lambda df: df.apply(f, axis=1),
         )
     # A second call hits the cache.
-    with SqlCounter(query_count=2, udtf_count=1):
+    with SqlCounter(query_count=2, udtf_count=1, join_count=2):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
@@ -1219,14 +1234,14 @@ def __init__(self) -> None:
 
     def operation(col, arg):
         return col + sum(arg.x)
-    with SqlCounter(query_count=6, udtf_count=1):
+    with SqlCounter(query_count=6, join_count=1, udtf_count=1):
         eval_snowpark_pandas_result(
             *create_test_dfs(test_data), lambda df: df.apply(operation, arg=arg)
         )
 
     # Mutate arg.x, preventing a cache entry from being created
     arg.x.append(10)
-    with SqlCounter(query_count=6, udtf_count=1):
+    with SqlCounter(query_count=6, join_count=1, udtf_count=1):
         eval_snowpark_pandas_result(
             *create_test_dfs(test_data), lambda df: df.apply(operation, arg=arg)
         )
@@ -1235,6 +1250,7 @@ def operation(col, arg):
 
     with SqlCounter(
         query_count=11,
+        join_count=2,
         udtf_count=2,
         high_count_expected=True,
         high_count_reason="multiple apply calls in sequence",
@@ -1251,7 +1267,7 @@ def operation(col, arg):
     # pickling creates different binary blobs.
     arg2 = A()
     arg2.x.append(10)
-    with SqlCounter(query_count=3, udtf_count=1):
+    with SqlCounter(query_count=3, join_count=1, udtf_count=1):
         eval_snowpark_pandas_result(
             *create_test_dfs(test_data), lambda df: df.apply(operation, arg=arg2)
         )
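
The recurring change in this diff is that each SqlCounter assertion now pins join_count, udtf_count, and (for row-wise applies that register a scalar UDF) udf_count alongside query_count. Below is a minimal sketch of that assertion pattern, assuming the test-utility import paths and using illustrative counts; the real values depend on the generated query plan, and the helper locations may differ in the repository.

# Minimal sketch of the SqlCounter assertion pattern used throughout this diff.
# Import paths below are assumptions based on the Snowpark pandas test suite layout.
import modin.pandas as pd
import pandas as native_pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # enables the Snowpark pandas backend

from tests.integ.modin.sql_counter import SqlCounter  # assumed path
from tests.integ.modin.utils import eval_snowpark_pandas_result  # assumed path


def test_apply_axis_1_counts_sketch():
    data = [[1.0, 1.5], [2.0, 2.5]]
    native_df = native_pd.DataFrame(data, columns=["A", "b"])
    snow_df = pd.DataFrame(data, columns=["A", "b"])

    # Assert the number of queries issued plus the joins and UDTF registrations
    # the row-wise apply is expected to trigger (counts here are illustrative).
    with SqlCounter(query_count=5, join_count=2, udtf_count=1):
        eval_snowpark_pandas_result(
            snow_df, native_df, lambda df: df.apply(lambda row: row.sum(), axis=1)
        )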