Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
21 commits
Select commit. Hold shift + click to select a range
7f5e847
Clean-up param protection
sfc-gh-aalam Nov 20, 2024
72d1623
clean-up testing and precommits
sfc-gh-aalam Nov 20, 2024
51a0384
fix typo
sfc-gh-aalam Nov 20, 2024
c25b157
fix test
sfc-gh-aalam Nov 20, 2024
751b3c3
fix test
sfc-gh-aalam Nov 20, 2024
7f62935
Merge branch 'main' into aalam-SNOW-1720855-clean-up-multithreading-c…
sfc-gh-aalam Nov 22, 2024
dfd792c
Merge branch 'main' into aalam-SNOW-1720855-clean-up-multithreading-c…
sfc-gh-aalam Nov 25, 2024
ee15875
merge
sfc-gh-aalam Dec 19, 2024
d060db8
Merge branch 'main' into aalam-SNOW-1720855-clean-up-multithreading-c…
sfc-gh-aalam Jan 2, 2025
dc7cae3
Merge branch 'main' into aalam-SNOW-1720855-clean-up-multithreading-c…
sfc-gh-aalam Jan 7, 2025
a8bb18c
fix pytest assumption issues
sfc-gh-aalam Jan 7, 2025
fa3d182
fix lint
sfc-gh-aalam Jan 7, 2025
98c1ca0
make sql counter thread-safe
sfc-gh-aalam Jan 7, 2025
5f8e146
fix sql counts
sfc-gh-aalam Jan 7, 2025
8b9aa94
use thread start and join to ensure all threads are completed
sfc-gh-aalam Jan 7, 2025
911036c
fix sql counter called
sfc-gh-aalam Jan 8, 2025
d11b3c3
Merge branch 'main' into aalam-SNOW-1720855-clean-up-multithreading-c…
sfc-gh-aalam Jan 8, 2025
9ebc09d
remove lock because it is not required
sfc-gh-aalam Jan 9, 2025
4a255b1
readd lock
sfc-gh-aalam Jan 10, 2025
928d3c0
Merge branch 'main' into aalam-SNOW-1720855-clean-up-multithreading-c…
sfc-gh-aalam Jan 13, 2025
e652ae3
Merge branch 'main' into aalam-SNOW-1720855-clean-up-multithreading-c…
sfc-gh-aalam Jan 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 0 additions & 69 deletions .github/workflows/daily_precommit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -546,75 +546,6 @@ jobs:
.tox/.coverage
.tox/coverage.xml

test-snowpark-disable-multithreading-mode:
name: Test Snowpark Multithreading Disabled py-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
needs: build
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest-64-cores]
python-version: ["3.9"]
cloud-provider: [aws]
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Decrypt parameters.py
shell: bash
run: .github/scripts/decrypt_parameters.sh
env:
PARAMETER_PASSWORD: ${{ secrets.PARAMETER_PASSWORD }}
CLOUD_PROVIDER: ${{ matrix.cloud-provider }}
- name: Install protoc
shell: bash
run: .github/scripts/install_protoc.sh
- name: Download wheel(s)
uses: actions/download-artifact@v4
with:
name: wheel
path: dist
- name: Show wheels downloaded
run: ls -lh dist
shell: bash
- name: Upgrade setuptools, pip and wheel
run: python -m pip install -U setuptools pip wheel
- name: Install tox
run: python -m pip install tox
- name: Run tests (excluding doctests)
run: python -m tox -e "py${PYTHON_VERSION/\./}-notmultithreaded-ci"
env:
PYTHON_VERSION: ${{ matrix.python-version }}
cloud_provider: ${{ matrix.cloud-provider }}
PYTEST_ADDOPTS: --color=yes --tb=short
TOX_PARALLEL_NO_SPINNER: 1
shell: bash
- name: Run local tests
run: python -m tox -e "py${PYTHON_VERSION/\./}-localnotmultithreaded-ci"
env:
PYTHON_VERSION: ${{ matrix.python-version }}
cloud_provider: ${{ matrix.cloud-provider }}
PYTEST_ADDOPTS: --color=yes --tb=short
TOX_PARALLEL_NO_SPINNER: 1
shell: bash
- name: Combine coverages
run: python -m tox -e coverage --skip-missing-interpreters false
shell: bash
env:
SNOWFLAKE_IS_PYTHON_RUNTIME_TEST: 1
- uses: actions/upload-artifact@v4
with:
include-hidden-files: true
name: coverage_${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}-snowpark-multithreading
path: |
.tox/.coverage
.tox/coverage.xml

combine-coverage:
if: ${{ success() || failure() }}
name: Combine coverage
Expand Down
26 changes: 5 additions & 21 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@
generate_random_alphanumeric,
get_copy_into_table_options,
is_sql_select_statement,
random_name_for_temp_object,
)
from snowflake.snowpark.row import Row
from snowflake.snowpark.types import StructType
Expand Down Expand Up @@ -664,12 +663,8 @@ def large_local_relation_plan(
source_plan: Optional[LogicalPlan],
schema_query: Optional[str],
) -> SnowflakePlan:
thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled
temp_table_name = (
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.TABLE)
)
temp_table_name = f"temp_name_placeholder_{generate_random_alphanumeric()}"

attributes = [
Attribute(attr.name, attr.datatype, attr.nullable) for attr in output
]
Expand All @@ -696,9 +691,7 @@ def large_local_relation_plan(
Query(
create_table_stmt,
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE)
if thread_safe_session_enabled
else None,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE),
),
BatchInsertQuery(insert_stmt, data),
Query(select_stmt),
Expand Down Expand Up @@ -1215,7 +1208,6 @@ def read_file(
metadata_project: Optional[List[str]] = None,
metadata_schema: Optional[List[Attribute]] = None,
):
thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled
format_type_options, copy_options = get_copy_into_table_options(options)
format_type_options = self._merge_file_format_options(
format_type_options, options
Expand Down Expand Up @@ -1247,8 +1239,6 @@ def read_file(
post_queries: List[Query] = []
format_name = self.session.get_fully_qualified_name_if_possible(
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.FILE_FORMAT)
)
queries.append(
Query(
Expand All @@ -1262,9 +1252,7 @@ def read_file(
is_generated=True,
),
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(format_name, TempObjectType.FILE_FORMAT)
if thread_safe_session_enabled
else None,
temp_obj_name_placeholder=(format_name, TempObjectType.FILE_FORMAT),
)
)
post_queries.append(
Expand Down Expand Up @@ -1323,8 +1311,6 @@ def read_file(

temp_table_name = self.session.get_fully_qualified_name_if_possible(
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.TABLE)
)
queries = [
Query(
Expand All @@ -1337,9 +1323,7 @@ def read_file(
is_generated=True,
),
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE)
if thread_safe_session_enabled
else None,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE),
),
Query(
copy_into_table(
Expand Down
60 changes: 28 additions & 32 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,35 +204,31 @@ def replace_temp_obj_placeholders(
To prevent this, we generate queries with temp object name placeholders and replace them with actual temp object
here.
"""
session = self._plan.session
if session._conn._thread_safe_session_enabled:
# This dictionary will store the mapping between placeholder name and actual temp object name.
placeholders = {}
# Final execution queries
execution_queries = {}
for query_type, query_list in queries.items():
execution_queries[query_type] = []
for query in query_list:
# If the query contains a temp object name placeholder, we generate a random
# name for the temp object and add it to the placeholders dictionary.
if query.temp_obj_name_placeholder:
(
placeholder_name,
temp_obj_type,
) = query.temp_obj_name_placeholder
placeholders[placeholder_name] = random_name_for_temp_object(
temp_obj_type
)

copied_query = copy.copy(query)
for placeholder_name, target_temp_name in placeholders.items():
# Copy the original query and replace all the placeholder names with the
# actual temp object names.
copied_query.sql = copied_query.sql.replace(
placeholder_name, target_temp_name
)

execution_queries[query_type].append(copied_query)
return execution_queries

return queries
# This dictionary will store the mapping between placeholder name and actual temp object name.
placeholders = {}
# Final execution queries
execution_queries = {}
for query_type, query_list in queries.items():
execution_queries[query_type] = []
for query in query_list:
# If the query contains a temp object name placeholder, we generate a random
# name for the temp object and add it to the placeholders dictionary.
if query.temp_obj_name_placeholder:
(
placeholder_name,
temp_obj_type,
) = query.temp_obj_name_placeholder
placeholders[placeholder_name] = random_name_for_temp_object(
temp_obj_type
)

copied_query = copy.copy(query)
for placeholder_name, target_temp_name in placeholders.items():
# Copy the original query and replace all the placeholder names with the
# actual temp object names.
copied_query.sql = copied_query.sql.replace(
placeholder_name, target_temp_name
)

execution_queries[query_type].append(copied_query)
return execution_queries
10 changes: 2 additions & 8 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import TelemetryClient
from snowflake.snowpark._internal.utils import (
create_rlock,
create_thread_local,
escape_quotes,
get_application_name,
get_version,
Expand Down Expand Up @@ -173,12 +171,8 @@ def __init__(
except TypeError:
pass

# thread safe param protection
self._thread_safe_session_enabled = self._get_client_side_session_parameter(
"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
)
self._lock = create_rlock(self._thread_safe_session_enabled)
self._thread_store = create_thread_local(self._thread_safe_session_enabled)
self._lock = threading.RLock()
self._thread_store = threading.local()

if "password" in self._lower_case_parameters:
self._lower_case_parameters["password"] = None
Expand Down
5 changes: 3 additions & 2 deletions src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import logging
from threading import RLock
import weakref
from collections import defaultdict
from typing import TYPE_CHECKING, Dict

from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable
from snowflake.snowpark._internal.utils import create_rlock, is_in_stored_procedure
from snowflake.snowpark._internal.utils import is_in_stored_procedure

_logger = logging.getLogger(__name__)

Expand All @@ -34,7 +35,7 @@ def __init__(self, session: "Session") -> None:
# this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
self.ref_count_map: Dict[str, int] = defaultdict(int)
# Lock to protect the ref_count_map
self.lock = create_rlock(session._conn._thread_safe_session_enabled)
self.lock = RLock()

def add(self, table: SnowflakeTable) -> None:
with self.lock:
Expand Down
48 changes: 1 addition & 47 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,12 +376,7 @@ def normalize_path(path: str, is_local: bool) -> str:
return f"'{path}'"


def warn_session_config_update_in_multithreaded_mode(
config: str, thread_safe_mode_enabled: bool
) -> None:
if not thread_safe_mode_enabled:
return

def warn_session_config_update_in_multithreaded_mode(config: str) -> None:
if threading.active_count() > 1:
_logger.warning(
"You might have more than one threads sharing the Session object trying to update "
Expand Down Expand Up @@ -763,47 +758,6 @@ def warning(self, text: str) -> None:
self.count += 1


# TODO: SNOW-1720855: Remove DummyRLock and DummyThreadLocal after the rollout
class DummyRLock:
"""This is a dummy lock that is used in place of threading.Rlock when multithreading is
disabled."""

def __enter__(self):
pass

def __exit__(self, exc_type, exc_val, exc_tb):
pass

def acquire(self, *args, **kwargs):
pass # pragma: no cover

def release(self, *args, **kwargs):
pass # pragma: no cover


class DummyThreadLocal:
"""This is a dummy thread local class that is used in place of threading.local when
multithreading is disabled."""

pass


def create_thread_local(
thread_safe_session_enabled: bool,
) -> Union[threading.local, DummyThreadLocal]:
if thread_safe_session_enabled:
return threading.local()
return DummyThreadLocal()


def create_rlock(
thread_safe_session_enabled: bool,
) -> Union[threading.RLock, DummyRLock]:
if thread_safe_session_enabled:
return threading.RLock()
return DummyRLock()


warning_dict: Dict[str, WarningHelper] = {}


Expand Down
8 changes: 2 additions & 6 deletions src/snowflake/snowpark/mock/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import functools
import json
import logging
import threading
import uuid
from copy import copy
from decimal import Decimal
Expand All @@ -30,7 +31,6 @@
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.server_connection import DEFAULT_STRING_SIZE
from snowflake.snowpark._internal.utils import (
create_rlock,
is_in_stored_procedure,
result_set_to_rows,
)
Expand Down Expand Up @@ -297,11 +297,7 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None:
self._cursor = Mock()
self._options = options or {}
session_params = self._options.get("session_parameters", {})
# thread safe param protection
self._thread_safe_session_enabled = session_params.get(
"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
)
self._lock = create_rlock(self._thread_safe_session_enabled)
self._lock = threading.RLock()
self._lower_case_parameters = {}
self._query_listeners = set()
self._telemetry_client = Mock()
Expand Down
Loading
Loading