Skip to content

Commit 888505c

Browse files
authored
SNOW-2246506: Support SPCS function invocation (#3798)
1 parent a6ae80b commit 888505c

File tree

7 files changed

+116
-0
lines changed

7 files changed

+116
-0
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Release History
22

3+
## 1.41.0 (YYYY-MM-DD)
4+
5+
### Snowpark Python API Updates
6+
7+
#### New Features
8+
9+
- Added a new function `service` in `snowflake.snowpark.functions` that allows users to create a callable representing a Snowpark Container Services (SPCS) service.
10+
311
## 1.40.0 (YYYY-MM-DD)
412

513
### Snowpark Python API Updates

docs/source/snowpark/functions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ Functions
414414
seq4
415415
seq8
416416
sequence
417+
service
417418
sha1
418419
sha2
419420
sin

src/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def pytest_collection_modifyitems(config, items):
9797
disabled_doctests = [
9898
"ai_classify",
9999
"model",
100+
"service",
100101
] # Add any test names that should be skipped
101102
for item in items:
102103
# identify doctest items

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
like_expression,
2929
list_agg,
3030
model_expression,
31+
service_expression,
3132
named_arguments_function,
3233
order_expression,
3334
range_statement,
@@ -74,6 +75,7 @@
7475
ListAgg,
7576
Literal,
7677
ModelExpression,
78+
ServiceExpression,
7779
MultipleExpression,
7880
NamedExpression,
7981
NamedFunctionExpression,
@@ -430,6 +432,16 @@ def analyze(
430432
],
431433
)
432434

435+
if isinstance(expr, ServiceExpression):
436+
return service_expression(
437+
expr.service_name,
438+
expr.method_name,
439+
[
440+
self.to_sql_try_avoid_cast(c, df_aliased_col_name_to_real_col_name)
441+
for c in expr.children
442+
],
443+
)
444+
433445
if isinstance(expr, FunctionExpression):
434446
if expr.api_call_source is not None:
435447
self.session._conn._telemetry_client.send_function_usage_telemetry(

src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,14 @@ def model_expression(
278278
return f"{MODEL}{LEFT_PARENTHESIS}{model_args_str}{RIGHT_PARENTHESIS}{EXCLAMATION_MARK}{method_name}{LEFT_PARENTHESIS}{COMMA.join(children)}{RIGHT_PARENTHESIS}"
279279

280280

281+
def service_expression(
282+
service_name: str,
283+
method_name: str,
284+
children: List[str],
285+
) -> str:
286+
return f"{service_name}{EXCLAMATION_MARK}{method_name}{LEFT_PARENTHESIS}{COMMA.join(children)}{RIGHT_PARENTHESIS}"
287+
288+
281289
def function_expression(name: str, children: List[str], is_distinct: bool) -> str:
282290
return (
283291
name

src/snowflake/snowpark/_internal/analyzer/expression.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,29 @@ def plan_node_category(self) -> PlanNodeCategory:
575575
return PlanNodeCategory.FUNCTION
576576

577577

578+
class ServiceExpression(Expression):
579+
def __init__(
580+
self,
581+
service_name: str,
582+
method_name: str,
583+
arguments: List[Expression],
584+
) -> None:
585+
super().__init__()
586+
self.service_name = service_name
587+
self.method_name = method_name
588+
self.children = arguments
589+
590+
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
591+
return derive_dependent_columns(*self.children)
592+
593+
def dependent_column_names_with_duplication(self) -> List[str]:
594+
return derive_dependent_columns_with_duplication(*self.children)
595+
596+
@property
597+
def plan_node_category(self) -> PlanNodeCategory:
598+
return PlanNodeCategory.FUNCTION
599+
600+
578601
class FunctionExpression(Expression):
579602
def __init__(
580603
self,

src/snowflake/snowpark/functions.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@
181181
ListAgg,
182182
Literal,
183183
ModelExpression,
184+
ServiceExpression,
184185
MultipleExpression,
185186
Star,
186187
NamedFunctionExpression,
@@ -10746,6 +10747,30 @@ def _call_model(
1074610747
)
1074710748

1074810749

10750+
def _call_service(
10751+
service_name: str,
10752+
method_name: str,
10753+
*args,
10754+
_emit_ast: bool = True,
10755+
) -> Column:
10756+
if _emit_ast:
10757+
_ast = build_function_expr("service", [service_name, method_name, *args])
10758+
else:
10759+
_ast = None
10760+
10761+
args_list = parse_positional_args_to_list(*args)
10762+
expressions = [Column._to_expr(arg) for arg in args_list]
10763+
return Column(
10764+
ServiceExpression(
10765+
service_name,
10766+
method_name,
10767+
expressions,
10768+
),
10769+
_ast=_ast,
10770+
_emit_ast=_emit_ast,
10771+
)
10772+
10773+
1074910774
@publicapi
1075010775
def model(
1075110776
model_name: str,
@@ -10775,6 +10800,44 @@ def model(
1077510800
)
1077610801

1077710802

10803+
@publicapi
10804+
def service(
10805+
service_name: str,
10806+
_emit_ast: bool = True,
10807+
) -> Callable:
10808+
"""
10809+
Creates a service function that can be used to call a service method.
10810+
10811+
Args:
10812+
service_name: The name of the service to call.
10813+
10814+
Example::
10815+
10816+
>>> service_instance = service("TESTSCHEMA_SNOWPARK_PYTHON.FORECAST_MODEL_SERVICE")
10817+
>>> # Prepare a DataFrame with the ten expected features
10818+
>>> df = session.create_dataframe(
10819+
... [
10820+
... (0.038076, 0.050680, 0.061696, 0.021872, -0.044223, -0.034821, -0.043401, -0.002592, 0.019907, -0.017646),
10821+
... ],
10822+
... schema=["age", "sex", "bmi", "bp", "s1", "s2", "s3", "s4", "s5", "s6"],
10823+
... )
10824+
>>> # Invoke the model's predict method exposed by the service
10825+
>>> result_df = df.select(
10826+
... service_instance("predict")(col("age"), col("sex"), col("bmi"), col("bp"), col("s1"), col("s2"), col("s3"), col("s4"), col("s5"), col("s6"))["output_feature_0"]
10827+
... )
10828+
>>> result_df.show()
10829+
------------------------------------------------------
10830+
|"TESTSCHEMA_SNOWPARK_PYTHON.FORECAST_MODEL_SERV... |
10831+
------------------------------------------------------
10832+
|220.2223358154297 |
10833+
------------------------------------------------------
10834+
<BLANKLINE>
10835+
"""
10836+
return lambda method_name: lambda *args: _call_service(
10837+
service_name, method_name, *args, _emit_ast=_emit_ast
10838+
)
10839+
10840+
1077810841
# Add these alias for user code migration
1077910842
call_builtin = call_function
1078010843
collect_set = array_unique_agg

0 commit comments

Comments
 (0)