Skip to content
57 changes: 41 additions & 16 deletions dissect/target/loaders/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def _read(self, offset: int, length: int, optimization_strategy: int = 0) -> byt
class MQTTConnection:
broker = None
host = None
prev = -1
factor = 1
prefetch_factor_inc = 10

def __init__(self, broker: Broker, host: str):
self.broker = broker
Expand Down Expand Up @@ -95,20 +98,32 @@ def info(self) -> list[MQTTStream]:

def read(self, disk_id: int, offset: int, length: int, optimization_strategy: int) -> bytes:
    """Return ``length`` bytes at ``offset`` of disk ``disk_id``, prefetching on sequential reads.

    The broker's cache is consulted first. On a miss, the prefetch factor is
    adjusted: when this request starts exactly ``length * factor`` bytes after
    the previous uncached request, the factor grows by ``prefetch_factor_inc``
    (growth stops once the factor has reached the cap); any other offset
    resets the factor to 1. A seek for ``length * factor`` bytes is then
    published and the broker is polled until the requested chunk is available.

    Args:
        disk_id: Identifier of the disk to read from.
        offset: Byte offset of the read.
        length: Number of bytes the caller wants back.
        optimization_strategy: Opaque value forwarded to ``broker.seek``.

    Returns:
        The ``data`` attribute of the message the broker holds for this request.
    """
    max_factor = 500  # stop growing the prefetch window beyond this point
    poll_interval = 0.1  # seconds to wait between cache polls
    resend_after = 300  # polls without an answer before re-publishing the seek

    # Serve from the cache first: an earlier prefetch may already cover this chunk.
    message = self.broker.read(self.host, disk_id, offset, length)
    if message:
        return message.data

    if self.prev == offset - (length * self.factor):
        # Sequential access relative to the previous request: widen the window.
        if self.factor < max_factor:
            self.factor += self.prefetch_factor_inc
    else:
        self.factor = 1

    self.prev = offset
    flength = length * self.factor
    # The broker reads this attribute when it derives the per-chunk length in seek().
    self.broker.factor = self.factor
    self.broker.seek(self.host, disk_id, offset, flength, optimization_strategy)

    attempts = 0
    while True:
        # Check before sleeping so an already-available response costs no delay.
        if message := self.broker.read(self.host, disk_id, offset, length):
            break

        attempts += 1
        time.sleep(poll_interval)
        if attempts > resend_after:
            # message might have not reached agent, resend...
            self.broker.seek(self.host, disk_id, offset, flength, optimization_strategy)
            attempts = 0

    return message.data
Expand All @@ -127,6 +142,7 @@ class Broker:
diskinfo = {}
index = {}
topo = {}
factor = 1

def __init__(self, broker: Broker, port: str, key: str, crt: str, ca: str, case: str, **kwargs):
self.broker_host = broker
Expand All @@ -137,10 +153,13 @@ def __init__(self, broker: Broker, port: str, key: str, crt: str, ca: str, case:
self.case = case
self.command = kwargs.get("command", None)

def clear_cache(self) -> None:
    """Forget every cached read response by binding ``index`` to a new, empty dict."""
    self.index = dict()

@suppress
def read(self, host: str, disk_id: int, seek_address: int, read_length: int) -> SeekMessage | None:
    """Look up the cached response for one read request.

    Args:
        host: Host name the request was addressed to.
        disk_id: Identifier of the disk.
        seek_address: Byte offset of the request.
        read_length: Length of the request in bytes.

    Returns:
        The message stored in ``index`` under the request's key, or ``None``
        when nothing is cached for it.
    """
    key = f"{host}-{disk_id}-{seek_address}-{read_length}"
    # Non-destructive lookup: entries stay in the index so seek() can detect
    # already-cached chunks; clear_cache() is what empties the index.
    return self.index.get(key)

@suppress
def disk(self, host: str) -> DiskMessage:
Expand All @@ -165,14 +184,15 @@ def _on_read(self, hostname: str, tokens: list[str], payload: bytes) -> None:
disk_id = tokens[3]
seek_address = int(tokens[4], 16)
read_length = int(tokens[5], 16)
msg = SeekMessage(data=payload)

key = f"{hostname}-{disk_id}-{seek_address}-{read_length}"
for i in range(self.factor):
sublength = int(read_length / self.factor)
start = i * sublength
key = f"{hostname}-{disk_id}-{seek_address+start}-{sublength}"
if key in self.index:
continue

if key in self.index:
return

self.index[key] = msg
self.index[key] = SeekMessage(data=payload[start : start + sublength])

def _on_id(self, hostname: str, payload: bytes) -> None:
key = hostname
Expand Down Expand Up @@ -204,9 +224,14 @@ def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.client.MQTTM
elif response == "ID":
self._on_id(hostname, msg.payload)

def seek(self, host: str, disk_id: int, offset: int, flength: int, optimization_strategy: int) -> None:
    """Publish a SEEK request unless the first chunk of it is already cached.

    Args:
        host: Host name the request is addressed to.
        disk_id: Identifier of the disk.
        offset: Byte offset of the request.
        flength: Total number of bytes requested (chunk length times ``factor``).
        optimization_strategy: Value packed as a little-endian unsigned 32-bit
            integer and sent as the message payload.
    """
    # Integer division keeps the chunk length exact; int(a / b) would go through a float.
    length = flength // self.factor
    key = f"{host}-{disk_id}-{offset}-{length}"
    if key in self.index:
        # The chunk starting at this offset is already available; nothing to request.
        return

    self.mqtt_client.publish(
        f"{self.case}/{host}/SEEK/{disk_id}/{hex(offset)}/{hex(flength)}", pack("<I", optimization_strategy)
    )

def info(self, host: str) -> None:
Expand Down
3 changes: 1 addition & 2 deletions dissect/target/tools/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ def main():
collected_plugins = {}

if targets:
for target in targets:
plugin_target = Target.open(target)
for plugin_target in Target.open_all(targets, args.children):
if isinstance(plugin_target._loader, ProxyLoader):
parser.error("can't list compatible plugins for remote targets.")
funcs, _ = find_plugin_functions(plugin_target, args.list, compatibility=True, show_hidden=True)
Expand Down
44 changes: 44 additions & 0 deletions tests/loaders/test_mqtt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import sys
import time
from dataclasses import dataclass
from struct import pack
from typing import Iterator
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -44,6 +47,24 @@ def publish(self, topic: str, *args) -> None:
self.on_message(self, None, response)


@dataclass
class MockSeekMessage:
    """Minimal stand-in for a seek response: carries only the payload bytes."""

    # Payload of the response; empty unless one is supplied.
    data: bytes = bytes()


class MockBroker(MagicMock):
    """Broker double: a ``read`` yields data only once after each ``seek``."""

    # True between a seek() call and the read() that consumes its response.
    _seek = False

    def seek(self, *args) -> None:
        """Record that a seek was issued; the arguments are ignored."""
        self._seek = True

    def read(self, *args) -> MockSeekMessage | None:
        """Return a fixed message if a seek is pending, otherwise ``None``."""
        if not self._seek:
            return None

        self._seek = False
        return MockSeekMessage(data=b"010101")


@pytest.fixture
def mock_paho(monkeypatch: pytest.MonkeyPatch) -> Iterator[MagicMock]:
with monkeypatch.context() as m:
Expand All @@ -62,6 +83,11 @@ def mock_client(mock_paho: MagicMock) -> Iterator[MagicMock]:
yield mock_client


@pytest.fixture
def mock_broker() -> Iterator[MockBroker]:
    """Provide a fresh ``MockBroker`` instance to a test."""
    broker = MockBroker()
    yield broker


@pytest.mark.parametrize(
"alias, hosts, disks, disk, seek, read, expected",
[
Expand Down Expand Up @@ -102,3 +128,21 @@ def test_remote_loader_stream(
target.disks[disk].seek(seek)
data = target.disks[disk].read(read)
assert data == expected


def test_mqtt_loader_prefetch(mock_broker: MockBroker) -> None:
    """Check how ``factor`` and ``prev`` evolve over three consecutive reads."""
    from dissect.target.loaders.mqtt import MQTTConnection

    connection = MQTTConnection(mock_broker, "")
    connection.prefetch_factor_inc = 10

    # Initial state, before any read has been issued.
    assert connection.factor == 1
    assert connection.prev == -1

    # (offset to read, number of prefetch increments expected afterwards)
    steps = ((0, 0), (100, 1), (1200, 2))
    for read_offset, increments in steps:
        connection.read(1, read_offset, 100, 0)
        assert connection.factor == (connection.prefetch_factor_inc * increments) + 1
        assert connection.prev == read_offset