Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/modules/media_repository_callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,25 @@ callback that does not return `None` will be used. If this happens, Synapse will
any of the subsequent implementations of this callback.

If no module returns a non-`None` value then the default [media upload limits config](https://element-hq.github.io/synapse/latest/usage/configuration/config_documentation.html#media_upload_limits) will be used.

### `on_media_upload_limit_exceeded`

_First introduced in Synapse v1.137.0_

```python
async def on_media_upload_limit_exceeded(user_id: str, limit: MediaUploadLimit, sent_bytes: int, attempted_bytes: int) -> None
```

**<span style="color:red">
Caution: This callback is currently experimental . The method signature or behaviour
may change without notice.
</span>**

Called when a user attempts to upload media that would exceed a configured media upload limit.

The arguments passed to this callback are:

* `user_id`: The Matrix user ID of the user (e.g. `@alice:example.com`) making the request.
* `limit`: The `MediaUploadLimit` that was reached.
* `sent_bytes`: The number of bytes already sent during the period of the limit.
* `attempted_bytes`: The number of bytes that the user attempted to send.
3 changes: 3 additions & 0 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,9 @@ async def create_or_update_content(
)

if uploaded_media_size + content_length > limit.max_bytes:
await self.media_repository_callbacks.on_media_upload_limit_exceeded(
user_id=auth_user.to_string(), limit=limit, sent_bytes=uploaded_media_size, attempted_bytes=content_length
)
raise SynapseError(
400, "Media upload limit exceeded", Codes.RESOURCE_LIMIT_EXCEEDED
)
Expand Down
3 changes: 3 additions & 0 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
GET_MEDIA_CONFIG_FOR_USER_CALLBACK,
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK,
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK,
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK,
)
from synapse.module_api.callbacks.ratelimit_callbacks import (
GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK,
Expand Down Expand Up @@ -468,6 +469,7 @@ def register_media_repository_callbacks(
get_media_upload_limits_for_user: Optional[
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
] = None,
on_media_upload_limit_exceeded: Optional[ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK] = None,
) -> None:
"""Registers callbacks for media repository capabilities.
Added in Synapse v1.132.0.
Expand All @@ -476,6 +478,7 @@ def register_media_repository_callbacks(
get_media_config_for_user=get_media_config_for_user,
is_user_allowed_to_upload_media_of_size=is_user_allowed_to_upload_media_of_size,
get_media_upload_limits_for_user=get_media_upload_limits_for_user,
on_media_upload_limit_exceeded=on_media_upload_limit_exceeded
)

def register_third_party_rules_callbacks(
Expand Down
24 changes: 24 additions & 0 deletions synapse/module_api/callbacks/media_repository_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
[str], Awaitable[Optional[List[MediaUploadLimit]]]
]

ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK = Callable[[str, MediaUploadLimit, int, int], Awaitable[None]]


class MediaRepositoryModuleApiCallbacks:
def __init__(self, hs: "HomeServer") -> None:
Expand All @@ -47,6 +49,9 @@ def __init__(self, hs: "HomeServer") -> None:
self._get_media_upload_limits_for_user_callbacks: List[
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
] = []
self._on_media_upload_limit_exceeded_callbacks: List[
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK
] = []

def register_callbacks(
self,
Expand All @@ -57,6 +62,7 @@ def register_callbacks(
get_media_upload_limits_for_user: Optional[
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
] = None,
on_media_upload_limit_exceeded: Optional[ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if get_media_config_for_user is not None:
Expand All @@ -72,6 +78,11 @@ def register_callbacks(
get_media_upload_limits_for_user
)

if on_media_upload_limit_exceeded is not None:
self._on_media_upload_limit_exceeded_callbacks.append(
on_media_upload_limit_exceeded
)

async def get_media_config_for_user(self, user_id: str) -> Optional[JsonDict]:
for callback in self._get_media_config_for_user_callbacks:
with Measure(
Expand Down Expand Up @@ -116,3 +127,16 @@ async def get_media_upload_limits_for_user(
return res

return None

async def on_media_upload_limit_exceeded(
self, user_id: str, limit: MediaUploadLimit, sent_bytes: int, attempted_bytes: int
) -> None:
for callback in self._on_media_upload_limit_exceeded_callbacks:
with Measure(
self.clock,
name=f"{callback.__module__}.{callback.__qualname__}",
server_name=self.server_name,
):
# Use a copy of the data in case the module modifies it
limit_copy = MediaUploadLimit(max_bytes=limit.max_bytes, time_period_ms=limit.time_period_ms)
await delay_cancellation(callback(user_id, limit_copy, sent_bytes, attempted_bytes))
33 changes: 31 additions & 2 deletions tests/rest/client/test_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -3029,6 +3029,20 @@ async def _get_media_upload_limits_for_user(
# otherwise use default
return None

async def _on_media_upload_limit_exceeded(
self,
user_id: str,
limit: MediaUploadLimit,
sent_bytes: int,
attempted_bytes: int
) -> None:
self.last_media_upload_limit_exceeded = {
"user_id": user_id,
"limit": limit,
"sent_bytes": sent_bytes,
"attempted_bytes": attempted_bytes
}

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.repo = hs.get_media_repository()
self.client = hs.get_federation_http_client()
Expand All @@ -3037,10 +3051,12 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.tok1 = self.login("user1", "pass")
self.user2 = self.register_user("user2", "pass")
self.tok2 = self.login("user2", "pass")
self.register_user("user3", "pass")
self.user3 = self.register_user("user3", "pass")
self.tok3 = self.login("user3", "pass")
self.last_media_upload_limit_exceeded = None
self.hs.get_module_api().register_media_repository_callbacks(
get_media_upload_limits_for_user=self._get_media_upload_limits_for_user
get_media_upload_limits_for_user=self._get_media_upload_limits_for_user,
on_media_upload_limit_exceeded=self._on_media_upload_limit_exceeded
)

def create_resource_dict(self) -> Dict[str, Resource]:
Expand Down Expand Up @@ -3075,27 +3091,40 @@ def test_upload_under_limit(self) -> None:
channel = self.upload_media(100, self.tok3)
self.assertEqual(channel.code, 200)

self.assertEqual(self.last_media_upload_limit_exceeded, None)

def test_uses_custom_limit(self) -> None:
"""Test that uploading media over the module provided daily limit fails."""

# User 1 uploads 3000 bytes
channel = self.upload_media(3000, self.tok1)
self.assertEqual(channel.code, 200)

# User 1 attempts to upload 4000 bytes taking it over the limit
channel = self.upload_media(4000, self.tok1)
self.assertEqual(channel.code, 400)
self.assertEqual(self.last_media_upload_limit_exceeded["user_id"], self.user1)
self.assertEqual(self.last_media_upload_limit_exceeded["limit"], MediaUploadLimit(max_bytes=5000, time_period_ms=Config.parse_duration("1d")))
self.assertEqual(self.last_media_upload_limit_exceeded["sent_bytes"], 3000)
self.assertEqual(self.last_media_upload_limit_exceeded["attempted_bytes"], 4000)

def test_uses_unlimited(self) -> None:
"""Test that unlimited user is not limited when module returns []."""
# User 2 uploads 10000 bytes which is over the default limit
channel = self.upload_media(10000, self.tok2)
self.assertEqual(channel.code, 200)
self.assertEqual(self.last_media_upload_limit_exceeded, None)

def test_uses_defaults(self) -> None:
"""Test that the default limits are applied when module returned None."""
# User 3 uploads 500 bytes
channel = self.upload_media(500, self.tok3)
self.assertEqual(channel.code, 200)

# User 3 uploads 800 bytes which is over the limit
channel = self.upload_media(800, self.tok3)
self.assertEqual(channel.code, 400)
self.assertEqual(self.last_media_upload_limit_exceeded["user_id"], self.user3)
self.assertEqual(self.last_media_upload_limit_exceeded["limit"], MediaUploadLimit(max_bytes=1024, time_period_ms=Config.parse_duration("1d")))
self.assertEqual(self.last_media_upload_limit_exceeded["sent_bytes"], 500)
self.assertEqual(self.last_media_upload_limit_exceeded["attempted_bytes"], 800)