diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 69747ba497d60c..63ee0aa35d4d4f 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -10,6 +10,7 @@ from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan @@ -17,6 +18,8 @@ from models.dataset import Dataset, Document, DocumentSegment from models.enums import IndexingStatus from services.feature_service import FeatureService +from services.summary_index_service import SummaryIndexService +from tasks.generate_summary_index_task import generate_summary_index_task logger = logging.getLogger(__name__) @@ -139,11 +142,12 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st ).all() if segments: index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id] + segment_ids = [segment.id for segment in segments] # delete from vector index index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids=segment_ids) - segment_ids = [segment.id for segment in segments] segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) session.execute(segment_delete_stmt) session.commit() @@ -157,6 +161,41 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st indexing_runner.run(list(documents)) end_at = time.perf_counter() logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + + # Rebuild summary indexes for duplicate uploads after the replacement segments are indexed. + session.expire_all() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) + if not dataset: + logger.warning("Dataset %s not found after duplicate indexing", dataset_id) + return + + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + summary_index_setting = dataset.summary_index_setting + if summary_index_setting and summary_index_setting.get("enable"): + documents = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() + ) + + for document in documents: + if ( + document.indexing_status == IndexingStatus.COMPLETED + and document.doc_form != IndexStructureType.QA_INDEX + and document.need_summary is True + ): + try: + generate_summary_index_task.delay(dataset.id, document.id, None) + logger.info( + "Queued summary index generation task for duplicate document %s in dataset %s", + document.id, + dataset.id, + ) + except Exception: + logger.exception( + "Failed to queue summary index generation task for duplicate document %s", + document.id, + ) except DocumentIsPausedError as ex: logger.info(click.style(str(ex), fg="yellow")) except Exception: diff --git a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py index f6dbc4275b4b0a..c8c70ea25292ac 100644 --- a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py @@ -1,12 +1,14 @@ -"""Unit tests for queue/wrapper behaviors in duplicate document indexing tasks (non-database logic).""" +"""Unit tests for duplicate document indexing task behavior.""" import uuid -from unittest.mock import Mock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch import pytest from core.rag.pipeline.queue import TenantIsolatedTaskQueue from tasks.duplicate_document_indexing_task import ( + _duplicate_document_indexing_task, _duplicate_document_indexing_task_with_tenant_queue, duplicate_document_indexing_task, normal_duplicate_document_indexing_task, @@ -40,6 +42,17 @@ def mock_tenant_isolated_queue(): yield mock_queue +class _SessionContext: + def __init__(self, session): + self.session = session + + def __enter__(self): + return self.session + + def __exit__(self, exc_type, exc, tb): + return False + + class TestDuplicateDocumentIndexingTask: """Tests for the deprecated duplicate_document_indexing_task function.""" @@ -52,6 +65,86 @@ def test_duplicate_document_indexing_task_calls_core_function(self, mock_core_fu # Assert mock_core_func.assert_called_once_with(dataset_id, document_ids) + def test_core_task_deletes_old_summaries_and_queues_summary_regeneration( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Duplicate indexing should refresh summary index data for replaced segments.""" + # Arrange + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + indexing_technique="high_quality", + summary_index_setting={"enable": True}, + ) + document = SimpleNamespace( + id="doc-1", + dataset_id="dataset-1", + doc_form="text", + indexing_status="completed", + need_summary=True, + ) + indexed_document = SimpleNamespace( + id="doc-1", + dataset_id="dataset-1", + doc_form="text", + indexing_status="completed", + need_summary=True, + ) + segment = SimpleNamespace(id="segment-1", index_node_id="node-1") + + session = MagicMock() + session.scalar.return_value = dataset + session.scalars.side_effect = [ + MagicMock(all=MagicMock(return_value=[document])), + MagicMock(all=MagicMock(return_value=[segment])), + MagicMock(all=MagicMock(return_value=[indexed_document])), + ] + monkeypatch.setattr( + "tasks.duplicate_document_indexing_task.session_factory.create_session", + MagicMock(return_value=_SessionContext(session)), + ) + + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False), + vector_space=SimpleNamespace(limit=0, size=0), + ) + monkeypatch.setattr( + "tasks.duplicate_document_indexing_task.FeatureService.get_features", + MagicMock(return_value=features), + ) + + index_processor = MagicMock() + monkeypatch.setattr( + "tasks.duplicate_document_indexing_task.IndexProcessorFactory", + MagicMock(return_value=MagicMock(init_index_processor=MagicMock(return_value=index_processor))), + ) + + indexing_runner = MagicMock() + monkeypatch.setattr( + "tasks.duplicate_document_indexing_task.IndexingRunner", + MagicMock(return_value=indexing_runner), + ) + + delete_summaries_mock = MagicMock() + monkeypatch.setattr( + "tasks.duplicate_document_indexing_task.SummaryIndexService", + SimpleNamespace(delete_summaries_for_segments=delete_summaries_mock), + raising=False, + ) + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.duplicate_document_indexing_task.generate_summary_index_task", + SimpleNamespace(delay=delay_mock), + raising=False, + ) + + # Act + _duplicate_document_indexing_task("dataset-1", ["doc-1"]) + + # Assert + delete_summaries_mock.assert_called_once_with(dataset, segment_ids=["segment-1"]) + delay_mock.assert_called_once_with("dataset-1", "doc-1", None) + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True) def test_duplicate_document_indexing_task_with_empty_document_ids(self, mock_core_func, dataset_id): """Test duplicate_document_indexing_task with empty document_ids list.""" diff --git a/api/uv.lock b/api/uv.lock index 1dbbd21356421a..8c7d2b21727283 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1300,6 +1300,7 @@ requires-dist = [ { name = "pydantic-ai-slim", extras = ["anthropic", "google", "openai"], marker = "extra == 'server'", specifier = ">=1.85.1,<2.0.0" }, { name = "pydantic-settings", marker = "extra == 'server'", specifier = ">=2.12.0,<3.0.0" }, { name = "redis", marker = "extra == 'server'", specifier = ">=7.4.0,<8.0.0" }, + { name = "shell-session-manager", marker = "extra == 'server'", specifier = "==2.1.1" }, { name = "typing-extensions", specifier = ">=4.12.2,<5.0.0" }, { name = "uvicorn", extras = ["standard"], marker = "extra == 'server'", specifier = "==0.46.0" }, ] diff --git a/dify-agent/src/dify_agent/layers/shell/layer.py b/dify-agent/src/dify_agent/layers/shell/layer.py index d265075631c817..5e31c0fb391ea3 100644 --- a/dify-agent/src/dify_agent/layers/shell/layer.py +++ b/dify-agent/src/dify_agent/layers/shell/layer.py @@ -256,9 +256,7 @@ def validate_workspace_and_offsets(self) -> Self: raise ValueError("workspace_cwd requires a matching session_id.") expected_workspace = _workspace_cwd(self.session_id) if self.workspace_cwd != expected_workspace: - raise ValueError( - f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}." - ) + raise ValueError(f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}.") unknown_offset_job_ids = set(self.job_offsets) - set(self.job_ids) if unknown_offset_job_ids: names = ", ".join(sorted(unknown_offset_job_ids)) @@ -694,12 +692,12 @@ def _workspace_mkdir_script(*, session_id: str) -> str: of silently reusing another session's workspace. """ safe_session_id = _validated_session_id(session_id) - workspace_dir = f'$HOME/workspace/{safe_session_id}' + workspace_dir = f"$HOME/workspace/{safe_session_id}" return ( 'mkdir -p "$HOME/workspace"; ' f'if mkdir "{workspace_dir}"; then exit 0; fi; ' f'if [ -e "{workspace_dir}" ]; then exit {_WORKSPACE_COLLISION_EXIT_CODE}; fi; ' - 'exit 1' + "exit 1" ) diff --git a/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py b/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py index a2ab4e435c7826..3d4ae3221ea9e0 100644 --- a/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py +++ b/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py @@ -277,6 +277,7 @@ def factory(entrypoint: str) -> FakeShellctlClient: return next(clients) compositor = Compositor([LayerNode("shell", _shell_provider(client_factory=factory))]) + async def scenario() -> None: async with compositor.enter(configs={"shell": DifyShellLayerConfig()}) as run: shell_layer = run.get_layer("shell", DifyShellLayer) @@ -342,7 +343,10 @@ async def scenario() -> None: assert client.events[:2] == [("run", 'rm -rf -- "$HOME/workspace/abc12ff"'), ("wait", "cleanup-job")] assert {call.job_id for call in client.delete_calls} == {"user-job", "mkdir-job", "cleanup-job"} - assert all(client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job")) for call in client.delete_calls) + assert all( + client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job")) + for call in client.delete_calls + ) assert all(call.force is True for call in client.delete_calls) assert layer.runtime_state.job_ids == [] assert layer.runtime_state.job_offsets == {} diff --git a/dify-agent/tests/local/dify_agent/runtime/test_compositor_factory.py b/dify-agent/tests/local/dify_agent/runtime/test_compositor_factory.py index 799ec94292e09f..8808cf7a9654d1 100644 --- a/dify-agent/tests/local/dify_agent/runtime/test_compositor_factory.py +++ b/dify-agent/tests/local/dify_agent/runtime/test_compositor_factory.py @@ -27,7 +27,9 @@ def factory(entrypoint: str) -> FakeFactoryClient: return factory - monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory) + monkeypatch.setattr( + compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory + ) providers = create_default_layer_providers( shellctl_entrypoint="http://shellctl.example", @@ -56,7 +58,9 @@ def factory(_entrypoint: str) -> FakeFactoryClient: return factory - monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory) + monkeypatch.setattr( + compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory + ) providers = create_default_layer_providers(shellctl_entrypoint="http://shellctl.example") shell_provider = next(provider for provider in providers if provider.type_id == DIFY_SHELL_LAYER_TYPE_ID) diff --git a/dify-agent/tests/local/dify_agent/runtime/test_runner.py b/dify-agent/tests/local/dify_agent/runtime/test_runner.py index 4a899f0a790bda..ec2acae15ca37e 100644 --- a/dify-agent/tests/local/dify_agent/runtime/test_runner.py +++ b/dify-agent/tests/local/dify_agent/runtime/test_runner.py @@ -684,7 +684,8 @@ def fake_create_agent(model: object, *, tools: list[Tool[object]], output_type: ), ) layer_providers = tuple( - provider for provider in create_default_layer_providers(shellctl_entrypoint="http://unused") + provider + for provider in create_default_layer_providers(shellctl_entrypoint="http://unused") if provider.type_id != DIFY_SHELL_LAYER_TYPE_ID ) + (shell_provider,)