diff --git a/dissect/util/compression/lz4.py b/dissect/util/compression/lz4.py index 6b8adf8..5ddd622 100644 --- a/dissect/util/compression/lz4.py +++ b/dissect/util/compression/lz4.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import io import struct -from typing import BinaryIO, Union +from typing import BinaryIO from dissect.util.exceptions import CorruptDataError @@ -23,12 +25,12 @@ def _get_length(src: BinaryIO, length: int) -> int: def decompress( - src: Union[bytes, BinaryIO], + src: bytes | BinaryIO, uncompressed_size: int = -1, max_length: int = -1, return_bytearray: bool = False, return_bytes_read: bool = False, -) -> Union[bytes, tuple[bytes, int]]: +) -> bytes | tuple[bytes, int]: """LZ4 decompress from a file-like object up to a certain length. Assumes no header. Args: @@ -92,5 +94,5 @@ def decompress( if return_bytes_read: return dst, src.tell() - start - else: - return dst + + return dst diff --git a/dissect/util/compression/lznt1.py b/dissect/util/compression/lznt1.py index bd0d25e..5981b86 100644 --- a/dissect/util/compression/lznt1.py +++ b/dissect/util/compression/lznt1.py @@ -1,8 +1,10 @@ # Reference: https://github.com/google/rekall/blob/master/rekall-core/rekall/plugins/filesystems/lznt1.py +from __future__ import annotations + import array import io import struct -from typing import BinaryIO, Union +from typing import BinaryIO def _get_displacement(offset: int) -> int: @@ -20,10 +22,10 @@ def _get_displacement(offset: int) -> int: COMPRESSED_MASK = 1 << 15 SIGNATURE_MASK = 3 << 12 SIZE_MASK = (1 << 12) - 1 -TAG_MASKS = [(1 << i) for i in range(0, 8)] +TAG_MASKS = [(1 << i) for i in range(8)] -def decompress(src: Union[bytes, BinaryIO]) -> bytes: +def decompress(src: bytes | BinaryIO) -> bytes: """LZNT1 decompress from a file-like object or bytes. Args: @@ -87,5 +89,4 @@ def decompress(src: Union[bytes, BinaryIO]) -> bytes: data = src.read(hsize + 1) dst.write(data) - result = dst.getvalue() - return result + return dst.getvalue() diff --git a/dissect/util/compression/lzo.py b/dissect/util/compression/lzo.py index d6a37c4..0c8b1e7 100644 --- a/dissect/util/compression/lzo.py +++ b/dissect/util/compression/lzo.py @@ -2,10 +2,11 @@ # - https://github.com/FFmpeg/FFmpeg/blob/master/libavutil/lzo.c # - https://docs.kernel.org/staging/lzo.html # - https://github.com/torvalds/linux/blob/master/lib/lzo/lzo1x_decompress_safe.c +from __future__ import annotations import io import struct -from typing import BinaryIO, Union +from typing import BinaryIO MAX_READ_LENGTH = (1 << 32) - 1000 @@ -22,7 +23,7 @@ def _read_length(src: BinaryIO, val: int, mask: int) -> int: return length + mask + val -def decompress(src: Union[bytes, BinaryIO], header: bool = True, buflen: int = -1) -> bytes: +def decompress(src: bytes | BinaryIO, header: bool = True, buflen: int = -1) -> bytes: """LZO decompress from a file-like object or bytes. Assumes no header. Arguments are largely compatible with python-lzo API. diff --git a/dissect/util/compression/lzxpress.py b/dissect/util/compression/lzxpress.py index d954332..2eb8b95 100644 --- a/dissect/util/compression/lzxpress.py +++ b/dissect/util/compression/lzxpress.py @@ -1,10 +1,12 @@ # Reference: [MS-XCA] +from __future__ import annotations + import io import struct -from typing import BinaryIO, Union +from typing import BinaryIO -def decompress(src: Union[bytes, BinaryIO]) -> bytes: +def decompress(src: bytes | BinaryIO) -> bytes: """LZXPRESS decompress from a file-like object or bytes. Args: @@ -62,7 +64,7 @@ def decompress(src: Union[bytes, BinaryIO]) -> bytes: match_length = struct.unpack(" int: @@ -16,7 +19,7 @@ def _read_16_bit(fh: BinaryIO) -> int: class Node: __slots__ = ("symbol", "is_leaf", "children") - def __init__(self, symbol: Optional[Symbol] = None, is_leaf: bool = False): + def __init__(self, symbol: Symbol | None = None, is_leaf: bool = False): self.symbol = symbol self.is_leaf = is_leaf self.children = [None, None] @@ -120,7 +123,7 @@ def decode(self, root: Node) -> Symbol: return node.symbol -def decompress(src: Union[bytes, BinaryIO]) -> bytes: +def decompress(src: bytes | BinaryIO) -> bytes: """LZXPRESS decompress from a file-like object or bytes. Decompresses until EOF of the input data. diff --git a/dissect/util/compression/sevenbit.py b/dissect/util/compression/sevenbit.py index e4d7d14..58e6fba 100644 --- a/dissect/util/compression/sevenbit.py +++ b/dissect/util/compression/sevenbit.py @@ -1,8 +1,10 @@ +from __future__ import annotations + from io import BytesIO -from typing import BinaryIO, Union +from typing import BinaryIO -def compress(src: Union[bytes, BinaryIO]) -> bytes: +def compress(src: bytes | BinaryIO) -> bytes: """Sevenbit compress from a file-like object or bytes. Args: @@ -37,7 +39,7 @@ def compress(src: Union[bytes, BinaryIO]) -> bytes: return bytes(dst) -def decompress(src: Union[bytes, BinaryIO], wide: bool = False) -> bytes: +def decompress(src: bytes | BinaryIO, wide: bool = False) -> bytes: """Sevenbit decompress from a file-like object or bytes. Args: diff --git a/dissect/util/cpio.py b/dissect/util/cpio.py index 93da41f..f92ea97 100644 --- a/dissect/util/cpio.py +++ b/dissect/util/cpio.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import stat import struct import tarfile @@ -38,7 +40,7 @@ class CpioInfo(tarfile.TarInfo): """ @classmethod - def fromtarfile(cls, tarfile: tarfile.TarFile) -> tarfile.TarInfo: + def fromtarfile(cls, tarfile: tarfile.TarFile) -> CpioInfo: if tarfile.format not in ( FORMAT_CPIO_BIN, FORMAT_CPIO_ODC, @@ -64,7 +66,7 @@ def fromtarfile(cls, tarfile: tarfile.TarFile) -> tarfile.TarInfo: return obj._proc_member(tarfile) @classmethod - def frombuf(cls, buf: bytes, format: int, encoding: str, errors: str) -> tarfile.TarInfo: + def frombuf(cls, buf: bytes, format: int, encoding: str, errors: str) -> CpioInfo: if format in (FORMAT_CPIO_BIN, FORMAT_CPIO_ODC, FORMAT_CPIO_HPBIN, FORMAT_CPIO_HPODC): obj = cls._old_frombuf(buf, format) elif format in (FORMAT_CPIO_NEWC, FORMAT_CPIO_CRC): @@ -78,7 +80,7 @@ def frombuf(cls, buf: bytes, format: int, encoding: str, errors: str) -> tarfile return obj @classmethod - def _old_frombuf(cls, buf: bytes, format: int): + def _old_frombuf(cls, buf: bytes, format: int) -> CpioInfo: if format in (FORMAT_CPIO_BIN, FORMAT_CPIO_HPBIN): values = list(struct.unpack("<13H", buf)) if values[0] == _swap16(CPIO_MAGIC_OLD): @@ -131,7 +133,7 @@ def _old_frombuf(cls, buf: bytes, format: int): return obj @classmethod - def _new_frombuf(cls, buf: bytes, format: int): + def _new_frombuf(cls, buf: bytes, format: int) -> CpioInfo: values = struct.unpack("<6s8s8s8s8s8s8s8s8s8s8s8s8s8s", buf) values = [int(values[0], 8)] + [int(v, 16) for v in values[1:]] if values[0] not in (CPIO_MAGIC_NEW, CPIO_MAGIC_CRC): @@ -157,7 +159,7 @@ def _new_frombuf(cls, buf: bytes, format: int): return obj - def _proc_member(self, tarfile: tarfile.TarFile) -> tarfile.TarInfo: + def _proc_member(self, tarfile: tarfile.TarFile) -> CpioInfo | None: self.name = tarfile.fileobj.read(self.namesize - 1).decode(tarfile.encoding, tarfile.errors) if self.name == "TRAILER!!!": # The last entry in a cpio file has the special name ``TRAILER!!!``, indicating the end of the archive @@ -177,10 +179,11 @@ def _proc_member(self, tarfile: tarfile.TarFile) -> tarfile.TarInfo: def _round_word(self, offset: int) -> int: if self.format in (FORMAT_CPIO_BIN, FORMAT_CPIO_HPBIN): return (offset + 1) & ~0x01 - elif self.format in (FORMAT_CPIO_NEWC, FORMAT_CPIO_CRC): + + if self.format in (FORMAT_CPIO_NEWC, FORMAT_CPIO_CRC): return (offset + 3) & ~0x03 - else: - return offset + + return offset def issocket(self) -> bool: """Return True if it is a socket.""" @@ -211,13 +214,13 @@ def _swap16(value: int) -> int: return ((value & 0xFF) << 8) | (value >> 8) -def CpioFile(*args, **kwargs): +def CpioFile(*args, **kwargs) -> tarfile.TarFile: """Utility wrapper around ``tarfile.TarFile`` to easily open cpio archives.""" kwargs.setdefault("format", FORMAT_CPIO_UNKNOWN) return tarfile.TarFile(*args, **kwargs, tarinfo=CpioInfo) -def open(*args, **kwargs): +def open(*args, **kwargs) -> tarfile.TarFile: """Utility wrapper around ``tarfile.open`` to easily open cpio archives.""" kwargs.setdefault("format", FORMAT_CPIO_UNKNOWN) return tarfile.open(*args, **kwargs, tarinfo=CpioInfo) diff --git a/dissect/util/encoding/surrogateescape.py b/dissect/util/encoding/surrogateescape.py index a62f709..7240cfe 100644 --- a/dissect/util/encoding/surrogateescape.py +++ b/dissect/util/encoding/surrogateescape.py @@ -1,7 +1,7 @@ import codecs -def error_handler(error): +def error_handler(error: Exception) -> tuple[str, int]: if not isinstance(error, UnicodeDecodeError): raise error diff --git a/dissect/util/feature.py b/dissect/util/feature.py index 70cec0f..8dac3d7 100644 --- a/dissect/util/feature.py +++ b/dissect/util/feature.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import functools import os from enum import Enum -from typing import Callable, Optional +from typing import Callable, NoReturn # Register feature flags in a central place to avoid chaos @@ -43,7 +45,7 @@ def parse_blob(): return feature in feature_flags() -def feature(flag: Feature, alternative: Optional[Callable] = None) -> Callable: +def feature(flag: Feature, alternative: Callable | None = None) -> Callable: """Feature flag decorator allowing you to guard a function behind a feature flag. Usage:: @@ -57,7 +59,7 @@ def my_func( ... ) -> ... if alternative is None: - def alternative(): + def alternative() -> NoReturn: raise FeatureException( "\n".join( [ @@ -68,7 +70,7 @@ def alternative(): ) ) - def decorator(func): + def decorator(func: Callable) -> Callable: return func if feature_enabled(flag) else alternative return decorator diff --git a/dissect/util/plist.py b/dissect/util/plist.py index a2d6c0e..cc636cf 100644 --- a/dissect/util/plist.py +++ b/dissect/util/plist.py @@ -1,12 +1,18 @@ +from __future__ import annotations + import plistlib import uuid from collections import UserDict +from typing import TYPE_CHECKING, Any, BinaryIO from dissect.util.ts import cocoatimestamp +if TYPE_CHECKING: + from datetime import datetime + class NSKeyedArchiver: - def __init__(self, fh): + def __init__(self, fh: BinaryIO): self.plist = plistlib.load(fh) if not isinstance(self.plist, dict) or not all( @@ -21,13 +27,16 @@ def __init__(self, fh): for name, value in self.plist.get("$top", {}).items(): self.top[name] = self._parse(value) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self.top[key] - def get(self, key, default=None): + def __repr__(self) -> str: + return f"" + + def get(self, key: str, default: Any | None = None) -> Any: return self.top.get(key, default) - def _parse(self, uid): + def _parse(self, uid: Any) -> Any: if not isinstance(uid, plistlib.UID): return uid @@ -38,7 +47,7 @@ def _parse(self, uid): self._cache[num] = result return result - def _parse_obj(self, obj): + def _parse_obj(self, obj: Any) -> Any: if isinstance(obj, dict): klass = obj.get("$class") if klass: @@ -55,12 +64,11 @@ def _parse_obj(self, obj): if isinstance(obj, str): return None if obj == "$null" else obj - def __repr__(self): - return f"" + return None class NSObject: - def __init__(self, nskeyed, obj): + def __init__(self, nskeyed: NSKeyedArchiver, obj: dict[str, Any]): self.nskeyed = nskeyed self.obj = obj @@ -68,20 +76,11 @@ def __init__(self, nskeyed, obj): self._classname = self._class.get("$classname", "Unknown") self._classes = self._class.get("$classes", []) - def keys(self): - return self.obj.keys() - - def get(self, attr, default=None): - try: - return self[attr] - except KeyError: - return default - - def __getitem__(self, attr): + def __getitem__(self, attr: str) -> Any: obj = self.obj[attr] return self.nskeyed._parse(obj) - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: try: return self[attr] except KeyError: @@ -90,45 +89,54 @@ def __getattr__(self, attr): def __repr__(self): return f"<{self._classname}>" + def keys(self) -> list[str]: + return self.obj.keys() + + def get(self, attr: str, default: Any | None = None) -> Any: + try: + return self[attr] + except KeyError: + return default + class NSDictionary(UserDict, NSObject): - def __init__(self, nskeyed, obj): + def __init__(self, nskeyed: NSKeyedArchiver, obj: dict[str, Any]): NSObject.__init__(self, nskeyed, obj) self.data = {nskeyed._parse(key): obj for key, obj in zip(obj["NS.keys"], obj["NS.objects"])} - def __getitem__(self, key): - return self.nskeyed._parse(self.data[key]) - - def __repr__(self): + def __repr__(self) -> str: return NSObject.__repr__(self) + def __getitem__(self, key: str) -> Any: + return self.nskeyed._parse(self.data[key]) + -def parse_nsarray(nskeyed, obj): +def parse_nsarray(nskeyed: NSKeyedArchiver, obj: dict[str, Any]) -> list[Any]: return list(map(nskeyed._parse, obj["NS.objects"])) -def parse_nsset(nskeyed, obj): +def parse_nsset(nskeyed: NSKeyedArchiver, obj: dict[str, Any]) -> list[Any]: # Some values are not hashable, so return as list return parse_nsarray(nskeyed, obj) -def parse_nsdata(nskeyed, obj): +def parse_nsdata(nskeyed: NSKeyedArchiver, obj: dict[str, Any]) -> Any: return obj["NS.data"] -def parse_nsdate(nskeyed, obj): +def parse_nsdate(nskeyed: NSKeyedArchiver, obj: dict[str, Any]) -> datetime: return cocoatimestamp(obj["NS.time"]) -def parse_nsuuid(nskeyed, obj): +def parse_nsuuid(nskeyed: NSKeyedArchiver, obj: dict[str, Any]) -> uuid.UUID: return uuid.UUID(bytes=obj["NS.uuidbytes"]) -def parse_nsurl(nskeyed, obj): +def parse_nsurl(nskeyed: NSKeyedArchiver, obj: dict[str, Any]) -> str: base = nskeyed._parse(obj["NS.base"]) relative = nskeyed._parse(obj["NS.relative"]) if base: - return "/".join([base, relative]) + return f"{base}/{relative}" return relative diff --git a/dissect/util/sid.py b/dissect/util/sid.py index 4587d2d..3180ba6 100644 --- a/dissect/util/sid.py +++ b/dissect/util/sid.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import io import struct -from typing import BinaryIO, Union +from typing import BinaryIO -def read_sid(fh: Union[BinaryIO, bytes], endian: str = "<", swap_last: bool = False) -> str: +def read_sid(fh: BinaryIO | bytes, endian: str = "<", swap_last: bool = False) -> str: """Read a Windows SID from bytes. Normally we'd do this with cstruct, but do it with just struct to keep dissect.util dependency-free. @@ -43,6 +45,4 @@ def read_sid(fh: Union[BinaryIO, bytes], endian: str = "<", swap_last: bool = Fa f"{authority}", ] sid_elements.extend(map(str, sub_authorities)) - readable_sid = "-".join(sid_elements) - - return readable_sid + return "-".join(sid_elements) diff --git a/dissect/util/stream.py b/dissect/util/stream.py index e602c0f..bb6fb4e 100644 --- a/dissect/util/stream.py +++ b/dissect/util/stream.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import io import os import sys import zlib from bisect import bisect_left, bisect_right from threading import Lock -from typing import BinaryIO, Optional, Union +from typing import BinaryIO STREAM_BUFFER_SIZE = int(os.getenv("DISSECT_STREAM_BUFFER_SIZE", io.DEFAULT_BUFFER_SIZE)) @@ -28,7 +30,7 @@ class AlignedStream(io.RawIOBase): align: The alignment size. Read operations are aligned on this boundary. Also determines buffer size. """ - def __init__(self, size: Optional[int] = None, align: int = STREAM_BUFFER_SIZE): + def __init__(self, size: int | None = None, align: int = STREAM_BUFFER_SIZE): super().__init__() self.size = size self.align = align @@ -107,11 +109,7 @@ def read(self, n: int = -1) -> bytes: if size is not None: remaining = size - self._pos - - if n == -1: - n = remaining - else: - n = min(n, remaining) + n = remaining if n == -1 else min(n, remaining) if n == 0 or size is not None and size <= self._pos: return b"" @@ -228,7 +226,7 @@ class RelativeStream(AlignedStream): align: The alignment size. """ - def __init__(self, fh: BinaryIO, offset: int, size: Optional[int] = None, align: int = STREAM_BUFFER_SIZE): + def __init__(self, fh: BinaryIO, offset: int, size: int | None = None, align: int = STREAM_BUFFER_SIZE): super().__init__(size, align) self._fh = fh self.offset = offset @@ -259,7 +257,7 @@ class BufferedStream(RelativeStream): align: The alignment size. """ - def __init__(self, fh: BinaryIO, offset: int = 0, size: Optional[int] = None, align: int = STREAM_BUFFER_SIZE): + def __init__(self, fh: BinaryIO, offset: int = 0, size: int | None = None, align: int = STREAM_BUFFER_SIZE): super().__init__(fh, offset, size, align) @@ -271,7 +269,7 @@ class MappingStream(AlignedStream): align: The alignment size. """ - def __init__(self, size: Optional[int] = None, align: int = STREAM_BUFFER_SIZE): + def __init__(self, size: int | None = None, align: int = STREAM_BUFFER_SIZE): super().__init__(size, align) self._runs: list[tuple[int, int, BinaryIO, int]] = [] @@ -360,7 +358,7 @@ class RunlistStream(AlignedStream): """ def __init__( - self, fh: BinaryIO, runlist: list[tuple[int, int]], size: int, block_size: int, align: Optional[int] = None + self, fh: BinaryIO, runlist: list[tuple[int, int]], size: int, block_size: int, align: int | None = None ): super().__init__(size, align or block_size) @@ -448,13 +446,13 @@ class OverlayStream(AlignedStream): align: The alignment size. """ - def __init__(self, fh: BinaryIO, size: Optional[int] = None, align: int = STREAM_BUFFER_SIZE): + def __init__(self, fh: BinaryIO, size: int | None = None, align: int = STREAM_BUFFER_SIZE): super().__init__(size, align) self._fh = fh self.overlays: dict[int, tuple[int, BinaryIO]] = {} self._lookup: list[int] = [] - def add(self, offset: int, data: Union[bytes, BinaryIO], size: Optional[int] = None) -> None: + def add(self, offset: int, data: bytes | BinaryIO, size: int | None = None) -> None: """Add an overlay at the given offset. Args: @@ -469,7 +467,7 @@ def add(self, offset: int, data: Union[bytes, BinaryIO], size: Optional[int] = N size = data.size if hasattr(data, "size") else data.seek(0, io.SEEK_END) if not size: - return + return None if size < 0: raise ValueError("Size must be positive") @@ -565,7 +563,7 @@ class ZlibStream(AlignedStream): size: The size the stream should be. """ - def __init__(self, fh: BinaryIO, size: Optional[int] = None, align: int = STREAM_BUFFER_SIZE, **kwargs): + def __init__(self, fh: BinaryIO, size: int | None = None, align: int = STREAM_BUFFER_SIZE, **kwargs): self._fh = fh self._zlib = None diff --git a/dissect/util/tools/dump_nskeyedarchiver.py b/dissect/util/tools/dump_nskeyedarchiver.py index a319a5d..422e4c8 100644 --- a/dissect/util/tools/dump_nskeyedarchiver.py +++ b/dissect/util/tools/dump_nskeyedarchiver.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import argparse +from typing import Any from dissect.util.plist import NSKeyedArchiver, NSObject -def main(): +def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("file", type=argparse.FileType("rb"), help="NSKeyedArchiver plist file to dump") args = parser.parse_args() @@ -18,7 +21,7 @@ def main(): print_object(obj.top) -def print_object(obj, indent=0, seen=None): +def print_object(obj: Any, indent: int = 0, seen: set | None = None) -> None: if seen is None: seen = set() @@ -50,7 +53,7 @@ def print_object(obj, indent=0, seen=None): print(fmt(obj, indent)) -def fmt(obj, indent): +def fmt(obj: Any, indent: int) -> str: return f"{' ' * (indent * 4)}{obj}" diff --git a/dissect/util/ts.py b/dissect/util/ts.py index 1118945..753d4f0 100644 --- a/dissect/util/ts.py +++ b/dissect/util/ts.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import struct import sys from datetime import datetime, timedelta, timezone, tzinfo -from typing import Dict if sys.platform in ("win32", "emscripten"): _EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc) @@ -275,9 +276,16 @@ def dostimestamp(ts: int, centiseconds: int = 0, swap: bool = False) -> datetime # extra_seconds will be at most 1. extra_seconds, centiseconds = divmod(centiseconds, 100) microseconds = centiseconds * 10000 - timestamp = datetime(year, month, day, hours, minutes, seconds + extra_seconds, microseconds) - return timestamp + return datetime( # noqa: DTZ001 + year, + month, + day, + hours, + minutes, + seconds + extra_seconds, + microseconds, + ) class UTC(tzinfo): @@ -287,17 +295,17 @@ class UTC(tzinfo): tz_dict: Dictionary of ``{"name": "timezone name", "offset": offset_from_UTC_in_minutes}`` """ - def __init__(self, tz_dict: Dict[str, int]): + def __init__(self, tz_dict: dict[str, str | int]): # offset should be in minutes self.name = tz_dict["name"] self.offset = timedelta(minutes=tz_dict["offset"]) - def utcoffset(self, dt): + def utcoffset(self, dt: datetime) -> timedelta: return self.offset - def tzname(self, dt): + def tzname(self, dt: datetime) -> str: return self.name - def dst(self, dt): + def dst(self, dt: datetime) -> timedelta: # do not account for daylight saving return timedelta(0) diff --git a/dissect/util/xmemoryview.py b/dissect/util/xmemoryview.py index 494df27..f5dfc78 100644 --- a/dissect/util/xmemoryview.py +++ b/dissect/util/xmemoryview.py @@ -2,10 +2,13 @@ import struct import sys -from typing import Any, Iterator, Union +from typing import TYPE_CHECKING, Any +if TYPE_CHECKING: + from collections.abc import Iterator -def xmemoryview(view: bytes, format: str) -> Union[memoryview, _xmemoryview]: + +def xmemoryview(view: bytes, format: str) -> memoryview | _xmemoryview: """Cast a memoryview to the specified format, including endianness. The regular ``memoryview.cast()`` method only supports host endianness. While that should be fine 99% of the time @@ -43,9 +46,9 @@ def xmemoryview(view: bytes, format: str) -> Union[memoryview, _xmemoryview]: ): # Native endianness, don't need to do anything return view - else: - # Non-native endianness - return _xmemoryview(view, format) + + # Non-native endianness + return _xmemoryview(view, format) class _xmemoryview: @@ -67,7 +70,7 @@ def __init__(self, view: memoryview, format: str): def tolist(self) -> list[int]: return self._convert(self._view.tolist()) - def _convert(self, value: Union[list[int], int]) -> int: + def _convert(self, value: list[int] | int) -> int: if isinstance(value, list): endian = self._format[0] fmt = self._format[1] @@ -75,13 +78,15 @@ def _convert(self, value: Union[list[int], int]) -> int: return list(struct.unpack(f"{endian}{pck}", struct.pack(f"={pck}", *value))) return self._struct_to.unpack(self._struct_frm.pack(value))[0] - def __getitem__(self, idx: Union[int, slice]) -> Union[int, bytes]: + def __getitem__(self, idx: int | slice) -> int | bytes: value = self._view[idx] if isinstance(idx, int): return self._convert(value) if isinstance(idx, slice): return _xmemoryview(self._view[idx], self._format) + raise TypeError("Invalid index type") + def __setitem__(self, *args, **kwargs) -> None: # setitem looks like it's a no-op on cast memoryviews? pass @@ -89,7 +94,7 @@ def __setitem__(self, *args, **kwargs) -> None: def __len__(self) -> int: return len(self._view) - def __eq__(self, other: Union[memoryview, _xmemoryview]): + def __eq__(self, other: memoryview | _xmemoryview) -> bool: if isinstance(other, _xmemoryview): other = other._view return self._view.__eq__(other) diff --git a/pyproject.toml b/pyproject.toml index b9703bb..55ac04a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,13 +44,56 @@ lz4 = [ [project.scripts] dump-nskeyedarchiver = "dissect.util.tools.dump_nskeyedarchiver:main" -[tool.black] +[tool.ruff] line-length = 120 +required-version = ">=0.6.0" -[tool.isort] -profile = "black" -known_first_party = ["dissect.util"] -known_third_party = ["dissect"] +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] +select = [ + "F", + "E", + "W", + "I", + "UP", + "YTT", + "ANN", + "B", + "C4", + "DTZ", + "T10", + "FA", + "ISC", + "G", + "INP", + "PIE", + "PYI", + "PT", + "Q", + "RSE", + "RET", + "SLOT", + "SIM", + "TID", + "TCH", + "PTH", + "PLC", + "TRY", + "FLY", + "PERF", + "FURB", + "RUF", +] +ignore = ["E203", "B904", "UP024", "ANN002", "ANN003", "ANN101", "ANN102", "ANN204", "ANN401", "SIM105", "TRY003"] + +[tool.ruff.lint.per-file-ignores] +"tests/docs/**" = ["INP001"] + +[tool.ruff.lint.isort] +known-first-party = ["dissect.util"] +known-third-party = ["dissect"] [tool.setuptools] license-files = ["LICENSE", "COPYRIGHT"] diff --git a/tests/test_compression.py b/tests/test_compression.py index 68f75f9..72b772b 100644 --- a/tests/test_compression.py +++ b/tests/test_compression.py @@ -4,24 +4,16 @@ import pytest -from dissect.util.compression import ( - lz4, - lznt1, - lzo, - lzxpress, - lzxpress_huffman, - sevenbit, - xz, -) +from dissect.util.compression import lz4, lznt1, lzo, lzxpress, lzxpress_huffman, sevenbit, xz -def test_lz4_decompress(): +def test_lz4_decompress() -> None: assert ( lz4.decompress(b"\xff\x0cLZ4 compression test string\x1b\x00\xdbPtring") == b"LZ4 compression test string" * 10 ) -def test_lznt1_decompress(): +def test_lznt1_decompress() -> None: assert lznt1.decompress( bytes.fromhex( "38b08846232000204720410010a24701a045204400084501507900c045200524" @@ -34,7 +26,7 @@ def test_lznt1_decompress(): ) -def test_lzo_decompress(): +def test_lzo_decompress() -> None: assert ( lzo.decompress(bytes.fromhex("0361626361626320f314000f616263616263616263616263616263616263110000"), False) == b"abc" * 100 @@ -220,7 +212,7 @@ def test_lzo_decompress(): assert lzo.decompress(bytes.fromhex("f0000000041574657374110000")) == b"test" -def test_lzxpress_huffman_decompress(): +def test_lzxpress_huffman_decompress() -> None: assert ( lzxpress_huffman.decompress( bytes.fromhex( @@ -239,29 +231,29 @@ def test_lzxpress_huffman_decompress(): ) -def test_lzxpress_decompress(): +def test_lzxpress_decompress() -> None: assert lzxpress.decompress(bytes.fromhex("ffff ff1f 6162 6317 000f ff26 01")) == b"abc" * 100 -def test_sevenbit_compress(): +def test_sevenbit_compress() -> None: result = sevenbit.compress(b"7-bit compression test string") target = bytes.fromhex("b796384d078ddf6db8bc3c9fa7df6e10bd3ca783e67479da7d06") assert result == target -def test_sevenbit_decompress(): +def test_sevenbit_decompress() -> None: result = sevenbit.decompress(bytes.fromhex("b796384d078ddf6db8bc3c9fa7df6e10bd3ca783e67479da7d06")) target = b"7-bit compression test string" assert result == target -def test_sevenbit_decompress_wide(): +def test_sevenbit_decompress_wide() -> None: result = sevenbit.decompress(bytes.fromhex("b796384d078ddf6db8bc3c9fa7df6e10bd3ca783e67479da7d06"), wide=True) target = "7-bit compression test string".encode("utf-16-le") assert result == target -def test_xz_repair_checksum(): +def test_xz_repair_checksum() -> None: buf = BytesIO( bytes.fromhex( "fd377a585a000004deadbeef0200210116000000deadbeefe00fff001e5d003a" diff --git a/tests/test_cpio.py b/tests/test_cpio.py index 56c0705..1962f1e 100644 --- a/tests/test_cpio.py +++ b/tests/test_cpio.py @@ -1,16 +1,17 @@ import gzip -import os +from pathlib import Path +from tarfile import TarFile import pytest from dissect.util import cpio -def absolute_path(filename): - return os.path.join(os.path.dirname(__file__), filename) +def absolute_path(filename: str) -> Path: + return Path(__file__).parent / filename -def _verify_archive(archive): +def _verify_archive(archive: TarFile) -> None: assert sorted(archive.getnames()) == sorted( [f"dir/file_{i}" for i in range(1, 101)] + ["large-file", "small-file", "symlink-1", "symlink-2"] ) @@ -39,7 +40,7 @@ def _verify_archive(archive): @pytest.mark.parametrize( - "path,format", + ("path", "format"), [ ("data/bin.cpio.gz", cpio.FORMAT_CPIO_BIN), ("data/odc.cpio.gz", cpio.FORMAT_CPIO_ODC), @@ -49,7 +50,7 @@ def _verify_archive(archive): ("data/crc.cpio.gz", cpio.FORMAT_CPIO_CRC), ], ) -def test_cpio_formats(path, format): +def test_cpio_formats(path: str, format: int) -> None: # With explicit format archive = cpio.open(absolute_path(path), format=format) _verify_archive(archive) diff --git a/tests/test_crc32c.py b/tests/test_crc32c.py index db04836..68a5f1e 100644 --- a/tests/test_crc32c.py +++ b/tests/test_crc32c.py @@ -4,7 +4,7 @@ @pytest.mark.parametrize( - "data, value, expected", + ("data", "value", "expected"), [ (b"hello, world!", 0, 0xCE8F3C63), (b"hello, world!", 0x12345678, 0x30663976), @@ -22,5 +22,5 @@ (bytes(reversed(range(32))), 0, 0x113FDB5C), ], ) -def test_crc32c(data: bytes, value: int, expected: int): +def test_crc32c(data: bytes, value: int, expected: int) -> None: assert crc32c.crc32c(data, value) == expected diff --git a/tests/test_feature.py b/tests/test_feature.py index d55d0a7..6f9e53e 100644 --- a/tests/test_feature.py +++ b/tests/test_feature.py @@ -4,23 +4,23 @@ def test_feature_flags() -> None: - def fallback(): + def fallback() -> bool: return False @feature(Feature.BETA, fallback) - def experimental(): + def experimental() -> bool: return True @feature(Feature.ADVANCED, fallback) - def advanced(): + def advanced() -> bool: return True @feature(Feature.LATEST) - def latest(): + def latest() -> bool: return True @feature("expert") - def expert(): + def expert() -> bool: return True assert experimental() is False diff --git a/tests/test_plist.py b/tests/test_plist.py index cbea887..60df5c5 100644 --- a/tests/test_plist.py +++ b/tests/test_plist.py @@ -10,7 +10,7 @@ @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") -def test_plist_nskeyedarchiver(): +def test_plist_nskeyedarchiver() -> None: UID = plistlib.UID data = { @@ -142,7 +142,7 @@ def test_plist_nskeyedarchiver(): assert root.String == "TestString" assert root.Bytes == b"bytes" assert root.Data == b"\x00" * 4 - assert root.UUID == uuid.UUID("00000000-0000-0000-0000-000000000000") + assert root.UUID == uuid.UUID("00000000-0000-0000-0000-000000000000") # noqa: SIM300 assert root.Date == datetime.datetime(2021, 12, 10, 13, 55, 52, 84823, tzinfo=datetime.timezone.utc) assert root.URL == "http://base/relative" assert root.URLBaseless == "relative" diff --git a/tests/test_sid.py b/tests/test_sid.py index 428dffd..59388a2 100644 --- a/tests/test_sid.py +++ b/tests/test_sid.py @@ -1,21 +1,28 @@ from __future__ import annotations import io +from typing import BinaryIO import pytest from dissect.util import sid -def id_fn(val): - if isinstance(val, (str,)): +def id_fn(val: bytes | str) -> str: + if isinstance(val, io.BytesIO): + val = val.getvalue() + + if isinstance(val, str): return val - else: - return "" + + if isinstance(val, bytes): + return val.hex() + + return "" @pytest.mark.parametrize( - "binary_sid, readable_sid, endian, swap_last", + ("binary_sid", "readable_sid", "endian", "swap_last"), [ ( b"\x01\x00\x00\x00\x00\x00\x00\x00", @@ -62,5 +69,5 @@ def id_fn(val): ], ids=id_fn, ) -def test_read_sid(binary_sid: bytes | io.BinaryIO, endian: str, swap_last: bool, readable_sid: str) -> None: +def test_read_sid(binary_sid: bytes | BinaryIO, endian: str, swap_last: bool, readable_sid: str) -> None: assert readable_sid == sid.read_sid(binary_sid, endian, swap_last) diff --git a/tests/test_stream.py b/tests/test_stream.py index 1f0cdd1..083c85f 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -7,7 +7,7 @@ from dissect.util import stream -def test_range_stream(): +def test_range_stream() -> None: buf = io.BytesIO(b"\x01" * 10 + b"\x02" * 10 + b"\x03" * 10) fh = stream.RangeStream(buf, 5, 15) @@ -41,7 +41,7 @@ def test_range_stream(): assert fh.tell() == 0 -def test_relative_stream(): +def test_relative_stream() -> None: buf = io.BytesIO(b"\x01" * 10 + b"\x02" * 10 + b"\x03" * 10) fh = stream.RelativeStream(buf, 5) @@ -65,7 +65,7 @@ def test_relative_stream(): assert fh.read() == b"\x01" * 5 + b"\x02" * 10 + b"\x03" * 10 -def test_buffered_stream(): +def test_buffered_stream() -> None: buf = io.BytesIO(b"\x01" * 512 + b"\x02" * 512 + b"\x03" * 512) fh = stream.BufferedStream(buf, size=None) @@ -75,7 +75,7 @@ def test_buffered_stream(): assert fh.read(1) == b"" -def test_mapping_stream(): +def test_mapping_stream() -> None: buffers = [ io.BytesIO(b"\x01" * 512), io.BytesIO(b"\x02" * 512), @@ -96,7 +96,7 @@ def test_mapping_stream(): assert fh.read(1) == b"" -def test_runlist_stream(): +def test_runlist_stream() -> None: buf = io.BytesIO(b"\x01" * 512 + b"\x02" * 512 + b"\x03" * 512) fh = stream.RunlistStream(buf, [(0, 32), (32, 16), (48, 48)], 1536, 16) @@ -112,7 +112,7 @@ def test_runlist_stream(): assert fh.read(1) == b"" -def test_aligned_stream_buffer(): +def test_aligned_stream_buffer() -> None: buf = io.BytesIO(b"\x01" * 512 + b"\x02" * 512 + b"\x03" * 512 + b"\x04" * 512) fh = stream.RelativeStream(buf, 0, align=512) @@ -130,7 +130,7 @@ def test_aligned_stream_buffer(): assert fh._buf == b"\x03" * 512 -def test_overlay_stream(): +def test_overlay_stream() -> None: buf = io.BytesIO(b"\x00" * 512 * 8) fh = stream.OverlayStream(buf, size=512 * 8, align=512) @@ -139,13 +139,13 @@ def test_overlay_stream(): fh.seek(0) # Add a small overlay - fh.add(512, b"\xFF" * 4) + fh.add(512, b"\xff" * 4) assert fh.read(512) == b"\x00" * 512 - assert fh.read(512) == (b"\xFF" * 4) + (b"\x00" * 508) + assert fh.read(512) == (b"\xff" * 4) + (b"\x00" * 508) fh.seek(510) - assert fh.read(4) == b"\x00\x00\xFF\xFF" + assert fh.read(4) == b"\x00\x00\xff\xff" # Add a large unaligned overlay fh.add(1000, b"\x01" * 1024) @@ -164,9 +164,9 @@ def test_overlay_stream(): fh.add(516, b"\x02" * 10) fh.seek(510) - assert fh.read(32) == b"\x00\x00" + (b"\xFF" * 4) + (b"\x02" * 10) + (b"\x00" * 16) + assert fh.read(32) == b"\x00\x00" + (b"\xff" * 4) + (b"\x02" * 10) + (b"\x00" * 16) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Overlap with existing overlay: \\(\\(512, 4\\)\\)"): fh.add(500, b"\x03" * 100) fh.add((512 * 8) - 4, b"\x04" * 100) @@ -174,7 +174,7 @@ def test_overlay_stream(): assert fh.read(100) == b"\x04" * 4 -def test_zlib_stream(): +def test_zlib_stream() -> None: data = b"\x01" * 8192 + b"\x02" * 8192 + b"\x03" * 8192 + b"\x04" * 8192 fh = stream.ZlibStream(io.BytesIO(zlib.compress(data)), size=8192 * 4, align=512) diff --git a/tests/test_ts.py b/tests/test_ts.py index 283f46e..ca021d1 100644 --- a/tests/test_ts.py +++ b/tests/test_ts.py @@ -1,27 +1,28 @@ import platform from datetime import datetime, timedelta, timezone from importlib import reload +from types import ModuleType from unittest.mock import patch import pytest @pytest.fixture(params=["windows", "emscripten", "linux"]) -def imported_ts(request): +def imported_ts(request: pytest.FixtureRequest) -> ModuleType: with patch.object(platform, "system", return_value=request.param): from dissect.util import ts - yield reload(ts) + return reload(ts) @pytest.fixture -def ts(): +def ts() -> ModuleType: from dissect.util import ts - yield reload(ts) + return reload(ts) -def test_now(ts): +def test_now(ts: ModuleType) -> None: ts_now = ts.now() datetime_now = datetime.now(timezone.utc) time_diff = datetime_now - ts_now @@ -31,109 +32,109 @@ def test_now(ts): assert ts_now.tzinfo == timezone.utc -def test_unix_now(imported_ts): +def test_unix_now(imported_ts: ModuleType) -> None: timestamp = imported_ts.unix_now() assert isinstance(timestamp, int) assert datetime.fromtimestamp(timestamp, tz=timezone.utc).microsecond == 0 -def test_unix_now_ms(imported_ts): +def test_unix_now_ms(imported_ts: ModuleType) -> None: timestamp = imported_ts.unix_now_ms() assert isinstance(timestamp, int) assert imported_ts.from_unix_ms(timestamp).microsecond == (timestamp % 1e3) * 1000 -def test_unix_now_us(imported_ts): +def test_unix_now_us(imported_ts: ModuleType) -> None: timestamp = imported_ts.unix_now_us() assert isinstance(timestamp, int) assert imported_ts.from_unix_us(timestamp).microsecond == timestamp % 1e6 -def test_unix_now_ns(imported_ts): +def test_unix_now_ns(imported_ts: ModuleType) -> None: timestamp = imported_ts.unix_now_ns() assert isinstance(timestamp, int) assert imported_ts.from_unix_ns(timestamp).microsecond == int((timestamp // 1000) % 1e6) -def test_to_unix(ts): +def test_to_unix(ts: ModuleType) -> None: dt = datetime(2018, 4, 11, 23, 34, 32, 915138, tzinfo=timezone.utc) assert ts.to_unix(dt) == 1523489672 -def test_to_unix_ms(ts): +def test_to_unix_ms(ts: ModuleType) -> None: dt = datetime(2018, 4, 11, 23, 34, 32, 915000, tzinfo=timezone.utc) assert ts.to_unix_ms(dt) == 1523489672915 -def test_to_unix_us(ts): +def test_to_unix_us(ts: ModuleType) -> None: dt = datetime(2018, 4, 11, 23, 34, 32, 915138, tzinfo=timezone.utc) assert ts.to_unix_us(dt) == 1523489672915138 -def test_to_unix_ns(ts): +def test_to_unix_ns(ts: ModuleType) -> None: dt = datetime(2018, 4, 11, 23, 34, 32, 915138, tzinfo=timezone.utc) assert ts.to_unix_ns(dt) == 1523489672915138000 -def test_from_unix(imported_ts): +def test_from_unix(imported_ts: ModuleType) -> None: assert imported_ts.from_unix(1523489672) == datetime(2018, 4, 11, 23, 34, 32, tzinfo=timezone.utc) -def test_from_unix_ms(imported_ts): +def test_from_unix_ms(imported_ts: ModuleType) -> None: assert imported_ts.from_unix_ms(1511260448882) == datetime(2017, 11, 21, 10, 34, 8, 882000, tzinfo=timezone.utc) -def test_from_unix_us(imported_ts): +def test_from_unix_us(imported_ts: ModuleType) -> None: assert imported_ts.from_unix_us(1511260448882000) == datetime(2017, 11, 21, 10, 34, 8, 882000, tzinfo=timezone.utc) -def test_from_unix_ns(imported_ts): +def test_from_unix_ns(imported_ts: ModuleType) -> None: assert imported_ts.from_unix_ns(1523489672915138048) == datetime( 2018, 4, 11, 23, 34, 32, 915138, tzinfo=timezone.utc ) -def test_xfstimestamp(imported_ts): +def test_xfstimestamp(imported_ts: ModuleType) -> None: assert imported_ts.xfstimestamp(1582541380, 451742903) == datetime( 2020, 2, 24, 10, 49, 40, 451743, tzinfo=timezone.utc ) -def test_ufstimestamp(imported_ts): +def test_ufstimestamp(imported_ts: ModuleType) -> None: assert imported_ts.ufstimestamp(1582541380, 451742903) == datetime( 2020, 2, 24, 10, 49, 40, 451743, tzinfo=timezone.utc ) -def test_wintimestamp(imported_ts): +def test_wintimestamp(imported_ts: ModuleType) -> None: assert imported_ts.wintimestamp(131679632729151386) == datetime( 2018, 4, 11, 23, 34, 32, 915138, tzinfo=timezone.utc ) -def test_oatimestamp(imported_ts): +def test_oatimestamp(imported_ts: ModuleType) -> None: dt = datetime(2016, 10, 17, 4, 6, 38, 362003, tzinfo=timezone.utc) assert imported_ts.oatimestamp(42660.171277338) == dt assert imported_ts.oatimestamp(4676095982878497960) == dt assert imported_ts.oatimestamp(-4542644417712532139) == datetime(1661, 4, 17, 11, 30, tzinfo=timezone.utc) -def test_webkittimestamp(imported_ts): +def test_webkittimestamp(imported_ts: ModuleType) -> None: assert imported_ts.webkittimestamp(13261574439236538) == datetime( 2021, 3, 30, 10, 40, 39, 236538, tzinfo=timezone.utc ) -def test_cocoatimestamp(imported_ts): +def test_cocoatimestamp(imported_ts: ModuleType) -> None: assert imported_ts.cocoatimestamp(622894123) == datetime(2020, 9, 27, 10, 8, 43, tzinfo=timezone.utc) assert imported_ts.cocoatimestamp(622894123.221783) == datetime(2020, 9, 27, 10, 8, 43, 221783, tzinfo=timezone.utc) -def test_negative_timestamps(imported_ts): +def test_negative_timestamps(imported_ts: ModuleType) -> None: # -5000.0 converted to a int representation assert imported_ts.oatimestamp(13885591609694748672) == datetime(1886, 4, 22, 0, 0, tzinfo=timezone.utc) assert imported_ts.oatimestamp(-5000.0) == datetime(1886, 4, 22, 0, 0, tzinfo=timezone.utc) diff --git a/tests/test_xmemoryview.py b/tests/test_xmemoryview.py index 59b3096..b965bcf 100644 --- a/tests/test_xmemoryview.py +++ b/tests/test_xmemoryview.py @@ -1,7 +1,7 @@ from dissect.util.xmemoryview import xmemoryview -def test_xmemoryview_little(): +def test_xmemoryview_little() -> None: # This is mostly a sanity test, since this will be a native memoryview on little endian systems buf = bytearray(range(256)) view = memoryview(buf).cast("I") @@ -18,7 +18,7 @@ def test_xmemoryview_little(): assert list(it)[:5] == [0x03020100, 0x07060504, 0x0B0A0908, 0x0F0E0D0C, 0x13121110] -def test_xmemoryview_big(): +def test_xmemoryview_big() -> None: buf = bytearray(range(256)) view = memoryview(buf).cast("I") diff --git a/tox.ini b/tox.ini index 3419558..b517c66 100644 --- a/tox.ini +++ b/tox.ini @@ -31,32 +31,19 @@ commands = [testenv:fix] package = skip deps = - black==23.1.0 - isort==5.11.4 + ruff==0.6.9 commands = - black dissect tests - isort dissect tests + ruff format dissect tests [testenv:lint] package = skip deps = - black==23.1.0 - flake8 - flake8-black - flake8-isort - isort==5.11.4 + ruff==0.6.9 vermin commands = - flake8 dissect tests + ruff check dissect tests vermin -t=3.9- --no-tips --lint dissect tests -[flake8] -max-line-length = 120 -extend-ignore = - # See https://github.com/PyCQA/pycodestyle/issues/373 - E203, -statistics = True - [testenv:docs-build] allowlist_externals = make deps =