Skip to content

Commit 771eec3

Browse files
hovaescoebyhr
authored andcommitted
Add option to pass attributes in connection string in SQLAlchemy
1 parent 236febf commit 771eec3

File tree

3 files changed

+42
-0
lines changed

3 files changed

+42
-0
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ rows = connection.execute(select(nodes)).fetchall()
9090
```
9191

9292
In order to pass additional connection attributes use [connect_args](https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.connect_args) method.
93+
Attributes can be also passed in connection string.
9394

9495
```python
9596
from sqlalchemy import create_engine
@@ -101,6 +102,13 @@ engine = create_engine(
101102
"client_tags": ["tag1", "tag2"]
102103
}
103104
)
105+
106+
# or in connection string
107+
engine = create_engine(
108+
'trino://user@localhost:8080/system?'
109+
'session_properties={"query_max_run_time": "1d"}'
110+
'&client_tags=["tag1", "tag2"]',
111+
)
104112
```
105113

106114
## Authentications

tests/unit/sqlalchemy/test_dialect.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,26 @@ def setup(self):
4141
source="trino-rulez"
4242
),
4343
),
44+
(
45+
make_url(
46+
'trino://user@localhost:8080?'
47+
'session_properties={"query_max_run_time": "1d"}'
48+
'&http_headers={"trino": 1}'
49+
'&extra_credential=[("a", "b"), ("c", "d")]'
50+
'&client_tags=[1, "sql"]'),
51+
list(),
52+
dict(
53+
host="localhost",
54+
port=8080,
55+
catalog="system",
56+
user="user",
57+
source="trino-sqlalchemy",
58+
session_properties={"query_max_run_time": "1d"},
59+
http_headers={"trino": 1},
60+
extra_credential=[("a", "b"), ("c", "d")],
61+
client_tags=[1, "sql"]
62+
),
63+
),
4464
],
4565
)
4666
def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]):

trino/sqlalchemy/dialect.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12+
import json
13+
from ast import literal_eval
1214
from textwrap import dedent
1315
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
1416

@@ -105,6 +107,18 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any
105107
else:
106108
kwargs["source"] = "trino-sqlalchemy"
107109

110+
if "session_properties" in url.query:
111+
kwargs["session_properties"] = json.loads(url.query["session_properties"])
112+
113+
if "http_headers" in url.query:
114+
kwargs["http_headers"] = json.loads(url.query["http_headers"])
115+
116+
if "extra_credential" in url.query:
117+
kwargs["extra_credential"] = literal_eval(url.query["extra_credential"])
118+
119+
if "client_tags" in url.query:
120+
kwargs["client_tags"] = json.loads(url.query["client_tags"])
121+
108122
return args, kwargs
109123

110124
def get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:

0 commit comments

Comments
 (0)