Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile dev-requirements.in
Expand All @@ -18,8 +18,6 @@ distlib==0.3.8
# via virtualenv
docopt==0.6.2
# via tbump
exceptiongroup==1.3.0
# via pytest
filelock==3.16.1
# via virtualenv
freezegun==1.5.1
Expand Down Expand Up @@ -74,15 +72,8 @@ tabulate==0.8.10
# via cli-ui
tbump==6.11.0
# via -r dev-requirements.in
tomli==2.2.1
# via
# build
# pip-tools
# pytest
tomlkit==0.11.8
# via tbump
typing-extensions==4.15.0
# via exceptiongroup
tzdata==2025.2
# via pandas
unidecode==1.3.8
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ members = [
"soda-sqlserver",
"soda-synapse",
"soda-tests",
"soda-trino"
]

[dependency-groups]
Expand Down
4 changes: 3 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ testpaths =
soda-redshift/tests
soda-fabric/tests
soda-athena/tests
soda-trino/tests

pythonpath =
soda-core/src
Expand All @@ -30,7 +31,8 @@ pythonpath =
soda-redshift/tests
soda-fabric/tests
soda-athena/tests

soda-trino/tests

log_cli=false
log_cli_level=DEBUG
log_cli_format=%(message)s
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ soda-snowflake
soda-sqlserver
soda-synapse
soda-tests
soda-trino
soda-sparkdf
2 changes: 1 addition & 1 deletion scripts/release_matrix.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/bin/bash

printf "["; find . -maxdepth 1 -mindepth 1 -type d -name "soda*" -printf '"%f",' | sed 's/,$//'; printf "]"
printf "["; find . -maxdepth 1 -mindepth 1 -type d -name "soda*" -not -name "soda-trino" -printf '"%f",' | sed 's/,$//'; printf "]"
2 changes: 1 addition & 1 deletion scripts/test_matrix.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/bin/bash

printf "[";find . -maxdepth 1 -mindepth 1 -type d -name "soda*" -not -name "soda-core" -not -name "soda-tests" -printf '"%f",' | sed 's/,$//;s/soda-//g';printf "]"
printf "[";find . -maxdepth 1 -mindepth 1 -type d -name "soda*" -not -name "soda-core" -not -name "soda-tests" -not -name "soda-trino" -printf '"%f",' | sed 's/,$//;s/soda-//g';printf "]"
12 changes: 12 additions & 0 deletions soda-tests/src/helpers/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ class TestConnection:

You can provide invalid connection parameters and test that an expected error is raised. See
example implementation in soda-bigquery/tests/data_sources/test_bigquery.py

If the connection is expected to fail prior to the creation of the connection object,
then valid_yaml should be False, and expected_yaml_error must be set to include a substring
contained within the error message.

If the connection is expected to fail after the creation of the connection object but
prior to successful connecting, then valid_connection_params should be False, and expected_connection_error must be set to include a substring
contained within the error message.

If executing a test query (normally `SELECT 1`) is expected to fail,
then query_should_succeed should be False, and expected_query_error must be set to include a substring
contained within the error message.
"""

test_name: str
Expand Down
56 changes: 56 additions & 0 deletions soda-trino/local_instance/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Trino JWT Testing

Use the `docker-compose.yml` and associated config files to launch a local Trino instance configured for JWT authentication.

To generate the keys, run from within the `local_instance` directory:

```
openssl genrsa -out jwt-private.pem 2048
openssl rsa -in jwt-private.pem -pubout -out jwt-public.pem
keytool -genkeypair -alias trino -keyalg RSA -keystore trino-config/keystore.jks \
-storepass changeit -keypass changeit -dname "CN=localhost" -validity 365
cp jwt-public.pem trino-config/
```

To start Trino:

```
docker compose up
```
If the following runs without error, your instance is up
```
curl -k https://localhost:8443/v1/info "
```




To generate a JWT token:

```
jwt -sign - \
-alg RS256 \
-key jwt-private.pem <<'EOF'
{
"sub": "test-user",
"iss": "local",
"aud": "trino"
}
EOF
```

If the following runs without error, your token is valid:
```
curl -k https://localhost:8443/v1/query -H "Authorization: Bearer {token}"
```


Copy the token into your .env file along with these env vars
```
TRINO_HOST="localhost"
TRINO_PORT=8443
TRINO_CATALOG="system"
TRINO_JWT_TOKEN="{token}"
```

Uncomment and run the `real_jwt_token` test in `test_trino.py`. The test should pass.
7 changes: 7 additions & 0 deletions soda-trino/local_instance/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
services:
trino:
image: trinodb/trino:latest
ports:
- "8443:8443"
volumes:
- ./trino-config:/etc/trino
13 changes: 13 additions & 0 deletions soda-trino/local_instance/trino-config/config.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
coordinator=true
node-scheduler.include-coordinator=true
http-server.http.enabled=false
http-server.https.enabled=true
http-server.https.port=8443
http-server.https.keystore.path=/etc/trino/keystore.jks
http-server.https.keystore.key=changeit
http-server.authentication.type=jwt
http-server.authentication.jwt.key-file=/etc/trino/jwt-public.pem
discovery.uri=https://localhost:8443
internal-communication.shared-secret=dev-secret
internal-communication.https.required=true
http-server.process-forwarded=true
4 changes: 4 additions & 0 deletions soda-trino/local_instance/trino-config/jvm.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-server
-Xmx1G
-XX:+UseG1GC
-XX:+ExitOnOutOfMemoryError
2 changes: 2 additions & 0 deletions soda-trino/local_instance/trino-config/node.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
node.environment=local
node.data-dir=/data/trino
26 changes: 26 additions & 0 deletions soda-trino/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[project]
name = "soda-trino"
version = "4.0.7rc0"
description = "Soda Trino V4"
requires-python = ">=3.10"
license = {text = "Proprietary"}
authors = [
{name = "Soda Data N.V.", email = "info@soda.io"}
]
dependencies = [
"soda-core==4.0.7rc0",
"trino>=0.336.0"
]

[project.entry-points."soda.plugins.data_source.trino"]
trinoDataSourceImpl = "soda_trino.common.data_sources.trino_data_source:trinoDataSourceImpl"

[tool.uv.sources]
soda-core = { workspace = true }

[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
package-dir = {"" = "src"}
35 changes: 35 additions & 0 deletions soda-trino/src/soda_trino/common/data_sources/trino_data_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import logging
from typing import Optional

from soda_core.common.data_source_connection import DataSourceConnection
from soda_core.common.data_source_impl import DataSourceImpl
from soda_core.common.logging_constants import soda_logger
from soda_core.common.sql_dialect import SqlDialect
from soda_trino.common.data_sources.trino_data_source_connection import (
TrinoDataSource as TrinoDataSourceModel,
)
from soda_trino.common.data_sources.trino_data_source_connection import (
TrinoDataSourceConnection,
)

logger: logging.Logger = soda_logger


# placeholder file


class TrinoDataSourceImpl(DataSourceImpl, model_class=TrinoDataSourceModel):
def __init__(self, data_source_model: TrinoDataSourceModel, connection: Optional[DataSourceConnection] = None):
super().__init__(data_source_model=data_source_model, connection=connection)

def _create_sql_dialect(self) -> SqlDialect:
return TrinoSqlDialect(data_source_impl=self)

def _create_data_source_connection(self) -> DataSourceConnection:
return TrinoDataSourceConnection(
name=self.data_source_model.name, connection_properties=self.data_source_model.connection_properties
)


class TrinoSqlDialect(SqlDialect):
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from __future__ import annotations

import logging
from abc import ABC
from typing import Literal, Optional, Union

import trino
from pydantic import BaseModel, Field
from soda_core.common.logging_constants import soda_logger

logger: logging.Logger = soda_logger


from soda_core.common.data_source_connection import DataSourceConnection
from soda_core.model.data_source.data_source import DataSourceBase
from soda_core.model.data_source.data_source_connection_properties import (
DataSourceConnectionProperties,
)


class TrinoConnectionProperties(DataSourceConnectionProperties):
host: str = Field(..., description="Database host")
catalog: str = Field(..., description="Database catalog")
port: str = Field("443", description="Database port")
http_scheme: Literal["https", "http"] = Field("https", description="HTTP scheme")
http_headers: Optional[dict[str, str]] = Field(None, description="HTTP headers")
source: str = Field("soda-core", description="Source")
client_tags: Optional[list[str]] = Field(None, description="Client tags")
verify: Optional[bool] = Field(True, description="Verify SSL certificate")


class TrinoUserPasswordConnectionProperties(TrinoConnectionProperties):
# Default if authType not specified
auth_type: Optional[Literal["BasicAuthentication"]] = Field(
"BasicAuthentication", description="Authentication type"
)
user: str = Field(..., description="Database username")
password: str = Field(..., description="Database password")


class TrinoJWTConnectionProperties(TrinoConnectionProperties):
auth_type: Literal["JWTAuthentication"] = Field(description="Authentication type")
access_token: str = Field(..., description="JWT access token")
user: Optional[str] = Field(None, description="Database username")


class TrinoOauthPayload(BaseModel):
token_url: str = Field(..., description="Token URL")
client_id: str = Field(..., description="Client ID")
client_secret: str = Field(..., description="Client secret")
scope: Optional[str] = Field(None, description="Scope")
grant_type: Optional[str] = Field("client_credentials", description="Grant type")


class TrinoOauthConnectionProperties(TrinoConnectionProperties):
auth_type: Literal["OAuth2ClientCredentialsAuthentication"] = Field(description="Authentication type")
oauth: TrinoOauthPayload = Field(..., description="OAuth configuration")
user: Optional[str] = Field(None, description="Database username")


class TrinoNoAuthenticationConnectionProperties(TrinoConnectionProperties):
auth_type: Literal["NoAuthentication"] = Field(description="Authentication type")


class TrinoDataSource(DataSourceBase, ABC):
type: Literal["trino"] = Field("trino")

connection_properties: Union[
TrinoUserPasswordConnectionProperties,
TrinoJWTConnectionProperties,
TrinoOauthConnectionProperties,
TrinoNoAuthenticationConnectionProperties,
] = Field(..., alias="connection", description="Trino connection configuration")


class TrinoDataSourceConnection(DataSourceConnection):
def __init__(self, name: str, connection_properties: DataSourceConnectionProperties):
super().__init__(name, connection_properties)

def _create_connection(
self,
config: TrinoConnectionProperties,
):
if isinstance(config, TrinoUserPasswordConnectionProperties):
self.auth = trino.auth.BasicAuthentication(config.user, config.password)
elif isinstance(config, TrinoJWTConnectionProperties):
self.auth = trino.auth.JWTAuthentication(token=config.access_token)
elif isinstance(config, TrinoOauthConnectionProperties):
# Use OAuth to get a JWT access token
# Note, this is a JWTAuthentication flow, not to be confused with OAuth2Authentication which launches a web browser
token = self._exchange_oauth_for_access_token(config.oauth)
self.auth = trino.auth.JWTAuthentication(token=token)
elif isinstance(config, TrinoNoAuthenticationConnectionProperties):
self.auth = None
else:
raise ValueError(f"Unrecognized Trino authentication type: {config.authType}")

connect_kwargs = {
"host": config.host,
"port": config.port,
"catalog": config.catalog,
"http_scheme": config.http_scheme,
"auth": self.auth,
"http_headers": config.http_headers,
"source": config.source,
"client_tags": config.client_tags,
"verify": config.verify,
}

if getattr(config, "user"):
connect_kwargs["user"] = config.user
return trino.dbapi.connect(**connect_kwargs)

def _exchange_oauth_for_access_token(self, oauth: TrinoOauthPayload) -> str:
if not oauth:
raise ValueError("OAuth configuration is required for OAuth2ClientCredentialsAuthentication")

token_url = oauth.token_url
client_id = oauth.client_id
client_secret = oauth.client_secret
scope = oauth.scope
grant_type = oauth.grant_type

import requests

# OAuth credentials
payload = {"client_id": client_id, "client_secret": client_secret, "grant_type": grant_type}
if scope:
payload["scope"] = scope
response = requests.post(token_url, data=payload)
if response.status_code == 200:
response_json = response.json()
expires_in = response_json.get("expires_in", 0)
scope = response_json.get("scope", "")
access_token = response_json["access_token"]
if access_token:
logger.info(
f"Obtained OAuth access token, expires in '{expires_in}' seconds, granted scopes: '{scope}'"
)
return access_token
else:
raise ValueError(
f"OAuth request did not return an access token: {response.status_code} {response.text}"
)
else:
raise ValueError(f"OAuth request failed: {response.status_code} {response.text}")
Loading
Loading