Skip to content

Commit 158a132

Browse files
uros-db authored and HyukjinKwon committed
[SPARK-54399][GEO][SQL][PYTHON] Implement the st_setsrid function in Scala and PySpark
### What changes were proposed in this pull request? Implement the `st_setsrid` function in Scala and PySpark API. ### Why are the changes needed? Expand API support for the `ST_SetSrid` expression. ### Does this PR introduce _any_ user-facing change? Yes, the new function is now available in Scala and PySpark API. ### How was this patch tested? Added appropriate Scala function unit tests: - `STFunctionsSuite` Added appropriate PySpark function unit tests: - `test_functions` ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#53117 from uros-db/geo-ST_SetSrid-scala. Authored-by: Uros Bojanic <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent ee0f692 commit 158a132

File tree

7 files changed

+101
-1
lines changed

7 files changed

+101
-1
lines changed

python/docs/source/reference/pyspark.sql/functions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,7 @@ Geospatial ST Functions
678678
st_asbinary
679679
st_geogfromwkb
680680
st_geomfromwkb
681+
st_setsrid
681682
st_srid
682683

683684

python/pyspark/sql/connect/functions/builtin.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4969,6 +4969,15 @@ def st_geomfromwkb(wkb: "ColumnOrName") -> Column:
49694969
st_geomfromwkb.__doc__ = pysparkfuncs.st_geomfromwkb.__doc__
49704970

49714971

4972+
def st_setsrid(geo: "ColumnOrName", srid: Union["ColumnOrName", int]) -> Column:
    # No inline docstring here: `__doc__` is assigned right after this definition.
    # NOTE(review): `_enum_to_value` presumably unwraps Enum members - confirm.
    srid_arg = _enum_to_value(srid)
    if isinstance(srid_arg, int):
        # A plain Python int is wrapped as a literal column; any other value
        # (Column or column name) is forwarded unchanged.
        srid_arg = lit(srid_arg)
    return _invoke_function_over_columns("st_setsrid", geo, srid_arg)
4976+
4977+
4978+
# Reuse the docstring of the `pysparkfuncs` implementation so both stay in sync.
st_setsrid.__doc__ = pysparkfuncs.st_setsrid.__doc__
4979+
4980+
49724981
def st_srid(geo: "ColumnOrName") -> Column:
    # Thin wrapper: forwards `geo` to the "st_srid" function by name.
    srid_col = _invoke_function_over_columns("st_srid", geo)
    return srid_col
49744983

python/pyspark/sql/functions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@
544544
"st_asbinary",
545545
"st_geogfromwkb",
546546
"st_geomfromwkb",
547+
"st_setsrid",
547548
"st_srid",
548549
# Call Functions
549550
"call_udf",

python/pyspark/sql/functions/builtin.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26143,6 +26143,39 @@ def st_geomfromwkb(wkb: "ColumnOrName") -> Column:
2614326143
return _invoke_function_over_columns("st_geomfromwkb", wkb)
2614426144

2614526145

26146+
@_try_remote_functions
def st_setsrid(geo: "ColumnOrName", srid: Union["ColumnOrName", int]) -> Column:
    """Returns a new GEOGRAPHY or GEOMETRY value whose SRID is the specified SRID value.

    .. versionadded:: 4.1.0

    Parameters
    ----------
    geo : :class:`~pyspark.sql.Column` or str
        A geospatial value, either a GEOGRAPHY or a GEOMETRY.
    srid : :class:`~pyspark.sql.Column`, str or int
        An INTEGER representing the new SRID of the geospatial value. It can be
        given as a column, as the name of a column, or as a Python integer literal.

    Returns
    -------
    :class:`~pyspark.sql.Column`
        A column with the geospatial value carrying the specified SRID.

    Examples
    --------

    Example 1: Setting the SRID on GEOGRAPHY with SRID from another column.

    >>> from pyspark.sql import functions as sf
    >>> df = spark.createDataFrame([(bytes.fromhex('0101000000000000000000F03F0000000000000040'), 4326)], ['wkb', 'srid']) # noqa
    >>> df.select(sf.st_srid(sf.st_setsrid(sf.st_geogfromwkb('wkb'), 'srid'))).collect()
    [Row(st_srid(st_setsrid(st_geogfromwkb(wkb), srid))=4326)]

    Example 2: Setting the SRID on GEOMETRY with SRID as an integer literal.

    >>> from pyspark.sql import functions as sf
    >>> df = spark.createDataFrame([(bytes.fromhex('0101000000000000000000F03F0000000000000040'),)], ['wkb']) # noqa
    >>> df.select(sf.st_srid(sf.st_setsrid(sf.st_geomfromwkb('wkb'), 4326))).collect()
    [Row(st_srid(st_setsrid(st_geomfromwkb(wkb), 4326))=4326)]
    """
    # NOTE(review): `_enum_to_value` presumably unwraps Enum members - confirm.
    srid_arg = _enum_to_value(srid)
    if isinstance(srid_arg, int):
        # A plain Python int is wrapped as a literal column; any other value
        # (Column or column name) is forwarded unchanged.
        srid_arg = lit(srid_arg)
    return _invoke_function_over_columns("st_setsrid", geo, srid_arg)
26177+
26178+
2614626179
@_try_remote_functions
2614726180
def st_srid(geo: "ColumnOrName") -> Column:
2614826181
"""Returns the SRID of the input GEOGRAPHY or GEOMETRY value.

python/pyspark/sql/tests/test_functions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2974,6 +2974,25 @@ def test_st_asbinary(self):
29742974
)
29752975
self.assertEqual(results, [expected])
29762976

2977+
def test_st_setsrid(self):
    # Single input row: a WKB payload and the SRID value to assign to it.
    wkb = bytes.fromhex("0101000000000000000000F03F0000000000000040")
    df = self.spark.createDataFrame([(wkb, 4326)], ["wkb", "srid"])
    geog = F.st_geogfromwkb("wkb")
    geom = F.st_geomfromwkb("wkb")
    # For each geospatial type, pass the SRID first by column name and then as
    # an integer literal; st_srid reads the value back.
    results = df.select(
        F.st_srid(F.st_setsrid(geog, "srid")),
        F.st_srid(F.st_setsrid(geom, "srid")),
        F.st_srid(F.st_setsrid(geog, 4326)),
        F.st_srid(F.st_setsrid(geom, 4326)),
    ).collect()
    self.assertEqual(results, [Row(4326, 4326, 4326, 4326)])
2995+
29772996
def test_st_srid(self):
29782997
df = self.spark.createDataFrame(
29792998
[(bytes.fromhex("0101000000000000000000F03F0000000000000040"),)],

sql/api/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9471,6 +9471,24 @@ object functions {
94719471
def st_geomfromwkb(wkb: Column): Column =
94729472
Column.fn("st_geomfromwkb", wkb)
94739473

9474+
/**
 * Returns a new GEOGRAPHY or GEOMETRY value whose SRID is the specified SRID value. In this
 * overload the SRID is supplied as a column.
 *
 * @param geo
 *   the geospatial value whose SRID is replaced
 * @param srid
 *   the column holding the new SRID
 *
 * @group st_funcs
 * @since 4.1.0
 */
def st_setsrid(geo: Column, srid: Column): Column = {
  Column.fn("st_setsrid", geo, srid)
}
9482+
9483+
/**
 * Returns a new GEOGRAPHY or GEOMETRY value whose SRID is the specified SRID value. In this
 * overload the SRID is supplied as an integer, which is wrapped in a literal column.
 *
 * @param geo
 *   the geospatial value whose SRID is replaced
 * @param srid
 *   the new SRID as an integer
 *
 * @group st_funcs
 * @since 4.1.0
 */
def st_setsrid(geo: Column, srid: Int): Column =
  st_setsrid(geo, lit(srid))
9491+
94749492
/**
94759493
* Returns the SRID of the input GEOGRAPHY or GEOMETRY value.
94769494
*

sql/core/src/test/scala/org/apache/spark/sql/STFunctionsSuite.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,24 @@ class STFunctionsSuite extends QueryTest with SharedSparkSession {
5959
Row(4326, 0))
6060
}
6161

62+
/** ST modifier expressions. */
63+
64+
test("st_setsrid") {
  // Test data: a hex-encoded Well-Known Binary (WKB) value and the SRID to assign to it.
  val df = Seq[(String, Int)](
    ("0101000000000000000000f03f0000000000000040", 4326)).toDF("wkb", "srid")
  // Cover every combination of input type (GEOGRAPHY / GEOMETRY) and SRID argument form
  // (column / Int literal); ST_Srid reads the assigned SRID back. Each output column gets a
  // distinct alias so the four cases stay distinguishable in the result.
  checkAnswer(
    df.select(
      st_srid(st_setsrid(st_geogfromwkb(unhex($"wkb")), $"srid")).as("col0"),
      st_srid(st_setsrid(st_geomfromwkb(unhex($"wkb")), $"srid")).as("col1"),
      st_srid(st_setsrid(st_geogfromwkb(unhex($"wkb")), 4326)).as("col2"),
      st_srid(st_setsrid(st_geomfromwkb(unhex($"wkb")), 4326)).as("col3")),
    Row(4326, 4326, 4326, 4326))
}
79+
6280
/** Geospatial feature is disabled. */
6381

6482
test("verify that geospatial functions are disabled when the config is off") {
@@ -68,7 +86,8 @@ class STFunctionsSuite extends QueryTest with SharedSparkSession {
6886
st_asbinary(lit(null)).as("res"),
6987
st_geogfromwkb(lit(null)).as("res"),
7088
st_geomfromwkb(lit(null)).as("res"),
71-
st_srid(lit(null)).as("res")
89+
st_srid(lit(null)).as("res"),
90+
st_setsrid(lit(null), lit(null)).as("res")
7291
).foreach { func =>
7392
checkError(
7493
exception = intercept[AnalysisException] {

0 commit comments

Comments
 (0)