diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 2afd96806c473..122b08ced23b5 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -40,6 +40,7 @@ from sqlalchemy.sql import select from structlog.contextvars import bind_contextvars +from airflow._shared.observability.metrics import stats from airflow._shared.observability.traces import override_ids from airflow._shared.state import TaskScope from airflow._shared.timezones import timezone @@ -467,6 +468,14 @@ def ti_update_state( extra=json.dumps({"host_name": hostname}) if hostname else None, ) ) + # Commit the TI state update now to release the task_instance row lock before + # running asset-event queries. The direct-INSERT fix in AssetManager removes + # the O(n) lazy-load on the alias-event table, but register_asset_changes_in_db + # also queries scheduled dags and inserts AssetDagRunQueue rows - all of which + # would otherwise hold the row lock and cause idle-in-transaction pile-up that + # exhausts API server memory and triggers OOMKill under high concurrency. + # The task outcome is durable from this point on. + session.commit() except SQLAlchemyError as e: log.error("Error updating Task Instance state", error=str(e)) raise HTTPException( @@ -490,7 +499,9 @@ def ti_update_state( task_id=task_id, map_index=map_index, ) + session.commit() except Exception: + session.rollback() log.warning( "Failed to clear task state on success", dag_id=dag_id, @@ -498,6 +509,30 @@ def ti_update_state( task_id=task_id, ) + # Asset registration runs outside the TI row lock. Failures are logged and counted; + # raising HTTP 500 here would be misleading because the task already succeeded and + # would make the worker retry a state update that has already completed. Durable + # retry/reconciliation for dropped asset events is out of scope for this hot-path fix. + if isinstance(ti_patch_payload, TISuccessStatePayload) and ti_patch_payload.task_outlets: + try: + ti_for_assets = session.get(TI, task_instance_id) + if ti_for_assets is not None: + TI.register_asset_changes_in_db( + ti_for_assets, + ti_patch_payload.task_outlets, + ti_patch_payload.outlet_events, + session=session, + ) + session.commit() + except Exception: + session.rollback() + stats.incr("asset.registration_failures") + log.exception( + "Failed to register asset changes; task state is already committed", + task_instance_id=str(task_instance_id), + new_state=updated_state, + ) + def _emit_task_span(ti, state): # just to be safe @@ -586,13 +621,7 @@ def _create_ti_state_update_query_and_update_state( retry_reason=(ti_patch_payload.retry_reason[:500] if ti_patch_payload.retry_reason else None), ) elif isinstance(ti_patch_payload, TISuccessStatePayload): - if ti is not None: - TI.register_asset_changes_in_db( - ti, - ti_patch_payload.task_outlets, - ti_patch_payload.outlet_events, - session=session, - ) + pass # Asset registration happens after the TI state is committed; see ti_update_state. try: _emit_task_span(ti, state=updated_state) except Exception: diff --git a/airflow-core/src/airflow/assets/manager.py b/airflow-core/src/airflow/assets/manager.py index 7c12c31f979d6..03c4067496f53 100644 --- a/airflow-core/src/airflow/assets/manager.py +++ b/airflow-core/src/airflow/assets/manager.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING import structlog -from sqlalchemy import exc, or_, select +from sqlalchemy import exc, insert, or_, select from sqlalchemy.orm import joinedload from airflow._shared.observability.metrics import stats @@ -40,6 +40,7 @@ DagScheduleAssetReference, DagScheduleAssetUriReference, PartitionedAssetKeyLog, + asset_alias_asset_event_association_table, ) from airflow.models.log import Log from airflow.utils.helpers import is_container @@ -327,8 +328,17 @@ def register_asset_change( ).unique() for asset_alias_model in asset_alias_models: - asset_alias_model.asset_events.append(asset_event) - session.add(asset_alias_model) + # Use a direct INSERT rather than ORM .append() to avoid lazy-loading the + # entire asset_events collection. On long-running deployments that collection + # can contain thousands of rows; loading it on the task-success hot path can + # leave DB connections idle-in-transaction for minutes, blocking other workers. + # This intentionally leaves asset_alias_model.asset_events unsynced in-session. + session.execute( + insert(asset_alias_asset_event_association_table).values( + alias_id=asset_alias_model.id, + event_id=asset_event.id, + ) + ) dags_to_queue_from_asset_alias |= { alias_ref.dag @@ -465,9 +475,9 @@ def _queue_dagruns( # constraint violation. # # If we support it, use ON CONFLICT to do nothing, otherwise - # "fallback" to running this in a nested transaction. This is needed - # so that the adding of these rows happens in the same transaction - # where `ti.state` is changed. + # "fallback" to running this in a nested transaction. Some callers + # run this as part of a TI state transaction; the Execution API commits + # the TI state first, then runs asset registration in a separate transaction. if get_dialect_name(session) == "postgresql": return cls._queue_dagruns_nonpartitioned_postgres(asset_id, non_partitioned_dags, session) return cls._queue_dagruns_nonpartitioned_slow_path(asset_id, non_partitioned_dags, session) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 3022bbfea06e3..99563992ec100 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -1368,6 +1368,190 @@ def test_ti_update_state_running_errors(self, client, session, create_task_insta assert response.status_code == 422 + def test_ti_update_state_to_success_asset_registration_failure_returns_204( + self, client, session, create_task_instance + ): + """Regression: asset registration failure after TI state commit must return 204, not 500. + + The TI state is committed (and the row lock released) before asset registration runs. + If registration fails at that point, the task outcome is already durable as SUCCESS, + so surfacing HTTP 500 would be misleading and cause unnecessary worker retries. + """ + asset = AssetModel( + id=42, + name="fail-asset", + uri="s3://bucket/fail-asset", + group="asset", + extra={}, + ) + asset_active = AssetActive.for_asset(asset) + session.add_all([asset, asset_active]) + + ti = create_task_instance( + task_id="test_asset_reg_failure", + start_date=DEFAULT_START_DATE, + state=State.RUNNING, + ) + session.commit() + + with ( + mock.patch( + "airflow.models.taskinstance.TaskInstance.register_asset_changes_in_db", + side_effect=Exception("simulated DB explosion during asset registration"), + ), + mock.patch("airflow.api_fastapi.execution_api.routes.task_instances.stats.incr") as mock_incr, + ): + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={ + "state": "success", + "end_date": DEFAULT_END_DATE.isoformat(), + "task_outlets": [ + {"name": "fail-asset", "uri": "s3://bucket/fail-asset", "type": "Asset"} + ], + "outlet_events": [], + }, + ) + + assert response.status_code == 204, f"Expected 204, got {response.status_code}: {response.text}" + session.expire_all() + ti_db = session.get(TaskInstance, ti.id) + assert ti_db is not None + assert ti_db.state == TaskInstanceState.SUCCESS + mock_incr.assert_any_call("asset.registration_failures") + + def test_ti_update_state_rolls_back_partial_asset_registration_on_failure( + self, client, session, create_task_instance + ): + asset = AssetModel( + id=43, + name="partial-asset", + uri="s3://bucket/partial-asset", + group="asset", + extra={}, + ) + session.add_all([asset, AssetActive.for_asset(asset)]) + + ti = create_task_instance( + task_id="test_partial_asset_registration_failure", + start_date=DEFAULT_START_DATE, + state=State.RUNNING, + ) + session.commit() + + def add_event_then_fail(ti, task_outlets, outlet_events, session): + session.add( + AssetEvent( + asset_id=asset.id, + extra={"partial": True}, + source_task_id=ti.task_id, + source_dag_id=ti.dag_id, + source_run_id=ti.run_id, + source_map_index=ti.map_index, + ) + ) + session.flush() + raise RuntimeError("simulated failure after partial asset registration") + + with ( + mock.patch( + "airflow.models.taskinstance.TaskInstance.register_asset_changes_in_db", + side_effect=add_event_then_fail, + ), + mock.patch("airflow.api_fastapi.execution_api.routes.task_instances.stats.incr") as mock_incr, + ): + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={ + "state": "success", + "end_date": DEFAULT_END_DATE.isoformat(), + "task_outlets": [ + {"name": "partial-asset", "uri": "s3://bucket/partial-asset", "type": "Asset"} + ], + "outlet_events": [], + }, + ) + + assert response.status_code == 204, f"Expected 204, got {response.status_code}: {response.text}" + session.expire_all() + ti_db = session.get(TaskInstance, ti.id) + assert ti_db is not None + assert ti_db.state == TaskInstanceState.SUCCESS + assert session.scalars(select(AssetEvent).where(AssetEvent.asset_id == asset.id)).all() == [] + mock_incr.assert_any_call("asset.registration_failures") + + def test_ti_update_state_swallow_asset_registration_commit_failure( + self, client, session, create_task_instance + ): + asset = AssetModel( + id=44, + name="commit-fail-asset", + uri="s3://bucket/commit-fail-asset", + group="asset", + extra={}, + ) + session.add_all([asset, AssetActive.for_asset(asset)]) + + ti = create_task_instance( + task_id="test_asset_registration_commit_failure", + start_date=DEFAULT_START_DATE, + state=State.RUNNING, + ) + session.commit() + + real_register_asset_changes_in_db = TaskInstance.register_asset_changes_in_db + real_commit = Session.commit + asset_registration_started = False + failed_asset_commit = False + + def register_asset_changes_then_mark_started(ti, task_outlets, outlet_events, *, session): + nonlocal asset_registration_started + real_register_asset_changes_in_db(ti, task_outlets, outlet_events, session=session) + asset_registration_started = True + + def fail_asset_registration_commit(session): + nonlocal failed_asset_commit + if asset_registration_started and not failed_asset_commit: + failed_asset_commit = True + raise RuntimeError("simulated asset registration commit failure") + return real_commit(session) + + with ( + mock.patch( + "airflow.models.taskinstance.TaskInstance.register_asset_changes_in_db", + side_effect=register_asset_changes_then_mark_started, + ), + mock.patch( + "airflow.api_fastapi.common.db.common.Session.commit", + fail_asset_registration_commit, + ), + mock.patch("airflow.api_fastapi.execution_api.routes.task_instances.stats.incr") as mock_incr, + ): + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={ + "state": "success", + "end_date": DEFAULT_END_DATE.isoformat(), + "task_outlets": [ + { + "name": "commit-fail-asset", + "uri": "s3://bucket/commit-fail-asset", + "type": "Asset", + } + ], + "outlet_events": [], + }, + ) + + assert response.status_code == 204, f"Expected 204, got {response.status_code}: {response.text}" + assert failed_asset_commit + session.expire_all() + ti_db = session.get(TaskInstance, ti.id) + assert ti_db is not None + assert ti_db.state == TaskInstanceState.SUCCESS + assert session.scalars(select(AssetEvent).where(AssetEvent.asset_id == asset.id)).all() == [] + mock_incr.assert_any_call("asset.registration_failures") + def test_ti_update_state_database_error(self, client, session, create_task_instance): """ Test that a database error is handled correctly when updating the Task Instance state. @@ -1977,6 +2161,74 @@ def test_ti_update_state_to_success_clears_task_state(self, client, session, cre session.expire_all() assert not session.scalars(select(TaskStoreModel).where(TaskStoreModel.task_id == ti.task_id)).all() + @pytest.mark.db_test + @conf_vars({("state_store", "clear_on_success"): "True"}) + def test_asset_registration_failure_does_not_rollback_successful_task_state_clear( + self, client, session, create_task_instance + ): + asset = AssetModel( + id=44, + name="partial-asset-with-state-clear", + uri="s3://bucket/partial-asset-with-state-clear", + group="asset", + extra={}, + ) + session.add_all([asset, AssetActive.for_asset(asset)]) + + ti = create_task_instance( + task_id="test_asset_failure_after_state_clear", + start_date=DEFAULT_START_DATE, + state=State.RUNNING, + ) + session.commit() + + backend = MetastoreStoreBackend() + scope = TaskScope(dag_id=ti.dag_id, run_id=ti.run_id, task_id=ti.task_id, map_index=ti.map_index) + backend.set(scope, "job_id", "app_1234", session=session) + session.commit() + + def add_event_then_fail(ti, task_outlets, outlet_events, session): + session.add( + AssetEvent( + asset_id=asset.id, + extra={"partial": True}, + source_task_id=ti.task_id, + source_dag_id=ti.dag_id, + source_run_id=ti.run_id, + source_map_index=ti.map_index, + ) + ) + session.flush() + raise RuntimeError("simulated failure after state clear") + + with mock.patch( + "airflow.models.taskinstance.TaskInstance.register_asset_changes_in_db", + side_effect=add_event_then_fail, + ): + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={ + "state": "success", + "end_date": DEFAULT_END_DATE.isoformat(), + "task_outlets": [ + { + "name": "partial-asset-with-state-clear", + "uri": "s3://bucket/partial-asset-with-state-clear", + "type": "Asset", + } + ], + "outlet_events": [], + }, + ) + + assert response.status_code == 204, f"Expected 204, got {response.status_code}: {response.text}" + session.expire_all() + ti_db = session.get(TaskInstance, ti.id) + assert ti_db is not None + assert ti_db.state == TaskInstanceState.SUCCESS + assert not session.scalars(select(TaskStoreModel).where(TaskStoreModel.task_id == ti.task_id)).all() + assert session.scalars(select(AssetEvent).where(AssetEvent.asset_id == asset.id)).all() == [] + @pytest.mark.db_test @conf_vars({("state_store", "clear_on_success"): "True"}) def test_ti_update_state_to_failed_does_not_clear_task_state(self, client, session, create_task_instance): diff --git a/airflow-core/tests/unit/assets/test_manager.py b/airflow-core/tests/unit/assets/test_manager.py index 8f9d290a2b0b2..c22d4faed36d7 100644 --- a/airflow-core/tests/unit/assets/test_manager.py +++ b/airflow-core/tests/unit/assets/test_manager.py @@ -24,7 +24,7 @@ from unittest import mock import pytest -from sqlalchemy import delete, func, select +from sqlalchemy import delete, func, insert, select from sqlalchemy.orm import Session from airflow import settings @@ -37,6 +37,7 @@ AssetPartitionDagRun, DagScheduleAssetAliasReference, DagScheduleAssetReference, + asset_alias_asset_event_association_table, ) from airflow.models.dag import DagModel from airflow.sdk.definitions.asset import Asset @@ -154,6 +155,88 @@ def test_register_asset_change_with_alias( ) assert session.scalar(select(func.count()).select_from(AssetDagRunQueue)) == 2 + def test_register_asset_change_with_alias_no_lazy_load( + self, session, mock_task_instance, testing_dag_bundle + ): + """Regression: alias-event association must use a direct INSERT, not ORM .append(). + + ORM .append() lazy-loads the entire asset_events collection before writing. + On long-running deployments with thousands of past events, this query runs + while the task_instance row lock is held in ti_update_state, causing idle-in-transaction + pile-up that exhausts API server memory and triggers OOMKill. + """ + asm = AssetModel(uri="test://asset-nolazy/", name="test_nolazy_asset", group="asset") + session.add(asm) + asam = AssetAliasModel(name="test_nolazy_alias", group="test") + session.add(asam) + session.flush() + + # Pre-populate existing alias-event rows to simulate a long-running deployment. + # If .append() is used, SQLAlchemy will lazy-load ALL of these before inserting the new one. + existing_events = [AssetEvent(asset_id=asm.id, extra={}) for _ in range(5)] + session.add_all(existing_events) + session.flush() + for ev in existing_events: + session.execute( + insert(asset_alias_asset_event_association_table).values(alias_id=asam.id, event_id=ev.id) + ) + session.flush() + + # Expire the alias so a lazy-load would have to hit the DB (no in-memory cache). + session.expire(asam) + + asset = Asset(uri="test://asset-nolazy", name="test_nolazy_asset") + asset_manager = AssetManager() + + lazy_load_selects: list[str] = [] + real_execute = session.execute + + def tracking_execute(stmt, *args, **kwargs): + try: + compiled = str(stmt.compile(compile_kwargs={"literal_binds": True})) + except Exception: + compiled = str(stmt) + # Detect a lazy-load SELECT joining asset_alias_asset_event with asset_event + if ( + "asset_alias_asset_event" in compiled.lower() + and "asset_event" in compiled.lower() + and compiled.strip().upper().startswith("SELECT") + ): + lazy_load_selects.append(compiled[:120]) + return real_execute(stmt, *args, **kwargs) + + with mock.patch.object(session, "execute", side_effect=tracking_execute): + asset_manager.register_asset_change( + task_instance=mock_task_instance, + asset=asset, + source_alias_names=["test_nolazy_alias"], + session=session, + ) + session.flush() + + # The new association row must exist + new_events = session.scalars( + select(AssetEvent).where( + AssetEvent.asset_id == asm.id, + AssetEvent.id.notin_([ev.id for ev in existing_events]), + ) + ).all() + assert len(new_events) == 1, "Expected exactly one new AssetEvent" + + row_count = session.scalar( + select(func.count()) + .select_from(asset_alias_asset_event_association_table) + .where( + asset_alias_asset_event_association_table.c.alias_id == asam.id, + asset_alias_asset_event_association_table.c.event_id == new_events[0].id, + ) + ) + assert row_count == 1, "Expected the alias-event association row to be written" + + assert lazy_load_selects == [], ( + f"Unexpected lazy-load SELECT on asset_alias_asset_event: {lazy_load_selects}" + ) + def test_register_asset_change_no_downstreams(self, session, mock_task_instance): asset_manager = AssetManager() diff --git a/scripts/ci/prek/check_connection_doc_labels.py b/scripts/ci/prek/check_connection_doc_labels.py index 8dc16b346eaae..e3f0e714a2a35 100755 --- a/scripts/ci/prek/check_connection_doc_labels.py +++ b/scripts/ci/prek/check_connection_doc_labels.py @@ -39,6 +39,7 @@ import re import sys +from os import walk from pathlib import Path from rich.console import Console @@ -62,6 +63,7 @@ TOP_LEVEL_ANCHOR_RE = re.compile(r"^\.\.\s+_howto/connection:([a-zA-Z0-9_-]+):\s*$", re.MULTILINE) ANY_ANCHOR_RE = re.compile(r"^\.\.\s+_(howto/connection:[^\s]+?):\s*$", re.MULTILINE) REF_RE = re.compile(r":ref:`(?:[^`]*<(howto/connection:[^>]+)>|(howto/connection:[^`]+))`") +SKIP_SCAN_DIRS = frozenset({"node_modules", ".pnpm-store"}) def collect_connection_types() -> set[str]: @@ -72,18 +74,26 @@ def collect_connection_types() -> set[str]: return conn_types +def collect_files(root: Path, suffix: str) -> list[Path]: + files: list[Path] = [] + for current_root, dirnames, filenames in walk(root): + dirnames[:] = [dirname for dirname in dirnames if dirname not in SKIP_SCAN_DIRS] + files.extend(Path(current_root, filename) for filename in filenames if filename.endswith(suffix)) + return sorted(files) + + def collect_rst_files() -> list[Path]: - rst_files: list[Path] = list(AIRFLOW_PROVIDERS_ROOT_PATH.rglob("*.rst")) + rst_files: list[Path] = collect_files(AIRFLOW_PROVIDERS_ROOT_PATH, ".rst") core_docs = AIRFLOW_ROOT_PATH / "airflow-core" / "docs" if core_docs.is_dir(): - rst_files.extend(core_docs.rglob("*.rst")) + rst_files.extend(collect_files(core_docs, ".rst")) return rst_files def collect_python_files() -> list[Path]: - py_files: list[Path] = list(AIRFLOW_PROVIDERS_ROOT_PATH.rglob("*.py")) + py_files: list[Path] = collect_files(AIRFLOW_PROVIDERS_ROOT_PATH, ".py") if AIRFLOW_CORE_SOURCES_PATH.is_dir(): - py_files.extend(AIRFLOW_CORE_SOURCES_PATH.rglob("*.py")) + py_files.extend(collect_files(AIRFLOW_CORE_SOURCES_PATH, ".py")) return py_files diff --git a/scripts/tests/ci/prek/test_check_connection_doc_labels.py b/scripts/tests/ci/prek/test_check_connection_doc_labels.py new file mode 100644 index 0000000000000..8775e732451b8 --- /dev/null +++ b/scripts/tests/ci/prek/test_check_connection_doc_labels.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from check_connection_doc_labels import collect_files + + +def test_collect_files_skips_volatile_dependency_directories(tmp_path): + source_file = tmp_path / "provider" / "docs" / "connection.rst" + source_file.parent.mkdir(parents=True) + source_file.touch() + + node_modules_file = tmp_path / "ui" / "node_modules" / "package" / "docs" / "connection.rst" + node_modules_file.parent.mkdir(parents=True) + node_modules_file.touch() + + pnpm_store_file = tmp_path / "ui" / ".pnpm-store" / "package" / "docs" / "connection.rst" + pnpm_store_file.parent.mkdir(parents=True) + pnpm_store_file.touch() + + assert collect_files(tmp_path, ".rst") == [source_file] + + +def test_collect_files_matches_suffix(tmp_path): + python_file = tmp_path / "src" / "module.py" + python_file.parent.mkdir(parents=True) + python_file.touch() + (python_file.parent / "module.pyi").touch() + + assert collect_files(tmp_path, ".py") == [python_file] diff --git a/shared/observability/src/airflow_shared/observability/metrics/metrics_template.yaml b/shared/observability/src/airflow_shared/observability/metrics/metrics_template.yaml index 876cc3cc8953b..625867564035a 100644 --- a/shared/observability/src/airflow_shared/observability/metrics/metrics_template.yaml +++ b/shared/observability/src/airflow_shared/observability/metrics/metrics_template.yaml @@ -255,6 +255,12 @@ metrics: legacy_name: "-" name_variables: [] + - name: "asset.registration_failures" + description: "Number of task success asset registration failures after the task state was updated" + type: "counter" + legacy_name: "-" + name_variables: [] + - name: "asset.orphaned" description: "Number of assets marked as orphans because they are no longer referenced in Dag schedule parameters or task outlets"