diff --git a/dissect/target/loaders/mqtt.py b/dissect/target/loaders/mqtt.py index d9154871d3..5ef7442168 100644 --- a/dissect/target/loaders/mqtt.py +++ b/dissect/target/loaders/mqtt.py @@ -65,6 +65,9 @@ def _read(self, offset: int, length: int, optimization_strategy: int = 0) -> byt class MQTTConnection: broker = None host = None + prev = -1 + factor = 1 + prefetch_factor_inc = 10 def __init__(self, broker: Broker, host: str): self.broker = broker @@ -95,20 +98,32 @@ def info(self) -> list[MQTTStream]: def read(self, disk_id: int, offset: int, length: int, optimization_strategy: int) -> bytes: message = None - self.broker.seek(self.host, disk_id, offset, length, optimization_strategy) + message = self.broker.read(self.host, disk_id, offset, length) + if message: + return message.data + + if self.prev == offset - (length * self.factor): + if self.factor < 500: + self.factor += self.prefetch_factor_inc + else: + self.factor = 1 + + self.prev = offset + flength = length * self.factor + self.broker.factor = self.factor + self.broker.seek(self.host, disk_id, offset, flength, optimization_strategy) attempts = 0 while True: - message = self.broker.read(self.host, disk_id, offset, length) - # don't waste time with sleep if we have a response - if message: + if message := self.broker.read(self.host, disk_id, offset, length): + # don't waste time with sleep if we have a response break attempts += 1 - time.sleep(0.01) - if attempts > 100: + time.sleep(0.1) + if attempts > 300: # message might have not reached agent, resend... - self.broker.seek(self.host, disk_id, offset, length, optimization_strategy) + self.broker.seek(self.host, disk_id, offset, flength, optimization_strategy) attempts = 0 return message.data @@ -127,6 +142,7 @@ class Broker: diskinfo = {} index = {} topo = {} + factor = 1 def __init__(self, broker: Broker, port: str, key: str, crt: str, ca: str, case: str, **kwargs): self.broker_host = broker @@ -137,10 +153,13 @@ def __init__(self, broker: Broker, port: str, key: str, crt: str, ca: str, case: self.case = case self.command = kwargs.get("command", None) + def clear_cache(self) -> None: + self.index = {} + @suppress def read(self, host: str, disk_id: int, seek_address: int, read_length: int) -> SeekMessage: key = f"{host}-{disk_id}-{seek_address}-{read_length}" - return self.index.pop(key) + return self.index.get(key) @suppress def disk(self, host: str) -> DiskMessage: @@ -165,14 +184,15 @@ def _on_read(self, hostname: str, tokens: list[str], payload: bytes) -> None: disk_id = tokens[3] seek_address = int(tokens[4], 16) read_length = int(tokens[5], 16) - msg = SeekMessage(data=payload) - key = f"{hostname}-{disk_id}-{seek_address}-{read_length}" + for i in range(self.factor): + sublength = int(read_length / self.factor) + start = i * sublength + key = f"{hostname}-{disk_id}-{seek_address+start}-{sublength}" + if key in self.index: + continue - if key in self.index: - return - - self.index[key] = msg + self.index[key] = SeekMessage(data=payload[start : start + sublength]) def _on_id(self, hostname: str, payload: bytes) -> None: key = hostname @@ -204,9 +224,14 @@ def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.client.MQTTM elif response == "ID": self._on_id(hostname, msg.payload) - def seek(self, host: str, disk_id: int, offset: int, length: int, optimization_strategy: int) -> None: + def seek(self, host: str, disk_id: int, offset: int, flength: int, optimization_strategy: int) -> None: + length = int(flength / self.factor) + key = f"{host}-{disk_id}-{offset}-{length}" + if key in self.index: + return + self.mqtt_client.publish( - f"{self.case}/{host}/SEEK/{disk_id}/{hex(offset)}/{hex(length)}", pack(" None: diff --git a/dissect/target/tools/query.py b/dissect/target/tools/query.py index 98159b66b8..f861b9a26e 100644 --- a/dissect/target/tools/query.py +++ b/dissect/target/tools/query.py @@ -173,8 +173,7 @@ def main(): collected_plugins = {} if targets: - for target in targets: - plugin_target = Target.open(target) + for plugin_target in Target.open_all(targets, args.children): if isinstance(plugin_target._loader, ProxyLoader): parser.error("can't list compatible plugins for remote targets.") funcs, _ = find_plugin_functions(plugin_target, args.list, compatibility=True, show_hidden=True) diff --git a/tests/loaders/test_mqtt.py b/tests/loaders/test_mqtt.py index 9e9eefc795..7cba68788f 100644 --- a/tests/loaders/test_mqtt.py +++ b/tests/loaders/test_mqtt.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import sys import time +from dataclasses import dataclass from struct import pack from typing import Iterator from unittest.mock import MagicMock, patch @@ -44,6 +47,24 @@ def publish(self, topic: str, *args) -> None: self.on_message(self, None, response) +@dataclass +class MockSeekMessage: + data: bytes = b"" + + +class MockBroker(MagicMock): + _seek = False + + def seek(self, *args) -> None: + self._seek = True + + def read(self, *args) -> MockSeekMessage | None: + if self._seek: + self._seek = False + return MockSeekMessage(data=b"010101") + return None + + @pytest.fixture def mock_paho(monkeypatch: pytest.MonkeyPatch) -> Iterator[MagicMock]: with monkeypatch.context() as m: @@ -62,6 +83,11 @@ def mock_client(mock_paho: MagicMock) -> Iterator[MagicMock]: yield mock_client +@pytest.fixture +def mock_broker() -> Iterator[MockBroker]: + yield MockBroker() + + @pytest.mark.parametrize( "alias, hosts, disks, disk, seek, read, expected", [ @@ -102,3 +128,21 @@ def test_remote_loader_stream( target.disks[disk].seek(seek) data = target.disks[disk].read(read) assert data == expected + + +def test_mqtt_loader_prefetch(mock_broker: MockBroker) -> None: + from dissect.target.loaders.mqtt import MQTTConnection + + connection = MQTTConnection(mock_broker, "") + connection.prefetch_factor_inc = 10 + assert connection.factor == 1 + assert connection.prev == -1 + connection.read(1, 0, 100, 0) + assert connection.factor == 1 + assert connection.prev == 0 + connection.read(1, 100, 100, 0) + assert connection.factor == connection.prefetch_factor_inc + 1 + assert connection.prev == 100 + connection.read(1, 1200, 100, 0) + assert connection.factor == (connection.prefetch_factor_inc * 2) + 1 + assert connection.prev == 1200