Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlglot/dialects/doris.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class Generator(MySQL.Generator):
VARCHAR_REQUIRES_SIZE = False
WITH_PROPERTIES_PREFIX = "PROPERTIES"
RENAME_TABLE_WITH_DB = False
UPDATE_STATEMENT_SUPPORTS_FROM = True

TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
Expand Down
29 changes: 1 addition & 28 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ class Generator(generator.Generator):
WRAP_DERIVED_VALUES = False
VARCHAR_REQUIRES_SIZE = True
SUPPORTS_MEDIAN = False
UPDATE_STATEMENT_SUPPORTS_FROM = False

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down Expand Up @@ -1337,31 +1338,3 @@ def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
@unsupported_args("this")
def currentschema_sql(self, expression: exp.CurrentSchema) -> str:
return self.func("SCHEMA")

def _update_from_joins_sql(self, expression: exp.Update) -> t.Tuple[str, str]:
from_expr = expression.args.get("from_")
if not from_expr:
return ("", "")

# Qualify unqualified columns in SET clause with the target table
# MySQL requires qualified column names in multi-table UPDATE to avoid ambiguity
target_table = expression.this
if isinstance(target_table, exp.Table):
target_name = target_table.alias_or_name
for eq in expression.expressions:
col = eq.this
if isinstance(col, exp.Column) and not col.table:
col.set("table", exp.to_identifier(target_name))

table = from_expr.this
nested_joins = table.args.get("joins") or []
if nested_joins:
table.set("joins", None)

join_sql = self.sql(exp.Join(this=table, on=exp.true()))
for nested in nested_joins:
if not nested.args.get("on") and not nested.args.get("using"):
nested.set("on", exp.true())
join_sql += self.sql(nested)

return (join_sql, "")
1 change: 1 addition & 0 deletions sqlglot/dialects/starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class Generator(MySQL.Generator):
VARCHAR_REQUIRES_SIZE = False
PARSE_JSON_NAME: t.Optional[str] = "PARSE_JSON"
WITH_PROPERTIES_PREFIX = "PROPERTIES"
UPDATE_STATEMENT_SUPPORTS_FROM = True

# StarRocks doesn't support "IS TRUE/FALSE" syntax.
IS_BOOL_ALLOWED = False
Expand Down
32 changes: 30 additions & 2 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,11 @@ class Generator(metaclass=_Generator):
# Whether to include the VARIABLE keyword for SET assignments
SET_ASSIGNMENT_REQUIRES_VARIABLE_KEYWORD = False

# Whether FROM is supported in UPDATE statements or if joins must be generated instead, e.g:
# Supported (Postgres, Doris etc): UPDATE t1 SET t1.a = t2.b FROM t2
# Unsupported (MySQL, SingleStore): UPDATE t1 JOIN t2 ON TRUE SET t1.a = t2.b
UPDATE_STATEMENT_SUPPORTS_FROM = True

TYPE_MAPPING = {
exp.DataType.Type.DATETIME2: "TIMESTAMP",
exp.DataType.Type.NCHAR: "CHAR",
Expand Down Expand Up @@ -2228,9 +2233,32 @@ def _update_from_joins_sql(self, expression: exp.Update) -> t.Tuple[str, str]:
Returns (join_sql, from_sql) for UPDATE statements.
- join_sql: placed after UPDATE table, before SET
- from_sql: placed after SET clause (standard position)
Dialects like MySQL override to convert FROM to JOIN syntax.
Dialects like MySQL need to convert FROM to JOIN syntax.
"""
return ("", self.sql(expression, "from_"))
if self.UPDATE_STATEMENT_SUPPORTS_FROM or not (from_expr := expression.args.get("from_")):
return ("", self.sql(expression, "from_"))

# Qualify unqualified columns in SET clause with the target table
# MySQL requires qualified column names in multi-table UPDATE to avoid ambiguity
target_table = expression.this
if isinstance(target_table, exp.Table):
target_name = exp.to_identifier(target_table.alias_or_name)
for eq in expression.expressions:
col = eq.this
if isinstance(col, exp.Column) and not col.table:
col.set("table", target_name)

table = from_expr.this
if nested_joins := table.args.get("joins", []):
table.set("joins", None)

join_sql = self.sql(exp.Join(this=table, on=exp.true()))
for nested in nested_joins:
if not nested.args.get("on") and not nested.args.get("using"):
nested.set("on", exp.true())
join_sql += self.sql(nested)

return (join_sql, "")

def update_sql(self, expression: exp.Update) -> str:
this = self.sql(expression, "this")
Expand Down
19 changes: 2 additions & 17 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,23 +193,8 @@ def test_ddl(self):
self.validate_identity("ALTER TABLE t ALTER INDEX i VISIBLE")
self.validate_identity("ALTER TABLE t ALTER COLUMN c SET INVISIBLE")
self.validate_identity("ALTER TABLE t ALTER COLUMN c SET VISIBLE")

def test_update_from_to_join(self):
# MySQL multi-table UPDATE requires qualified columns in SET to avoid ambiguity
self.validate_all(
"UPDATE foo JOIN bar ON TRUE SET foo.a = bar.a WHERE foo.id = bar.id",
read={
"postgres": "UPDATE foo SET a = bar.a FROM bar WHERE foo.id = bar.id",
"mysql": "UPDATE foo JOIN bar ON TRUE SET foo.a = bar.a WHERE foo.id = bar.id",
},
)

# Multiple columns in SET clause
self.validate_all(
"UPDATE t1 JOIN t2 ON TRUE SET t1.id = t2.id, t1.name = t2.name WHERE t1.x = t2.x",
read={
"postgres": "UPDATE t1 SET id = t2.id, name = t2.name FROM t2 WHERE t1.x = t2.x",
},
self.validate_identity(
"UPDATE foo JOIN bar ON TRUE SET foo.a = bar.a WHERE foo.id = bar.id"
)

def test_identity(self):
Expand Down
11 changes: 11 additions & 0 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,17 @@ def test_postgres(self):
width_bucket = self.validate_identity("WIDTH_BUCKET(10, 5, 15, 25)")
self.assertIsNone(width_bucket.args.get("threshold"))

self.validate_all(
"UPDATE foo SET a = bar.a, b = bar.b FROM bar WHERE foo.id = bar.id",
write={
"postgres": "UPDATE foo SET a = bar.a, b = bar.b FROM bar WHERE foo.id = bar.id",
"doris": "UPDATE foo SET a = bar.a, b = bar.b FROM bar WHERE foo.id = bar.id",
"starrocks": "UPDATE foo SET a = bar.a, b = bar.b FROM bar WHERE foo.id = bar.id",
"mysql": "UPDATE foo JOIN bar ON TRUE SET foo.a = bar.a, foo.b = bar.b WHERE foo.id = bar.id",
"singlestore": "UPDATE foo JOIN bar ON TRUE SET foo.a = bar.a, foo.b = bar.b WHERE foo.id = bar.id",
},
)

def test_ddl(self):
# Checks that user-defined types are parsed into DataType instead of Identifier
self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(
Expand Down