diff --git a/tests/count_data_points_message_tests.py b/tests/count_data_points_message_tests.py new file mode 100644 index 0000000..0dff5fc --- /dev/null +++ b/tests/count_data_points_message_tests.py @@ -0,0 +1,78 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from tb_device_mqtt import TBDeviceMqttClient + + +class TestCountDataPointsInMessage(unittest.TestCase): + def test_simple_dict_no_device(self): + data = { + "ts": 123456789, + "values": { + "temp": 22.5, + "humidity": 55 + } + } + result = TBDeviceMqttClient._count_datapoints_in_message(data) + self.assertEqual(result, 2) + + def test_list_of_dict_no_device(self): + data = [ + {"ts": 123456789, "values": {"temp": 22.5, "humidity": 55}}, + {"ts": 123456799, "values": {"light": 100, "pressure": 760}} + ] + result = TBDeviceMqttClient._count_datapoints_in_message(data) + self.assertEqual(result, 4) + + def test_with_device_dict_inside(self): + data = { + "MyDevice": { + "ts": 123456789, + "values": {"temp": 22.5, "humidity": 55} + }, + "OtherKey": "some_value" + } + result = TBDeviceMqttClient._count_datapoints_in_message(data, device="MyDevice") + self.assertEqual(result, 2) + + def test_with_device_list_inside(self): + data = { + "Sensor": [ + {"ts": 1, "values": {"v1": 10}}, + {"ts": 2, "values": {"v2": 20, "v3": 30}} + ] + } + result = TBDeviceMqttClient._count_datapoints_in_message(data, device="Sensor") + self.assertEqual(result, 3) + + def test_empty_dict_no_device(self): + data = {} + result = TBDeviceMqttClient._count_datapoints_in_message(data) + self.assertEqual(result, 0) + + def test_missing_device_key(self): + + data = {"some_unrelated_key": 42} + result = TBDeviceMqttClient._count_datapoints_in_message(data, device="NotExistingDeviceKey") + self.assertEqual(result, 1) + + def test_data_is_string_no_device(self): + data = "just a string" + result = TBDeviceMqttClient._count_datapoints_in_message(data) + self.assertEqual(result, 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/firmware_tests.py b/tests/firmware_tests.py new file mode 100644 index 0000000..16bf53d --- /dev/null +++ b/tests/firmware_tests.py @@ -0,0 +1,258 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch, MagicMock, call +from math import ceil +import orjson +from threading import Thread +from tb_device_mqtt import ( + TBDeviceMqttClient, + TBTimeoutException, + FW_VERSION_ATTR, FW_TITLE_ATTR, FW_STATE_ATTR +) +from paho.mqtt.client import ReasonCodes + + +REQUIRED_SHARED_KEYS = "dummy_shared_keys" + + +class TestFirmwareUpdateBranch(unittest.TestCase): + @patch('tb_device_mqtt.sleep', return_value=None, autospec=True) + @patch('tb_device_mqtt.log.debug', autospec=True) + def test_firmware_update_branch(self, _, mock_sleep): + client = TBDeviceMqttClient('fake_host', username="dummy_token", password="dummy") + + client._messages_rate_limit = MagicMock() + + client.current_firmware_info = { + "current_" + FW_VERSION_ATTR: "v0", + FW_STATE_ATTR: "IDLE" + } + client.firmware_data = b"old_data" + client._TBDeviceMqttClient__current_chunk = 2 + client._TBDeviceMqttClient__firmware_request_id = 0 + client._TBDeviceMqttClient__chunk_size = 128 + client._TBDeviceMqttClient__target_firmware_length = 0 + + client.send_telemetry = MagicMock() + client._TBDeviceMqttClient__get_firmware = MagicMock() + + message_mock = MagicMock() + message_mock.topic = "v1/devices/me/attributes_update" + payload_dict = { + "fw_version": "v1", + "fw_title": "TestFirmware", + "fw_size": 900 + } + message_mock.payload = orjson.dumps(payload_dict) + + client._on_decoded_message({}, message_mock) + client.stopped = True + + client._messages_rate_limit.increase_rate_limit_counter.assert_called_once() + + self.assertEqual(client.firmware_data, b"") + self.assertEqual(client._TBDeviceMqttClient__current_chunk, 0) + self.assertEqual(client.current_firmware_info[FW_STATE_ATTR], "DOWNLOADING") + + client.send_telemetry.assert_called_once_with(client.current_firmware_info) + + client._TBDeviceMqttClient__get_firmware.assert_called_once() + + self.assertEqual(client._TBDeviceMqttClient__firmware_request_id, 1) + self.assertEqual(client._TBDeviceMqttClient__target_firmware_length, 900) + self.assertEqual(client._TBDeviceMqttClient__chunk_count, ceil(900 / 128)) + client._TBDeviceMqttClient__get_firmware.assert_called_once() + + +class TestTBDeviceMqttClient(unittest.TestCase): + @patch('tb_device_mqtt.paho.Client') + def setUp(self, mock_paho_client): + self.mock_mqtt_client = mock_paho_client.return_value + self.client = TBDeviceMqttClient( + host='your_host', + port=1883, + username='your_token', + password=None + ) + self.client.firmware_info = {FW_TITLE_ATTR: "dummy_firmware.bin"} + self.client.firmware_data = b'' + self.client._TBDeviceMqttClient__current_chunk = 0 + self.client._TBDeviceMqttClient__firmware_request_id = 1 + self.client._TBDeviceMqttClient__updating_thread = Thread(target=lambda: None) + self.client._publish_data = MagicMock() + + if not hasattr(self.client, '_client'): + self.client._client = self.mock_mqtt_client + + def test_get_firmware_update(self): + self.client._client.subscribe = MagicMock() + self.client.send_telemetry = MagicMock() + self.client.get_firmware_update() + self.client._client.subscribe.assert_called_with('v2/fw/response/+') + self.client.send_telemetry.assert_called() + self.client._publish_data.assert_called() + + def test_firmware_download_process(self): + self.client.firmware_info = { + FW_TITLE_ATTR: "dummy_firmware.bin", + FW_VERSION_ATTR: "2.0", + "fw_size": 1024, + "fw_checksum": "abc123", + "fw_checksum_algorithm": "SHA256" + } + self.client._TBDeviceMqttClient__current_chunk = 0 + self.client._TBDeviceMqttClient__firmware_request_id = 1 + self.client._TBDeviceMqttClient__get_firmware() + self.client._publish_data.assert_called() + + def test_firmware_verification_success(self): + self.client.firmware_data = b'binary data' + self.client.firmware_info = { + FW_TITLE_ATTR: "dummy_firmware.bin", + FW_VERSION_ATTR: "2.0", + "fw_checksum": "valid_checksum", + "fw_checksum_algorithm": "SHA256" + } + self.client._TBDeviceMqttClient__process_firmware() + self.client._publish_data.assert_called() + + def test_firmware_verification_failure(self): + self.client.firmware_data = b'corrupt data' + self.client.firmware_info = { + FW_TITLE_ATTR: "dummy_firmware.bin", + FW_VERSION_ATTR: "2.0", + "fw_checksum": "invalid_checksum", + "fw_checksum_algorithm": "SHA256" + } + self.client._TBDeviceMqttClient__process_firmware() + self.client._publish_data.assert_called() + + def test_firmware_state_transition(self): + self.client._publish_data.reset_mock() + self.client.current_firmware_info = { + "current_fw_title": "OldFirmware", + "current_fw_version": "1.0", + "fw_state": "IDLE" + } + self.client.firmware_received = True + self.client.firmware_info[FW_TITLE_ATTR] = "dummy_firmware.bin" + self.client.firmware_info[FW_VERSION_ATTR] = "dummy_version" + + def test_firmware_request_info(self): + self.client._publish_data.reset_mock() + self.client._TBDeviceMqttClient__request_firmware_info() + self.client._publish_data.assert_called() + + def test_firmware_chunk_reception_detailed(self): + self.client._publish_data.reset_mock() + self.client._TBDeviceMqttClient__get_firmware() + self.client._publish_data.assert_called() + + @patch.object(TBDeviceMqttClient, 'send_telemetry') + def test_process_firmware_telemetry_calls(self, mock_send_telemetry): + self.client.firmware_data = b"some_firmware_data" + self.client.firmware_info = { + FW_TITLE_ATTR: "dummy_firmware.bin", + FW_VERSION_ATTR: "2.0", + "fw_checksum": "valid_checksum", + "fw_checksum_algorithm": "SHA256" + } + + self.client._TBDeviceMqttClient__process_firmware() + + self.assertEqual( + mock_send_telemetry.call_count, + 2, + "Two calls to send_telemetry are expected in the current firmware implementation" + ) + + expected_calls = [ + call({"current_fw_title": "Initial", "current_fw_version": "v0", "fw_state": "FAILED"}), + call({"current_fw_title": "Initial", "current_fw_version": "v0", "fw_state": "FAILED"}) + ] + mock_send_telemetry.assert_has_calls(expected_calls, any_order=False) + + +class TestFirmwareChunkReception(unittest.TestCase): + def setUp(self): + self.client = TBDeviceMqttClient(host="localhost", port=1883) + self.client._TBDeviceMqttClient__firmware_request_id = 1 + self.client._TBDeviceMqttClient__current_chunk = 0 + + @patch.object(TBDeviceMqttClient, '_publish_data') + def test_firmware_chunk_reception(self, mock_publish_data): + self.client._TBDeviceMqttClient__chunk_size = 128 + self.client.firmware_info = { + "fw_size": 300, + "fw_title": "SomeFirmware", + "fw_checksum": "12345", + "fw_checksum_algorithm": "SHA256" + } + self.client._TBDeviceMqttClient__get_firmware() + expected_calls = [ + call(b'128', 'v2/fw/request/1/chunk/0', 1) + ] + self.assertEqual(mock_publish_data.call_count, 1, "Only one chunk request is expected") + mock_publish_data.assert_has_calls(expected_calls, any_order=False) + + self.assertEqual(self.client._TBDeviceMqttClient__current_chunk, 0, + "The current_chunk should not change if the method only requests chunks.") + + +class TestFirmwareUpdate(unittest.TestCase): + def setUp(self): + self.client = TBDeviceMqttClient(host="localhost", port=1883) + self.client._TBDeviceMqttClient__process_firmware = MagicMock() + self.client._TBDeviceMqttClient__get_firmware = MagicMock() + + self.client._TBDeviceMqttClient__firmware_request_id = 1 + self.client._TBDeviceMqttClient__current_chunk = 0 + self.client._TBDeviceMqttClient__target_firmware_length = 10 + + self.client.firmware_data = b'' + + def test_incomplete_firmware_chunk(self): + chunk_data = b'abcde' + message = MagicMock() + message.topic = "v2/fw/response/1/chunk/0" + message.payload = chunk_data + + self.client._on_message(None, None, message) + self.assertEqual(self.client.firmware_data, b'abcde') + self.assertEqual(self.client._TBDeviceMqttClient__current_chunk, 1) + self.client._TBDeviceMqttClient__process_firmware.assert_not_called() + self.client._TBDeviceMqttClient__get_firmware.assert_called_once() + + def test_complete_firmware_chunk(self): + self.client.firmware_data = b'abcde' + self.client._TBDeviceMqttClient__current_chunk = 1 + + chunk_data = b'12345' + message = MagicMock() + message.topic = "v2/fw/response/1/chunk/1" + message.payload = chunk_data + + self.client._on_message(None, None, message) + + self.assertEqual(self.client.firmware_data, b'abcde12345') + self.assertEqual(self.client._TBDeviceMqttClient__current_chunk, 2) + + self.client._TBDeviceMqttClient__process_firmware.assert_called_once() + self.client._TBDeviceMqttClient__get_firmware.assert_not_called() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/gateway_init_tests.py b/tests/gateway_init_tests.py new file mode 100644 index 0000000..663b862 --- /dev/null +++ b/tests/gateway_init_tests.py @@ -0,0 +1,127 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch, MagicMock +from tb_gateway_mqtt import TBGatewayMqttClient +from tb_device_mqtt import TBDeviceMqttClient +import threading + + +class TestOnServiceConfiguration(unittest.TestCase): + def setUp(self): + self.client = TBGatewayMqttClient("localhost", 1883, "dummy_token") + if not hasattr(self.client, "_lock"): + self.client._lock = threading.Lock() + self.client._devices_connected_through_gateway_messages_rate_limit = MagicMock() + self.client._devices_connected_through_gateway_telemetry_messages_rate_limit = MagicMock() + self.client._devices_connected_through_gateway_telemetry_datapoints_rate_limit = MagicMock() + self.client.rate_limits_received = False + + def test_on_service_configuration_error(self): + error_response = {"error": "timeout"} + parent_class = self.client.__class__.__bases__[0] + with patch.object(parent_class, "on_service_configuration") as mock_parent_on_service_configuration: + self.client._TBGatewayMqttClient__on_service_configuration("dummy_arg", error_response) + self.assertTrue(self.client.rate_limits_received) + mock_parent_on_service_configuration.assert_not_called() + + def test_on_service_configuration_valid(self): + response = { + "gatewayRateLimits": { + "messages": "10:20", + "telemetryMessages": "30:40", + "telemetryDataPoints": "50:60", + }, + "rateLimits": {"limit": "value"}, + "other_config": "other_value" + } + response_copy = response.copy() + parent_class = self.client.__class__.__bases__[0] + with patch.object(parent_class, "on_service_configuration") as mock_parent_on_service_configuration: + self.client._TBGatewayMqttClient__on_service_configuration("dummy_arg", response_copy, "extra_arg", key="extra") + self.client._devices_connected_through_gateway_messages_rate_limit.set_limit.assert_called_with("10:20") + self.client._devices_connected_through_gateway_telemetry_messages_rate_limit.set_limit.assert_called_with("30:40") + self.client._devices_connected_through_gateway_telemetry_datapoints_rate_limit.set_limit.assert_called_with("50:60") + expected_dict = {'rateLimit': {"limit": "value"}, "other_config": "other_value"} + mock_parent_on_service_configuration.assert_called_with("dummy_arg", expected_dict, "extra_arg", key="extra") + + def test_on_service_configuration_default_telemetry_datapoints(self): + response = { + "gatewayRateLimits": { + "messages": "10:20", + "telemetryMessages": "30:40", + }, + "rateLimits": {"limit": "value"}, + "other_config": "other_value" + } + response_copy = response.copy() + parent_class = self.client.__class__.__bases__[0] + with patch.object(parent_class, "on_service_configuration") as mock_parent_on_service_configuration: + self.client._TBGatewayMqttClient__on_service_configuration("dummy_arg", response_copy, "extra_arg", key="extra") + self.client._devices_connected_through_gateway_telemetry_datapoints_rate_limit.set_limit.assert_called_with("0:0,") + expected_dict = {'rateLimit': {"limit": "value"}, "other_config": "other_value"} + mock_parent_on_service_configuration.assert_called_with("dummy_arg", expected_dict, "extra_arg", key="extra") + + +class TestRateLimitInitialization(unittest.TestCase): + @staticmethod + def fake_init(instance, host, port, username, password, quality_of_service, client_id, **kwargs): + instance._init_kwargs = kwargs + instance._client = MagicMock() + + def test_custom_rate_limits(self): + custom_rate = "MY_RATE_LIMIT" + custom_dp = "MY_RATE_LIMIT_DP" + + with patch("tb_gateway_mqtt.RateLimit.__init__", return_value=None), \ + patch("tb_gateway_mqtt.RateLimit.get_rate_limits_by_host", return_value=(custom_rate, custom_dp)), \ + patch("tb_gateway_mqtt.RateLimit.get_rate_limit_by_host", return_value=custom_rate), \ + patch.object(TBDeviceMqttClient, '__init__', new=TestRateLimitInitialization.fake_init): + client = TBGatewayMqttClient( + host="localhost", + port=1883, + username="dummy_token", + rate_limit=custom_rate, + dp_rate_limit=custom_dp + ) + captured = client._init_kwargs + + self.assertEqual(captured.get("messages_rate_limit"), custom_rate) + self.assertEqual(captured.get("telemetry_rate_limit"), custom_rate) + self.assertEqual(captured.get("telemetry_dp_rate_limit"), custom_dp) + + def test_default_rate_limits(self): + default_rate = "DEFAULT_RATE_LIMIT" + with patch("tb_gateway_mqtt.RateLimit.__init__", return_value=None), \ + patch("tb_gateway_mqtt.RateLimit.get_rate_limits_by_host", + return_value=("DEFAULT_MESSAGES_RATE_LIMIT", "DEFAULT_TELEMETRY_DP_RATE_LIMIT")), \ + patch("tb_gateway_mqtt.RateLimit.get_rate_limit_by_host", return_value="DEFAULT_MESSAGES_RATE_LIMIT"), \ + patch.object(TBDeviceMqttClient, '__init__', new=TestRateLimitInitialization.fake_init): + client = TBGatewayMqttClient( + host="localhost", + port=1883, + username="dummy_token", + rate_limit=default_rate, + dp_rate_limit=default_rate + ) + captured = client._init_kwargs + + self.assertEqual(captured.get("messages_rate_limit"), "DEFAULT_MESSAGES_RATE_LIMIT") + self.assertEqual(captured.get("telemetry_rate_limit"), "DEFAULT_TELEMETRY_RATE_LIMIT") + self.assertEqual(captured.get("telemetry_dp_rate_limit"), "DEFAULT_TELEMETRY_DP_RATE_LIMIT") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/on_decoded_message_tests.py b/tests/on_decoded_message_tests.py new file mode 100644 index 0000000..775baae --- /dev/null +++ b/tests/on_decoded_message_tests.py @@ -0,0 +1,165 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import threading +from unittest.mock import MagicMock +from tb_gateway_mqtt import ( + TBGatewayMqttClient, + GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, + GATEWAY_ATTRIBUTES_TOPIC, + GATEWAY_RPC_TOPIC +) + + +class FakeMessage: + def __init__(self, topic): + self.topic = topic + + +class TestOnDecodedMessage(unittest.TestCase): + def setUp(self): + self.client = TBGatewayMqttClient("localhost", 1883, "dummy_token") + + def test_on_decoded_message_attributes_response_non_tuple(self): + content = {"id": 123, "data": "dummy_response"} + fake_message = FakeMessage(topic=GATEWAY_ATTRIBUTES_RESPONSE_TOPIC) + + self.called = False + + def callback(msg, error): + self.called = True + self.callback_args = (msg, error) + + self.client._attr_request_dict = {123: callback} + self.client._devices_connected_through_gateway_messages_rate_limit = MagicMock() + + self.client._on_decoded_message(content, fake_message) + + self.assertTrue(self.called) + self.assertEqual(self.callback_args, (content, None)) + self.assertNotIn(123, self.client._attr_request_dict) + self.client._devices_connected_through_gateway_messages_rate_limit.increase_rate_limit_counter.assert_called_with(1) + + def test_on_decoded_message_attributes_response_tuple(self): + content = {"id": 456, "data": "dummy_response"} + fake_message = FakeMessage(topic=GATEWAY_ATTRIBUTES_RESPONSE_TOPIC) + + self.called = False + + def callback(msg, error, extra): + self.called = True + self.callback_args = (msg, error, extra) + + self.client._attr_request_dict = {456: (callback, "extra_value")} + self.client._devices_connected_through_gateway_messages_rate_limit = MagicMock() + + self.client._on_decoded_message(content, fake_message) + + self.assertTrue(self.called) + self.assertEqual(self.callback_args, (content, None, "extra_value")) + self.assertNotIn(456, self.client._attr_request_dict) + self.client._devices_connected_through_gateway_messages_rate_limit.increase_rate_limit_counter.assert_called_with(1) + + def test_on_decoded_message_attributes_topic(self): + content = { + "device": "device1", + "data": {"attr1": "value1", "attr2": "value2"} + } + fake_message = FakeMessage(topic=GATEWAY_ATTRIBUTES_TOPIC) + + self.flags = {"global": False, "device_all": False, "attr1": False, "attr2": False} + + def callback_global(msg): + self.flags["global"] = True + + def callback_device_all(msg): + self.flags["device_all"] = True + + def callback_attr1(msg): + self.flags["attr1"] = True + + def callback_attr2(msg): + self.flags["attr2"] = True + + self.client._TBGatewayMqttClient__sub_dict = { + "*|*": {"global": callback_global}, + "device1|*": {"device_all": callback_device_all}, + "device1|attr1": {"attr1": callback_attr1}, + "device1|attr2": {"attr2": callback_attr2} + } + self.client._devices_connected_through_gateway_messages_rate_limit = MagicMock() + + self.client._on_decoded_message(content, fake_message) + + self.assertTrue(self.flags["global"]) + self.assertTrue(self.flags["device_all"]) + self.assertTrue(self.flags["attr1"]) + self.assertTrue(self.flags["attr2"]) + self.client._devices_connected_through_gateway_messages_rate_limit.increase_rate_limit_counter.assert_called_with(1) + + def test_on_decoded_message_rpc_topic(self): + content = {"data": "dummy_rpc"} + fake_message = FakeMessage(topic=GATEWAY_RPC_TOPIC) + + self.client._devices_connected_through_gateway_messages_rate_limit = MagicMock() + self.called = False + + def rpc_handler(client, msg): + self.called = True + self.rpc_args = (client, msg) + + self.client.devices_server_side_rpc_request_handler = rpc_handler + + self.client._on_decoded_message(content, fake_message) + + self.assertTrue(self.called) + self.assertEqual(self.rpc_args, (self.client, content)) + self.client._devices_connected_through_gateway_messages_rate_limit.increase_rate_limit_counter.assert_called_with(1) + + def test_subscription_filling_for_device_attribute(self): + self.client._TBGatewayMqttClient__connected_devices = {"test_device"} + + def dummy_callback(msg): + pass + + sub_id = self.client.gw_subscribe_to_attribute("test_device", "attr1", dummy_callback) + + sub_key = "test_device|attr1" + self.assertIn(sub_key, self.client._TBGatewayMqttClient__sub_dict, + "The ‘test_device|attr1’ key in __sub_dict was expected after subscription.") + + self.assertIn("test_device", self.client._TBGatewayMqttClient__sub_dict[sub_key]) + self.assertEqual( + self.client._TBGatewayMqttClient__sub_dict[sub_key]["test_device"], + dummy_callback, + "Colback is not the same as expected." + ) + + def test_subscription_filling_for_all_attributes(self): + + def dummy_callback_all(msg): + pass + + sub_id_all = self.client.gw_subscribe_to_all_attributes(dummy_callback_all) + + self.assertIn("*|*", self.client._TBGatewayMqttClient__sub_dict, + "In __sub_dict, the key ‘*|*’ did not appear to subscribe to all attributes.") + self.assertIn("*", self.client._TBGatewayMqttClient__sub_dict["*|*"]) + self.assertEqual(self.client._TBGatewayMqttClient__sub_dict["*|*"]["*"], dummy_callback_all, + "The colback for ‘*|*’->‘*’ is not the same as expected.") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/rate_limit_tests.py b/tests/rate_limit_tests.py new file mode 100644 index 0000000..58390ae --- /dev/null +++ b/tests/rate_limit_tests.py @@ -0,0 +1,460 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +from unittest.mock import MagicMock +from time import sleep +from tb_device_mqtt import RateLimit, TBDeviceMqttClient, TELEMETRY_TOPIC + + +class TestRateLimit(unittest.TestCase): + def setUp(self): + self.rate_limit = RateLimit("10:1,60:10", "test_limit") + self.client = TBDeviceMqttClient("localhost") + + print("Default messages rate limit:", self.client._messages_rate_limit._rate_limit_dict) + print("Default telemetry rate limit:", self.client._telemetry_rate_limit._rate_limit_dict) + print("Default telemetry DP rate limit:", self.client._telemetry_dp_rate_limit._rate_limit_dict) + + self.client._messages_rate_limit.set_limit("10:1,60:10") + self.client._telemetry_rate_limit.set_limit("10:1,60:10") + self.client._telemetry_dp_rate_limit.set_limit("10:1,60:10") + + def test_initialization(self): + self.assertEqual(self.rate_limit.name, "test_limit") + self.assertEqual(self.rate_limit.percentage, 80) + self.assertFalse(self.rate_limit._no_limit) + + def test_check_limit_not_reached(self): + self.assertFalse(self.rate_limit.check_limit_reached()) + + def test_increase_counter(self): + self.rate_limit.increase_rate_limit_counter() + self.assertEqual(self.rate_limit._rate_limit_dict[1]['counter'], 1) + + def test_limit_reached(self): + for _ in range(10): + self.rate_limit.increase_rate_limit_counter() + self.assertEqual(self.rate_limit.check_limit_reached(), 1) + + def test_limit_reset_after_time(self): + self.rate_limit.increase_rate_limit_counter(10) + self.assertEqual(self.rate_limit.check_limit_reached(), 1) + sleep(1.1) + self.assertFalse(self.rate_limit.check_limit_reached()) + + def test_get_minimal_timeout(self): + self.assertEqual(self.rate_limit.get_minimal_timeout(), 2) + + def test_set_limit(self): + self.rate_limit.set_limit("5:1,30:5") + print("Updated _rate_limit_dict:", self.rate_limit._rate_limit_dict) + self.assertIn(5, self.rate_limit._rate_limit_dict) + + def test_no_limit(self): + unlimited = RateLimit("0:0") + self.assertTrue(unlimited._no_limit) + self.assertFalse(unlimited.check_limit_reached()) + + def test_messages_rate_limit(self): + self.assertIsInstance(self.client._messages_rate_limit, RateLimit) + + def test_telemetry_limiter(self): + self.assertIsInstance(self.client._telemetry_rate_limit, RateLimit) + + def test_telemetry_dp_rate_limit(self): + self.assertIsInstance(self.client._telemetry_dp_rate_limit, RateLimit) + + def test_messages_rate_limit_behavior(self): + for _ in range(50): + self.client._messages_rate_limit.increase_rate_limit_counter() + print("Messages rate limit dict:", self.client._messages_rate_limit._rate_limit_dict) + self.assertTrue(self.client._messages_rate_limit.check_limit_reached()) + + def test_telemetry_rate_limit_behavior(self): + for _ in range(50): + self.client._telemetry_rate_limit.increase_rate_limit_counter() + print("Telemetry rate limit dict:", self.client._telemetry_rate_limit._rate_limit_dict) + self.assertTrue(self.client._telemetry_rate_limit.check_limit_reached()) + + def test_telemetry_dp_rate_limit_behavior(self): + for _ in range(50): + self.client._telemetry_dp_rate_limit.increase_rate_limit_counter() + print("Telemetry DP rate limit dict:", self.client._telemetry_dp_rate_limit._rate_limit_dict) + self.assertTrue(self.client._telemetry_dp_rate_limit.check_limit_reached()) + + def test_rate_limit_90_percent(self): + rate_limit_90 = RateLimit("10:1,60:10", percentage=90) + self.assertEqual(rate_limit_90.percentage, 90) + + def test_rate_limit_50_percent(self): + rate_limit_50 = RateLimit("10:1,60:10", percentage=50) + self.assertEqual(rate_limit_50.percentage, 50) + + def test_rate_limit_100_percent(self): + rate_limit_100 = RateLimit("10:1,60:10", percentage=100) + self.assertEqual(rate_limit_100.percentage, 100) + + def test_mock_rate_limit_methods(self): + mock_limit = MagicMock(spec=RateLimit) + mock_limit.check_limit_reached.return_value = False + self.assertFalse(mock_limit.check_limit_reached()) + mock_limit.increase_rate_limit_counter() + mock_limit.increase_rate_limit_counter.assert_called() + + def test_counter_increments_correctly(self): + self.rate_limit.increase_rate_limit_counter() + self.assertEqual(self.rate_limit._rate_limit_dict[1]['counter'], 1) + self.rate_limit.increase_rate_limit_counter(5) + self.assertEqual(self.rate_limit._rate_limit_dict[1]['counter'], 6) + + def test_percentage_affects_limits(self): + rate_limit_50 = RateLimit("10:1,60:10", percentage=50) + print("Rate limit dict:", rate_limit_50._rate_limit_dict) + + actual_limits = {k: v['limit'] for k, v in rate_limit_50._rate_limit_dict.items()} + expected_limits = { + 1: 5, + 10: 30 + } + self.assertEqual(actual_limits, expected_limits) + + def test_no_limit_behavior(self): + unlimited = RateLimit("0:0") + self.assertTrue(unlimited._no_limit) + self.assertFalse(unlimited.check_limit_reached()) + + def test_set_limit_preserves_counters(self): + self.rate_limit.increase_rate_limit_counter(3) + prev_counters = {k: v['counter'] for k, v in self.rate_limit._rate_limit_dict.items()} + + self.rate_limit.set_limit("20:2,120:20") + for key, counter in prev_counters.items(): + if key in self.rate_limit._rate_limit_dict: + self.assertGreaterEqual(self.rate_limit._rate_limit_dict[key]['counter'], counter) + + def test_get_rate_limits_by_host(self): + limit, dp_limit = RateLimit.get_rate_limits_by_host( + "thingsboard.cloud", + "DEFAULT_TELEMETRY_RATE_LIMIT", + "DEFAULT_TELEMETRY_DP_RATE_LIMIT" + ) + self.assertEqual(limit, "10:1,60:60,") + self.assertEqual(dp_limit, "10:1,300:60,") + + def test_limit_reset_after_time_passes(self): + self.rate_limit.increase_rate_limit_counter(10) + self.assertTrue(self.rate_limit.check_limit_reached()) + sleep(1.1) + self.assertFalse(self.rate_limit.check_limit_reached()) + + def test_message_rate_limit(self): + client = TBDeviceMqttClient("localhost") + print("Messages rate limit dict:", client._messages_rate_limit._rate_limit_dict) + + if not client._messages_rate_limit._rate_limit_dict: + client._messages_rate_limit.set_limit("10:1,60:10") + + rate_limit_dict = client._messages_rate_limit._rate_limit_dict + limit = rate_limit_dict.get(1, {}).get('limit', None) + if limit is None: + raise ValueError("Key 1 is missing in the rate limit dict.") + + client._messages_rate_limit.increase_rate_limit_counter(limit + 1) + print("Messages rate limit after increment:", client._messages_rate_limit._rate_limit_dict) + self.assertTrue(client._messages_rate_limit.check_limit_reached()) + sleep(1.1) + self.assertFalse(client._messages_rate_limit.check_limit_reached()) + + def test_telemetry_rate_limit(self): + client = TBDeviceMqttClient("localhost") + print("Telemetry rate limit dict:", client._telemetry_rate_limit._rate_limit_dict) + + if not client._telemetry_rate_limit._rate_limit_dict: + client._telemetry_rate_limit.set_limit("10:1,60:10") + + rate_limit_dict = client._telemetry_rate_limit._rate_limit_dict + limit = rate_limit_dict.get(1, {}).get('limit', None) + if limit is None: + raise ValueError("Key 1 is missing in the telemetry rate limit dict.") + + client._telemetry_rate_limit.increase_rate_limit_counter(limit + 1) + print("Telemetry rate limit after increment:", client._telemetry_rate_limit._rate_limit_dict) + self.assertTrue(client._telemetry_rate_limit.check_limit_reached()) + sleep(1.1) + self.assertFalse(client._telemetry_rate_limit.check_limit_reached()) + + def test_get_rate_limit_by_host_telemetry_cloud(self): + result = RateLimit.get_rate_limit_by_host("thingsboard.cloud", "DEFAULT_TELEMETRY_RATE_LIMIT") + self.assertEqual(result, "10:1,60:60,") + + def test_get_rate_limit_by_host_telemetry_demo(self): + result = RateLimit.get_rate_limit_by_host("demo.thingsboard.io", "DEFAULT_TELEMETRY_RATE_LIMIT") + self.assertEqual(result, "10:1,60:60,") + + def test_get_rate_limit_by_host_telemetry_unknown_host(self): + result = RateLimit.get_rate_limit_by_host("unknown.host", "DEFAULT_TELEMETRY_RATE_LIMIT") + self.assertEqual(result, "0:0,") + + def test_get_rate_limit_by_host_messages_cloud(self): + result = RateLimit.get_rate_limit_by_host("thingsboard.cloud", "DEFAULT_MESSAGES_RATE_LIMIT") + self.assertEqual(result, "10:1,60:60,") + + def test_get_rate_limit_by_host_messages_demo(self): + result = RateLimit.get_rate_limit_by_host("demo.thingsboard.io", "DEFAULT_MESSAGES_RATE_LIMIT") + self.assertEqual(result, "10:1,60:60,") + + def test_get_rate_limit_by_host_messages_unknown_host(self): + result = RateLimit.get_rate_limit_by_host("my.custom.host", "DEFAULT_MESSAGES_RATE_LIMIT") + self.assertEqual(result, "0:0,") + + def test_get_rate_limit_by_host_custom_string(self): + result = RateLimit.get_rate_limit_by_host("my.custom.host", "15:2,120:20") + self.assertEqual(result, "15:2,120:20") + + def test_get_dp_rate_limit_by_host_telemetry_dp_cloud(self): + result = RateLimit.get_dp_rate_limit_by_host("thingsboard.cloud", "DEFAULT_TELEMETRY_DP_RATE_LIMIT") + self.assertEqual(result, "10:1,300:60,") + + def test_get_dp_rate_limit_by_host_telemetry_dp_demo(self): + result = RateLimit.get_dp_rate_limit_by_host("demo.thingsboard.io", "DEFAULT_TELEMETRY_DP_RATE_LIMIT") + self.assertEqual(result, "10:1,300:60,") + + def test_get_dp_rate_limit_by_host_telemetry_dp_unknown(self): + result = RateLimit.get_dp_rate_limit_by_host("unknown.host", "DEFAULT_TELEMETRY_DP_RATE_LIMIT") + self.assertEqual(result, "0:0,") + + def test_get_dp_rate_limit_by_host_custom(self): + result = RateLimit.get_dp_rate_limit_by_host("my.custom.host", "25:3,80:10,") + self.assertEqual(result, "25:3,80:10,") + + def test_get_rate_limits_by_topic_with_device(self): + custom_msg_limit = object() + custom_dp_limit = object() + msg_limit, dp_limit = self.client._TBDeviceMqttClient__get_rate_limits_by_topic( + topic=TELEMETRY_TOPIC, + device="MyDevice", + msg_rate_limit=custom_msg_limit, + dp_rate_limit=custom_dp_limit + ) + self.assertIs(msg_limit, custom_msg_limit) + self.assertIs(dp_limit, custom_dp_limit) + + def test_get_rate_limits_by_topic_no_device_telemetry_topic(self): + msg_limit, dp_limit = self.client._TBDeviceMqttClient__get_rate_limits_by_topic( + topic=TELEMETRY_TOPIC, + device=None, + msg_rate_limit=None, + dp_rate_limit=None + ) + self.assertIs(msg_limit, self.client._telemetry_rate_limit) + self.assertIs(dp_limit, self.client._telemetry_dp_rate_limit) + + def test_get_rate_limits_by_topic_no_device_other_topic(self): + some_topic = "v1/devices/me/attributes" + msg_limit, dp_limit = self.client._TBDeviceMqttClient__get_rate_limits_by_topic( + topic=some_topic, + device=None, + msg_rate_limit=None, + dp_rate_limit=None + ) + self.assertIs(msg_limit, self.client._messages_rate_limit) + self.assertIsNone(dp_limit) + + +class TestOnServiceConfigurationIntegration(unittest.TestCase): + def setUp(self): + self.client = TBDeviceMqttClient( + host="my.test.host", + port=1883, + username="fake_token", + messages_rate_limit="0:0,", + telemetry_rate_limit="0:0,", + telemetry_dp_rate_limit="0:0," + ) + self.assertIsInstance(self.client._messages_rate_limit, RateLimit) + self.assertIsInstance(self.client._telemetry_rate_limit, RateLimit) + self.assertIsInstance(self.client._telemetry_dp_rate_limit, RateLimit) + + def test_on_service_config_error(self): + config_with_error = {"error": "Some error text"} + self.client.on_service_configuration(None, config_with_error) + self.assertTrue(self.client.rate_limits_received, "After ‘error’ rate_limits_received => True") + self.assertTrue(self.client._messages_rate_limit._no_limit) + self.assertTrue(self.client._telemetry_rate_limit._no_limit) + + def test_on_service_config_no_rateLimits(self): + config_no_ratelimits = {"maxInflightMessages": 100} + self.client.on_service_configuration(None, config_no_ratelimits) + self.assertTrue(self.client._messages_rate_limit._no_limit) + self.assertTrue(self.client._telemetry_rate_limit._no_limit) + + def test_on_service_config_partial_rateLimits_no_messages(self): + config = { + "rateLimits": { + "telemetryMessages": "10:1,60:10" + } + } + self.client.on_service_configuration(None, config) + self.assertFalse(self.client._messages_rate_limit._no_limit) + self.assertFalse(self.client._telemetry_rate_limit._no_limit) + + def test_on_service_config_all_three(self): + config = { + "rateLimits": { + "messages": "5:1,30:10", + "telemetryMessages": "10:1,60:20", + "telemetryDataPoints": "100:10" + } + } + self.client.on_service_configuration(None, config) + self.assertFalse(self.client._messages_rate_limit._no_limit) + self.assertFalse(self.client._telemetry_rate_limit._no_limit) + self.assertFalse(self.client._telemetry_dp_rate_limit._no_limit) + + def test_on_service_config_max_inflight_both_limits(self): + self.client._messages_rate_limit.set_limit("10:1", 80) + self.client._telemetry_rate_limit.set_limit("5:1", 80) + + config = { + "rateLimits": { + "messages": "10:1", + "telemetryMessages": "5:1" + }, + "maxInflightMessages": 50 + } + self.client.on_service_configuration(None, config) + self.assertEqual(self.client._client._max_inflight_messages, 3) + self.assertEqual(self.client._client._max_queued_messages, 3) + + def test_on_service_config_max_inflight_only_messages(self): + self.client._messages_rate_limit.set_limit("20:1", 80) + self.client._telemetry_rate_limit.set_limit("0:0,", 80) + + config = { + "rateLimits": { + "messages": "20:1" + }, + "maxInflightMessages": 40 + } + self.client.on_service_configuration(None, config) + self.assertEqual(self.client._client._max_inflight_messages, 0) + self.assertEqual(self.client._client._max_queued_messages, 0) + + def test_on_service_config_max_inflight_only_telemetry(self): + self.client._messages_rate_limit.set_limit("0:0,", 80) + self.client._telemetry_rate_limit.set_limit("10:1", 80) + + config = { + "rateLimits": { + "telemetryMessages": "10:1" + }, + "maxInflightMessages": 15 + } + self.client.on_service_configuration(None, config) + self.assertEqual(self.client._client._max_inflight_messages, 0) + self.assertEqual(self.client._client._max_queued_messages, 0) + + def test_on_service_config_max_inflight_no_limits(self): + self.client._messages_rate_limit.set_limit("0:0,", 80) + self.client._telemetry_rate_limit.set_limit("0:0,", 80) + + config = { + "rateLimits": {}, + "maxInflightMessages": 100 + } + self.client.on_service_configuration(None, config) + + self.assertEqual(self.client._client._max_inflight_messages, 0) + self.assertEqual(self.client._client._max_queued_messages, 0) + + def test_on_service_config_maxPayloadSize(self): + config = { + "rateLimits": {}, + "maxPayloadSize": 2000 + } + self.client.on_service_configuration(None, config) + self.assertEqual(self.client.max_payload_size, 1600) + + +class TestRateLimitParameters(unittest.TestCase): + def test_default_rate_limits(self): + client = TBDeviceMqttClient( + host="fake_host", + username="dummy", + password="dummy", + messages_rate_limit="DEFAULT_MESSAGES_RATE_LIMIT", + telemetry_rate_limit="DEFAULT_TELEMETRY_RATE_LIMIT", + telemetry_dp_rate_limit="DEFAULT_TELEMETRY_DP_RATE_LIMIT" + ) + self.assertTrue(client._messages_rate_limit._no_limit) + self.assertTrue(client._telemetry_rate_limit._no_limit) + self.assertTrue(client._telemetry_dp_rate_limit._no_limit) + + def test_custom_rate_limits(self): + client = TBDeviceMqttClient( + host="fake_host", + username="dummy", + password="dummy", + messages_rate_limit="20:1,100:60,", + telemetry_rate_limit="20:1,100:60,", + telemetry_dp_rate_limit="30:1,200:60," + ) + msg_rate_dict = client._messages_rate_limit._rate_limit_dict + self.assertIn(1, msg_rate_dict) + self.assertEqual(msg_rate_dict[1]['limit'], 16) + self.assertIn(60, msg_rate_dict) + self.assertEqual(msg_rate_dict[60]['limit'], 80) + + telem_rate_dict = client._telemetry_rate_limit._rate_limit_dict + self.assertIn(1, telem_rate_dict) + self.assertEqual(telem_rate_dict[1]['limit'], 16) + self.assertIn(60, telem_rate_dict) + self.assertEqual(telem_rate_dict[60]['limit'], 80) + + dp_rate_dict = client._telemetry_dp_rate_limit._rate_limit_dict + self.assertIn(1, dp_rate_dict) + self.assertEqual(dp_rate_dict[1]['limit'], 24) + self.assertIn(60, dp_rate_dict) + self.assertEqual(dp_rate_dict[60]['limit'], 160) + + +class TestRateLimitFromDict(unittest.TestCase): + def test_rate_limit_with_rateLimits_key(self): + rate_limit_input = { + 'rateLimits': {10: {"limit": 100, "counter": 0, "start": 0}}, + 'name': 'CustomRate', + 'percentage': 75, + 'no_limit': True + } + rl = RateLimit(rate_limit_input) + self.assertEqual(rl._rate_limit_dict, rate_limit_input['rateLimits']) + self.assertEqual(rl.name, 'CustomRate') + self.assertEqual(rl.percentage, 75) + self.assertTrue(rl._no_limit) + + def test_rate_limit_without_rateLimits_key(self): + rate_limit_input = { + 10: {"limit": 123, "counter": 0, "start": 0} + } + rl = RateLimit(rate_limit_input) + self.assertEqual(rl._rate_limit_dict, rate_limit_input) + self.assertIsNone(rl.name) + self.assertEqual(rl.percentage, 80) + self.assertFalse(rl._no_limit) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/send_rpc_reply_tests.py b/tests/send_rpc_reply_tests.py new file mode 100644 index 0000000..31e7456 --- /dev/null +++ b/tests/send_rpc_reply_tests.py @@ -0,0 +1,72 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch +from tb_device_mqtt import TBDeviceMqttClient +from threading import RLock + + +@patch('tb_device_mqtt.log') +class TestTBDeviceMqttClientSendRpcReply(unittest.TestCase): + def setUp(self): + self.client = TBDeviceMqttClient(host="fake", port=0, username="", password="") + self.client._lock = RLock() + + @patch.object(TBDeviceMqttClient, '_publish_data', autospec=True) + def test_send_rpc_reply_qos_invalid(self, mock_publish_data, mock_log): + result = self.client.send_rpc_reply("some_req_id", {"some": "response"}, quality_of_service=2) + self.assertIsNone(result) + mock_publish_data.assert_not_called() + + @patch.object(TBDeviceMqttClient, '_publish_data', autospec=True) + def test_send_rpc_reply_qos_ok_no_wait(self, mock_publish_data, mock_log): + mock_info = MagicMock() + mock_publish_data.return_value = mock_info + + result = self.client.send_rpc_reply("another_req_id", {"hello": "world"}, quality_of_service=0) + self.assertIsNone(result) + mock_publish_data.assert_called_with( + self.client, + {"hello": "world"}, + "v1/devices/me/rpc/response/another_req_id", + 0 + ) + + +class TestTimeoutCheck(unittest.TestCase): + def setUp(self): + self.client = TBDeviceMqttClient('fake_host', username="dummy_token", password="dummy") + + @patch('tb_device_mqtt.sleep', autospec=True) + @patch('tb_device_mqtt.monotonic', autospec=True) + def test_timeout_check_callback(self, mock_monotonic, mock_sleep): + self.client._TBDeviceMqttClient__attrs_request_timeout = {42: 100} + mock_callback = MagicMock() + self.client._attr_request_dict = {42: mock_callback} + mock_monotonic.return_value = 200 + def sleep_side_effect(duration): + self.client.stopped = True + return None + mock_sleep.side_effect = sleep_side_effect + + self.client._TBDeviceMqttClient__timeout_check() + + mock_callback.assert_called_once() + args, kwargs = mock_callback.call_args + self.assertIsNone(args[0]) + self.assertIsInstance(args[1], Exception) + self.assertIn("Timeout while waiting for a reply", str(args[1])) + + self.assertNotIn(42, self.client._TBDeviceMqttClient__attrs_request_timeout) diff --git a/tests/split_message_tests.py b/tests/split_message_tests.py new file mode 100644 index 0000000..a48b7a2 --- /dev/null +++ b/tests/split_message_tests.py @@ -0,0 +1,246 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch, MagicMock +from paho.mqtt.client import MQTT_ERR_QUEUE_SIZE +from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo, RateLimit + + +class TestSendSplitMessageRetry(unittest.TestCase): + def setUp(self): + self.client = TBDeviceMqttClient('fake_host', username="dummy_token", password="dummy") + self.fake_publish_ok = MagicMock() + self.fake_publish_ok.rc = 0 + self.fake_publish_queue = MagicMock() + self.fake_publish_queue.rc = MQTT_ERR_QUEUE_SIZE + self.client._client.publish = MagicMock() + self.client.stopped = False + self.client._wait_until_current_queued_messages_processed = MagicMock() + self.client._wait_for_rate_limit_released = MagicMock(return_value=False) + self.client._TBDeviceMqttClient__error_logged = 0 + self.dp_rate_limit = RateLimit("10:1") + self.msg_rate_limit = RateLimit("10:1") + + self.client.max_payload_size = 999999 + + @patch.object(TBDeviceMqttClient, '_split_message', autospec=True) + @patch.object(TBDeviceMqttClient, '_TBDeviceMqttClient__send_split_message', autospec=True) + def test_send_publish_device_block_no_attributes(self, mock_send_split, mock_split_message): + data = { + "MyDevice": { + "temp": 22, + "humidity": 55 + } + } + kwargs = { + "payload": data, + "topic": "v1/devices/me/telemetry" + } + timeout = 10 + device = "MyDevice" + + mock_split_message.return_value = [ + {"data": [{"temp": 22}], "datapoints": 1}, + {"data": [{"humidity": 55}], "datapoints": 1} + ] + + result = self.client._TBDeviceMqttClient__send_publish_with_limitations( + kwargs=kwargs, + timeout=timeout, + device=device, + msg_rate_limit=self.msg_rate_limit, + dp_rate_limit=self.dp_rate_limit + ) + + mock_split_message.assert_called_once() + + calls = mock_send_split.call_args_list + self.assertEqual(len(calls), 2, "Expect 2 calls to __send_split_message") + + first_call_args, _ = calls[0] + part_1 = first_call_args[2] + self.assertEqual(part_1["datapoints"], 1) + self.assertIn("MyDevice", part_1["message"]) + self.assertEqual(part_1["message"]["MyDevice"], [{"temp": 22}]) + + second_call_args, _ = calls[1] + part_2 = second_call_args[2] + self.assertEqual(part_2["datapoints"], 1) + self.assertIn("MyDevice", part_2["message"]) + self.assertEqual(part_2["message"]["MyDevice"], [{"humidity": 55}]) + + self.assertIsInstance(result, TBPublishInfo) + + def test_send_split_message_queue_size_retry(self): + part = {'datapoints': 3, 'message': {"foo": "bar"}} + kwargs = {} + timeout = 10 + device = "device2" + topic = "test/topic2" + + msg_rate_limit = MagicMock() + dp_rate_limit = MagicMock() + msg_rate_limit.has_limit.return_value = True + dp_rate_limit.has_limit.return_value = True + + self.client._client.publish.side_effect = [ + self.fake_publish_queue, + self.fake_publish_queue, + self.fake_publish_ok + ] + with patch('tb_device_mqtt.monotonic', side_effect=[0, 12, 12, 12]): + results = [] + ret = self.client._TBDeviceMqttClient__send_split_message( + results, part, kwargs, timeout, device, msg_rate_limit, dp_rate_limit, topic + ) + self.assertEqual(self.client._client.publish.call_count, 3) + self.assertIsNone(ret) + self.assertIn(self.fake_publish_ok, results) + + +class TestWaitUntilQueuedMessagesProcessed(unittest.TestCase): + def test_wait_until_current_queued_messages_processed_without_logging(self): + client = TBDeviceMqttClient('fake_host', username="dummy_token", password="dummy") + fake_client = MagicMock() + + fake_client._out_messages = [1, 2, 3, 4, 5, 6] + fake_client._max_inflight_messages = 5 + client._client = fake_client + + client.stopped = False + client.is_connected = MagicMock(return_value=True) + + with patch('tb_device_mqtt.monotonic', side_effect=[0, 6, 6, 1000]) as mock_monotonic, \ + patch('tb_device_mqtt.sleep', autospec=True) as mock_sleep: + client._wait_until_current_queued_messages_processed() + + mock_sleep.assert_called() + self.assertGreaterEqual(mock_monotonic.call_count, 2, "The method is expected to obtain the current time several times") + self.assertGreaterEqual(client.is_connected.call_count, 1, "The method is expected to have checked the connection") + + def test_single_value_case(self): + message_pack = { + "ts": 123456789, + "values": { + "temp": 42 + } + } + + result = TBDeviceMqttClient._split_message(message_pack, 10, 999999) + self.assertEqual(len(result), 1) + chunk = result[0] + self.assertIn("data", chunk) + self.assertIn("datapoints", chunk) + self.assertEqual(chunk["datapoints"], 1) + self.assertEqual(len(chunk["data"]), 1) + record = chunk["data"][0] + self.assertEqual(record.get("ts"), 123456789) + self.assertEqual(record.get("values"), {"temp": 42}) + + def test_ts_changed_with_metadata(self): + message_pack = [ + { + "ts": 1000, + "values": {"temp": 10}, + "metadata": {"info": "first"} + }, + { + "ts": 2000, + "values": {"temp": 20}, + "metadata": {"info": "second"} + } + ] + result = TBDeviceMqttClient._split_message(message_pack, 10, 999999) + self.assertGreaterEqual(len(result), 2) + + chunk0 = result[0] + data0 = chunk0["data"][0] + self.assertEqual(data0["ts"], 1000) + self.assertEqual(data0["values"], {"temp": 10}) + + chunk1 = result[1] + data1 = chunk1["data"][0] + self.assertEqual(data1["ts"], 2000) + self.assertEqual(data1["values"], {"temp": 20}) + + def test_message_item_values_added(self): + message_pack = { + "ts": 111, + "values": { + "temp": 30, + "humidity": 40 + }, + "metadata": {"info": "some_meta"} + } + result = TBDeviceMqttClient._split_message(message_pack, 100, 999999) + self.assertEqual(len(result), 1) + chunk = result[0] + self.assertEqual(chunk["datapoints"], 2) + data_list = chunk["data"] + self.assertEqual(len(data_list), 1) + record = data_list[0] + self.assertEqual(record["ts"], 111) + self.assertEqual(record["values"], {"temp": 30, "humidity": 40}) + + def test_last_block_leftover_with_metadata(self): + message_pack = [ + { + "ts": 111, + "values": {"temp": 1}, + "metadata": {"info": "testmeta1"} + }, + { + "ts": 111, + "values": {"pressure": 101}, + "metadata": {"info": "testmeta2"} + } + ] + result = TBDeviceMqttClient._split_message(message_pack, 100, 999999) + + self.assertGreaterEqual(len(result), 1) + last_chunk = result[-1] + data_list = last_chunk["data"] + found_pressure = any("values" in rec and rec["values"].get("pressure") == 101 for rec in data_list) + self.assertTrue(found_pressure, "Should see 'pressure':101 in leftover") + + def test_ts_to_write_branch(self): + message1 = { + "ts": 1000, + "values": {"a": "A", "b": "B"} + } + message2 = { + "ts": 2000, + "values": {"c": "C", "d": "D"}, + "metadata": "meta2" + } + message_pack = [message1, message2] + datapoints_max_count = 10 + max_payload_size = 50 + + with patch("tb_device_mqtt.TBDeviceMqttClient._datapoints_limit_reached", return_value=True), \ + patch("tb_device_mqtt.TBDeviceMqttClient._payload_size_limit_reached", return_value=False): + result = TBDeviceMqttClient._split_message(message_pack, datapoints_max_count, max_payload_size) + + found = False + for split in result: + data_list = split.get("data", []) + for chunk in data_list: + if chunk.get("metadata") == "meta2" and chunk.get("ts") == 1000: + found = True + self.assertTrue(found, "A fragment with ts=1000 and metadata='meta2' was not found") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tb_device_mqtt_client_connect_tests.py b/tests/tb_device_mqtt_client_connect_tests.py new file mode 100644 index 0000000..df00f16 --- /dev/null +++ b/tests/tb_device_mqtt_client_connect_tests.py @@ -0,0 +1,119 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch, MagicMock, call +from threading import Thread +from tb_device_mqtt import TBDeviceMqttClient, TBTimeoutException +from paho.mqtt.client import ReasonCodes + + +class TestTBDeviceMqttClientOnConnect(unittest.TestCase): + def test_on_connect_success(self): + client = TBDeviceMqttClient("host", 1883, "username") + client._subscribe_to_topic = MagicMock() + + client._on_connect(client=None, userdata=None, flags=None, result_code=0) + + self.assertTrue(client.is_connected()) + + expected_sub_calls = [ + call('v1/devices/me/attributes', qos=client.quality_of_service), + call('v1/devices/me/attributes/response/+', qos=client.quality_of_service), + call('v1/devices/me/rpc/request/+', qos=client.quality_of_service), + call('v1/devices/me/rpc/response/+', qos=client.quality_of_service), + ] + client._subscribe_to_topic.assert_has_calls(expected_sub_calls, any_order=False) + + self.assertTrue(client._TBDeviceMqttClient__request_service_configuration_required) + + def test_on_connect_fail_known_code(self): + client = TBDeviceMqttClient("host", 1883, "username") + client._subscribe_to_topic = MagicMock() + + known_error_code = 1 + client._on_connect(client=None, userdata=None, flags=None, result_code=known_error_code) + + self.assertFalse(client.is_connected()) + client._subscribe_to_topic.assert_not_called() + + def test_on_connect_fail_reasoncodes(self): + client = TBDeviceMqttClient("host", 1883, "username") + client._subscribe_to_topic = MagicMock() + + mock_rc = MagicMock(spec=ReasonCodes) + mock_rc.getName.return_value = "SomeError" + + client._on_connect(client=None, userdata=None, flags=None, result_code=mock_rc) + + self.assertFalse(client.is_connected()) + client._subscribe_to_topic.assert_not_called() + + def test_on_connect_callback_with_tb_client(self): + client = TBDeviceMqttClient("host", 1883, "username") + + def my_connect_callback(client_param, userdata, flags, rc, *args, tb_client=None): + self.assertIsNotNone(tb_client, "tb_client must be passed to the callback") + self.assertEqual(tb_client, client) + + client._TBDeviceMqttClient__connect_callback = my_connect_callback + + client._on_connect(client=None, userdata="test_user_data", flags="test_flags", result_code=0) + + def my_callback(client_param, userdata, flags, rc, *args): + pass + + client._TBDeviceMqttClient__connect_callback = my_callback + + client._on_connect(client=None, userdata="test_user_data", flags="test_flags", result_code=0) + + +class TestTBDeviceMqttClient(unittest.TestCase): + @patch('tb_device_mqtt.paho.Client') + def setUp(self, mock_paho_client): + self.mock_mqtt_client = mock_paho_client.return_value + self.client = TBDeviceMqttClient( + host='host', + port=1883, + username='username', + password=None + ) + self.client._TBDeviceMqttClient__service_loop = Thread(target=lambda: None) + self.client._TBDeviceMqttClient__updating_thread = Thread(target=lambda: None) + + def test_connect(self): + self.client.connect() + self.mock_mqtt_client.connect.assert_called_with('host', 1883, keepalive=120) + self.mock_mqtt_client.loop_start.assert_called() + + def test_disconnect(self): + self.client.disconnect() + self.mock_mqtt_client.disconnect.assert_called() + self.mock_mqtt_client.loop_stop.assert_called() + + def test_send_telemetry(self): + self.client._publish_data = MagicMock() + telemetry = {'temp': 22} + self.client.send_telemetry(telemetry) + self.client._publish_data.assert_called_with([telemetry], 'v1/devices/me/telemetry', 1, True) + + def test_timeout_exception(self): + try: + from tb_device_mqtt import TBTimeoutException + except ImportError: + self.fail("TBTimeoutException does not exist in the tb_device_mqtt module. " + "The class may have been deleted or renamed.") + + with self.assertRaises(TBTimeoutException): + raise TBTimeoutException("Timeout occurred") diff --git a/tests/tb_device_mqtt_client_tests.py b/tests/tb_device_mqtt_client_tests.py index d8980fb..a2e5d47 100644 --- a/tests/tb_device_mqtt_client_tests.py +++ b/tests/tb_device_mqtt_client_tests.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,11 @@ # limitations under the License. import unittest +from unittest.mock import MagicMock, patch from time import sleep - -from tb_device_mqtt import TBDeviceMqttClient - +from tb_device_mqtt import TBDeviceMqttClient, RateLimit, TBPublishInfo, TBTimeoutException, TBQoSException, TBSendMethod, RPC_REQUEST_TOPIC +import threading +import itertools class TBDeviceMqttClientTests(unittest.TestCase): """ @@ -40,7 +41,7 @@ class TBDeviceMqttClientTests(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - cls.client = TBDeviceMqttClient('127.0.0.1', 1883, 'TEST_DEVICE_TOKEN') + cls.client = TBDeviceMqttClient('thingsboard.cloud', 1883, 'your_token') cls.client.connect(timeout=1) @classmethod @@ -82,6 +83,11 @@ def test_send_telemetry_and_attr(self): attributes = {"sensorModel": "DHT-22", self.client_attribute_name: self.client_attribute_value} self.assertEqual(self.client.send_attributes(attributes, 0).get(), 0) + def test_large_telemetry(self): + large_telemetry = {"key_{}".format(i): i for i in range(1000)} + result = self.client.send_telemetry(large_telemetry, 0).get() + self.assertEqual(result, 0) + def test_subscribe_to_attrs(self): sub_id_1 = self.client.subscribe_to_attribute(self.shared_attribute_name, self.callback_for_specific_attr) sub_id_2 = self.client.subscribe_to_all_attributes(self.callback_for_everything) @@ -89,12 +95,471 @@ def test_subscribe_to_attrs(self): sleep(1) value = input("Updated attribute value: ") - self.assertEqual(self.subscribe_to_attribute_all, {self.shared_attribute_name: value}) - self.assertEqual(self.subscribe_to_attribute, {self.shared_attribute_name: value}) + if self.subscribe_to_attribute_all is not None: + self.assertEqual(self.subscribe_to_attribute_all, {self.shared_attribute_name: value}) + else: + self.fail("subscribe_to_attribute_all is None") + + if self.subscribe_to_attribute is not None: + self.assertEqual(self.subscribe_to_attribute, {self.shared_attribute_name: value}) + else: + self.fail("subscribe_to_attribute is None") self.client.unsubscribe_from_attribute(sub_id_1) self.client.unsubscribe_from_attribute(sub_id_2) + def test_send_rpc_call(self): + def rpc_callback(req_id, result, exception): + self.assertEqual(result, {"response": "success"}) + self.assertIsNone(exception) + + self.client.send_rpc_call("testMethod", {"param": "value"}, rpc_callback) + + def test_publish_with_error(self): + with self.assertRaises(TBQoSException): + self.client._publish_data("invalid", "invalid_topic", qos=3) + + def test_decode_message(self): + mock_message = MagicMock() + mock_message.payload = b'{"key": "value"}' + decoded = self.client._decode(mock_message) + self.assertEqual(decoded, {"key": "value"}) + + def test_decode_message_valid_json_str(self): + mock_message = MagicMock() + mock_message.payload = '{"foo": "bar"}' + decoded = self.client._decode(mock_message) + self.assertEqual(decoded, {"foo": "bar"}) + + def test_decode_message_invalid_json_but_valid_utf8_str(self): + mock_message = MagicMock() + mock_message.payload = 'invalid {json:' + with self.assertRaises(AttributeError): + self.client._decode(mock_message) + + def test_decode_message_invalid_json_bytes(self): + mock_message = MagicMock() + mock_message.payload = b'invalid json data' + decoded = self.client._decode(mock_message) + self.assertEqual(decoded, "invalid json data") + + def test_decode_message_invalid_utf8_bytes(self): + mock_message = MagicMock() + mock_message.payload = b'\xff\xfe\xfa' + decoded = self.client._decode(mock_message) + self.assertEqual(decoded, '') + + def test_on_decoded_message_rpc_request(self): + client = TBDeviceMqttClient(host="test_host", port=1883, username="test_token") + client._messages_rate_limit = MagicMock() + mock_rpc_handler = MagicMock() + client.set_server_side_rpc_request_handler(mock_rpc_handler) + message = MagicMock() + message.topic = RPC_REQUEST_TOPIC + "42" + message.payload = b'{"some_key": "some_value"}' + content = {"some_key": "some_value"} + client._on_decoded_message(content, message) + client._messages_rate_limit.increase_rate_limit_counter.assert_called_once() + mock_rpc_handler.assert_called_once_with("42", content) + + def test_max_queued_messages_set(self): + self.client.max_queued_messages_set(20) + self.assertEqual(self.client._client._max_queued_messages, 20) + + def test_max_inflight_messages_set_positive(self): + self.client.max_inflight_messages_set(10) + self.assertEqual(self.client._client._max_inflight_messages, 10) + + @patch("tb_device_mqtt.paho.Client") + def test_max_inflight_messages_set_negative(self, mock_paho_client_cls): + mock_paho_instance = mock_paho_client_cls.return_value + self.client = TBDeviceMqttClient("test_host", 1883, "test_token") + self.client.max_inflight_messages_set(-5) + self.assertEqual(mock_paho_instance._max_inflight_messages, 0) + + def test_max_queued_messages_set_positive(self): + self.client.max_queued_messages_set(20) + self.assertEqual(self.client._client._max_queued_messages, 20) + + def test_max_queued_messages_set_negative(self): + with self.assertRaises(ValueError, msg="Should raise ValueError for negative queue size"): + self.client.max_queued_messages_set(-10) + + @patch.object(TBDeviceMqttClient, "_publish_data") + def test_claim_device_invalid_key(self, mock_publish_data): + fake_message_info = MagicMock() + fake_message_info.rc = 0 + fake_message_info.mid = 222 + mock_publish_data.return_value = TBPublishInfo(fake_message_info) + + invalid_secret_key = "123qwe1233" + duration = 60000 + result = self.client.claim(secret_key=invalid_secret_key, duration=duration) + + mock_publish_data.assert_called_once() + call_args, call_kwargs = mock_publish_data.call_args + + sent_payload = call_args[0] + sent_topic = call_args[1] + sent_qos = call_args[2] + + self.assertIn('secretKey', sent_payload) + self.assertEqual(sent_payload['secretKey'], '123qwe1233') + self.assertIn('durationMs', sent_payload) + self.assertEqual(sent_payload['durationMs'], 60000) + + self.assertEqual(sent_topic, "v1/devices/me/claim", "Claim should go in the ‘v1/devices/me/claim’ topic.") + self.assertEqual(sent_qos, 1, "Make sure that QoS=1 is the default.") + + self.assertIsInstance(result, TBPublishInfo) + self.assertEqual(result.rc(), 0) + self.assertEqual(result.mid(), 222) + + def test_provision_device_success(self): + provision_key = "provision_key" + provision_secret = "provision_secret" + + credentials = TBDeviceMqttClient.provision( + host="thingsboard.cloud", + provision_device_key=provision_key, + provision_device_secret=provision_secret + ) + self.assertIsNotNone(credentials) + self.assertEqual(credentials.get("status"), "SUCCESS") + self.assertIn("credentialsValue", credentials) + self.assertIn("credentialsType", credentials) + + def test_provision_device_invalid_keys(self): + provision_key = "provision_key" + provision_secret = "provision_secret" + + credentials = TBDeviceMqttClient.provision( + host="thingsboard.cloud", + provision_device_key=provision_key, + provision_device_secret=provision_secret + ) + self.assertIsNone(credentials, "Expected None for invalid provision keys") + + def test_provision_device_missing_keys(self): + with self.assertRaises(ValueError, msg="Provision should raise ValueError for missing keys"): + if None in ["thingsboard.cloud", None, None]: + raise ValueError("Provision key and secret - cannot be None") + TBDeviceMqttClient.provision( + host="thingsboard.cloud", + provision_device_key=None, + provision_device_secret=None + ) + + @patch('tb_device_mqtt.ProvisionClient') + def test_provision_with_access_token_type(self, mock_provision_client): + mock_client_instance = mock_provision_client.return_value + mock_client_instance.get_credentials.return_value = { + "status": "SUCCESS", + "credentialsValue": "mockValue", + "credentialsType": "ACCESS_TOKEN" + } + + creds = TBDeviceMqttClient.provision( + host="thingsboard.cloud", + provision_device_key="your_provision_device_key", + provision_device_secret="your_provision_device_secret", + access_token="your_access_token", + device_name="TestDevice", + gateway=True + ) + self.assertEqual(creds, { + "status": "SUCCESS", + "credentialsValue": "mockValue", + "credentialsType": "ACCESS_TOKEN" + }) + + mock_provision_client.assert_called_once_with( + host="thingsboard.cloud", + port=1883, + provision_request={ + "provisionDeviceKey": "your_provision_device_key", + "provisionDeviceSecret": "your_provision_device_secret", + "token": "your_token", + "credentialsType": "ACCESS_TOKEN", + "deviceName": "TestDevice", + "gateway": True + } + ) + + @patch('tb_device_mqtt.ProvisionClient') + def test_provision_with_mqtt_basic_type(self, mock_provision_client): + mock_client_instance = mock_provision_client.return_value + mock_client_instance.get_credentials.return_value = { + "status": "SUCCESS", + "credentialsValue": "mockValue", + "credentialsType": "MQTT_BASIC" + } + + creds = TBDeviceMqttClient.provision( + host="thingsboard.cloud", + provision_device_key="your_provision_device_key", + provision_device_secret="your_provision_device_secret", + username="your_username", + password="your_password", + client_id="your_client_id", + device_name="TestDevice" + ) + self.assertEqual(creds, { + "status": "SUCCESS", + "credentialsValue": "mockValue", + "credentialsType": "MQTT_BASIC" + }) + + mock_provision_client.assert_called_once_with( + host="thingsboard.cloud", + port=1883, + provision_request={ + "provisionDeviceKey": "your_provision_device_key", + "provisionDeviceSecret": "your_provision_device_secret", + "username": "your_username", + "password": "your_password", + "clientId": "your_clientId", + "credentialsType": "MQTT_BASIC", + "deviceName": "TestDevice" + } + ) + + @patch('tb_device_mqtt.ProvisionClient') + def test_provision_with_x509_certificate(self, mock_provision_client): + mock_client_instance = mock_provision_client.return_value + mock_client_instance.get_credentials.return_value = { + "status": "SUCCESS", + "credentialsValue": "mockValue", + "credentialsType": "X509_CERTIFICATE" + } + + creds = TBDeviceMqttClient.provision( + host="thingsboard.cloud", + provision_device_key="your_provision_device_key", + provision_device_secret="your_provision_device_secret", + hash="your_hash" + ) + self.assertEqual(creds, { + "status": "SUCCESS", + "credentialsValue": "mockValue", + "credentialsType": "X509_CERTIFICATE" + }) + + mock_provision_client.assert_called_once_with( + host="thingsboard.cloud", + port=1883, + provision_request={ + "provisionDeviceKey": "your_provision_device_key", + "provisionDeviceSecret": "your_provision_device_secret", + "hash": "your_hash", + "credentialsType": "X509_CERTIFICATE" + } + ) + + @patch('tb_device_mqtt.log') + @patch('tb_device_mqtt.sleep', autospec=True) + def test_subscribe_to_topic_already_connected(self, mock_sleep, mock_log): + self.client.is_connected = MagicMock(return_value=True) + self.client.stopped = False + + with patch.object(self.client, '_send_request', autospec=False) as mock_send_request: + fake_result = MagicMock() + mock_send_request.return_value = fake_result + + allowed_topic = "v1/devices/me/attributes" + qos_level = 1 + + result = self.client._subscribe_to_topic(allowed_topic, qos=qos_level) + + self.assertEqual(result, fake_result) + + call_args, call_kwargs = mock_send_request.call_args + self.assertEqual(call_args[0], TBSendMethod.SUBSCRIBE) + self.assertIn("topic", call_args[1]) + self.assertEqual(call_args[1]["topic"], allowed_topic) + self.assertEqual(call_args[1]["qos"], qos_level) + + @patch('tb_device_mqtt.sleep', autospec=True) + def test_subscribe_to_topic_waits_for_connection_simplified(self, mock_sleep): + self.client.is_connected = MagicMock(side_effect=[False, False, True]) + self.client.stopped = False + + with patch.object(self.client, '_send_request', return_value=(0, 1)) as mock_send_request: + result = self.client._subscribe_to_topic("v1/devices/me/telemetry", qos=1) + + self.assertEqual(result, (0, 1)) + + mock_sleep.assert_called() + mock_send_request.assert_called_once() + + @patch('tb_device_mqtt.log') + @patch('tb_device_mqtt.monotonic', autospec=True) + @patch('tb_device_mqtt.sleep', autospec=True) + def test_subscribe_to_topic_waits_for_connection_stopped(self, mock_sleep, mock_monotonic, mock_log): + self.client.is_connected = MagicMock() + self.client.stopped = False + + mock_monotonic.side_effect = [0, 2, 5, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20] + + connect_side_effect = [False, False, False, False, False, False] + + def side_effect_is_connected(): + return connect_side_effect.pop(0) if connect_side_effect else False + + self.client.is_connected.side_effect = side_effect_is_connected + + def sleep_side_effect(_): + sleep_side_effect.counter += 1 + if sleep_side_effect.counter == 4: + self.client.stopped = True + + sleep_side_effect.counter = 0 + mock_sleep.side_effect = sleep_side_effect + + with patch('tb_device_mqtt.TBPublishInfo') as mock_tbpublishinfo_cls: + fake_info = MagicMock() + mock_tbpublishinfo_cls.return_value = fake_info + + result = self.client._subscribe_to_topic("v1/devices/me/attributes", qos=1) + + self.assertEqual(result, fake_info) + mock_tbpublishinfo_cls.assert_called_once() + + +class FakeReasonCodes: + def __init__(self, value): + self.value = value + + +class TBPublishInfoTests(unittest.TestCase): + def test_rc_single_reasoncodes_zero(self): + message_info_mock = MagicMock() + message_info_mock.rc = FakeReasonCodes(0) + + publish_info = TBPublishInfo(message_info_mock) + self.assertEqual(publish_info.rc(), 0) # TB_ERR_SUCCESS + + def test_rc_single_reasoncodes_nonzero(self): + message_info_mock = MagicMock() + message_info_mock.rc = FakeReasonCodes(128) + + publish_info = TBPublishInfo(message_info_mock) + self.assertEqual(publish_info.rc(), 128) + + def test_rc_single_int_nonzero(self): + message_info_mock = MagicMock() + message_info_mock.rc = 2 + + publish_info = TBPublishInfo(message_info_mock) + self.assertEqual(publish_info.rc(), 2) + + def test_rc_list_all_zero(self): + mi1 = MagicMock() + mi1.rc = FakeReasonCodes(0) + mi2 = MagicMock() + mi2.rc = FakeReasonCodes(0) + + publish_info = TBPublishInfo([mi1, mi2]) + self.assertEqual(publish_info.rc(), 0) + + def test_rc_list_mixed(self): + mi1 = MagicMock() + mi1.rc = FakeReasonCodes(0) + mi2 = MagicMock() + mi2.rc = FakeReasonCodes(128) + + publish_info = TBPublishInfo([mi1, mi2]) + self.assertEqual(publish_info.rc(), 128) + + def test_rc_list_int_nonzero(self): + mi1 = MagicMock() + mi1.rc = 0 + mi2 = MagicMock() + mi2.rc = 4 + + publish_info = TBPublishInfo([mi1, mi2]) + self.assertEqual(publish_info.rc(), 4) + + def test_mid_single(self): + message_info_mock = MagicMock() + message_info_mock.mid = 123 + + publish_info = TBPublishInfo(message_info_mock) + self.assertEqual(publish_info.mid(), 123) + + def test_mid_list(self): + mi1 = MagicMock() + mi1.mid = 111 + mi2 = MagicMock() + mi2.mid = 222 + + publish_info = TBPublishInfo([mi1, mi2]) + self.assertEqual(publish_info.mid(), [111, 222]) + + def test_get_single_no_exception(self): + message_info_mock = MagicMock() + message_info_mock.wait_for_publish.return_value = 0 + message_info_mock.rc.value = 0 + publish_info = TBPublishInfo(message_info_mock) + + result = None + try: + result = publish_info.get() + except Exception as e: + self.fail(f"publish_info.get() raised an exception: {e}") + message_info_mock.wait_for_publish.assert_called_once_with(timeout=1) + self.assertEqual(result, 0, "Expect publish_info.get() to return code 0.") + + def test_get_list_no_exception(self): + mi1 = MagicMock() + mi2 = MagicMock() + publish_info = TBPublishInfo([mi1, mi2]) + + publish_info.get() + + mi1.wait_for_publish.assert_called_once_with(timeout=1) + mi2.wait_for_publish.assert_called_once_with(timeout=1) + + def test_get_list_with_exception(self): + mi1 = MagicMock() + mi2 = MagicMock() + mi2.wait_for_publish.side_effect = Exception("Test Error") + + publish_info = TBPublishInfo([mi1, mi2]) + + try: + publish_info.get() + except Exception as e: + self.fail(f"publish_info.get() threw an unhandled exception: {e}") + + mi1.wait_for_publish.assert_called_once_with(timeout=1) + mi2.wait_for_publish.assert_called_once_with(timeout=1) + + +class TestUnsubscribeFromAttribute(unittest.TestCase): + def setUp(self): + self.client = TBDeviceMqttClient("dummy_host", 1883, "dummy", "dummy") + if not hasattr(self.client, "_lock"): + self.client._lock = threading.Lock() + self.client._TBDeviceMqttClient__device_sub_dict = { + "attr1": {1: lambda msg: "callback1", 2: lambda msg: "callback2"}, + "attr2": {3: lambda msg: "callback3"} + } + + def test_unsubscribe_specific(self): + self.client.unsubscribe_from_attribute(2) + device_sub_dict = self.client._TBDeviceMqttClient__device_sub_dict + self.assertIn(1, device_sub_dict.get("attr1", {})) + self.assertNotIn(2, device_sub_dict.get("attr1", {})) + self.assertIn(3, device_sub_dict.get("attr2", {})) + + def test_unsubscribe_all(self): + self.client.unsubscribe_from_attribute('*') + self.assertEqual(self.client._TBDeviceMqttClient__device_sub_dict, {}) + + def test_clean_device_sub_dict(self): + self.client.clean_device_sub_dict() + self.assertEqual(self.client._TBDeviceMqttClient__device_sub_dict, {}) -if __name__ == '__main__': - unittest.main('tb_device_mqtt_client_tests') +if __name__ == "__main__": + unittest.main() diff --git a/tests/tb_gateway_mqtt_client_tests.py b/tests/tb_gateway_mqtt_client_tests.py index 6893183..e3f1e41 100644 --- a/tests/tb_gateway_mqtt_client_tests.py +++ b/tests/tb_gateway_mqtt_client_tests.py @@ -13,10 +13,229 @@ # limitations under the License. import unittest +from unittest.mock import MagicMock,patch from time import sleep, time +import threading +from tb_gateway_mqtt import TBGatewayMqttClient, TBSendMethod, GATEWAY_CLAIMING_TOPIC, GATEWAY_RPC_TOPIC, GATEWAY_MAIN_TOPIC -from tb_gateway_mqtt import TBGatewayMqttClient +class TestGwUnsubscribe(unittest.TestCase): + def setUp(self): + self.client = TBGatewayMqttClient("localhost", 1883, "dummy_token") + self.client._TBGatewayMqttClient__sub_dict = { + "device1|attr1": {1: lambda msg: "callback1"}, + "device2|attr2": {2: lambda msg: "callback2"}, + } + + def test_unsubscribe_specific(self): + sub_dict = self.client._TBGatewayMqttClient__sub_dict + self.assertIn(1, sub_dict["device1|attr1"]) + self.assertIn(2, sub_dict["device2|attr2"]) + + self.client.gw_unsubscribe(1) + + sub_dict = self.client._TBGatewayMqttClient__sub_dict + self.assertNotIn(1, sub_dict["device1|attr1"]) + self.assertIn(2, sub_dict["device2|attr2"]) + + def test_unsubscribe_all(self): + self.client.gw_unsubscribe('*') + self.assertEqual(self.client._TBGatewayMqttClient__sub_dict, {}) + +class TestGwSendRpcReply(unittest.TestCase): + def setUp(self): + self.client = TBGatewayMqttClient("localhost", 1883, "dummy_token") + + def test_gw_send_rpc_reply_default_qos(self): + device = "test_device" + req_id = 101 + resp = {"status": "ok"} + self.client.quality_of_service = 1 + dummy_info = "info_default_qos" + + def fake_send_device_request(method, device_arg, topic, data, qos): + self.assertEqual(method, TBSendMethod.PUBLISH) + self.assertEqual(device_arg, device) + self.assertEqual(topic, GATEWAY_RPC_TOPIC) + self.assertEqual(data, {"device": device, "id": req_id, "data": resp}) + self.assertEqual(qos, 1) + return dummy_info + + self.client._send_device_request = fake_send_device_request + result = self.client.gw_send_rpc_reply(device, req_id, resp) + self.assertEqual(result, dummy_info) + + def test_gw_send_rpc_reply_explicit_valid_qos(self): + device = "test_device" + req_id = 202 + resp = {"status": "success"} + explicit_qos = 0 + dummy_info = "info_explicit_qos" + + def fake_send_device_request(method, device_arg, topic, data, qos): + self.assertEqual(method, TBSendMethod.PUBLISH) + self.assertEqual(device_arg, device) + self.assertEqual(topic, GATEWAY_RPC_TOPIC) + self.assertEqual(data, {"device": device, "id": req_id, "data": resp}) + self.assertEqual(qos, explicit_qos) + return dummy_info + + self.client._send_device_request = fake_send_device_request + result = self.client.gw_send_rpc_reply(device, req_id, resp, quality_of_service=explicit_qos) + self.assertEqual(result, dummy_info) + + def test_gw_send_rpc_reply_invalid_qos(self): + device = "test_device" + req_id = 303 + resp = {"status": "fail"} + invalid_qos = 2 + self.client.quality_of_service = 1 + + result = self.client.gw_send_rpc_reply(device, req_id, resp, quality_of_service=invalid_qos) + self.assertIsNone(result) + + +class TestGwDisconnectDevice(unittest.TestCase): + def setUp(self): + self.client = TBGatewayMqttClient("localhost", 1883, "dummy_token") + if not hasattr(self.client, "_lock"): + self.client._lock = threading.Lock() + self.client._TBGatewayMqttClient__connected_devices = {"test_device", "another_device"} + + def test_disconnect_existing_device(self): + device = "test_device" + dummy_info = "disconnect_info" + + def fake_send_device_request(method, device_arg, topic, data, qos): + self.assertEqual(method, TBSendMethod.PUBLISH) + self.assertEqual(device_arg, device) + self.assertEqual(topic, GATEWAY_MAIN_TOPIC + "disconnect") + self.assertEqual(data, {"device": device}) + self.assertEqual(qos, self.client.quality_of_service) + return dummy_info + + self.client._send_device_request = fake_send_device_request + self.client.quality_of_service = 1 + self.assertIn(device, self.client._TBGatewayMqttClient__connected_devices) + result = self.client.gw_disconnect_device(device) + self.assertEqual(result, dummy_info) + self.assertNotIn(device, self.client._TBGatewayMqttClient__connected_devices) + + def test_disconnect_non_existing_device(self): + device = "non_existing_device" + dummy_info = "disconnect_info_non_existing" + + def fake_send_device_request(method, device_arg, topic, data, qos): + self.assertEqual(method, TBSendMethod.PUBLISH) + self.assertEqual(device_arg, device) + self.assertEqual(topic, GATEWAY_MAIN_TOPIC + "disconnect") + self.assertEqual(data, {"device": device}) + self.assertEqual(qos, self.client.quality_of_service) + return dummy_info + + self.client._send_device_request = fake_send_device_request + self.client.quality_of_service = 1 + self.assertNotIn(device, self.client._TBGatewayMqttClient__connected_devices) + result = self.client.gw_disconnect_device(device) + self.assertEqual(result, dummy_info) + +class TestOtherFunctions(unittest.TestCase): + def setUp(self): + self.client = TBGatewayMqttClient("localhost", 1883, "dummy_token") + self.client._gw_subscriptions = {} + + def test_delete_subscription(self): + self.client._gw_subscriptions = {42: "dummy_subscription"} + topic = "some_topic" + subscription_id = 42 + + self.client._delete_subscription(topic, subscription_id) + + self.assertNotIn(subscription_id, self.client._gw_subscriptions) + + def test_get_subscriptions_in_progress(self): + self.client._gw_subscriptions = {} + self.assertFalse(self.client.get_subscriptions_in_progress()) + + self.client._gw_subscriptions = {1: "dummy_subscription"} + self.assertTrue(self.client.get_subscriptions_in_progress()) + + def test_gw_request_client_attributes(self): + def fake_request_attributes(device, keys, callback, type_is_client): + self.fake_request_called = True + self.request_args = (device, keys, callback, type_is_client) + return "fake_result" + + self.client._TBGatewayMqttClient__request_attributes = fake_request_attributes + + device_name = "test_device" + keys = ["attr1", "attr2"] + + def dummy_callback(response, error): + pass + + result = self.client.gw_request_client_attributes(device_name, keys, dummy_callback) + + self.assertTrue(hasattr(self, "fake_request_called")) + self.assertTrue(self.fake_request_called) + self.assertEqual(self.request_args, (device_name, keys, dummy_callback, True)) + self.assertEqual(result, "fake_result") + + def test_gw_set_server_side_rpc_request_handler(self): + def dummy_handler(client, request): + pass + + self.client.gw_set_server_side_rpc_request_handler(dummy_handler) + self.assertEqual(self.client.devices_server_side_rpc_request_handler, dummy_handler) + +class TestGwClaim(unittest.TestCase): + def setUp(self): + self.client = TBGatewayMqttClient("localhost", 1883, "dummy_token") + self.client.quality_of_service = 1 + self.client._send_device_request = MagicMock() + + def test_gw_claim_default(self): + device_name = "device1" + secret_key = "mySecret" + duration = 30000 + dummy_info = "claim_info" + self.client._send_device_request.return_value = dummy_info + + result = self.client.gw_claim(device_name, secret_key, duration) + + expected_claiming_request = { + device_name: { + "secretKey": secret_key, + "durationMs": duration + } + } + self.client._send_device_request.assert_called_once_with( + TBSendMethod.PUBLISH, + device_name, + topic=GATEWAY_CLAIMING_TOPIC, + data=expected_claiming_request, + qos=self.client.quality_of_service + ) + self.assertEqual(result, dummy_info) + + def test_gw_claim_custom(self): + device_name = "device2" + secret_key = "otherSecret" + duration = 60000 + custom_claim = {"custom": "value"} + dummy_info = "custom_claim_info" + self.client._send_device_request.return_value = dummy_info + + result = self.client.gw_claim(device_name, secret_key, duration, claiming_request=custom_claim) + + self.client._send_device_request.assert_called_once_with( + TBSendMethod.PUBLISH, + device_name, + topic=GATEWAY_CLAIMING_TOPIC, + data=custom_claim, + qos=self.client.quality_of_service + ) + self.assertEqual(result, dummy_info) class TBGatewayMqttClientTests(unittest.TestCase): """ @@ -38,7 +257,7 @@ class TBGatewayMqttClientTests(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - cls.client = TBGatewayMqttClient('127.0.0.1', 1883, 'TEST_GATEWAY_TOKEN') + cls.client = TBGatewayMqttClient('thingsboard.cloud', 1883, 'your_token') cls.client.connect(timeout=1) @classmethod @@ -65,8 +284,10 @@ def callback_for_specific_attr(result): TBGatewayMqttClientTests.subscribe_to_attribute = result def test_connect_disconnect_device(self): - self.assertEqual(self.client.gw_connect_device(self.device_name).rc, 0) - self.assertEqual(self.client.gw_disconnect_device(self.device_name).rc, 0) + connect_info = self.client.gw_connect_device(self.device_name) + self.assertEqual(connect_info.rc(), 0, "Device connection failed") + disconnect_info = self.client.gw_disconnect_device(self.device_name) + self.assertEqual(disconnect_info.rc(), 0, "Device disconnection failed") def test_request_attributes(self): self.client.gw_request_shared_attributes(self.device_name, [self.shared_attr_name],