diff --git a/singlestoredb/fusion/handlers/workspace.py b/singlestoredb/fusion/handlers/workspace.py index 9644f95cb..083d265cd 100644 --- a/singlestoredb/fusion/handlers/workspace.py +++ b/singlestoredb/fusion/handlers/workspace.py @@ -15,7 +15,7 @@ class UseWorkspaceHandler(SQLHandler): """ - USE WORKSPACE workspace [ with_database ]; + USE WORKSPACE workspace [ in_group ] [ with_database ]; # Workspace workspace = { workspace_id | workspace_name | current_workspace } @@ -29,6 +29,15 @@ class UseWorkspaceHandler(SQLHandler): # Current workspace current_workspace = @@CURRENT + # Workspace group specification + in_group = IN GROUP { group_id | group_name } + + # ID of workspace group + group_id = ID '' + + # Name of workspace group + group_name = '' + # Name of database with_database = WITH DATABASE 'database-name' @@ -38,13 +47,18 @@ class UseWorkspaceHandler(SQLHandler): Arguments --------- - * ````: The ID of the workspace to delete. - * ````: The name of the workspace to delete. + * ````: The ID of the workspace to use. + * ````: The name of the workspace to use. + * ````: The ID of the workspace group to search in. + * ````: The name of the workspace group to search in. Remarks ------- * If you want to specify a database in the current workspace, the workspace name can be specified as ``@@CURRENT``. + * Use the ``IN GROUP`` clause to specify the ID or name of the workspace + group where the workspace should be found. If not specified, the current + workspace group will be used. * Specify the ``WITH DATABASE`` clause to select a default database for the session. * This command only works in a notebook session in the @@ -57,23 +71,69 @@ class UseWorkspaceHandler(SQLHandler): USE WORKSPACE 'examplews' WITH DATABASE 'dbname'; + The following command sets the workspace to ``examplews`` from a specific + workspace group:: + + USE WORKSPACE 'examplews' IN GROUP 'my-workspace-group'; + """ def run(self, params: Dict[str, Any]) -> Optional[FusionSQLResult]: from singlestoredb.notebook import portal + + # Handle current workspace case if params['workspace'].get('current_workspace'): if params.get('with_database'): portal.default_database = params['with_database'] - elif params.get('with_database'): - if params['workspace'].get('workspace_name'): - portal.connection = params['workspace']['workspace_name'], \ - params['with_database'] + return None + + # Get workspace name or ID + workspace_name = params['workspace'].get('workspace_name') + workspace_id = params['workspace'].get('workspace_id') + + # If IN GROUP is specified, look up workspace in that group + if params.get('in_group'): + workspace_group = get_workspace_group(params) + + if workspace_name: + workspace = workspace_group.workspaces[workspace_name] + elif workspace_id: + # Find workspace by ID in the specified group + workspace = next( + (w for w in workspace_group.workspaces if w.id == workspace_id), + None, + ) + if workspace is None: + raise KeyError(f'no workspace found with ID: {workspace_id}') + + workspace_id = workspace.id + + # Set workspace and database + if params.get('with_database'): + if params.get('in_group'): + # Use 3-element tuple: (workspace_group_id, workspace_name_or_id, + # database) + portal.connection = ( # type: ignore[assignment] + workspace_group.id, + workspace_name or workspace_id, + params['with_database'], + ) else: - portal.connection = params['workspace']['workspace_id'], \ - params['with_database'] - elif params['workspace'].get('workspace_name'): - portal.workspace = params['workspace']['workspace_name'] + # Use 2-element tuple: (workspace_name_or_id, database) + portal.connection = ( + workspace_name or workspace_id, + params['with_database'], + ) else: - portal.workspace = params['workspace']['workspace_id'] + if params.get('in_group'): + # Use 2-element tuple: (workspace_group_id, workspace_name_or_id) + portal.workspace = ( # type: ignore[assignment] + workspace_group.id, + workspace_name or workspace_id, + ) + else: + # Use string: workspace_name_or_id + portal.workspace = workspace_name or workspace_id + return None diff --git a/singlestoredb/notebook/_portal.py b/singlestoredb/notebook/_portal.py index d960d4aec..664348c1a 100644 --- a/singlestoredb/notebook/_portal.py +++ b/singlestoredb/notebook/_portal.py @@ -10,6 +10,7 @@ from typing import List from typing import Optional from typing import Tuple +from typing import Union from . import _objects as obj from ..management import workspace as mgr @@ -167,15 +168,32 @@ def workspace(self) -> obj.Workspace: return obj.workspace @workspace.setter - def workspace(self, name_or_id: str) -> None: + def workspace(self, workspace_spec: Union[str, Tuple[str, str]]) -> None: """Set workspace.""" - if re.match( - r'[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}', - name_or_id, flags=re.I, - ): - w = mgr.get_workspace(name_or_id) + if isinstance(workspace_spec, tuple): + # 2-element tuple: (workspace_group_id, workspace_name_or_id) + workspace_group_id, name_or_id = workspace_spec + uuid_pattern = ( + r'[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}' + ) + if re.match(uuid_pattern, name_or_id, flags=re.I): + w = mgr.get_workspace(name_or_id) + else: + w = mgr.get_workspace_group(workspace_group_id).workspaces[ + name_or_id + ] else: - w = mgr.get_workspace_group(self.workspace_group_id).workspaces[name_or_id] + # String: workspace_name_or_id (existing behavior) + name_or_id = workspace_spec + uuid_pattern = ( + r'[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}' + ) + if re.match(uuid_pattern, name_or_id, flags=re.I): + w = mgr.get_workspace(name_or_id) + else: + w = mgr.get_workspace_group( + self.workspace_group_id, + ).workspaces[name_or_id] if w.state and w.state.lower() not in ['active', 'resumed']: raise RuntimeError('workspace is not active') @@ -196,16 +214,37 @@ def connection(self) -> Tuple[obj.Workspace, Optional[str]]: return self.workspace, self.default_database @connection.setter - def connection(self, workspace_and_default_database: Tuple[str, str]) -> None: + def connection( + self, + connection_spec: Union[Tuple[str, str], Tuple[str, str, str]], + ) -> None: """Set workspace and default database name.""" - name_or_id, default_database = workspace_and_default_database - if re.match( - r'[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}', - name_or_id, flags=re.I, - ): - w = mgr.get_workspace(name_or_id) + if len(connection_spec) == 3: + # 3-element tuple: (workspace_group_id, workspace_name_or_id, + # default_database) + workspace_group_id, name_or_id, default_database = connection_spec + uuid_pattern = ( + r'[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}' + ) + if re.match(uuid_pattern, name_or_id, flags=re.I): + w = mgr.get_workspace(name_or_id) + else: + w = mgr.get_workspace_group(workspace_group_id).workspaces[ + name_or_id + ] else: - w = mgr.get_workspace_group(self.workspace_group_id).workspaces[name_or_id] + # 2-element tuple: (workspace_name_or_id, default_database) + # existing behavior + name_or_id, default_database = connection_spec + uuid_pattern = ( + r'[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}' + ) + if re.match(uuid_pattern, name_or_id, flags=re.I): + w = mgr.get_workspace(name_or_id) + else: + w = mgr.get_workspace_group( + self.workspace_group_id, + ).workspaces[name_or_id] if w.state and w.state.lower() not in ['active', 'resumed']: raise RuntimeError('workspace is not active')