Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/18277.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Hashes of media files are now tracked by Synapse. Media quarantines will now apply to all files with the same hash.
46 changes: 40 additions & 6 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@
respond_with_responder,
)
from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage
from synapse.media.media_storage import (
MediaStorage,
SHA256TransparentIOReader,
SHA256TransparentIOWriter,
)
from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
Expand Down Expand Up @@ -301,15 +305,26 @@ async def update_content(
auth_user: The user_id of the uploader
"""
file_info = FileInfo(server_name=None, file_id=media_id)
fname = await self.media_storage.store_file(content, file_info)
sha256reader = SHA256TransparentIOReader(content)
# This implements all of IO as it has a passthrough
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
sha256 = sha256reader.hexdigest()
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
logger.info("Stored local media in file %r", fname)

if should_quarantine:
logger.warn(
"Media has been automatically quarantined as it matched existing quarantined media"
)

await self.store.update_local_media(
media_id=media_id,
media_type=media_type,
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
sha256=sha256,
quarantined_by="system" if should_quarantine else None,
)

try:
Expand Down Expand Up @@ -342,18 +357,29 @@ async def create_content(
media_id = random_string(24)

file_info = FileInfo(server_name=None, file_id=media_id)

fname = await self.media_storage.store_file(content, file_info)
# This implements all of IO as it has a passthrough
sha256reader = SHA256TransparentIOReader(content)
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
sha256 = sha256reader.hexdigest()
should_quarantine = await self.store.get_is_hash_quarantined(sha256)

logger.info("Stored local media in file %r", fname)

if should_quarantine:
logger.warn(
"Media has been automatically quarantined as it matched existing quarantined media"
)

await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
sha256=sha256,
# TODO: Better name?
quarantined_by="system" if should_quarantine else None,
)

try:
Expand Down Expand Up @@ -756,11 +782,13 @@ async def _download_remote_file(
file_info = FileInfo(server_name=server_name, file_id=file_id)

async with self.media_storage.store_into_file(file_info) as (f, fname):
sha256writer = SHA256TransparentIOWriter(f)
try:
length, headers = await self.client.download_media(
server_name,
media_id,
output_stream=f,
# This implements all of BinaryIO as it has a passthrough
output_stream=sha256writer.wrap(),
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
Expand Down Expand Up @@ -825,6 +853,7 @@ async def _download_remote_file(
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
sha256=sha256writer.hexdigest(),
)

logger.info("Stored remote media in file %r", fname)
Expand All @@ -845,6 +874,7 @@ async def _download_remote_file(
last_access_ts=time_now_ms,
quarantined_by=None,
authenticated=authenticated,
sha256=sha256writer.hexdigest(),
)

async def _federation_download_remote_file(
Expand Down Expand Up @@ -879,11 +909,13 @@ async def _federation_download_remote_file(
file_info = FileInfo(server_name=server_name, file_id=file_id)

async with self.media_storage.store_into_file(file_info) as (f, fname):
sha256writer = SHA256TransparentIOWriter(f)
try:
res = await self.client.federation_download_media(
server_name,
media_id,
output_stream=f,
# This implements all of BinaryIO as it has a passthrough
output_stream=sha256writer.wrap(),
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
Expand Down Expand Up @@ -954,6 +986,7 @@ async def _federation_download_remote_file(
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
sha256=sha256writer.hexdigest(),
)

logger.debug("Stored remote media in file %r", fname)
Expand All @@ -974,6 +1007,7 @@ async def _federation_download_remote_file(
last_access_ts=time_now_ms,
quarantined_by=None,
authenticated=authenticated,
sha256=sha256writer.hexdigest(),
)

def _get_thumbnail_requirements(
Expand Down
84 changes: 83 additions & 1 deletion synapse/media/media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#
#
import contextlib
import hashlib
import json
import logging
import os
Expand Down Expand Up @@ -70,6 +71,88 @@
CRLF = b"\r\n"


class SHA256TransparentIOWriter:
    """Wraps a writable binary stream and hashes everything written through it.

    Every buffer passed to `write()` is forwarded to the source stream and
    then fed into a running SHA256 hash. Any attribute not defined on this
    class is looked up on the source stream instead.

    Args:
        source: Source stream.
    """

    def __init__(self, source: BinaryIO):
        self._hash = hashlib.sha256()
        self._source = source

    def write(self, buffer: Union[bytes, bytearray]) -> int:
        """Wrapper for source.write()

        Args:
            buffer: The data to write to the source stream and add to the hash.

        Returns:
            the value of source.write()
        """
        # Write first so that a failing write does not pollute the hash.
        # NOTE(review): the whole buffer is hashed regardless of the count
        # returned by source.write(); this assumes the source never performs
        # short writes -- confirm for unbuffered sources.
        res = self._source.write(buffer)
        self._hash.update(buffer)
        return res

    def hexdigest(self) -> str:
        """The digest of the data written so far.

        Returns:
            The digest in hex format.
        """
        return self._hash.hexdigest()

    def wrap(self) -> BinaryIO:
        """Return this object typed as a `BinaryIO`.

        This class implements a subset of the IO interface and passes through
        everything else via `__getattr__`, so the cast only satisfies the type
        checker; no new object is created.
        """
        return cast(BinaryIO, self)

    # Passthrough any other calls
    def __getattr__(self, attr_name: str) -> Any:
        # `__getattr__` only runs when normal lookup fails. If our own private
        # attributes are missing (instance created without `__init__`, e.g. by
        # copy/unpickle), delegating via `self._source` would re-enter this
        # method forever, so fail with a regular AttributeError instead.
        if attr_name in ("_source", "_hash"):
            raise AttributeError(attr_name)
        return getattr(self._source, attr_name)


class SHA256TransparentIOReader:
    """Wraps a readable stream and hashes everything read through `read()`.

    Each chunk returned by `read()` is fed into a running SHA256 hash before
    being handed to the caller. Any attribute not defined on this class is
    looked up on the source stream instead; data consumed through such
    passthrough attributes is not included in the hash.

    Args:
        source: Source IO stream.
    """

    def __init__(self, source: IO):
        self._hash = hashlib.sha256()
        self._source = source

    def read(self, n: int = -1) -> bytes:
        """Wrapper for source.read()

        Args:
            n: Passed through unchanged to source.read().

        Returns:
            the value of source.read()
        """
        chunk = self._source.read(n)
        self._hash.update(chunk)
        return chunk

    def hexdigest(self) -> str:
        """The digest of the data read so far.

        Returns:
            The digest in hex format.
        """
        return self._hash.hexdigest()

    def wrap(self) -> IO:
        """Return this object typed as an `IO`.

        This class implements a subset of the IO interface and passes through
        everything else via `__getattr__`, so the cast only satisfies the type
        checker; no new object is created.
        """
        return cast(IO, self)

    # Passthrough any other calls
    def __getattr__(self, attr_name: str) -> Any:
        # `__getattr__` only runs when normal lookup fails. If our own private
        # attributes are missing (instance created without `__init__`, e.g. by
        # copy/unpickle), delegating via `self._source` would re-enter this
        # method forever, so fail with a regular AttributeError instead.
        if attr_name in ("_source", "_hash"):
            raise AttributeError(attr_name)
        return getattr(self._source, attr_name)


class MediaStorage:
"""Responsible for storing/fetching files from local sources.

Expand Down Expand Up @@ -107,7 +190,6 @@ async def store_file(self, source: IO, file_info: FileInfo) -> str:
Returns:
the file path written to in the primary media store
"""

async with self.store_into_file(file_info) as (f, fname):
# Write to the main media repository
await self.write_to_file(source, f)
Expand Down
Loading
Loading