diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 04c0eb0aebfb..a2125b469e3e 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -130,6 +130,7 @@ def get_default_instructions( - get_dashboard_layout: Get parsed tabs and chart positions for a dashboard (companion to get_dashboard_info when its omitted_fields hint flags position_json) - generate_dashboard: Create a dashboard from chart IDs (requires write access) - add_chart_to_existing_dashboard: Add a chart to an existing dashboard (requires write access) +- manage_native_filters: Add, update, remove, or reorder native filters on a dashboard (requires write access; supports filter_select and filter_time) Annotation Layers: - list_annotation_layers: List annotation layers with advanced filters (1-based pagination) @@ -414,7 +415,8 @@ def get_default_instructions( {_instance_info_role_bullet}- ALWAYS check the user's roles BEFORE suggesting write operations (creating datasets, charts, or dashboards). SQL execution is a separate permission — see execute_sql below. - Write tools (generate_chart, generate_dashboard, update_chart, create_virtual_dataset, - save_sql_query, add_chart_to_existing_dashboard, update_chart_preview) require write + save_sql_query, add_chart_to_existing_dashboard, manage_native_filters, + update_chart_preview) require write permissions. These tools are only listed for users who have the necessary access. If a write tool does not appear in the tool list, the current user lacks write access. - execute_sql requires SQL Lab access (execute_sql_query permission), which is separate @@ -683,6 +685,7 @@ def create_mcp_app( get_dashboard_info, get_dashboard_layout, list_dashboards, + manage_native_filters, ) from superset.mcp_service.database.tool import ( # noqa: F401, E402 get_database_info, diff --git a/superset/mcp_service/dashboard/schemas.py b/superset/mcp_service/dashboard/schemas.py index 73a57a6e292e..0495e64a77f6 100644 --- a/superset/mcp_service/dashboard/schemas.py +++ b/superset/mcp_service/dashboard/schemas.py @@ -1298,3 +1298,206 @@ def dashboard_layout_serializer(dashboard: "Dashboard") -> DashboardLayout: has_layout=bool(position_json_str), ) ) + + +# --------------------------------------------------------------------------- +# manage_native_filters schemas +# --------------------------------------------------------------------------- + + +class BaseNewFilterSpec(BaseModel): + """Common fields shared by all new native filter specs.""" + + name: str = Field(..., min_length=1, description="Filter display name") + description: str = Field("", description="Optional filter description") + scope_chart_ids: List[int] | None = Field( + None, + description=( + "Chart IDs this filter should apply to. When omitted the filter " + "applies to all charts on the dashboard. All IDs must belong to " + "charts that are on the dashboard." + ), + ) + + +class FilterSelectSpec(BaseNewFilterSpec): + """Spec for a new dropdown (filter_select) native filter.""" + + filter_type: Literal["filter_select"] = Field( + ..., description="Discriminator - must be 'filter_select'" + ) + dataset_id: int = Field(..., description="ID of the dataset to filter on") + column: str = Field( + ..., min_length=1, description="Name of the dataset column to filter on" + ) + multi_select: bool = Field( + True, description="Allow selecting multiple values (default True)" + ) + default_to_first_item: bool = Field( + False, description="Default the filter to the first item in the list" + ) + enable_empty_filter: bool = Field( + False, description="Require a value before the filter is applied" + ) + sort_ascending: bool | None = Field( + None, + description=( + "Sort filter values ascending (True) or descending (False). " + "When omitted, values are not explicitly sorted." + ), + ) + search_all_options: bool = Field( + False, description="Query the database on search rather than client-side" + ) + + +class FilterTimeSpec(BaseNewFilterSpec): + """Spec for a new time range (filter_time) native filter.""" + + filter_type: Literal["filter_time"] = Field( + ..., description="Discriminator - must be 'filter_time'" + ) + default_time_range: str | None = Field( + None, + description=( + "Default time range value, e.g. 'Last week', 'Last month', " + "'2024-01-01 : 2024-12-31'. When omitted the filter has no default." + ), + ) + + +NewNativeFilterSpec = Annotated[ + FilterSelectSpec | FilterTimeSpec, + Field(discriminator="filter_type"), +] + + +class NativeFilterUpdateSpec(BaseModel): + """Partial update for an existing native filter. + + Only ``id`` is required; any other provided field is merged into the + existing filter configuration. Fields that only apply to one filter + type (e.g. ``multi_select`` for filter_select, ``default_time_range`` + for filter_time) are rejected when used on the wrong filter type. + """ + + id: str = Field(..., min_length=1, description="ID of the filter to update") + name: str | None = Field(None, min_length=1, description="New display name") + description: str | None = Field(None, description="New description") + dataset_id: int | None = Field( + None, description="New target dataset ID (filter_select only)" + ) + column: str | None = Field( + None, min_length=1, description="New target column name (filter_select only)" + ) + multi_select: bool | None = Field( + None, description="Allow multiple values (filter_select only)" + ) + default_to_first_item: bool | None = Field( + None, description="Default to first item (filter_select only)" + ) + enable_empty_filter: bool | None = Field( + None, description="Require a value (filter_select only)" + ) + sort_ascending: bool | None = Field( + None, description="Sort values ascending/descending (filter_select only)" + ) + search_all_options: bool | None = Field( + None, description="Search all options in the database (filter_select only)" + ) + default_time_range: str | None = Field( + None, description="Default time range (filter_time only)" + ) + scope_chart_ids: List[int] | None = Field( + None, + description=( + "Chart IDs this filter should apply to. Replaces the current " + "scope. All IDs must belong to charts on the dashboard." + ), + ) + + +class ManageNativeFiltersRequest(BaseModel): + """Request schema for the manage_native_filters tool.""" + + dashboard_id: int = Field(..., description="ID of the dashboard to modify") + add: List[NewNativeFilterSpec] = Field( + default_factory=list, + description=( + "New filters to create. Supported types: filter_select " + "(dropdown) and filter_time (time range). Other filter types " + "(numerical range, time column, time grain) are not yet " + "supported by this tool." + ), + ) + update: List[NativeFilterUpdateSpec] = Field( + default_factory=list, + description="Partial updates to existing filters, addressed by filter ID", + ) + remove: List[str] = Field( + default_factory=list, + description="IDs of filters to delete from the dashboard", + ) + reorder: List[str] | None = Field( + None, + description=( + "Complete ordered list of filter IDs defining the new filter " + "order. Must include every filter that remains on the dashboard " + "(after removals); newly added filters are appended " + "automatically and may be omitted." + ), + ) + + @model_validator(mode="after") + def _require_at_least_one_operation(self) -> "ManageNativeFiltersRequest": + if not self.add and not self.update and not self.remove and not self.reorder: + raise ValueError( + "At least one operation (add, update, remove, reorder) is required" + ) + return self + + +class ManageNativeFiltersResponse(BaseModel): + """Response schema for the manage_native_filters tool.""" + + dashboard_id: int | None = Field(None, description="ID of the dashboard") + dashboard_url: str | None = Field( + None, description="URL to view the updated dashboard" + ) + added_filter_ids: List[str] = Field( + default_factory=list, + description=( + "Server-generated IDs of the newly created filters, in request order" + ), + ) + updated_filter_ids: List[str] = Field( + default_factory=list, description="IDs of the filters that were updated" + ) + removed_filter_ids: List[str] = Field( + default_factory=list, description="IDs of the filters that were removed" + ) + filters: List[NativeFilterSummary] = Field( + default_factory=list, + description="Final native filter configuration after the operation, in order", + ) + error: str | None = Field(None, description="Error message, if operation failed") + permission_denied: bool = Field( + default=False, + description=( + "True when the operation failed because the current user does " + "not have edit rights on the target dashboard." + ), + ) + + @field_validator("error") + @classmethod + def sanitize_error_for_llm_context(cls, value: str | None) -> str | None: + """Wrap error text before it is exposed to LLM context. + + The error may echo user-supplied filter names or dashboard-controlled + metadata - both must be wrapped so the LLM treats them as data, not + instructions. + """ + if value is None: + return value + return sanitize_for_llm_context(value, field_path=("error",)) diff --git a/superset/mcp_service/dashboard/tool/__init__.py b/superset/mcp_service/dashboard/tool/__init__.py index 389acfb192ab..8eaadf108985 100644 --- a/superset/mcp_service/dashboard/tool/__init__.py +++ b/superset/mcp_service/dashboard/tool/__init__.py @@ -20,6 +20,7 @@ from .get_dashboard_info import get_dashboard_info from .get_dashboard_layout import get_dashboard_layout from .list_dashboards import list_dashboards +from .manage_native_filters import manage_native_filters __all__ = [ "list_dashboards", @@ -27,4 +28,5 @@ "get_dashboard_layout", "generate_dashboard", "add_chart_to_existing_dashboard", + "manage_native_filters", ] diff --git a/superset/mcp_service/dashboard/tool/manage_native_filters.py b/superset/mcp_service/dashboard/tool/manage_native_filters.py new file mode 100644 index 000000000000..331df30dbe50 --- /dev/null +++ b/superset/mcp_service/dashboard/tool/manage_native_filters.py @@ -0,0 +1,450 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +MCP tool: manage_native_filters + +Adds, updates, removes, and reorders native filters on a dashboard by +translating high-level operations into the ``deleted`` / ``modified`` / +``reordered`` payload consumed by ``UpdateDashboardNativeFiltersCommand``. +""" + +import copy +import logging +from typing import Any + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.dashboard.constants import generate_id +from superset.mcp_service.dashboard.schemas import ( + FilterSelectSpec, + FilterTimeSpec, + ManageNativeFiltersRequest, + ManageNativeFiltersResponse, + NativeFilterSummary, + NativeFilterUpdateSpec, +) +from superset.mcp_service.utils.url_utils import get_superset_base_url +from superset.utils import json + +logger = logging.getLogger(__name__) + +# Control values that map to filter_select controlValues keys. +_SELECT_CONTROL_FIELDS: dict[str, str] = { + "multi_select": "multiSelect", + "default_to_first_item": "defaultToFirstItem", + "enable_empty_filter": "enableEmptyFilter", + "sort_ascending": "sortAscending", + "search_all_options": "searchAllOptions", +} + + +class _FilterValidationError(Exception): + """Raised internally when a filter operation fails validation.""" + + +def _empty_data_mask() -> dict[str, Any]: + return {"filterState": {"value": None}, "extraFormData": {}} + + +def _time_data_mask(default_time_range: str | None) -> dict[str, Any]: + if not default_time_range: + return _empty_data_mask() + return { + "filterState": {"value": default_time_range}, + "extraFormData": {"time_range": default_time_range}, + } + + +def _validate_dataset_column(dataset_id: int, column: str) -> None: + """Validate that the dataset exists and contains the given column.""" + from superset.daos.dataset import DatasetDAO + + dataset = DatasetDAO.find_by_id(dataset_id) + if not dataset: + raise _FilterValidationError( + f"Dataset with ID {dataset_id} not found." + " Use list_datasets to get valid dataset IDs." + ) + column_names = [c.column_name for c in dataset.columns] + if column not in column_names: + raise _FilterValidationError( + f"Column '{column}' not found in dataset {dataset_id}. " + f"Available columns: {', '.join(sorted(column_names))}." + ) + + +def _build_scope( + scope_chart_ids: list[int] | None, + dashboard_chart_ids: list[int], +) -> dict[str, Any]: + """Translate scope_chart_ids into the frontend scope structure. + + The frontend expresses scope as an exclusion list, so charts NOT in + ``scope_chart_ids`` are excluded. When ``scope_chart_ids`` is None + the filter applies to all charts (empty exclusion list). + """ + if scope_chart_ids is None: + return {"rootPath": ["ROOT_ID"], "excluded": []} + unknown = sorted(set(scope_chart_ids) - set(dashboard_chart_ids)) + if unknown: + raise _FilterValidationError( + f"scope_chart_ids contains chart IDs not on the dashboard: " + f"{unknown}. Charts on this dashboard: {sorted(dashboard_chart_ids)}." + ) + excluded = sorted(set(dashboard_chart_ids) - set(scope_chart_ids)) + return {"rootPath": ["ROOT_ID"], "excluded": excluded} + + +def _build_new_filter_config( + spec: FilterSelectSpec | FilterTimeSpec, + dashboard_chart_ids: list[int], +) -> dict[str, Any]: + """Build a full native filter config dict for a new filter.""" + scope = _build_scope(spec.scope_chart_ids, dashboard_chart_ids) + filter_id = generate_id("NATIVE_FILTER") + + if isinstance(spec, FilterSelectSpec): + _validate_dataset_column(spec.dataset_id, spec.column) + control_values: dict[str, Any] = { + "multiSelect": spec.multi_select, + "defaultToFirstItem": spec.default_to_first_item, + "enableEmptyFilter": spec.enable_empty_filter, + "searchAllOptions": spec.search_all_options, + } + if spec.sort_ascending is not None: + control_values["sortAscending"] = spec.sort_ascending + return { + "id": filter_id, + "type": "NATIVE_FILTER", + "filterType": "filter_select", + "name": spec.name, + "description": spec.description, + "scope": scope, + "targets": [ + {"datasetId": spec.dataset_id, "column": {"name": spec.column}} + ], + "controlValues": control_values, + "defaultDataMask": _empty_data_mask(), + "cascadeParentIds": [], + } + + # filter_time: no dataset target, empty controlValues + return { + "id": filter_id, + "type": "NATIVE_FILTER", + "filterType": "filter_time", + "name": spec.name, + "description": spec.description, + "scope": scope, + "targets": [{}], + "controlValues": {}, + "defaultDataMask": _time_data_mask(spec.default_time_range), + "cascadeParentIds": [], + } + + +def _validate_update_type_compat( + spec: NativeFilterUpdateSpec, filter_type: str | None +) -> None: + """Reject update fields that do not apply to the filter's type.""" + select_fields_set = [ + field + for field in (*_SELECT_CONTROL_FIELDS, "dataset_id", "column") + if getattr(spec, field) is not None + ] + if filter_type != "filter_select" and select_fields_set: + raise _FilterValidationError( + f"Filter '{spec.id}' has type '{filter_type}'; fields " + f"{select_fields_set} only apply to filter_select filters." + ) + if filter_type != "filter_time" and spec.default_time_range is not None: + raise _FilterValidationError( + f"Filter '{spec.id}' has type '{filter_type}'; default_time_range " + "only applies to filter_time filters." + ) + + +def _merge_target(spec: NativeFilterUpdateSpec, merged: dict[str, Any]) -> None: + """Merge dataset_id / column changes into the filter's first target.""" + targets = merged.get("targets") or [{}] + target = dict(targets[0]) if targets else {} + dataset_id = ( + spec.dataset_id if spec.dataset_id is not None else target.get("datasetId") + ) + column = ( + spec.column + if spec.column is not None + else (target.get("column") or {}).get("name") + ) + if dataset_id is None or not column: + raise _FilterValidationError( + f"Filter '{spec.id}' is missing a dataset or column target; " + "provide both dataset_id and column to set the target." + ) + _validate_dataset_column(dataset_id, column) + target["datasetId"] = dataset_id + target["column"] = {"name": column} + merged["targets"] = [target] + + +def _merge_filter_update( + spec: NativeFilterUpdateSpec, + existing: dict[str, Any], + dashboard_chart_ids: list[int], +) -> dict[str, Any]: + """Merge a partial update into an existing filter config. + + Returns a FULL filter config (the backend command substitutes whole + entries, it does not merge deltas). + """ + merged = copy.deepcopy(existing) + _validate_update_type_compat(spec, merged.get("filterType")) + + if spec.name is not None: + merged["name"] = spec.name + if spec.description is not None: + merged["description"] = spec.description + if spec.scope_chart_ids is not None: + merged["scope"] = _build_scope(spec.scope_chart_ids, dashboard_chart_ids) + if spec.dataset_id is not None or spec.column is not None: + _merge_target(spec, merged) + + control_values = dict(merged.get("controlValues") or {}) + for field, control_key in _SELECT_CONTROL_FIELDS.items(): + value = getattr(spec, field) + if value is not None: + control_values[control_key] = value + merged["controlValues"] = control_values + + if spec.default_time_range is not None: + merged["defaultDataMask"] = _time_data_mask(spec.default_time_range) + + return merged + + +def _filter_summary(conf: dict[str, Any]) -> NativeFilterSummary: + return NativeFilterSummary( + id=conf.get("id"), + name=conf.get("name"), + filter_type=conf.get("filterType"), + targets=[t for t in (conf.get("targets") or []) if t], + ) + + +def _build_native_filters_payload( # noqa: C901 + request: ManageNativeFiltersRequest, + current_config: list[dict[str, Any]], + dashboard_chart_ids: list[int], +) -> tuple[dict[str, Any], list[str], list[str]]: + """Translate tool operations into the command payload. + + Returns ``(payload, added_filter_ids, updated_filter_ids)`` where the + payload has the ``deleted`` / ``modified`` / ``reordered`` shape expected + by ``UpdateDashboardNativeFiltersCommand``. + """ + current_by_id = {conf["id"]: conf for conf in current_config if conf.get("id")} + + unknown_removals = [fid for fid in request.remove if fid not in current_by_id] + if unknown_removals: + raise _FilterValidationError( + f"Cannot remove filters that do not exist on the dashboard: " + f"{unknown_removals}. Existing filter IDs: " + f"{sorted(current_by_id)}." + ) + + removed_ids = set(request.remove) + modified: list[dict[str, Any]] = [] + updated_filter_ids: list[str] = [] + + for update_spec in request.update: + if update_spec.id in removed_ids: + raise _FilterValidationError( + f"Filter '{update_spec.id}' cannot be both updated and removed." + ) + existing = current_by_id.get(update_spec.id) + if existing is None: + raise _FilterValidationError( + f"Cannot update filter '{update_spec.id}': not found on the " + f"dashboard. Existing filter IDs: {sorted(current_by_id)}." + ) + modified.append( + _merge_filter_update(update_spec, existing, dashboard_chart_ids) + ) + updated_filter_ids.append(update_spec.id) + + added_filter_ids: list[str] = [] + for new_spec in request.add: + config = _build_new_filter_config(new_spec, dashboard_chart_ids) + modified.append(config) + added_filter_ids.append(config["id"]) + + payload: dict[str, Any] = {} + if request.remove: + payload["deleted"] = list(request.remove) + if modified: + payload["modified"] = modified + + if request.reorder is not None: + # The DAO drops any surviving filter that is absent from the + # reordered list, so require a complete ordering of surviving + # pre-existing filters. Newly added filters are appended + # automatically by the DAO and may be omitted. + surviving_ids = set(current_by_id) - removed_ids + reorder_ids = [fid for fid in request.reorder if fid not in added_filter_ids] + if len(set(request.reorder)) != len(request.reorder): + raise _FilterValidationError("reorder contains duplicate filter IDs.") + missing = sorted(surviving_ids - set(reorder_ids)) + unknown = sorted(set(reorder_ids) - surviving_ids) + if missing or unknown: + raise _FilterValidationError( + "reorder must list every remaining filter exactly once. " + f"Missing: {missing}. Unknown: {unknown}. " + f"Remaining filter IDs: {sorted(surviving_ids)}." + ) + payload["reordered"] = list(request.reorder) + + return payload, added_filter_ids, updated_filter_ids + + +@tool( + tags=["mutate"], + class_permission_name="Dashboard", + method_permission_name="write", + annotations=ToolAnnotations( + title="Manage dashboard native filters", + readOnlyHint=False, + destructiveHint=True, + ), +) +def manage_native_filters( + request: ManageNativeFiltersRequest, ctx: Context +) -> ManageNativeFiltersResponse: + """ + Add, update, remove, and reorder native filters on a dashboard. + + Supported filter types for new filters: filter_select (dropdown backed + by a dataset column) and filter_time (time range). Other filter types + (numerical range, time column, time grain) are not yet supported. + Filter IDs are generated by the server and returned in the response. + """ + from superset.commands.dashboard.exceptions import ( + DashboardForbiddenError, + DashboardInvalidError, + DashboardNativeFiltersUpdateFailedError, + DashboardNotFoundError, + ) + from superset.commands.dashboard.update import ( + UpdateDashboardNativeFiltersCommand, + ) + from superset.commands.exceptions import TagForbiddenError + from superset.daos.dashboard import DashboardDAO + + try: + with event_logger.log_context(action="mcp.manage_native_filters.validation"): + dashboard = DashboardDAO.find_by_id(request.dashboard_id) + if not dashboard: + return ManageNativeFiltersResponse( + error=( + f"Dashboard with ID {request.dashboard_id} not found." + " Use list_dashboards to get valid dashboard IDs." + ), + ) + + try: + metadata = json.loads(dashboard.json_metadata or "{}") + except (json.JSONDecodeError, TypeError): + metadata = {} + current_config = metadata.get("native_filter_configuration") or [] + dashboard_chart_ids = [slc.id for slc in dashboard.slices] + + try: + payload, added_ids, updated_ids = _build_native_filters_payload( + request, current_config, dashboard_chart_ids + ) + except _FilterValidationError as exc: + return ManageNativeFiltersResponse( + dashboard_id=request.dashboard_id, + error=str(exc), + ) + + with event_logger.log_context(action="mcp.manage_native_filters.db_write"): + configuration = UpdateDashboardNativeFiltersCommand( + request.dashboard_id, payload + ).run() + + dashboard_url = ( + f"{get_superset_base_url()}/superset/dashboard/{request.dashboard_id}/" + ) + logger.info( + "Managed native filters on dashboard %s (added=%d updated=%d removed=%d)", + request.dashboard_id, + len(added_ids), + len(updated_ids), + len(request.remove), + ) + return ManageNativeFiltersResponse( + dashboard_id=request.dashboard_id, + dashboard_url=dashboard_url, + added_filter_ids=added_ids, + updated_filter_ids=updated_ids, + removed_filter_ids=list(request.remove), + filters=[_filter_summary(conf) for conf in configuration], + ) + + except DashboardNotFoundError: + return ManageNativeFiltersResponse( + error=( + f"Dashboard with ID {request.dashboard_id} not found." + " Use list_dashboards to get valid dashboard IDs." + ), + ) + except DashboardForbiddenError: + return ManageNativeFiltersResponse( + dashboard_id=request.dashboard_id, + permission_denied=True, + error=( + f"You don't have permission to edit dashboard " + f"{request.dashboard_id}. Changing native filters requires " + "ownership of the dashboard." + ), + ) + except TagForbiddenError as exc: + return ManageNativeFiltersResponse( + dashboard_id=request.dashboard_id, + permission_denied=True, + error=str(exc), + ) + except DashboardInvalidError as exc: + return ManageNativeFiltersResponse( + dashboard_id=request.dashboard_id, + error=f"Invalid dashboard update: {exc.normalized_messages()}", + ) + except DashboardNativeFiltersUpdateFailedError as exc: + return ManageNativeFiltersResponse( + dashboard_id=request.dashboard_id, + error=f"Failed to update native filters: {exc}", + ) + except Exception as exc: + logger.exception( + "Unexpected error managing native filters on dashboard %s: %s", + request.dashboard_id, + exc, + ) + raise diff --git a/tests/unit_tests/mcp_service/dashboard/tool/test_manage_native_filters.py b/tests/unit_tests/mcp_service/dashboard/tool/test_manage_native_filters.py new file mode 100644 index 000000000000..a746beb4cc8e --- /dev/null +++ b/tests/unit_tests/mcp_service/dashboard/tool/test_manage_native_filters.py @@ -0,0 +1,588 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for the manage_native_filters MCP tool. + +Follows the pattern from test_add_chart_to_existing_dashboard.py: +- Tests run through the async MCP Client (not direct function calls) +- Patches applied at source locations (superset.daos.dashboard.*, etc.) +- auth is mocked via the autouse mock_auth fixture + +Covers: +- Adding a filter_select filter (full config shape, scope translation) +- Adding a filter_time filter (with default time range) +- Updating a filter (merge produces a FULL config, not a delta) +- Removing a filter +- Reordering filters (including incomplete-reorder validation) +- Invalid dataset / column errors +- Dashboard not found +- Permission denied (DashboardForbiddenError) +""" + +import logging +from typing import Any +from unittest.mock import Mock, patch + +import pytest +from fastmcp import Client + +from superset.commands.dashboard.exceptions import DashboardForbiddenError +from superset.mcp_service.app import mcp +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +DAO_FIND_BY_ID = "superset.daos.dashboard.DashboardDAO.find_by_id" +DATASET_FIND_BY_ID = "superset.daos.dataset.DatasetDAO.find_by_id" +COMMAND_PATH = "superset.commands.dashboard.update.UpdateDashboardNativeFiltersCommand" + + +@pytest.fixture +def mcp_server() -> object: + """Return the FastMCP app instance for use in MCP client tests.""" + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +EXISTING_SELECT_FILTER = { + "id": "NATIVE_FILTER-existing1", + "type": "NATIVE_FILTER", + "filterType": "filter_select", + "name": "Region", + "description": "", + "scope": {"rootPath": ["ROOT_ID"], "excluded": []}, + "targets": [{"datasetId": 5, "column": {"name": "region"}}], + "controlValues": { + "multiSelect": True, + "defaultToFirstItem": False, + "enableEmptyFilter": False, + "searchAllOptions": False, + }, + "defaultDataMask": {"filterState": {"value": None}, "extraFormData": {}}, + "cascadeParentIds": [], +} + +EXISTING_TIME_FILTER = { + "id": "NATIVE_FILTER-existing2", + "type": "NATIVE_FILTER", + "filterType": "filter_time", + "name": "Time Range", + "description": "", + "scope": {"rootPath": ["ROOT_ID"], "excluded": []}, + "targets": [{}], + "controlValues": {}, + "defaultDataMask": {"filterState": {"value": None}, "extraFormData": {}}, + "cascadeParentIds": [], +} + + +def _mock_dashboard( + id: int = 1, + filters: list[dict[str, Any]] | None = None, + chart_ids: list[int] | None = None, +) -> Mock: + dashboard = Mock() + dashboard.id = id + dashboard.dashboard_title = "Test Dashboard" + dashboard.json_metadata = json.dumps({"native_filter_configuration": filters or []}) + slices = [] + for chart_id in chart_ids or [10, 11]: + slc = Mock() + slc.id = chart_id + slices.append(slc) + dashboard.slices = slices + return dashboard + + +def _mock_dataset(columns: list[str] | None = None) -> Mock: + dataset = Mock() + dataset.id = 5 + cols = [] + for name in columns or ["region", "country", "ds"]: + col = Mock() + col.column_name = name + cols.append(col) + dataset.columns = cols + return dataset + + +def _mock_command(captured: dict[str, Any]): + """Build a mock UpdateDashboardNativeFiltersCommand class. + + Captures constructor args and returns the modified configuration + the way the real DAO would (existing filters with substitutions, + new filters appended, deletions removed). + """ + + def command_factory(dashboard_id, payload): + captured["dashboard_id"] = dashboard_id + captured["payload"] = payload + + command = Mock() + + def run(): + current = captured.get("current_config", []) + deleted = payload.get("deleted", []) + modified = payload.get("modified", []) + result = [] + for conf in current: + if conf["id"] in deleted: + continue + replacement = next((m for m in modified if m["id"] == conf["id"]), None) + result.append(replacement if replacement else conf) + for m in modified: + if m["id"] not in [c["id"] for c in result]: + result.append(m) + if reordered := list(payload.get("reordered", [])): + for m in modified: + if m["id"] not in reordered: + reordered.append(m["id"]) + by_id = {c["id"]: c for c in result} + result = [by_id[fid] for fid in reordered if fid in by_id] + captured["result"] = result + return result + + command.run = run + return command + + return command_factory + + +async def _call(mcp_server, request: dict[str, Any]) -> dict[str, Any]: + async with Client(mcp_server) as client: + result = await client.call_tool("manage_native_filters", {"request": request}) + return json.loads(result.content[0].text) + + +# --------------------------------------------------------------------------- +# Add +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_filter_select(mcp_server): + captured: dict = {"current_config": []} + dashboard = _mock_dashboard(filters=[], chart_ids=[10, 11, 12]) + + with ( + patch(DAO_FIND_BY_ID, return_value=dashboard), + patch(DATASET_FIND_BY_ID, return_value=_mock_dataset()), + patch(COMMAND_PATH, side_effect=_mock_command(captured)), + ): + data = await _call( + mcp_server, + { + "dashboard_id": 1, + "add": [ + { + "filter_type": "filter_select", + "name": "Region", + "dataset_id": 5, + "column": "region", + "multi_select": False, + "default_to_first_item": True, + "enable_empty_filter": True, + "sort_ascending": False, + "search_all_options": True, + "scope_chart_ids": [10, 11], + } + ], + }, + ) + + assert data["error"] is None + assert len(data["added_filter_ids"]) == 1 + new_id = data["added_filter_ids"][0] + assert new_id.startswith("NATIVE_FILTER-") + + payload = captured["payload"] + assert "deleted" not in payload + assert "reordered" not in payload + assert len(payload["modified"]) == 1 + config = payload["modified"][0] + assert config == { + "id": new_id, + "type": "NATIVE_FILTER", + "filterType": "filter_select", + "name": "Region", + "description": "", + "scope": {"rootPath": ["ROOT_ID"], "excluded": [12]}, + "targets": [{"datasetId": 5, "column": {"name": "region"}}], + "controlValues": { + "multiSelect": False, + "defaultToFirstItem": True, + "enableEmptyFilter": True, + "searchAllOptions": True, + "sortAscending": False, + }, + "defaultDataMask": {"filterState": {"value": None}, "extraFormData": {}}, + "cascadeParentIds": [], + } + assert data["filters"][0]["id"] == new_id + assert data["filters"][0]["filter_type"] == "filter_select" + + +@pytest.mark.asyncio +async def test_add_filter_time(mcp_server): + captured: dict = {"current_config": []} + dashboard = _mock_dashboard(filters=[]) + + with ( + patch(DAO_FIND_BY_ID, return_value=dashboard), + patch(COMMAND_PATH, side_effect=_mock_command(captured)), + ): + data = await _call( + mcp_server, + { + "dashboard_id": 1, + "add": [ + { + "filter_type": "filter_time", + "name": "Time Range", + "default_time_range": "Last week", + } + ], + }, + ) + + assert data["error"] is None + new_id = data["added_filter_ids"][0] + config = captured["payload"]["modified"][0] + assert config["id"] == new_id + assert config["type"] == "NATIVE_FILTER" + assert config["filterType"] == "filter_time" + assert config["targets"] == [{}] + assert config["controlValues"] == {} + assert config["scope"] == {"rootPath": ["ROOT_ID"], "excluded": []} + assert config["defaultDataMask"] == { + "filterState": {"value": "Last week"}, + "extraFormData": {"time_range": "Last week"}, + } + + +# --------------------------------------------------------------------------- +# Update +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_update_merge_produces_full_config(mcp_server): + captured: dict = {"current_config": [EXISTING_SELECT_FILTER]} + dashboard = _mock_dashboard(filters=[EXISTING_SELECT_FILTER]) + + with ( + patch(DAO_FIND_BY_ID, return_value=dashboard), + patch(DATASET_FIND_BY_ID, return_value=_mock_dataset()), + patch(COMMAND_PATH, side_effect=_mock_command(captured)), + ): + data = await _call( + mcp_server, + { + "dashboard_id": 1, + "update": [ + { + "id": "NATIVE_FILTER-existing1", + "name": "Region (updated)", + "column": "country", + "multi_select": False, + } + ], + }, + ) + + assert data["error"] is None + assert data["updated_filter_ids"] == ["NATIVE_FILTER-existing1"] + + config = captured["payload"]["modified"][0] + # Full config substituted, not a delta: untouched fields preserved + assert config["id"] == "NATIVE_FILTER-existing1" + assert config["type"] == "NATIVE_FILTER" + assert config["filterType"] == "filter_select" + assert config["name"] == "Region (updated)" + assert config["targets"] == [{"datasetId": 5, "column": {"name": "country"}}] + assert config["controlValues"]["multiSelect"] is False + # Untouched control values preserved from the existing config + assert config["controlValues"]["enableEmptyFilter"] is False + assert config["controlValues"]["searchAllOptions"] is False + assert config["defaultDataMask"] == EXISTING_SELECT_FILTER["defaultDataMask"] + assert config["cascadeParentIds"] == [] + assert config["scope"] == EXISTING_SELECT_FILTER["scope"] + + +@pytest.mark.asyncio +async def test_update_unknown_filter_id(mcp_server): + dashboard = _mock_dashboard(filters=[EXISTING_SELECT_FILTER]) + + with patch(DAO_FIND_BY_ID, return_value=dashboard): + data = await _call( + mcp_server, + { + "dashboard_id": 1, + "update": [{"id": "NATIVE_FILTER-nope", "name": "X"}], + }, + ) + + assert "not found on the" in data["error"] + assert "NATIVE_FILTER-existing1" in data["error"] + + +@pytest.mark.asyncio +async def test_update_time_field_on_select_filter_rejected(mcp_server): + dashboard = _mock_dashboard(filters=[EXISTING_SELECT_FILTER]) + + with patch(DAO_FIND_BY_ID, return_value=dashboard): + data = await _call( + mcp_server, + { + "dashboard_id": 1, + "update": [ + { + "id": "NATIVE_FILTER-existing1", + "default_time_range": "Last week", + } + ], + }, + ) + + assert "default_time_range" in data["error"] + assert "filter_time" in data["error"] + + +# --------------------------------------------------------------------------- +# Remove +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_remove_filter(mcp_server): + captured: dict = {"current_config": [EXISTING_SELECT_FILTER, EXISTING_TIME_FILTER]} + dashboard = _mock_dashboard(filters=[EXISTING_SELECT_FILTER, EXISTING_TIME_FILTER]) + + with ( + patch(DAO_FIND_BY_ID, return_value=dashboard), + patch(COMMAND_PATH, side_effect=_mock_command(captured)), + ): + data = await _call( + mcp_server, + {"dashboard_id": 1, "remove": ["NATIVE_FILTER-existing1"]}, + ) + + assert data["error"] is None + assert data["removed_filter_ids"] == ["NATIVE_FILTER-existing1"] + assert captured["payload"] == {"deleted": ["NATIVE_FILTER-existing1"]} + assert [f["id"] for f in data["filters"]] == ["NATIVE_FILTER-existing2"] + + +@pytest.mark.asyncio +async def test_remove_unknown_filter_id(mcp_server): + dashboard = _mock_dashboard(filters=[EXISTING_SELECT_FILTER]) + + with patch(DAO_FIND_BY_ID, return_value=dashboard): + data = await _call( + mcp_server, + {"dashboard_id": 1, "remove": ["NATIVE_FILTER-nope"]}, + ) + + assert "do not exist" in data["error"] + + +# --------------------------------------------------------------------------- +# Reorder +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_reorder_filters(mcp_server): + captured: dict = {"current_config": [EXISTING_SELECT_FILTER, EXISTING_TIME_FILTER]} + dashboard = _mock_dashboard(filters=[EXISTING_SELECT_FILTER, EXISTING_TIME_FILTER]) + + with ( + patch(DAO_FIND_BY_ID, return_value=dashboard), + patch(COMMAND_PATH, side_effect=_mock_command(captured)), + ): + data = await _call( + mcp_server, + { + "dashboard_id": 1, + "reorder": [ + "NATIVE_FILTER-existing2", + "NATIVE_FILTER-existing1", + ], + }, + ) + + assert data["error"] is None + assert captured["payload"] == { + "reordered": ["NATIVE_FILTER-existing2", "NATIVE_FILTER-existing1"] + } + assert [f["id"] for f in data["filters"]] == [ + "NATIVE_FILTER-existing2", + "NATIVE_FILTER-existing1", + ] + + +@pytest.mark.asyncio +async def test_reorder_must_include_all_filters(mcp_server): + """The DAO silently drops filters missing from the reordered list, + so the tool must reject incomplete reorders.""" + dashboard = _mock_dashboard(filters=[EXISTING_SELECT_FILTER, EXISTING_TIME_FILTER]) + + with patch(DAO_FIND_BY_ID, return_value=dashboard): + data = await _call( + mcp_server, + {"dashboard_id": 1, "reorder": ["NATIVE_FILTER-existing2"]}, + ) + + assert "every remaining filter" in data["error"] + assert "NATIVE_FILTER-existing1" in data["error"] + + +# --------------------------------------------------------------------------- +# Validation errors +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_with_invalid_dataset(mcp_server): + dashboard = _mock_dashboard(filters=[]) + + with ( + patch(DAO_FIND_BY_ID, return_value=dashboard), + patch(DATASET_FIND_BY_ID, return_value=None), + ): + data = await _call( + mcp_server, + { + "dashboard_id": 1, + "add": [ + { + "filter_type": "filter_select", + "name": "Region", + "dataset_id": 999, + "column": "region", + } + ], + }, + ) + + assert "Dataset with ID 999 not found" in data["error"] + + +@pytest.mark.asyncio +async def test_add_with_invalid_column(mcp_server): + dashboard = _mock_dashboard(filters=[]) + + with ( + patch(DAO_FIND_BY_ID, return_value=dashboard), + patch(DATASET_FIND_BY_ID, return_value=_mock_dataset(["region", "ds"])), + ): + data = await _call( + mcp_server, + { + "dashboard_id": 1, + "add": [ + { + "filter_type": "filter_select", + "name": "Region", + "dataset_id": 5, + "column": "nonexistent", + } + ], + }, + ) + + assert "Column 'nonexistent' not found in dataset 5" in data["error"] + assert "region" in data["error"] + + +@pytest.mark.asyncio +async def test_scope_chart_ids_not_on_dashboard(mcp_server): + dashboard = _mock_dashboard(filters=[], chart_ids=[10, 11]) + + with ( + patch(DAO_FIND_BY_ID, return_value=dashboard), + patch(DATASET_FIND_BY_ID, return_value=_mock_dataset()), + ): + data = await _call( + mcp_server, + { + "dashboard_id": 1, + "add": [ + { + "filter_type": "filter_select", + "name": "Region", + "dataset_id": 5, + "column": "region", + "scope_chart_ids": [10, 99], + } + ], + }, + ) + + assert "not on the dashboard" in data["error"] + assert "99" in data["error"] + + +# --------------------------------------------------------------------------- +# Not found / forbidden +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_dashboard_not_found(mcp_server): + with patch(DAO_FIND_BY_ID, return_value=None): + data = await _call( + mcp_server, + {"dashboard_id": 42, "remove": ["NATIVE_FILTER-x"]}, + ) + + assert "Dashboard with ID 42 not found" in data["error"] + assert data["permission_denied"] is False + + +@pytest.mark.asyncio +async def test_dashboard_forbidden(mcp_server): + dashboard = _mock_dashboard(filters=[EXISTING_SELECT_FILTER]) + + with ( + patch(DAO_FIND_BY_ID, return_value=dashboard), + patch(COMMAND_PATH, side_effect=DashboardForbiddenError), + ): + data = await _call( + mcp_server, + {"dashboard_id": 1, "remove": ["NATIVE_FILTER-existing1"]}, + ) + + assert data["permission_denied"] is True + assert "permission" in data["error"]