Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing
from typing import Any, List

from sqlalchemy import Sequence
from sqlalchemy.sql.schema import MetaData, SchemaItem, Table

from ..._constants import DIALECT_NAME
Expand Down Expand Up @@ -51,8 +52,28 @@ def __init__(
super().__init__(name, metadata, *args, **kw)

if not kw.get("autoload_with", False):
self._attach_implicit_sequence()
self._validate_table()

def _attach_implicit_sequence(self):
"""Attach an implicit Sequence to the autoincrement column if needed.

Snowflake does not support INSERT ... RETURNING, so the ORM cannot
retrieve auto-generated primary key values after INSERT. By attaching
a Sequence, SQLAlchemy's fire_sequence mechanism pre-fetches the next
ID before each INSERT, allowing the ORM to work transparently.
"""
auto_col = self._autoincrement_column
if (
auto_col is not None
and auto_col.default is None
and auto_col.identity is None
and auto_col.server_default is None
):
seq_name = f"{self.name}_{auto_col.name}_seq"
seq = Sequence(seq_name, schema=self.schema)
auto_col._init_items(seq)

def _validate_table(self):
exceptions: List[Exception] = []

Expand Down
86 changes: 85 additions & 1 deletion tests/test_sequence.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import pytest
from sqlalchemy import (
Column,
Identity,
Expand All @@ -11,11 +11,16 @@
String,
Table,
insert,
inspect,
select,
)
from sqlalchemy.orm import Mapped, Session, declarative_base, mapped_column
from sqlalchemy.orm.exc import FlushError
from sqlalchemy.sql import text
from sqlalchemy.sql.ddl import CreateTable

from snowflake.sqlalchemy import SnowflakeTable


def test_table_with_sequence(engine_testaccount, db_parameters):
"""Snowflake does not guarantee generating sequence numbers without gaps.
Expand Down Expand Up @@ -139,6 +144,85 @@ def test_table_with_autoincrement(engine_testaccount):
metadata.drop_all(engine_testaccount)


def test_orm_autoincrement_without_snowflake_table_fails(engine_testaccount):
"""ORM autoincrement fails with default Table because Snowflake does not
support INSERT ... RETURNING and the connector lacks lastrowid support.

Without SnowflakeTable's implicit Sequence, the ORM cannot retrieve the
auto-generated primary key after INSERT, resulting in a FlushError.
"""

Base = declarative_base()

class User(Base):
__tablename__ = "test_orm_autoincrement_no_sf_table"

id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column()

try:
Base.metadata.create_all(engine_testaccount)

with Session(engine_testaccount) as session:
session.add(User(name="Spongebob"))
with pytest.raises(FlushError, match="NULL identity key"):
session.commit()
finally:
Base.metadata.drop_all(engine_testaccount)


def test_orm_autoincrement_with_snowflake_table(engine_testaccount):
"""ORM autoincrement should work with SnowflakeTable without explicit Sequence.

SnowflakeTable implicitly creates a Sequence for autoincrement columns,
allowing the ORM to pre-fetch the next ID via fire_sequence before INSERT.
This works around the lack of INSERT ... RETURNING support in Snowflake.
"""

Base = declarative_base()

class User(Base):
__tablename__ = "test_orm_autoincrement"

id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column()

@classmethod
def __table_cls__(cls, name, metadata, *arg, **kw):
return SnowflakeTable(name, metadata, *arg, **kw)

# Verify implicit Sequence was attached to the id column
id_col = User.__table__.c.id
assert isinstance(
id_col.default, Sequence
), "SnowflakeTable should attach an implicit Sequence to the autoincrement column"
assert id_col.default.name == "test_orm_autoincrement_id_seq"

try:
Base.metadata.create_all(engine_testaccount)

# Verify the sequence actually exists in Snowflake
insp = inspect(engine_testaccount)
assert insp.has_sequence(
"test_orm_autoincrement_id_seq"
), "Implicit sequence should be created in Snowflake"

with Session(engine_testaccount) as session:
u1 = User(name="Spongebob")
u2 = User(name="Patrick")
session.add_all([u1, u2])
session.commit()

assert u1.id is not None
assert u2.id is not None
assert u1.id != u2.id

fetched = session.get(User, u1.id)
assert fetched.name == "Spongebob"
finally:
Base.metadata.drop_all(engine_testaccount)


def test_table_with_identity(sql_compiler):
test_table_name = "identity"
metadata = MetaData()
Expand Down
Loading