Skip to content

Commit 73abfac

Browse files
committed
Fix(redshift): do not inherit postgres ROUND generator closes #6340
1 parent c81258e commit 73abfac

File tree

3 files changed

+30
-22
lines changed

3 files changed

+30
-22
lines changed

sqlglot/dialects/postgres.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,29 @@ def _versioned_anyvalue_sql(self: Postgres.Generator, expression: exp.AnyValue)
268268
return rename_func("ANY_VALUE")(self, expression)
269269

270270

271+
def _round_sql(self: Postgres.Generator, expression: exp.Round) -> str:
272+
this = self.sql(expression, "this")
273+
decimals = self.sql(expression, "decimals")
274+
275+
if not decimals:
276+
return self.func("ROUND", this)
277+
278+
if not expression.type:
279+
from sqlglot.optimizer.annotate_types import annotate_types
280+
281+
expression = annotate_types(expression, dialect=self.dialect)
282+
283+
# ROUND(double precision, integer) is not permitted in Postgres
284+
# so it's necessary to cast to decimal before rounding.
285+
if expression.this.is_type(exp.DataType.Type.DOUBLE):
286+
decimal_type = exp.DataType.build(
287+
exp.DataType.Type.DECIMAL, expressions=expression.expressions
288+
)
289+
this = self.sql(exp.Cast(this=this, to=decimal_type))
290+
291+
return self.func("ROUND", this, decimals)
292+
293+
271294
class Postgres(Dialect):
272295
INDEX_OFFSET = 1
273296
TYPED_DIVISION = True
@@ -663,6 +686,7 @@ class Generator(generator.Generator):
663686
e.args.get("occurrence"),
664687
regexp_replace_global_modifier(e),
665688
),
689+
exp.Round: _round_sql,
666690
exp.Select: transforms.preprocess(
667691
[
668692
transforms.eliminate_semi_and_anti_joins,
@@ -710,28 +734,6 @@ class Generator(generator.Generator):
710734
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
711735
}
712736

713-
def round_sql(self, expression: exp.Round) -> str:
714-
this = self.sql(expression, "this")
715-
decimals = self.sql(expression, "decimals")
716-
717-
if not decimals:
718-
return self.func("ROUND", this)
719-
720-
if not expression.type:
721-
from sqlglot.optimizer.annotate_types import annotate_types
722-
723-
expression = annotate_types(expression, dialect=self.dialect)
724-
725-
# ROUND(double precision, integer) is not permitted in Postgres
726-
# so it's necessary to cast to decimal before rounding.
727-
if expression.this.is_type(exp.DataType.Type.DOUBLE):
728-
decimal_type = exp.DataType.build(
729-
exp.DataType.Type.DECIMAL, expressions=expression.expressions
730-
)
731-
this = self.sql(exp.Cast(this=this, to=decimal_type))
732-
733-
return self.func("ROUND", this, decimals)
734-
735737
def schemacommentproperty_sql(self, expression: exp.SchemaCommentProperty) -> str:
736738
self.unsupported("Table comments are not supported in the CREATE statement")
737739
return ""

sqlglot/dialects/redshift.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ class Generator(Postgres.Generator):
231231
TRANSFORMS.pop(exp.LastDay)
232232
TRANSFORMS.pop(exp.SHA2)
233233

234+
# Postgres does not permit a double precision argument in ROUND; Redshift does
235+
TRANSFORMS.pop(exp.Round)
236+
234237
RESERVED_KEYWORDS = {
235238
"aes128",
236239
"aes256",

tests/dialects/test_redshift.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ class TestRedshift(Validator):
77

88
def test_redshift(self):
99
self.validate_identity("SELECT COSH(1.5)")
10+
self.validate_identity(
11+
"ROUND(CAST(a AS DOUBLE PRECISION) / CAST(b AS DOUBLE PRECISION), 2)"
12+
)
1013
self.validate_all(
1114
"SELECT SPLIT_TO_ARRAY('12,345,6789')",
1215
write={

0 commit comments

Comments
 (0)