Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
72457dc
Simple PG workflow working
aaron-congo Feb 6, 2026
6cd61c7
Cleanup
aaron-congo Feb 6, 2026
bc607f6
Fix failover2 wrong writer host
aaron-congo Feb 9, 2026
5daa970
Add mysql-connector SQLAlchemy ORM
jonathanl-bq Mar 25, 2026
5ee014a
Revert connection string in sqlalchemy orm unit test
jonathanl-bq Mar 25, 2026
4871ac7
Add __init__.py for sqlalchemy integration tests
jonathanl-bq Mar 26, 2026
068e2e9
Fix RdsUtils not being found
jonathanl-bq Mar 26, 2026
f81b1db
Translate basic django test to sqlalchemy
jonathanl-bq Apr 2, 2026
a1ebf4c
Add basic CRUD test for sqlalchemy ORM mysql tests
jonathanl-bq Apr 4, 2026
f081a63
Add remaining basic MySQL SQLAlchemy ORM tests
jonathanl-bq Apr 7, 2026
1b8d401
Remove temporary changes to get tests to run locally
jonathanl-bq Apr 7, 2026
4429afd
Add license headers and remove unused import
jonathanl-bq Apr 8, 2026
4f6500c
Try fixing mypy errors in tests
jonathanl-bq Apr 9, 2026
9bd824f
Try to fix mypy Base class error
jonathanl-bq Apr 13, 2026
ee53f8c
fix: test failures due to using legacy sqlalchemy api (#1226)
karenc-bq Apr 13, 2026
9f484bb
Add WIP sqlalchemy plugins tests
jonathanl-bq Apr 13, 2026
0ab9c66
Fix multiple class definition errors
jonathanl-bq Apr 14, 2026
a64d748
Override initialize for mysql_orm_dialect.py
jonathanl-bq Apr 21, 2026
fa1f875
Fix most of the sqlalchemy ORM plugin tests
jonathanl-bq May 4, 2026
c981d2d
test: add clean up between tests (#1232)
karenc-bq May 7, 2026
5c21749
Fix issue with plugins being shadowed by sqlalchemy create_engine
jonathanl-bq May 8, 2026
d5b2b82
Remove wrapper_plugins from opts after processing it
jonathanl-bq May 11, 2026
79f990f
Merge branch 'sqlalchemy-orm-mysql' into sqlalchemy-orm-plugins
jonathanl-bq May 12, 2026
f955a9c
Try to fix mypy issues
jonathanl-bq May 13, 2026
c6958bc
Try fixing one mypy error
jonathanl-bq May 13, 2026
1480c48
Fix syntax error
jonathanl-bq May 14, 2026
0f9859d
Fix retrieved variable types
jonathanl-bq May 14, 2026
c75e9f7
Fix mypy error about row
jonathanl-bq May 14, 2026
771b4b0
Try to fix mypy errors in mysql_orm_dialect.py
jonathanl-bq May 14, 2026
7941fe0
Try to fix LSP violation errors
jonathanl-bq May 14, 2026
159b6c3
Try to fix mypy error for missing errno field
jonathanl-bq May 14, 2026
7e2bb1c
Fix last mypy error
jonathanl-bq May 14, 2026
edd4d02
Use err variable's errno property
jonathanl-bq May 14, 2026
9f80c4e
Check errno property on correct type
jonathanl-bq May 14, 2026
f0d7d91
Address flake8 errors
jonathanl-bq May 14, 2026
cfebe5b
Add annotations import
jonathanl-bq May 14, 2026
b45cd27
Run isort
jonathanl-bq May 14, 2026
88dfb8e
Run isort on tests
jonathanl-bq May 14, 2026
c999602
Move int cast for connect_timeout to fix unit tests
jonathanl-bq May 15, 2026
9d2f642
Remove breakpoint call
jonathanl-bq May 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 169 additions & 1 deletion aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import mysql.connector
from mysql.connector import CMySQLConnection
from mysql.connector.errors import Error
from sqlalchemy.dialects.mysql.mysqlconnector import \
MySQLDialect_mysqlconnector
from sqlalchemy.engine import default

from aws_advanced_python_wrapper import AwsWrapperConnection
from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.utils.properties import (Properties,
PropertiesUtils)

if TYPE_CHECKING:
from sqlalchemy import Connection

from aws_advanced_python_wrapper.hostinfo import HostInfo


class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector):
Expand All @@ -27,3 +44,154 @@ class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector):

name = 'mysql'
driver = 'aws_wrapper_mysqlconnector'

@classmethod
def import_dbapi(cls):
"""
Return the DB-API 2.0 module.
SQLAlchemy calls this to get the driver module.
"""
import aws_advanced_python_wrapper
return aws_advanced_python_wrapper

def create_connect_args(self, url):
"""
Transform SQLAlchemy URL into connection arguments.
Must include the 'target' parameter for our wrapper driver.
"""
# Extract standard connection parameters
opts = url.translate_connect_args(username='user')

# Add query string parameters
opts.update(url.query)

# Add the required 'target' parameter for our wrapper
if 'target' not in opts:
opts['target'] = mysql.connector.Connect
if 'wrapper_plugins' not in opts:
opts['plugins'] = "aurora_connection_tracker,failover"
else:
opts['plugins'] = opts['wrapper_plugins']
opts.pop('wrapper_plugins', None)
if 'connect_timeout' in opts:
opts['connect_timeout'] = int(opts['connect_timeout'])

# Return empty args list and kwargs dict
return [], opts

def _detect_charset(self, connection: Connection) -> str:
if isinstance(connection, CMySQLConnection):
return connection.charset
else:
raise Exception("Could not detect charset because connection was not a CMySQLConnection.")

def _extract_error_code(self, exception: BaseException) -> int:
if isinstance(exception, AwsWrapperError):
err = exception.driver_error
if err and isinstance(err, Error):
return err.errno
else:
raise Exception("Could not extract error code because driver_error was not a BaseException.")
else:
raise Exception("Could not extract error code because exception was not an AwsWrapperError.")

def initialize(self, connection):
"""
Override initialization to handle type introspection.
The parent class tries to use TypeInfo.fetch() which requires
a native SQLAlchemy connection, not AwsWrapperConnection.
"""

# Unwrap SQLAlchemy's connection object
wrapper_conn, wrapper_parent = self._get_wrapper_connection_and_parent(connection)

# this is driver-based, does not need server version info
# and is fairly critical for even basic SQL operations
self._connection_charset: Optional[str] = self._detect_charset(
wrapper_conn.target_connection
)

# call super().initialize() because we need to have
# server_version_info set up. in 1.4 under python 2 only this does the
# "check unicode returns" thing, which is the one area that some
# SQL gets compiled within initialize() currently
default.DefaultDialect.initialize(self, connection)

self._detect_sql_mode(connection)
self._detect_ansiquotes(connection) # depends on sql mode
self._detect_casing(connection)
if self._server_ansiquotes:
# if ansiquotes == True, build a new IdentifierPreparer
# with the new setting
self.identifier_preparer = self.preparer(
self, server_ansiquotes=self._server_ansiquotes
)

self.supports_sequences = (
self.is_mariadb and self.server_version_info >= (10, 3)
)

self.supports_for_update_of = (
self._is_mysql and self.server_version_info >= (8,)
)

self.use_mysql_for_share = (
self._is_mysql and self.server_version_info >= (8, 0, 1)
)

self._needs_correct_for_88718_96365 = (
not self.is_mariadb and self.server_version_info >= (8,)
)

self.delete_returning = (
self.is_mariadb and self.server_version_info >= (10, 0, 5)
)

self.insert_returning = (
self.is_mariadb and self.server_version_info >= (10, 5)
)

self._requires_alias_for_on_duplicate_key = (
self._is_mysql and self.server_version_info >= (8, 0, 20)
)

self._warn_for_known_db_issues()

def _get_wrapper_connection_and_parent(self, connection):
"""
Traverse the connection chain to find AwsWrapperConnection and its parent connection.

Args:
connection: SQLAlchemy Connection object

Returns:
AwsWrapperConnection instance or None, parent connection of AwsWrapperConnection or None
"""
# Start with the DBAPI connection
parent = connection
child = connection.connection

# Traverse up to 5 levels deep (reasonable limit)
for _ in range(5):
if isinstance(child, AwsWrapperConnection):
return child, parent

# Try to go deeper if there's a .connection attribute
if hasattr(child, 'connection'):
parent = child
child = child.connection
else:
break

return None

def prepare_connect_info(self, host_info: HostInfo, props: Properties) -> Properties:
prop_copy: Properties = Properties(props.copy())

prop_copy["host"] = host_info.host

if host_info.is_port_specified():
prop_copy["port"] = str(host_info.port)

PropertiesUtils.remove_wrapper_props(prop_copy)
return prop_copy
29 changes: 8 additions & 21 deletions tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
subqueryload)
from sqlalchemy.sql import func

from tests.integration.container.utils.rds_test_utility import RdsTestUtility
from ..utils.conditions import (disable_on_features, enable_on_deployments,
enable_on_engines)
from ..utils.database_engine import DatabaseEngine
Expand Down Expand Up @@ -114,41 +113,29 @@ class Book(Base):
TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT,
TestEnvironmentFeatures.PERFORMANCE])
class TestSqlAlchemy:
@pytest.fixture(scope='class')
def rds_utils(self):
region: str = TestEnvironment.get_current().get_info().get_region()
return RdsTestUtility(region)


@pytest.fixture(scope="class")
@pytest.fixture(scope="function")
def engine(self, conn_utils):
conn_str = f'mysql+aws_wrapper_mysqlconnector://{conn_utils.user}:{conn_utils.password}@{conn_utils.writer_cluster_host}:{conn_utils.port}/{conn_utils.dbname}'
engine = create_engine(conn_str)
Base.metadata.create_all(engine)
yield engine
Base.metadata.drop_all(engine)
engine.dispose()

@pytest.fixture(scope="class")
def Session(self, engine):
Session = sessionmaker(bind=engine)
yield Session

@pytest.fixture(scope="class")
def session(self, Session):
session = Session()
@pytest.fixture(scope="function")
def session(self, engine):
session = sessionmaker(bind=engine)()
yield session
session.rollback()
session.close()

def test_sqlalchemy_backend_configuration(self, test_environment: TestEnvironment, engine):
def test_sqlalchemy_backend_configuration(self, test_environment: TestEnvironment, session):
"""Test SQLAlchemy backend configuration with empty plugins"""
# Verify that the connection is using the AWS wrapper
with engine.connect() as connection:
assert connection.connection is not None
assert session.connection().connection is not None

# Test basic connection functionality
with Session(engine) as session:
assert session.query(TestModel).count() == 0
assert session.query(TestModel).count() == 0

def test_sqlalchemy_basic_model_operations(self, session, test_environment: TestEnvironment):
"""Test basic SQLAlchemy ORM operations (CRUD)"""
Expand Down
Loading
Loading