Skip to content

Commit 04b23af

Browse files
committed
SNOW-1873529: Sync structured type AST changes
1 parent 95ca97f commit 04b23af

File tree

16 files changed

+242
-235
lines changed

16 files changed

+242
-235
lines changed

scripts/copy-remote-ast.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/bin/bash
2+
set -euxo pipefail
23

34
# This script assumes the target Cloud Workspace specified as the command-line argument has the build target.
45
# To make sure this is the case, run bazel build //Snowpark/ast:ast_proto and bazel build //Snowpark/unparser.

src/snowflake/snowpark/_internal/ast/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def build_proto_from_struct_type(
339339

340340
expr.structured = schema.structured
341341
for field in schema.fields:
342-
ast_field = expr.fields.add()
342+
ast_field = expr.fields.list.add()
343343
field.column_identifier._fill_ast(ast_field.column_identifier) # type: ignore[attr-defined] # TODO(SNOW-1491199) # "ColumnIdentifier" has no attribute "_fill_ast"
344344
field.datatype._fill_ast(ast_field.data_type) # type: ignore[attr-defined] # TODO(SNOW-1491199) # "DataType" has no attribute "_fill_ast"
345345
ast_field.nullable = field.nullable

src/snowflake/snowpark/_internal/proto/ast.proto

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ message List_SpDataType {
2222
repeated SpDataType list = 1;
2323
}
2424

25+
message List_SpStructField {
26+
repeated SpStructField list = 1;
27+
}
28+
2529
message List_String {
2630
repeated string list = 1;
2731
}
@@ -174,7 +178,7 @@ message SpStructField {
174178

175179
// sp-type.ir:46
176180
message SpStructType {
177-
repeated SpStructField fields = 1;
181+
List_SpStructField fields = 1;
178182
bool structured = 2;
179183
}
180184

src/snowflake/snowpark/types.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -384,11 +384,8 @@ def json_value(self) -> Dict[str, Any]:
384384

385385
def _fill_ast(self, ast: proto.SpDataType) -> None:
386386
ast.sp_array_type.structured = self.structured
387-
if self.element_type is None:
388-
raise NotImplementedError(
389-
"SNOW-1862700: AST does not support empty element_type."
390-
)
391-
self.element_type._fill_ast(ast.sp_array_type.ty)
387+
if self.element_type is not None:
388+
self.element_type._fill_ast(ast.sp_array_type.ty)
392389

393390

394391
class MapType(DataType):
@@ -783,8 +780,8 @@ def json_value(self) -> Dict[str, Any]:
783780

784781
def _fill_ast(self, ast: proto.SpDataType) -> None:
785782
ast.sp_struct_type.structured = self.structured
786-
for field in self.fields:
787-
field._fill_ast(ast.sp_struct_type.fields.add())
783+
for field in self.fields or []:
784+
field._fill_ast(ast.sp_struct_type.fields.list.add())
788785

789786

790787
class VariantType(DataType):

tests/ast/data/DataFrame.count2.test

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,28 @@ body {
3737
sp_dataframe_schema__struct {
3838
v {
3939
fields {
40-
column_identifier {
41-
name: "\"A\""
42-
}
43-
data_type {
44-
sp_string_type {
45-
length {
46-
value: 16777216
40+
list {
41+
column_identifier {
42+
name: "\"A\""
43+
}
44+
data_type {
45+
sp_string_type {
46+
length {
47+
value: 16777216
48+
}
4749
}
4850
}
51+
nullable: true
4952
}
50-
nullable: true
51-
}
52-
fields {
53-
column_identifier {
54-
name: "\"B\""
55-
}
56-
data_type {
57-
sp_long_type: true
53+
list {
54+
column_identifier {
55+
name: "\"B\""
56+
}
57+
data_type {
58+
sp_long_type: true
59+
}
60+
nullable: true
5861
}
59-
nullable: true
6062
}
6163
}
6264
}

tests/ast/data/Dataframe.to_snowpark_pandas.test

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,7 @@ body {
3737
}
3838
}
3939
src {
40-
end_column: 41
41-
end_line: 25
4240
file: "SRC_POSITION_TEST_MODE"
43-
start_column: 13
4441
start_line: 25
4542
}
4643
variant {
@@ -69,10 +66,7 @@ body {
6966
}
7067
}
7168
src {
72-
end_column: 52
73-
end_line: 27
7469
file: "SRC_POSITION_TEST_MODE"
75-
start_column: 29
7670
start_line: 27
7771
}
7872
}
@@ -101,10 +95,7 @@ body {
10195
list: "A"
10296
}
10397
src {
104-
end_column: 65
105-
end_line: 29
10698
file: "SRC_POSITION_TEST_MODE"
107-
start_column: 29
10899
start_line: 29
109100
}
110101
}
@@ -134,10 +125,7 @@ body {
134125
}
135126
}
136127
src {
137-
end_column: 70
138-
end_line: 31
139128
file: "SRC_POSITION_TEST_MODE"
140-
start_column: 29
141129
start_line: 31
142130
}
143131
}
@@ -170,10 +158,7 @@ body {
170158
list: "A"
171159
}
172160
src {
173-
end_column: 87
174-
end_line: 33
175161
file: "SRC_POSITION_TEST_MODE"
176-
start_column: 29
177162
start_line: 33
178163
}
179164
}

tests/ast/data/RelationalGroupedDataFrame.test

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -869,34 +869,36 @@ body {
869869
}
870870
output_schema {
871871
fields {
872-
column_identifier {
873-
name: "location"
874-
}
875-
data_type {
876-
sp_string_type {
877-
length {
872+
list {
873+
column_identifier {
874+
name: "location"
875+
}
876+
data_type {
877+
sp_string_type {
878+
length {
879+
}
878880
}
879881
}
882+
nullable: true
880883
}
881-
nullable: true
882-
}
883-
fields {
884-
column_identifier {
885-
name: "temp_c"
886-
}
887-
data_type {
888-
sp_float_type: true
889-
}
890-
nullable: true
891-
}
892-
fields {
893-
column_identifier {
894-
name: "temp_f"
884+
list {
885+
column_identifier {
886+
name: "temp_c"
887+
}
888+
data_type {
889+
sp_float_type: true
890+
}
891+
nullable: true
895892
}
896-
data_type {
897-
sp_float_type: true
893+
list {
894+
column_identifier {
895+
name: "temp_f"
896+
}
897+
data_type {
898+
sp_float_type: true
899+
}
900+
nullable: true
898901
}
899-
nullable: true
900902
}
901903
}
902904
src {

tests/ast/data/Session.create_dataframe.test

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -412,25 +412,27 @@ body {
412412
sp_dataframe_schema__struct {
413413
v {
414414
fields {
415-
column_identifier {
416-
name: "a"
417-
}
418-
data_type {
419-
sp_integer_type: true
420-
}
421-
nullable: true
422-
}
423-
fields {
424-
column_identifier {
425-
name: "b"
415+
list {
416+
column_identifier {
417+
name: "a"
418+
}
419+
data_type {
420+
sp_integer_type: true
421+
}
422+
nullable: true
426423
}
427-
data_type {
428-
sp_string_type {
429-
length {
424+
list {
425+
column_identifier {
426+
name: "b"
427+
}
428+
data_type {
429+
sp_string_type {
430+
length {
431+
}
430432
}
431433
}
434+
nullable: true
432435
}
433-
nullable: true
434436
}
435437
}
436438
}

tests/ast/data/Table.merge.test

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -165,25 +165,27 @@ body {
165165
sp_dataframe_schema__struct {
166166
v {
167167
fields {
168-
column_identifier {
169-
name: "num"
170-
}
171-
data_type {
172-
sp_integer_type: true
173-
}
174-
nullable: true
175-
}
176-
fields {
177-
column_identifier {
178-
name: "str"
168+
list {
169+
column_identifier {
170+
name: "num"
171+
}
172+
data_type {
173+
sp_integer_type: true
174+
}
175+
nullable: true
179176
}
180-
data_type {
181-
sp_string_type {
182-
length {
177+
list {
178+
column_identifier {
179+
name: "str"
180+
}
181+
data_type {
182+
sp_string_type {
183+
length {
184+
}
183185
}
184186
}
187+
nullable: true
185188
}
186-
nullable: true
187189
}
188190
}
189191
}

tests/ast/data/col_cast.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ df = df.select(col("A").cast(MapType(StringType(), StringType(), structured=Fals
9393

9494
df = df.select(col("A").cast(VectorType(FloatType(), 42)))
9595

96-
df = df.select(col("A").cast(StructType([], structured=False)))
96+
df = df.select(col("A").cast(StructType(structured=False)))
9797

9898
df = df.select(col("A").cast(VariantType()))
9999

0 commit comments

Comments
 (0)