Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from sqlalchemy.sql import select
from structlog.contextvars import bind_contextvars

from airflow._shared.observability.metrics import stats
from airflow._shared.observability.traces import override_ids
from airflow._shared.state import TaskScope
from airflow._shared.timezones import timezone
Expand Down Expand Up @@ -467,6 +468,14 @@ def ti_update_state(
extra=json.dumps({"host_name": hostname}) if hostname else None,
)
)
# Commit the TI state update now to release the task_instance row lock before
# running asset-event queries. The direct-INSERT fix in AssetManager removes
# the O(n) lazy-load on the alias-event table, but register_asset_changes_in_db
# also queries scheduled dags and inserts AssetDagRunQueue rows - all of which
# would otherwise hold the row lock and cause idle-in-transaction pile-up that
# exhausts API server memory and triggers OOMKill under high concurrency.
# The task outcome is durable from this point on.
session.commit()
Comment thread
hkc-8010 marked this conversation as resolved.
except SQLAlchemyError as e:
log.error("Error updating Task Instance state", error=str(e))
raise HTTPException(
Expand All @@ -490,14 +499,40 @@ def ti_update_state(
task_id=task_id,
map_index=map_index,
)
session.commit()
except Exception:
session.rollback()
log.warning(
"Failed to clear task state on success",
dag_id=dag_id,
run_id=run_id,
task_id=task_id,
)

# Asset registration runs outside the TI row lock. Failures are logged and counted;
# raising HTTP 500 here would be misleading because the task already succeeded and
# would make the worker retry a state update that has already completed. Durable
# retry/reconciliation for dropped asset events is out of scope for this hot-path fix.
if isinstance(ti_patch_payload, TISuccessStatePayload) and ti_patch_payload.task_outlets:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gates on the payload type, but `updated_state` may already have been flipped to `FAILED` by the time we get here. If `_create_ti_state_update_query_and_update_state` raises, the `except` block commits the TI as failed (`query.values(state=(updated_state := TaskInstanceState.FAILED))`, line 445). The payload is still a `TISuccessStatePayload`, so this block runs anyway, emitting asset events and queuing downstream `AssetDagRunQueue` rows for a task that actually ended up failed.

On `main`, the registration ran inside `_create_ti_state_update_query_and_update_state` under that same `try`, so a raised exception skipped it. Moving it out here and switching the guard from the committed state to the payload type drops that coupling.

The state_store clear block right above gates on `if updated_state == TaskInstanceState.SUCCESS:` (line 485). I suggest using the same guard here so that a forced failure can't still register assets.

try:
ti_for_assets = session.get(TI, task_instance_id)
if ti_for_assets is not None:
TI.register_asset_changes_in_db(
ti_for_assets,
ti_patch_payload.task_outlets,
ti_patch_payload.outlet_events,
session=session,
)
session.commit()
except Exception:
session.rollback()
stats.incr("asset.registration_failures")
log.exception(
"Failed to register asset changes; task state is already committed",
task_instance_id=str(task_instance_id),
new_state=updated_state,
)


def _emit_task_span(ti, state):
# just to be safe
Expand Down Expand Up @@ -586,13 +621,7 @@ def _create_ti_state_update_query_and_update_state(
retry_reason=(ti_patch_payload.retry_reason[:500] if ti_patch_payload.retry_reason else None),
)
elif isinstance(ti_patch_payload, TISuccessStatePayload):
if ti is not None:
TI.register_asset_changes_in_db(
ti,
ti_patch_payload.task_outlets,
ti_patch_payload.outlet_events,
session=session,
)
pass # Asset registration happens after the TI state is committed; see ti_update_state.
try:
_emit_task_span(ti, state=updated_state)
except Exception:
Expand Down
22 changes: 16 additions & 6 deletions airflow-core/src/airflow/assets/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import TYPE_CHECKING

import structlog
from sqlalchemy import exc, or_, select
from sqlalchemy import exc, insert, or_, select
from sqlalchemy.orm import joinedload

from airflow._shared.observability.metrics import stats
Expand All @@ -40,6 +40,7 @@
DagScheduleAssetReference,
DagScheduleAssetUriReference,
PartitionedAssetKeyLog,
asset_alias_asset_event_association_table,
)
from airflow.models.log import Log
from airflow.utils.helpers import is_container
Expand Down Expand Up @@ -327,8 +328,17 @@ def register_asset_change(
).unique()

for asset_alias_model in asset_alias_models:
asset_alias_model.asset_events.append(asset_event)
session.add(asset_alias_model)
# Use a direct INSERT rather than ORM .append() to avoid lazy-loading the
# entire asset_events collection. On long-running deployments that collection
# can contain thousands of rows; loading it on the task-success hot path can
# leave DB connections idle-in-transaction for minutes, blocking other workers.
# This intentionally leaves asset_alias_model.asset_events unsynced in-session.
session.execute(
Comment thread
hkc-8010 marked this conversation as resolved.
insert(asset_alias_asset_event_association_table).values(
alias_id=asset_alias_model.id,
event_id=asset_event.id,
)
)

dags_to_queue_from_asset_alias |= {
alias_ref.dag
Expand Down Expand Up @@ -465,9 +475,9 @@ def _queue_dagruns(
# constraint violation.
#
# If we support it, use ON CONFLICT to do nothing, otherwise
# "fallback" to running this in a nested transaction. This is needed
# so that the adding of these rows happens in the same transaction
# where `ti.state` is changed.
# "fallback" to running this in a nested transaction. Some callers
# run this as part of a TI state transaction; the Execution API commits
# the TI state first, then runs asset registration in a separate transaction.
if get_dialect_name(session) == "postgresql":
return cls._queue_dagruns_nonpartitioned_postgres(asset_id, non_partitioned_dags, session)
return cls._queue_dagruns_nonpartitioned_slow_path(asset_id, non_partitioned_dags, session)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,190 @@ def test_ti_update_state_running_errors(self, client, session, create_task_insta

assert response.status_code == 422

def test_ti_update_state_to_success_asset_registration_failure_returns_204(
    self, client, session, create_task_instance
):
    """
    Regression: a failing asset registration must still answer 204 for a success update.

    ``register_asset_changes_in_db`` is patched to raise. The PATCH request is nevertheless
    expected to return 204, the task instance must read back as SUCCESS, and the
    ``asset.registration_failures`` counter must have been incremented. Returning HTTP 500
    here would be misleading and would make the worker retry a completed state update.
    """
    asset = AssetModel(
        id=42,
        name="fail-asset",
        uri="s3://bucket/fail-asset",
        group="asset",
        extra={},
    )
    session.add_all([asset, AssetActive.for_asset(asset)])

    ti = create_task_instance(
        task_id="test_asset_reg_failure",
        start_date=DEFAULT_START_DATE,
        state=State.RUNNING,
    )
    session.commit()

    success_payload = {
        "state": "success",
        "end_date": DEFAULT_END_DATE.isoformat(),
        "task_outlets": [{"name": "fail-asset", "uri": "s3://bucket/fail-asset", "type": "Asset"}],
        "outlet_events": [],
    }
    # Every registration attempt blows up; the counter is replaced by a spy.
    exploding_registration = mock.patch(
        "airflow.models.taskinstance.TaskInstance.register_asset_changes_in_db",
        side_effect=Exception("simulated DB explosion during asset registration"),
    )
    counter_patch = mock.patch("airflow.api_fastapi.execution_api.routes.task_instances.stats.incr")

    with exploding_registration, counter_patch as incr_spy:
        response = client.patch(f"/execution/task-instances/{ti.id}/state", json=success_payload)

    assert response.status_code == 204, f"Expected 204, got {response.status_code}: {response.text}"

    # Drop cached attributes so the state is re-read from the database.
    session.expire_all()
    persisted_ti = session.get(TaskInstance, ti.id)
    assert persisted_ti is not None
    assert persisted_ti.state == TaskInstanceState.SUCCESS
    incr_spy.assert_any_call("asset.registration_failures")

def test_ti_update_state_rolls_back_partial_asset_registration_on_failure(
    self, client, session, create_task_instance
):
    """
    A registration that flushes an ``AssetEvent`` and then raises must leave no event behind.

    The patched registration adds and flushes one event before raising. The request is
    expected to return 204 with the task instance stored as SUCCESS, no ``AssetEvent`` row
    for the asset, and the ``asset.registration_failures`` counter incremented.
    """
    asset = AssetModel(
        id=43,
        name="partial-asset",
        uri="s3://bucket/partial-asset",
        group="asset",
        extra={},
    )
    asset_active = AssetActive.for_asset(asset)
    session.add_all([asset, asset_active])

    ti = create_task_instance(
        task_id="test_partial_asset_registration_failure",
        start_date=DEFAULT_START_DATE,
        state=State.RUNNING,
    )
    session.commit()

    def flush_event_then_raise(ti, task_outlets, outlet_events, session):
        # Write a half-finished registration (flushed, not committed) and then fail.
        partial_event = AssetEvent(
            asset_id=asset.id,
            extra={"partial": True},
            source_task_id=ti.task_id,
            source_dag_id=ti.dag_id,
            source_run_id=ti.run_id,
            source_map_index=ti.map_index,
        )
        session.add(partial_event)
        session.flush()
        raise RuntimeError("simulated failure after partial asset registration")

    success_payload = {
        "state": "success",
        "end_date": DEFAULT_END_DATE.isoformat(),
        "task_outlets": [{"name": "partial-asset", "uri": "s3://bucket/partial-asset", "type": "Asset"}],
        "outlet_events": [],
    }
    partial_registration = mock.patch(
        "airflow.models.taskinstance.TaskInstance.register_asset_changes_in_db",
        side_effect=flush_event_then_raise,
    )
    counter_patch = mock.patch("airflow.api_fastapi.execution_api.routes.task_instances.stats.incr")

    with partial_registration, counter_patch as incr_spy:
        response = client.patch(f"/execution/task-instances/{ti.id}/state", json=success_payload)

    assert response.status_code == 204, f"Expected 204, got {response.status_code}: {response.text}"

    session.expire_all()
    persisted_ti = session.get(TaskInstance, ti.id)
    assert persisted_ti is not None
    assert persisted_ti.state == TaskInstanceState.SUCCESS

    # The flushed event must not have survived the failed registration.
    surviving_events = session.scalars(select(AssetEvent).where(AssetEvent.asset_id == asset.id)).all()
    assert surviving_events == []
    incr_spy.assert_any_call("asset.registration_failures")

def test_ti_update_state_swallow_asset_registration_commit_failure(
    self, client, session, create_task_instance
):
    """
    A commit failure right after asset registration must not surface to the caller.

    The real registration runs, then the first ``Session.commit`` issued afterwards is made
    to raise. The request is expected to return 204 with the task instance stored as
    SUCCESS, no ``AssetEvent`` row for the asset, and the ``asset.registration_failures``
    counter incremented.
    """
    asset = AssetModel(
        id=44,
        name="commit-fail-asset",
        uri="s3://bucket/commit-fail-asset",
        group="asset",
        extra={},
    )
    asset_active = AssetActive.for_asset(asset)
    session.add_all([asset, asset_active])

    ti = create_task_instance(
        task_id="test_asset_registration_commit_failure",
        start_date=DEFAULT_START_DATE,
        state=State.RUNNING,
    )
    session.commit()

    # Keep handles on the unpatched callables so the wrappers below can delegate to them.
    original_register = TaskInstance.register_asset_changes_in_db
    original_commit = Session.commit
    progress = {"registered": False, "commit_failed": False}

    def register_asset_changes_then_mark_started(ti, task_outlets, outlet_events, *, session):
        original_register(ti, task_outlets, outlet_events, session=session)
        progress["registered"] = True

    def fail_asset_registration_commit(session):
        # Only the first commit issued after registration has run is sabotaged;
        # every other commit goes through untouched.
        if not progress["registered"] or progress["commit_failed"]:
            return original_commit(session)
        progress["commit_failed"] = True
        raise RuntimeError("simulated asset registration commit failure")

    success_payload = {
        "state": "success",
        "end_date": DEFAULT_END_DATE.isoformat(),
        "task_outlets": [
            {
                "name": "commit-fail-asset",
                "uri": "s3://bucket/commit-fail-asset",
                "type": "Asset",
            }
        ],
        "outlet_events": [],
    }
    tracked_registration = mock.patch(
        "airflow.models.taskinstance.TaskInstance.register_asset_changes_in_db",
        side_effect=register_asset_changes_then_mark_started,
    )
    sabotaged_commit = mock.patch(
        "airflow.api_fastapi.common.db.common.Session.commit",
        fail_asset_registration_commit,
    )
    counter_patch = mock.patch("airflow.api_fastapi.execution_api.routes.task_instances.stats.incr")

    with tracked_registration, sabotaged_commit, counter_patch as incr_spy:
        response = client.patch(f"/execution/task-instances/{ti.id}/state", json=success_payload)

    assert response.status_code == 204, f"Expected 204, got {response.status_code}: {response.text}"
    assert progress["commit_failed"]

    session.expire_all()
    persisted_ti = session.get(TaskInstance, ti.id)
    assert persisted_ti is not None
    assert persisted_ti.state == TaskInstanceState.SUCCESS

    surviving_events = session.scalars(select(AssetEvent).where(AssetEvent.asset_id == asset.id)).all()
    assert surviving_events == []
    incr_spy.assert_any_call("asset.registration_failures")

def test_ti_update_state_database_error(self, client, session, create_task_instance):
"""
Test that a database error is handled correctly when updating the Task Instance state.
Expand Down Expand Up @@ -1977,6 +2161,74 @@ def test_ti_update_state_to_success_clears_task_state(self, client, session, cre
session.expire_all()
assert not session.scalars(select(TaskStoreModel).where(TaskStoreModel.task_id == ti.task_id)).all()

@pytest.mark.db_test
@conf_vars({("state_store", "clear_on_success"): "True"})
def test_asset_registration_failure_does_not_rollback_successful_task_state_clear(
    self, client, session, create_task_instance
):
    """
    A failed asset registration must not undo the task state clear done on success.

    With ``state_store.clear_on_success`` enabled, a ``job_id`` entry is stored for the task
    before the request. The patched registration flushes one ``AssetEvent`` and raises. The
    request is expected to return 204 with the task instance stored as SUCCESS, the stored
    task state gone, and no ``AssetEvent`` row for the asset.
    """
    asset = AssetModel(
        id=44,
        name="partial-asset-with-state-clear",
        uri="s3://bucket/partial-asset-with-state-clear",
        group="asset",
        extra={},
    )
    asset_active = AssetActive.for_asset(asset)
    session.add_all([asset, asset_active])

    ti = create_task_instance(
        task_id="test_asset_failure_after_state_clear",
        start_date=DEFAULT_START_DATE,
        state=State.RUNNING,
    )
    session.commit()

    # Seed one task-state entry that the success transition is expected to clear.
    MetastoreStoreBackend().set(
        TaskScope(dag_id=ti.dag_id, run_id=ti.run_id, task_id=ti.task_id, map_index=ti.map_index),
        "job_id",
        "app_1234",
        session=session,
    )
    session.commit()

    def flush_event_then_raise(ti, task_outlets, outlet_events, session):
        # Write a half-finished registration (flushed, not committed) and then fail.
        partial_event = AssetEvent(
            asset_id=asset.id,
            extra={"partial": True},
            source_task_id=ti.task_id,
            source_dag_id=ti.dag_id,
            source_run_id=ti.run_id,
            source_map_index=ti.map_index,
        )
        session.add(partial_event)
        session.flush()
        raise RuntimeError("simulated failure after state clear")

    success_payload = {
        "state": "success",
        "end_date": DEFAULT_END_DATE.isoformat(),
        "task_outlets": [
            {
                "name": "partial-asset-with-state-clear",
                "uri": "s3://bucket/partial-asset-with-state-clear",
                "type": "Asset",
            }
        ],
        "outlet_events": [],
    }
    partial_registration = mock.patch(
        "airflow.models.taskinstance.TaskInstance.register_asset_changes_in_db",
        side_effect=flush_event_then_raise,
    )

    with partial_registration:
        response = client.patch(f"/execution/task-instances/{ti.id}/state", json=success_payload)

    assert response.status_code == 204, f"Expected 204, got {response.status_code}: {response.text}"

    session.expire_all()
    persisted_ti = session.get(TaskInstance, ti.id)
    assert persisted_ti is not None
    assert persisted_ti.state == TaskInstanceState.SUCCESS

    # The state clear must have stuck even though the later registration failed.
    remaining_state = session.scalars(
        select(TaskStoreModel).where(TaskStoreModel.task_id == ti.task_id)
    ).all()
    assert not remaining_state

    surviving_events = session.scalars(select(AssetEvent).where(AssetEvent.asset_id == asset.id)).all()
    assert surviving_events == []

@pytest.mark.db_test
@conf_vars({("state_store", "clear_on_success"): "True"})
def test_ti_update_state_to_failed_does_not_clear_task_state(self, client, session, create_task_instance):
Expand Down
Loading
Loading