Skip to content

Commit 5eedff2

Browse files
committed
feat: support AUTOINCREMENT when creating table
resolves #184
1 parent b05ee7f commit 5eedff2

File tree

4 files changed

+142
-4
lines changed

4 files changed

+142
-4
lines changed

fakesnow/cursor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,12 +320,14 @@ def _transform_explode(self, expression: exp.Expression) -> list[exp.Expression]
320320
# Applies transformations that require splitting the expression into multiple expressions
321321
# Split transforms have limited support at the moment.
322322

323-
# Try merge transform first
323+
auto_seq_result = transforms.create_table_autoincrement(expression)
324+
if len(auto_seq_result) > 1:
325+
return auto_seq_result
326+
324327
merge_result = transforms.merge(expression)
325328
if len(merge_result) > 1:
326329
return merge_result
327330

328-
# If merge didn't split the expression, try alter_table_add_multiple_columns
329331
alter_result = transforms.alter_table_add_multiple_columns(expression)
330332
if len(alter_result) > 1:
331333
return alter_result

fakesnow/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from fakesnow.transforms.ddl import (
44
alter_table_add_multiple_columns as alter_table_add_multiple_columns,
55
alter_table_strip_cluster_by as alter_table_strip_cluster_by,
6+
create_table_autoincrement as create_table_autoincrement,
67
)
78
from fakesnow.transforms.merge import merge as merge
89
from fakesnow.transforms.show import (

fakesnow/transforms/ddl.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,17 @@
1313

1414
from __future__ import annotations
1515

16+
import secrets
17+
18+
import sqlglot
1619
from sqlglot import exp
1720

1821
from fakesnow.transforms.transforms import SUCCESS_NOP
1922

2023

21-
def alter_table_add_multiple_columns(expression: exp.Expression) -> list[exp.Expression]:
24+
def alter_table_add_multiple_columns(
25+
expression: exp.Expression,
26+
) -> list[exp.Expression]:
2227
"""Transform ALTER TABLE ADD COLUMN with multiple columns into separate statements.
2328
2429
Snowflake supports: ALTER TABLE IF EXISTS tab1 ADD COLUMN IF NOT EXISTS col1 INT, col2 VARCHAR(50), col3 BOOLEAN;
@@ -82,3 +87,99 @@ def alter_table_strip_cluster_by(expression: exp.Expression) -> exp.Expression:
8287
):
8388
return SUCCESS_NOP
8489
return expression
90+
91+
92+
def create_table_autoincrement(
93+
expression: exp.Expression,
94+
) -> list[exp.Expression]:
95+
"""Split CREATE TABLE with AUTOINCREMENT into CREATE SEQUENCE + CREATE TABLE with DEFAULT NEXTVAL.
96+
97+
Example transform:
98+
CREATE TABLE test_table (id NUMERIC NOT NULL AUTOINCREMENT, name VARCHAR)
99+
->
100+
CREATE SEQUENCE test_table_id_seq START 1;
101+
CREATE TABLE test_table (id NUMERIC NOT NULL DEFAULT NEXTVAL('test_table_id_seq'), name VARCHAR)
102+
"""
103+
104+
if not (
105+
isinstance(expression, exp.Create)
106+
and expression.kind == "TABLE"
107+
and (schema := expression.this)
108+
and (table := schema.this)
109+
and isinstance(schema, exp.Schema)
110+
and isinstance(table, exp.Table)
111+
# Find AUTOINCREMENT/IDENTITY columns
112+
and (
113+
auto_cols := [
114+
cd
115+
for cd in (schema.expressions or [])
116+
if isinstance(cd, exp.ColumnDef)
117+
and (cd.find(exp.AutoIncrementColumnConstraint) or cd.find(exp.GeneratedAsIdentityColumnConstraint))
118+
]
119+
)
120+
):
121+
return [expression]
122+
123+
if len(auto_cols) > 1:
124+
raise NotImplementedError("Multiple AUTOINCREMENT columns")
125+
126+
auto = auto_cols[0]
127+
col_name = auto.this.name
128+
table_name = table.name
129+
# When recreating the same table with a sequence, we need to give the sequence a unique name to avoid
130+
# Dependency Error: Cannot drop entry "_FS_SEQ_..." because there are entries that depend on it.
131+
random_suffix = secrets.token_hex(4)
132+
seq_name = f"_fs_seq_{table_name}_{col_name}_{random_suffix}"
133+
134+
# Build CREATE SEQUENCE, using START/INCREMENT if provided
135+
start_val = "1"
136+
increment_val = "1"
137+
138+
identity = auto.find(exp.GeneratedAsIdentityColumnConstraint)
139+
if identity:
140+
s = identity.args.get("start")
141+
i = identity.args.get("increment")
142+
if isinstance(s, exp.Literal):
143+
start_val = s.this
144+
if isinstance(i, exp.Literal):
145+
increment_val = i.this
146+
147+
seq_stmt = sqlglot.parse_one(
148+
f"CREATE SEQUENCE {seq_name} START WITH {start_val} INCREMENT BY {increment_val}",
149+
read="duckdb",
150+
)
151+
152+
# Build modified CREATE TABLE with DEFAULT NEXTVAL('<seq_name>') and without AUTOINCREMENT
153+
new_create: exp.Create = expression.copy()
154+
new_schema: exp.Schema = new_create.this
155+
156+
# Find the corresponding column in the copied schema
157+
target_col = next(
158+
cd for cd in (new_schema.expressions or []) if isinstance(cd, exp.ColumnDef) and cd.this.name == col_name
159+
)
160+
161+
existing_constraints: list[exp.Expression] = target_col.args.get("constraints", []) or []
162+
163+
# Replace the AUTOINCREMENT/IDENTITY constraint in-place with DEFAULT NEXTVAL('<seq_name>')
164+
for c in existing_constraints:
165+
if isinstance(c, exp.ColumnConstraint) and isinstance(
166+
c.args.get("kind"),
167+
(
168+
exp.AutoIncrementColumnConstraint,
169+
exp.GeneratedAsIdentityColumnConstraint,
170+
),
171+
):
172+
c.set(
173+
"kind",
174+
exp.DefaultColumnConstraint(
175+
this=exp.Anonymous(
176+
this="NEXTVAL",
177+
expressions=[exp.Literal(this=seq_name, is_string=True)],
178+
)
179+
),
180+
)
181+
break
182+
183+
target_col.set("constraints", existing_constraints)
184+
185+
return [seq_stmt, new_create]

tests/test_sequence.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,41 @@
1+
import re
2+
13
import snowflake.connector.cursor
24
import sqlglot
35

4-
from fakesnow.transforms import sequence_nextval
6+
from fakesnow.transforms import create_table_autoincrement, sequence_nextval
7+
8+
9+
def test_autoincrement(cur: snowflake.connector.cursor.SnowflakeCursor):
10+
cur.execute("CREATE TABLE test_table (id NUMERIC NOT NULL AUTOINCREMENT, name VARCHAR)")
11+
cur.execute("insert into test_table(name) values ('foo'), ('bar')")
12+
cur.execute("select * from test_table")
13+
assert cur.fetchall() == [(1, "foo"), (2, "bar")]
14+
15+
# recreate the table with a different sequence
16+
cur.execute(
17+
"CREATE or replace TABLE test_table(id NUMERIC NOT NULL IDENTITY start 10 increment 5 order, name VARCHAR)"
18+
)
19+
cur.execute("insert into test_table(name) values ('foo'), ('bar')")
20+
cur.execute("select * from test_table")
21+
assert cur.fetchall() == [(10, "foo"), (15, "bar")]
22+
23+
24+
def test_autoincrement_transform() -> None:
25+
expr = sqlglot.parse_one(
26+
"create table test_table (id numeric autoincrement start 10 increment 5 order)",
27+
dialect="snowflake",
28+
)
29+
seq_stmt, table_stmt = create_table_autoincrement(expr)
30+
seq_sql = seq_stmt.sql()
31+
table_sql = table_stmt.sql()
32+
33+
m_seq = re.search(r"CREATE SEQUENCE (\S+) START WITH 10 INCREMENT BY 5", seq_sql)
34+
assert m_seq, f"Unexpected SEQUENCE SQL: {seq_sql}"
35+
seq_name = m_seq.group(1)
36+
assert re.match(r"^_fs_seq_test_table_id_[0-9a-f]{8}$", seq_name)
37+
38+
assert table_sql == f"CREATE TABLE test_table (id DECIMAL(38, 0) DEFAULT NEXTVAL('{seq_name}'))"
539

640

741
def test_sequence(dcur: snowflake.connector.cursor.SnowflakeCursor):

0 commit comments

Comments
 (0)