Skip to content

Commit 3203d7d

Browse files
hovaescoebyhr
authored andcommitted
Add source property and default to SQLAlchemy
1 parent ca57138 commit 3203d7d

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,20 @@ nodes = Table(
8989
rows = connection.execute(select(nodes)).fetchall()
9090
```
9191

92+
In order to pass additional connection attributes use [connect_args](https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.connect_args) method.
93+
94+
```python
95+
from sqlalchemy import create_engine
96+
97+
engine = create_engine(
98+
'trino://user@localhost:8080/system',
99+
connect_args={
100+
"session_properties": {'query_max_run_time': '1d'},
101+
"client_tags": ["tag1", "tag2"]
102+
}
103+
)
104+
```
105+
92106
## Authentications
93107

94108
### Basic Authentication

tests/unit/sqlalchemy/test_dialect.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ def setup(self):
2222
(
2323
make_url("trino://user@localhost"),
2424
list(),
25-
dict(host="localhost", catalog="system", user="user"),
25+
dict(host="localhost", catalog="system", user="user", source="trino-sqlalchemy"),
2626
),
2727
(
2828
make_url("trino://user@localhost:8080"),
2929
list(),
30-
dict(host="localhost", port=8080, catalog="system", user="user"),
30+
dict(host="localhost", port=8080, catalog="system", user="user", source="trino-sqlalchemy"),
3131
),
3232
(
33-
make_url("trino://user:pass@localhost:8080"),
33+
make_url("trino://user:pass@localhost:8080?source=trino-rulez"),
3434
list(),
3535
dict(
3636
host="localhost",
@@ -39,6 +39,7 @@ def setup(self):
3939
user="user",
4040
auth=BasicAuthentication("user", "pass"),
4141
http_scheme="https",
42+
source="trino-rulez"
4243
),
4344
),
4445
],

trino/sqlalchemy/dialect.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any
9696
kwargs["http_scheme"] = "https"
9797
kwargs["auth"] = JWTAuthentication(url.query["access_token"])
9898

99+
if "source" in url.query:
100+
kwargs["source"] = url.query["source"]
101+
else:
102+
kwargs["source"] = "trino-sqlalchemy"
103+
99104
return args, kwargs
100105

101106
def get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:

0 commit comments

Comments
 (0)