Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 60 additions & 117 deletions src/logcollector/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,19 @@ async def start(self) -> None:
)

task_fetch = asyncio.Task(self.fetch())
task_send = asyncio.Task(self.send())

try:
await asyncio.gather(
task_fetch,
task_send,
)
except KeyboardInterrupt:
task_fetch.cancel()
task_send.cancel()

logger.info("LogCollector stopped.")

async def fetch(self) -> None:
"""Starts a loop to continuously listen on the configured Kafka topic. If a message is consumed, it is
decoded and stored."""
decoded and sent."""
loop = asyncio.get_running_loop()

while True:
Expand All @@ -90,134 +87,80 @@ async def fetch(self) -> None:
)
logger.debug(f"From Kafka: '{value}'")

await self.store(datetime.datetime.now(), value)
self.send(datetime.datetime.now(), value)

async def send(self) -> None:
"""Continuously sends the next logline in JSON format to the BatchSender, where it is stored in
def send(self, timestamp_in: datetime.datetime, message: str) -> None:
"""Sends the logline in JSON format to the BatchSender, where it is stored in
a temporary batch before being sent to the :class:`Prefilter`. Adds the subnet ID to the message.

Args:
timestamp_in (datetime.datetime): Timestamp of entering the pipeline
message (str): Message to be validated and sent
"""

try:
while True:
if not self.loglines.empty():
timestamp_in, logline = await self.loglines.get()

self.fill_levels.insert(
dict(
timestamp=datetime.datetime.now(),
stage=module_name,
entry_type="total_loglines",
entry_count=self.loglines.qsize(),
)
)

try:
fields = self.logline_handler.validate_logline_and_get_fields_as_json(
logline
)
except ValueError:
self.failed_dns_loglines.insert(
dict(
message_text=logline,
timestamp_in=timestamp_in,
timestamp_failed=datetime.datetime.now(),
reason_for_failure=None, # TODO: Add actual reason
)
)
continue

subnet_id = self._get_subnet_id(
ipaddress.ip_address(fields.get("client_ip"))
)

additional_fields = fields.copy()
for field in REQUIRED_FIELDS:
additional_fields.pop(field)

logline_id = uuid.uuid4()

self.dns_loglines.insert(
dict(
logline_id=logline_id,
subnet_id=subnet_id,
timestamp=datetime.datetime.strptime(
fields.get("timestamp"), TIMESTAMP_FORMAT
),
status_code=fields.get("status_code"),
client_ip=fields.get("client_ip"),
record_type=fields.get("record_type"),
additional_fields=json.dumps(additional_fields),
)
)

self.logline_timestamps.insert(
dict(
logline_id=logline_id,
stage=module_name,
status="in_process",
timestamp=timestamp_in,
is_active=True,
)
)

message_fields = fields.copy()
message_fields["logline_id"] = str(logline_id)

self.logline_timestamps.insert(
dict(
logline_id=logline_id,
stage=module_name,
status="finished",
timestamp=datetime.datetime.now(),
is_active=True,
)
)

self.batch_handler.add_message(
subnet_id, json.dumps(message_fields)
)
logger.debug(f"Sent: '{logline}'")
else:
await asyncio.sleep(0.1)
except KeyboardInterrupt:
while not self.loglines.empty():
logline = await self.loglines.get()

self.fill_levels.insert(
dict(
timestamp=datetime.datetime.now(),
stage=module_name,
entry_type="total_loglines",
entry_count=self.loglines.qsize(),
)
fields = self.logline_handler.validate_logline_and_get_fields_as_json(
message
)
except ValueError:
self.failed_dns_loglines.insert(
dict(
message_text=message,
timestamp_in=timestamp_in,
timestamp_failed=datetime.datetime.now(),
reason_for_failure=None, # TODO: Add actual reason
)
)
return

fields = self.logline_handler.validate_logline_and_get_fields_as_json(
logline
)
subnet_id = self._get_subnet_id(
ipaddress.ip_address(fields.get("client_ip"))
)
subnet_id = self._get_subnet_id(ipaddress.ip_address(fields.get("client_ip")))

self.batch_handler.add_message(subnet_id, json.dumps(fields))
additional_fields = fields.copy()
for field in REQUIRED_FIELDS:
additional_fields.pop(field)

async def store(self, timestamp_in: datetime.datetime, message: str):
"""Stores the message temporarily.
logline_id = uuid.uuid4()

Args:
timestamp_in (datetime.datetime): Timestamp of entering the pipeline
message (str): Message to be stored
"""
await self.loglines.put((timestamp_in, message))
self.dns_loglines.insert(
dict(
logline_id=logline_id,
subnet_id=subnet_id,
timestamp=datetime.datetime.strptime(
fields.get("timestamp"), TIMESTAMP_FORMAT
),
status_code=fields.get("status_code"),
client_ip=fields.get("client_ip"),
record_type=fields.get("record_type"),
additional_fields=json.dumps(additional_fields),
)
)

self.fill_levels.insert(
self.logline_timestamps.insert(
dict(
timestamp=datetime.datetime.now(),
logline_id=logline_id,
stage=module_name,
entry_type="total_loglines",
entry_count=self.loglines.qsize(),
status="in_process",
timestamp=timestamp_in,
is_active=True,
)
)

message_fields = fields.copy()
message_fields["logline_id"] = str(logline_id)

self.logline_timestamps.insert(
dict(
logline_id=logline_id,
stage=module_name,
status="finished",
timestamp=datetime.datetime.now(),
is_active=True,
)
)

self.batch_handler.add_message(subnet_id, json.dumps(message_fields))
logger.debug(f"Sent: '{message}'")

@staticmethod
def _get_subnet_id(address: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
"""
Expand Down
39 changes: 5 additions & 34 deletions tests/logcollector/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def setUp(
async def test_start_successful_execution(self):
# Arrange
self.sut.fetch = AsyncMock()
self.sut.send = AsyncMock()

async def mock_gather(*args, **kwargs):
return None
Expand All @@ -72,12 +71,10 @@ async def mock_gather(*args, **kwargs):
# Assert
mock.assert_called_once()
self.sut.fetch.assert_called_once()
self.sut.send.assert_called_once()

async def test_start_handles_keyboard_interrupt(self):
# Arrange
self.sut.fetch = AsyncMock()
self.sut.send = AsyncMock()

async def mock_gather(*args, **kwargs):
raise KeyboardInterrupt
Expand All @@ -91,7 +88,6 @@ async def mock_gather(*args, **kwargs):
# Assert
mock.assert_called_once()
self.sut.fetch.assert_called_once()
self.sut.send.assert_called_once()


class TestFetch(unittest.IsolatedAsyncioTestCase):
Expand All @@ -109,15 +105,15 @@ async def asyncSetUp(
self.sut = LogCollector()
self.sut.kafka_consume_handler = AsyncMock()

@patch("src.logcollector.collector.LogCollector.store")
@patch("src.logcollector.collector.LogCollector.send")
@patch("src.logcollector.collector.logger")
@patch("asyncio.get_running_loop")
@patch("src.logcollector.collector.ClickHouseKafkaSender")
async def test_handle_kafka_inputs(
self, mock_clickhouse, mock_get_running_loop, mock_logger, mock_store
self, mock_clickhouse, mock_get_running_loop, mock_logger, mock_send
):
mock_store_instance = AsyncMock()
mock_store.return_value = mock_store_instance
mock_send_instance = AsyncMock()
mock_send.return_value = mock_send_instance
mock_loop = AsyncMock()
mock_get_running_loop.return_value = mock_loop
self.sut.kafka_consume_handler.consume.return_value = (
Expand All @@ -134,7 +130,7 @@ async def test_handle_kafka_inputs(
with self.assertRaises(asyncio.CancelledError):
await self.sut.fetch()

mock_store.assert_called_once()
mock_send.assert_called_once()


class TestSend(unittest.IsolatedAsyncioTestCase):
Expand Down Expand Up @@ -341,31 +337,6 @@ async def test_send_value_error(
)


class TestStore(unittest.IsolatedAsyncioTestCase):
@patch("src.logcollector.collector.ExactlyOnceKafkaConsumeHandler")
@patch("src.logcollector.collector.BufferedBatchSender")
@patch("src.logcollector.collector.LoglineHandler")
@patch("src.logcollector.collector.ClickHouseKafkaSender")
async def test_store(
self,
mock_clickhouse,
mock_logline_handler,
mock_batch_handler,
mock_kafka_consume_handler,
):
# Arrange
sut = LogCollector()
self.assertTrue(sut.loglines.empty())

# Act
await sut.store(datetime.datetime.now(), "test_message")

# Assert
stored_timestamp, stored_message = await sut.loglines.get()
self.assertEqual("test_message", stored_message)
self.assertTrue(sut.loglines.empty())


class TestGetSubnetId(unittest.TestCase):
@patch("src.logcollector.collector.IPV4_PREFIX_LENGTH", 24)
@patch("src.logcollector.collector.ExactlyOnceKafkaConsumeHandler")
Expand Down
Loading