diff --git a/python/destinations/MQTT/README.md b/python/destinations/MQTT/README.md index 2f78aff3..482941f9 100644 --- a/python/destinations/MQTT/README.md +++ b/python/destinations/MQTT/README.md @@ -24,11 +24,18 @@ You'll need to have a MQTT either locally or in the cloud The connector uses the following environment variables: - **input**: Name of the input topic to listen to. -- **mqtt_topic_root**: The root for messages in MQTT, this can be anything. -- **mqtt_server**: The address of your MQTT server. -- **mqtt_port**: The port of your MQTT server. -- **mqtt_username**: Username of your MQTT user. -- **mqtt_password**: Password for the MQTT user. +- **MQTT_CLIENT_ID**: A client ID for the sink. + **Default**: `mqtt-sink` +- **MQTT_TOPIC_ROOT**: The root for messages in MQTT, this can be anything. +- **MQTT_SERVER**: The address of your MQTT server. +- **MQTT_PORT**: The port of your MQTT server. + **Default**: `8883` +- **MQTT_USERNAME**: Username of your MQTT user. +- **MQTT_PASSWORD**: Password for the MQTT user. +- **MQTT_VERSION**: MQTT protocol version; choose 3.1, 3.1.1, or 5. + **Default**: `3.1.1` +- **MQTT_USE_TLS**: Set to true if the server uses TLS. 
+ **Default**: `True` ## Contribute diff --git a/python/destinations/MQTT/library.json b/python/destinations/MQTT/library.json index 9f3a4514..ab70cb26 100644 --- a/python/destinations/MQTT/library.json +++ b/python/destinations/MQTT/library.json @@ -18,25 +18,32 @@ "Type": "EnvironmentVariable", "InputType": "InputTopic", "Description": "Name of the input topic to listen to.", - "DefaultValue": "", "Required": true }, { - "Name": "mqtt_topic_root", + "Name": "MQTT_CLIENT_ID", + "Type": "EnvironmentVariable", + "InputType": "FreeText", + "Description": "A client ID for the sink", + "DefaultValue": "mqtt-sink", + "Required": true + }, + { + "Name": "MQTT_TOPIC_ROOT", "Type": "EnvironmentVariable", "InputType": "FreeText", "Description": "The root for messages in MQTT, this can be anything", "Required": true }, { - "Name": "mqtt_server", + "Name": "MQTT_SERVER", "Type": "EnvironmentVariable", "InputType": "FreeText", "Description": "The address of your MQTT server", "Required": true }, { - "Name": "mqtt_port", + "Name": "MQTT_PORT", "Type": "EnvironmentVariable", "InputType": "FreeText", "Description": "The port of your MQTT server", @@ -44,25 +51,32 @@ "Required": true }, { - "Name": "mqtt_username", + "Name": "MQTT_USERNAME", "Type": "EnvironmentVariable", "InputType": "FreeText", "Description": "Username of your MQTT user", - "Required": false + "Required": true }, { - "Name": "mqtt_password", + "Name": "MQTT_PASSWORD", "Type": "EnvironmentVariable", "InputType": "Secret", "Description": "Password for the MQTT user", - "Required": false + "Required": true + }, + { + "Name": "MQTT_VERSION", + "Type": "EnvironmentVariable", + "InputType": "FreeText", + "Description": "MQTT protocol version; choose 3.1, 3.1.1, or 5", + "Required": true }, { - "Name": "mqtt_version", + "Name": "MQTT_USE_TLS", "Type": "EnvironmentVariable", "InputType": "FreeText", - "Description": "MQTT protocol version: 3.1, 3.1.1, 5", - "DefaultValue": "3.1.1", + "Description": "Set to true if the 
server uses TLS", + "DefaultValue": "true", "Required": true } ], diff --git a/python/destinations/MQTT/main.py b/python/destinations/MQTT/main.py index 8aa2b90a..ec55882d 100644 --- a/python/destinations/MQTT/main.py +++ b/python/destinations/MQTT/main.py @@ -1,95 +1,28 @@ -from quixstreams import Application, context -import paho.mqtt.client as paho -from paho import mqtt -import json +from mqtt import MQTTSink +from quixstreams import Application import os # Load environment variables (useful when working locally) -from dotenv import load_dotenv -load_dotenv() +# from dotenv import load_dotenv +# load_dotenv() -def mqtt_protocol_version(): - if os.environ["mqtt_version"] == "3.1": - print("Using MQTT version 3.1") - return paho.MQTTv31 - if os.environ["mqtt_version"] == "3.1.1": - print("Using MQTT version 3.1.1") - return paho.MQTTv311 - if os.environ["mqtt_version"] == "5": - print("Using MQTT version 5") - return paho.MQTTv5 - print("Defaulting to MQTT version 3.1.1") - return paho.MQTTv311 +app = Application(consumer_group="mqtt_consumer_group", auto_offset_reset="earliest") +input_topic = app.topic(os.environ["input"], value_deserializer="double") -def configure_authentication(mqtt_client): - mqtt_username = os.getenv("mqtt_username", "") - if mqtt_username != "": - mqtt_password = os.getenv("mqtt_password", "") - if mqtt_password == "": - raise ValueError('mqtt_password must set when mqtt_username is set') - print("Using username & password authentication") - mqtt_client.username_pw_set(os.environ["mqtt_username"], os.environ["mqtt_password"]) - return - print("Using anonymous authentication") +sink = MQTTSink( + client_id=os.environ["MQTT_CLIENT_ID"], + server=os.environ["MQTT_SERVER"], + port=int(os.environ["MQTT_PORT"]), + topic_root=os.environ["MQTT_TOPIC_ROOT"], + username=os.environ["MQTT_USERNAME"], + password=os.environ["MQTT_PASSWORD"], + version=os.environ["MQTT_VERSION"], + tls_enabled=os.environ["MQTT_USE_TLS"].lower() == "true" +) -mqtt_port = 
os.environ["mqtt_port"] -# Validate the config -if not mqtt_port.isnumeric(): - raise ValueError('mqtt_port must be a numeric value') +sdf = app.dataframe(topic=input_topic) +sdf.sink(sink) -client_id = os.getenv("Quix__Deployment__Id", "default") -mqtt_client = paho.Client(callback_api_version=paho.CallbackAPIVersion.VERSION2, - client_id = client_id, userdata = None, protocol = mqtt_protocol_version()) -mqtt_client.tls_set(tls_version = mqtt.client.ssl.PROTOCOL_TLS) # we'll be using tls -mqtt_client.reconnect_delay_set(5, 60) -configure_authentication(mqtt_client) -# Create a Quix platform-specific application instead -app = Application(consumer_group="mqtt_consumer_group", auto_offset_reset='earliest') -# initialize the topic, this will combine the topic name with the environment details to produce a valid topic identifier -input_topic = app.topic(os.environ["input"]) - -# setting callbacks for different events to see if it works, print the message etc. -def on_connect_cb(client: paho.Client, userdata: any, connect_flags: paho.ConnectFlags, - reason_code: paho.ReasonCode, properties: paho.Properties): - if reason_code == 0: - print("CONNECTED!") # required for Quix to know this has connected - else: - print(f"ERROR! - ({reason_code.value}). {reason_code.getName()}") - -def on_disconnect_cb(client: paho.Client, userdata: any, disconnect_flags: paho.DisconnectFlags, - reason_code: paho.ReasonCode, properties: paho.Properties): - print(f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!") - -mqtt_client.on_connect = on_connect_cb -mqtt_client.on_disconnect = on_disconnect_cb - -mqtt_topic_root = os.environ["mqtt_topic_root"] - -# connect to MQTT Cloud on port 8883 (default for MQTT) -mqtt_client.connect(os.environ["mqtt_server"], int(mqtt_port)) - -# Hook up to termination signal (for docker image) and CTRL-C -print("Listening to streams. 
Press CTRL-C to exit.")
-
-sdf = app.dataframe(input_topic)
-
-def publish_to_mqtt(data, key, timestamp, headers):
-    json_data = json.dumps(data)
-    message_key_string = key.decode('utf-8') # Convert to string using utf-8 encoding
-    # publish to MQTT
-    mqtt_client.publish(mqtt_topic_root + "/" + message_key_string, payload = json_data, qos = 1)
-
-sdf = sdf.apply(publish_to_mqtt, metadata=True)
-
-
-# start the background process to handle MQTT messages
-mqtt_client.loop_start()
-
-print("Starting application")
-# run the data processing pipeline
-app.run(sdf)
-
-# stop handling MQTT messages
-mqtt_client.loop_stop()
-print("Exiting")
\ No newline at end of file
+if __name__ == '__main__':
+    app.run()
diff --git a/python/destinations/MQTT/mqtt.py b/python/destinations/MQTT/mqtt.py
new file mode 100644
index 00000000..7b90a755
--- /dev/null
+++ b/python/destinations/MQTT/mqtt.py
@@ -0,0 +1,309 @@
+import json
+import logging
+import time
+from datetime import datetime
+from typing import Any, Callable, Literal, Optional, Union, get_args
+
+from quixstreams.models.types import HeadersTuples
+from quixstreams.sinks import (
+    BaseSink,
+    ClientConnectFailureCallback,
+    ClientConnectSuccessCallback,
+)
+
+try:
+    import paho.mqtt.client as paho
+except ImportError as exc:
+    raise ImportError(
+        'Package "paho-mqtt" is missing: ' "run pip install quixstreams[mqtt] to fix it"
+    ) from exc
+
+
+logger = logging.getLogger(__name__)
+
+VERSION_MAP = {
+    "3.1": paho.MQTTv31,
+    "3.1.1": paho.MQTTv311,
+    "5": paho.MQTTv5,
+}
+MQTT_SUCCESS = paho.MQTT_ERR_SUCCESS
+ProtocolVersion = Literal["3.1", "3.1.1", "5"]
+MqttPropertiesHandler = Union[paho.Properties, Callable[[Any], paho.Properties]]
+RetainHandler = Union[bool, Callable[[Any], bool]]
+
+
+class MQTTSink(BaseSink):
+    """
+    A sink that publishes messages to an MQTT broker.
+
+    Each record is published to the topic "topic_root/serialized key". With
+    QoS 1, `flush()` waits (up to a timeout) for every outstanding publish to
+    be acknowledged.
+    """
+
+    def __init__(
+        self,
+        client_id: str,
+        server: str,
+        port: int,
+        topic_root: str,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        version: ProtocolVersion = "3.1.1",
+        tls_enabled: bool = True,
+        key_serializer: Callable[[Any], str] = bytes.decode,
+        value_serializer: Callable[[Any], str] = json.dumps,
+        qos: Literal[0, 1] = 1,
+        mqtt_flush_timeout_seconds: int = 10,
+        retain: Union[bool, Callable[[Any], bool]] = False,
+        properties: Optional[MqttPropertiesHandler] = None,
+        on_client_connect_success: Optional[ClientConnectSuccessCallback] = None,
+        on_client_connect_failure: Optional[ClientConnectFailureCallback] = None,
+    ):
+        """
+        Initialize the MQTTSink.
+
+        :param client_id: MQTT client identifier.
+        :param server: MQTT broker server address.
+        :param port: MQTT broker server port.
+        :param topic_root: Root topic to publish messages to.
+        :param username: Username for MQTT broker authentication. Default = None
+        :param password: Password for MQTT broker authentication. Default = None
+        :param version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
+        :param tls_enabled: Whether to use TLS encryption. Default = True
+        :param key_serializer: How to serialize the MQTT message key for producing.
+        :param value_serializer: How to serialize the MQTT message value for producing.
+        :param qos: Quality of Service level (0 or 1; 2 not yet supported) Default = 1.
+        :param mqtt_flush_timeout_seconds: how long to wait for publish acknowledgment
+            of MQTT messages before failing. Default = 10.
+        :param retain: Retain last message for new subscribers. Default = False.
+            Also accepts a callable that uses the current message value as input.
+        :param properties: An optional Properties instance for messages. Default = None.
+            Also accepts a callable that uses the current message value as input.
+        :param on_client_connect_success: An optional callback made after successful
+            client authentication, primarily for additional logging.
+        :param on_client_connect_failure: An optional callback made after failed
+            client authentication (which should raise an Exception).
+            Callback should accept the raised Exception as an argument.
+            Callback must resolve (or propagate/re-raise) the Exception.
+        :raises ValueError: if `qos` is 2, `version` is not a supported protocol
+            version, or `properties` is given with a version other than "5".
+        """
+        super().__init__(
+            on_client_connect_success=on_client_connect_success,
+            on_client_connect_failure=on_client_connect_failure,
+        )
+        if qos == 2:
+            raise ValueError(f"MQTT QoS level {2} is currently not supported.")
+        if not (protocol := VERSION_MAP.get(version)):
+            raise ValueError(
+                f"Invalid MQTT version {version}; valid: {get_args(ProtocolVersion)}"
+            )
+        # `protocol` is the paho constant looked up above, not the version
+        # string, so the MQTT v5 check has to be made against `version`.
+        if properties and version != "5":
+            raise ValueError(
+                "MQTT Properties can only be used with MQTT protocol version 5"
+            )
+
+        self._version = version
+        self._server = server
+        self._port = port
+        self._topic_root = topic_root
+        self._key_serializer = key_serializer
+        self._value_serializer = value_serializer
+        self._qos = qos
+        self._flush_timeout = mqtt_flush_timeout_seconds
+        # Message ids of QoS > 0 publishes that are still awaiting acknowledgment.
+        self._pending_acks: set[int] = set()
+        self._retain = _get_retain_callable(retain)
+        self._properties = _get_properties_callable(properties)
+
+        self._client = paho.Client(
+            callback_api_version=paho.CallbackAPIVersion.VERSION2,
+            client_id=client_id,
+            userdata=None,
+            protocol=protocol,
+        )
+
+        if username:
+            self._client.username_pw_set(username, password)
+        if tls_enabled:
+            self._client.tls_set(tls_version=paho.ssl.PROTOCOL_TLS)
+        self._client.reconnect_delay_set(5, 60)
+        self._client.on_connect = _mqtt_on_connect_cb
+        self._client.on_disconnect = _mqtt_on_disconnect_cb
+        self._client.on_publish = self._on_publish_cb
+        # Messages published since the last flush (reported, then reset, by flush()).
+        self._publish_count = 0
+
+    def setup(self):
+        """Connect to the broker and start the paho network loop."""
+        self._client.connect(self._server, self._port)
+        self._client.loop_start()
+
+    def _publish_to_mqtt(
+        self,
+        data: Any,
+        topic_suffix: Any,
+    ):
+        """
+        Publish `data` to the topic "topic_root/serialized topic_suffix".
+
+        :raises MqttPublishEnqueueFailed: if QoS > 0 and the client reports a
+            non-success return code for the publish call.
+        """
+        properties = self._properties
+        info = self._client.publish(
+            f"{self._topic_root}/{self._key_serializer(topic_suffix)}",
+            payload=self._value_serializer(data),
+            qos=self._qos,
+            properties=properties(data) if properties else None,
+            retain=self._retain(data),
+        )
+        if self._qos:
+            if info.rc != MQTT_SUCCESS:
+                raise MqttPublishEnqueueFailed(
+                    f"Failed adding message to MQTT publishing queue; "
+                    f"error code {info.rc}: {paho.error_string(info.rc)}"
+                )
+            self._pending_acks.add(info.mid)
+        else:
+            # QoS 0 has no acknowledgment to wait for, so count it right away.
+            self._publish_count += 1
+
+    def _on_publish_cb(
+        self,
+        client: paho.Client,
+        userdata: Any,
+        mid: int,
+        rc: paho.ReasonCode,
+        p: paho.Properties,
+    ):
+        """
+        Record the completion of a QoS > 0 publish reported by the client.
+
+        paho also invokes `on_publish` for QoS 0 messages (once they have left
+        the client). Those are counted when published and never tracked in
+        `_pending_acks`, so they are skipped here to avoid double counting.
+        """
+        if not self._qos:
+            return
+        self._publish_count += 1
+        # discard() rather than remove(): an untracked message id must not raise
+        # inside the client callback.
+        self._pending_acks.discard(mid)
+
+    def add(
+        self,
+        topic: str,
+        partition: int,
+        offset: int,
+        key: bytes,
+        value: bytes,
+        timestamp: datetime,
+        headers: HeadersTuples,
+    ):
+        """
+        Publish one record; only `key` and `value` are used.
+
+        On any failure the MQTT client is stopped and disconnected before the
+        exception is re-raised.
+        """
+        try:
+            self._publish_to_mqtt(value, key)
+        except Exception:
+            self._cleanup()
+            raise
+
+    def flush(self):
+        """
+        Wait for outstanding publish acknowledgments, then log and reset the
+        published-message counter.
+
+        :raises MqttPublishAckTimeout: if acknowledgments are still pending after
+            the configured flush timeout.
+        """
+        if self._pending_acks:
+            # Re-read the clock on every pass so the wait is actually bounded by
+            # the configured timeout.
+            deadline = time.monotonic() + self._flush_timeout
+            while (
+                self._pending_acks
+                and (remaining := deadline - time.monotonic()) > 0
+            ):
+                logger.debug(f"Pending acks remaining: {len(self._pending_acks)}")
+                time.sleep(min(1, remaining))
+            if self._pending_acks:
+                self._cleanup()
+                raise MqttPublishAckTimeout(
+                    f"Mqtt acknowledgement timeout of {self._flush_timeout}s reached."
+                )
+        logger.info(f"{self._publish_count} MQTT messages published.")
+        self._publish_count = 0
+
+    def on_paused(self):
+        """Does nothing."""
+        pass
+
+    def _cleanup(self):
+        """Stop the paho network loop and disconnect from the broker."""
+        self._client.loop_stop()
+        self._client.disconnect()
+
+
+class MqttPublishEnqueueFailed(Exception):
+    """Raised when a QoS > 0 message could not be handed to the MQTT client."""
+
+
+
+class MqttPublishAckTimeout(Exception):
+    """Raised by `MQTTSink.flush` when publish acknowledgments time out."""
+
+
+
+def _mqtt_on_connect_cb(
+    client: paho.Client,
+    userdata: Any,
+    connect_flags: paho.ConnectFlags,
+    reason_code: paho.ReasonCode,
+    properties: paho.Properties,
+):
+    """Connect callback: raise `ConnectionError` for a non-zero reason code."""
+    if reason_code != 0:
+        raise ConnectionError(
+            f"Failed to connect to MQTT broker; ERROR: ({reason_code.value}).{reason_code.getName()}"
+        )
+
+
+def _mqtt_on_disconnect_cb(
+    client: paho.Client,
+    userdata: Any,
+    disconnect_flags: paho.DisconnectFlags,
+    reason_code: paho.ReasonCode,
+    properties: paho.Properties,
+):
+    """Disconnect callback: log the reason code."""
+    logger.info(
+        f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!"
+    )
+
+
+def _get_properties_callable(
+    properties: Optional[MqttPropertiesHandler],
+) -> Optional[Callable[[Any], paho.Properties]]:
+    """
+    Normalize `properties` to a callable taking the message value.
+
+    A `Properties` instance is wrapped so it is returned for every message;
+    a callable (or None) is returned unchanged.
+    """
+    if isinstance(properties, paho.Properties):
+        # Return the instance itself; a Properties object is not callable.
+        return lambda data: properties
+    return properties
+
+
+def _get_retain_callable(retain: RetainHandler) -> Callable[[Any], bool]:
+    """Normalize `retain` to a callable taking the message value."""
+    if isinstance(retain, bool):
+        return lambda data: retain
+    return retain
diff --git a/python/destinations/MQTT/requirements.txt b/python/destinations/MQTT/requirements.txt
index 55e9dfb4..ccd7a0c1 100644
--- a/python/destinations/MQTT/requirements.txt
+++ b/python/destinations/MQTT/requirements.txt
@@ -1,3 +1,2 @@
-quixstreams==2.9.0
-paho-mqtt==2.1.0
+quixstreams[mqtt]==3.22.0
 python-dotenv
\ No newline at end of file