|
| 1 | +# |
| 2 | +# sambacc: a samba container configuration tool (and more) |
| 3 | +# Copyright (C) 2025 John Mulligan |
| 4 | +# |
| 5 | +# This program is free software: you can redistribute it and/or modify |
| 6 | +# it under the terms of the GNU General Public License as published by |
| 7 | +# the Free Software Foundation, either version 3 of the License, or |
| 8 | +# (at your option) any later version. |
| 9 | +# |
| 10 | +# This program is distributed in the hope that it will be useful, |
| 11 | +# but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 12 | +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 13 | +# GNU General Public License for more details. |
| 14 | +# |
| 15 | +# You should have received a copy of the GNU General Public License |
| 16 | +# along with this program. If not, see <http://www.gnu.org/licenses/> |
| 17 | +# |
| 18 | + |
| 19 | +from typing import Iterator, Protocol, Optional |
| 20 | + |
| 21 | +import concurrent.futures |
| 22 | +import contextlib |
| 23 | +import logging |
| 24 | + |
| 25 | +import grpc |
| 26 | + |
| 27 | +import sambacc.grpc.backend as rbe |
| 28 | +import sambacc.grpc.generated.control_pb2 as pb |
| 29 | +import sambacc.grpc.generated.control_pb2_grpc as control_rpc |
| 30 | + |
| 31 | +_logger = logging.getLogger(__name__) |
| 32 | + |
| 33 | + |
| 34 | +class Backend(Protocol): |
| 35 | + def get_versions(self) -> rbe.Versions: ... |
| 36 | + |
| 37 | + def is_clustered(self) -> bool: ... |
| 38 | + |
| 39 | + def get_status(self) -> rbe.Status: ... |
| 40 | + |
| 41 | + def close_share(self, share_name: str, denied_users: bool) -> None: ... |
| 42 | + |
| 43 | + def kill_client(self, ip_address: str) -> None: ... |
| 44 | + |
| 45 | + |
| 46 | +@contextlib.contextmanager |
| 47 | +def _in_rpc(context: grpc.ServicerContext, allowed: bool) -> Iterator[None]: |
| 48 | + if not allowed: |
| 49 | + _logger.error("Blocking operation") |
| 50 | + context.abort( |
| 51 | + grpc.StatusCode.PERMISSION_DENIED, "Operation not permitted" |
| 52 | + ) |
| 53 | + try: |
| 54 | + yield |
| 55 | + except Exception: |
| 56 | + _logger.exception("exception in rpc call") |
| 57 | + context.abort(grpc.StatusCode.UNKNOWN, "Unexpected server error") |
| 58 | + |
| 59 | + |
| 60 | +def _get_info(backend: Backend) -> pb.GeneralInfo: |
| 61 | + _info = backend.get_versions() |
| 62 | + clustered = backend.is_clustered() |
| 63 | + return pb.GeneralInfo( |
| 64 | + samba_info=pb.SambaInfo( |
| 65 | + version=_info.samba_version, |
| 66 | + clustered=clustered, |
| 67 | + ), |
| 68 | + container_info=pb.SambaContainerInfo( |
| 69 | + sambacc_version=_info.sambacc_version, |
| 70 | + container_version=_info.container_version, |
| 71 | + ), |
| 72 | + ) |
| 73 | + |
| 74 | + |
| 75 | +def _convert_session(session: rbe.Session) -> pb.SessionInfo: |
| 76 | + info = pb.SessionInfo( |
| 77 | + session_id=session.session_id, |
| 78 | + username=session.username, |
| 79 | + groupname=session.groupname, |
| 80 | + remote_machine=session.remote_machine, |
| 81 | + hostname=session.hostname, |
| 82 | + session_dialect=session.session_dialect, |
| 83 | + ) |
| 84 | + # python side takes -1 to mean not found uid/gid. in protobufs |
| 85 | + # that would mean the fields are unset |
| 86 | + if session.uid > 0: |
| 87 | + info.uid = session.uid |
| 88 | + if session.gid > 0: |
| 89 | + info.gid = session.gid |
| 90 | + return info |
| 91 | + |
| 92 | + |
| 93 | +def _convert_tcon(tcon: rbe.TreeConnection) -> pb.ConnInfo: |
| 94 | + return pb.ConnInfo( |
| 95 | + tcon_id=tcon.tcon_id, |
| 96 | + session_id=tcon.session_id, |
| 97 | + service_name=tcon.service_name, |
| 98 | + ) |
| 99 | + |
| 100 | + |
| 101 | +def _convert_status(status: rbe.Status) -> pb.StatusInfo: |
| 102 | + return pb.StatusInfo( |
| 103 | + server_timestamp=status.timestamp, |
| 104 | + sessions=[_convert_session(s) for s in status.sessions], |
| 105 | + tree_connections=[_convert_tcon(t) for t in status.tcons], |
| 106 | + ) |
| 107 | + |
| 108 | + |
| 109 | +class ControlService(control_rpc.SambaControlServicer): |
| 110 | + def __init__(self, backend: Backend, *, read_only: bool = False): |
| 111 | + self._backend = backend |
| 112 | + self._read_only = read_only |
| 113 | + self._ok_to_read = True |
| 114 | + self._ok_to_modify = not read_only |
| 115 | + |
| 116 | + def Info( |
| 117 | + self, request: pb.InfoRequest, context: grpc.ServicerContext |
| 118 | + ) -> pb.GeneralInfo: |
| 119 | + _logger.debug("RPC Called: Info") |
| 120 | + with _in_rpc(context, self._ok_to_read): |
| 121 | + info = _get_info(self._backend) |
| 122 | + return info |
| 123 | + |
| 124 | + def Status( |
| 125 | + self, request: pb.StatusRequest, context: grpc.ServicerContext |
| 126 | + ) -> pb.StatusInfo: |
| 127 | + _logger.debug("RPC Called: Status") |
| 128 | + with _in_rpc(context, self._ok_to_read): |
| 129 | + info = _convert_status(self._backend.get_status()) |
| 130 | + return info |
| 131 | + |
| 132 | + def CloseShare( |
| 133 | + self, request: pb.CloseShareRequest, context: grpc.ServicerContext |
| 134 | + ) -> pb.CloseShareInfo: |
| 135 | + _logger.debug("RPC Called: CloseShare") |
| 136 | + with _in_rpc(context, self._ok_to_modify): |
| 137 | + self._backend.close_share(request.share_name, request.denied_users) |
| 138 | + info = pb.CloseShareInfo() |
| 139 | + return info |
| 140 | + |
| 141 | + def KillClientConnection( |
| 142 | + self, request: pb.KillClientRequest, context: grpc.ServicerContext |
| 143 | + ) -> pb.KillClientInfo: |
| 144 | + _logger.debug("RPC Called: KillClientConnection") |
| 145 | + with _in_rpc(context, self._ok_to_modify): |
| 146 | + self._backend.kill_client(request.ip_address) |
| 147 | + info = pb.KillClientInfo() |
| 148 | + return info |
| 149 | + |
| 150 | + |
| 151 | +class ServerConfig: |
| 152 | + max_workers: int = 8 |
| 153 | + address: str = "localhost:54445" |
| 154 | + read_only: bool = False |
| 155 | + insecure: bool = True |
| 156 | + server_key: Optional[bytes] = None |
| 157 | + server_cert: Optional[bytes] = None |
| 158 | + ca_cert: Optional[bytes] = None |
| 159 | + |
| 160 | + |
| 161 | +def serve(config: ServerConfig, backend: Backend) -> None: |
| 162 | + _logger.info( |
| 163 | + "Starting gRPC server on %s (%s, %s)", |
| 164 | + config.address, |
| 165 | + "insecure" if config.insecure else "tls", |
| 166 | + "read-only" if config.read_only else "read-modify", |
| 167 | + ) |
| 168 | + service = ControlService(backend, read_only=config.read_only) |
| 169 | + executor = concurrent.futures.ThreadPoolExecutor( |
| 170 | + max_workers=config.max_workers |
| 171 | + ) |
| 172 | + server = grpc.server(executor) |
| 173 | + control_rpc.add_SambaControlServicer_to_server(service, server) |
| 174 | + if config.insecure: |
| 175 | + server.add_insecure_port(config.address) |
| 176 | + else: |
| 177 | + if not config.server_key: |
| 178 | + raise ValueError("missing server TLS key") |
| 179 | + if not config.server_cert: |
| 180 | + raise ValueError("missing server TLS cert") |
| 181 | + if config.ca_cert: |
| 182 | + creds = grpc.ssl_server_credentials( |
| 183 | + [(config.server_key, config.server_cert)], |
| 184 | + root_certificates=config.ca_cert, |
| 185 | + require_client_auth=True, |
| 186 | + ) |
| 187 | + else: |
| 188 | + creds = grpc.ssl_server_credentials( |
| 189 | + [(config.server_key, config.server_cert)], |
| 190 | + ) |
| 191 | + server.add_secure_port(config.address, creds) |
| 192 | + server.start() |
| 193 | + # hack for testing |
| 194 | + wait_fn = getattr(config, "wait", None) |
| 195 | + if wait_fn: |
| 196 | + wait_fn(server) |
| 197 | + else: |
| 198 | + server.wait_for_termination() |
0 commit comments