Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/19311.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Put membership updates in a background resumable task when changing the avatar or the display name.
82 changes: 80 additions & 2 deletions synapse/handlers/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import random
from typing import TYPE_CHECKING

from twisted.internet.defer import CancelledError

from synapse.api.constants import ProfileFields
from synapse.api.errors import (
AuthError,
Expand All @@ -32,7 +34,16 @@
SynapseError,
)
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import JsonDict, JsonValue, Requester, UserID, create_requester
from synapse.types import (
JsonDict,
JsonMapping,
JsonValue,
Requester,
ScheduledTask,
TaskStatus,
UserID,
create_requester,
)
from synapse.util.caches.descriptors import cached
from synapse.util.duration import Duration
from synapse.util.stringutils import parse_and_validate_mxc_uri
Expand All @@ -46,6 +57,8 @@
MAX_AVATAR_URL_LEN = 1000
# Field name length is specced at 255 bytes.
MAX_CUSTOM_FIELD_LEN = 255
UPDATE_JOIN_STATES_ACTION_NAME = "update_join_states"
UPDATE_JOIN_STATES_LOCK_NAME = "update_join_states_lock"


class ProfileHandler:
Expand Down Expand Up @@ -78,6 +91,12 @@ def __init__(self, hs: "HomeServer"):

self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules

self._task_scheduler = hs.get_task_scheduler()
self._task_scheduler.register_action(
self._update_join_states_task, UPDATE_JOIN_STATES_ACTION_NAME
)
self._worker_locks = hs.get_worker_locks_handler()

async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict:
"""
Get a user's profile as a JSON dictionary.
Expand Down Expand Up @@ -587,7 +606,59 @@ async def _update_join_states(
await self.clock.sleep(Duration(seconds=random.randint(1, 10)))
return

room_ids = await self.store.get_rooms_for_user(target_user.to_string())
target_user_str = target_user.to_string()

# Cancel any ongoing profile membership updates for this user,
# and start a new one.
async with self._worker_locks.acquire_read_write_lock(
UPDATE_JOIN_STATES_LOCK_NAME,
target_user_str,
write=True,
):
tasks_to_cancel = await self._task_scheduler.get_tasks(
actions=[UPDATE_JOIN_STATES_ACTION_NAME],
resource_id=target_user_str,
statuses=[TaskStatus.ACTIVE, TaskStatus.SCHEDULED],
)

for task in tasks_to_cancel:
await self._task_scheduler.cancel_task(task.id)

await self._task_scheduler.schedule_task(
UPDATE_JOIN_STATES_ACTION_NAME,
resource_id=target_user_str,
params={
"requester_authenticated_entity": requester.authenticated_entity,
},
)

async def _update_join_states_task(
self,
task: ScheduledTask,
) -> tuple[TaskStatus, JsonMapping | None, str | None]:
assert task.resource_id
assert task.params

target_user = UserID.from_string(task.resource_id)
room_ids = sorted(await self.store.get_rooms_for_user(target_user.to_string()))

last_room_id = task.result.get("last_room_id", None) if task.result else None

if last_room_id:
unhandled_room_ids = []
# The unhandled rooms should be at the end of the list, so we iterate in reverse
# and break when we reach the last handled room.
for room_id in reversed(room_ids):
if room_id > last_room_id:
unhandled_room_ids.append(room_id)
else:
break
room_ids = unhandled_room_ids

requester = create_requester(
user_id=target_user,
authenticated_entity=task.params.get("requester_authenticated_entity"),
)

for room_id in room_ids:
handler = self.hs.get_room_member_handler()
Expand All @@ -601,10 +672,17 @@ async def _update_join_states(
"join", # We treat a profile update like a join.
ratelimit=False, # Try to hide that these events aren't atomic.
)
except CancelledError as e:
raise e
except Exception as e:
logger.warning(
"Failed to update join event for room %s - %s", room_id, str(e)
)
await self._task_scheduler.update_task(
task.id, result={"last_room_id": last_room_id}
)

return TaskStatus.COMPLETE, None, None

async def check_profile_query_allowed(
self, target_user: UserID, requester: UserID | None = None
Expand Down
5 changes: 4 additions & 1 deletion synapse/util/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,12 @@ async def wrapper() -> None:
log_context,
start_time,
)
result = None
error = None
try:
(status, result, error) = await function(task)
except defer.CancelledError:
status = TaskStatus.CANCELLED
except Exception:
f = Failure()
logger.error(
Expand All @@ -481,7 +485,6 @@ async def wrapper() -> None:
exc_info=(f.type, f.value, f.getTracebackObject()),
)
status = TaskStatus.FAILED
result = None
error = f.getErrorMessage()

await self._store.update_scheduled_task(
Expand Down
172 changes: 170 additions & 2 deletions tests/handlers/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,35 @@
#
#
from typing import Any, Awaitable, Callable
from unittest.mock import AsyncMock, Mock
from unittest.mock import AsyncMock, Mock, patch

from parameterized import parameterized

from twisted.internet.testing import MemoryReactor

import synapse.types
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
from synapse.types.state import StateFilter
from synapse.util.clock import Clock
from synapse.util.duration import Duration
from synapse.util.task_scheduler import TaskStatus

from tests import unittest


class ProfileTestCase(unittest.HomeserverTestCase):
"""Tests profile management."""

servlets = [admin.register_servlets]
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]

def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = AsyncMock()
Expand All @@ -62,12 +71,15 @@ def register_query_handler(

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.storage_controllers = self.hs.get_storage_controllers()
self.task_scheduler = hs.get_task_scheduler()

self.frank = UserID.from_string("@1234abcd:test")
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")

self.register_user(self.frank.localpart, "frankpassword")
self.frank_token = self.login(self.frank.localpart, "frankpassword")

self.handler = hs.get_profile_handler()

Expand Down Expand Up @@ -113,6 +125,162 @@ def test_set_my_name(self) -> None:
self.get_success(self.store.get_profile_displayname(self.frank))
)

def test_update_room_membership_on_set_displayname(self) -> None:
"""Test that `set_displayname` updates membership events in rooms."""

self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
)
)

room_id = self.helper.create_room_as(
self.frank.to_string(), tok=self.frank_token
)

state_tuple = (EventTypes.Member, self.frank.to_string())

membership = self.get_success(
self.storage_controllers.state.get_current_state(
room_id, StateFilter.from_types([state_tuple])
)
)
self.assertEqual(membership[state_tuple].content["displayname"], "Frank")

self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
)

membership = self.get_success(
self.storage_controllers.state.get_current_state(
room_id, StateFilter.from_types([state_tuple])
)
)
self.assertEqual(membership[state_tuple].content["displayname"], "Frank Jr.")

def test_backgound_update_room_membership_on_set_displayname(self) -> None:
"""Test that `set_displayname` returns immediately and that room membership updates are still done in background."""

self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
)
)

room_id = self.helper.create_room_as(
self.frank.to_string(), tok=self.frank_token
)

original_update_membership = self.hs.get_room_member_handler().update_membership

async def slow_update_membership(*args: Any, **kwargs: Any) -> tuple[str, int]:
await self.clock.sleep(Duration(milliseconds=10))
return await original_update_membership(*args, **kwargs)

with patch.object(
self.hs.get_room_member_handler(),
"update_membership",
side_effect=slow_update_membership,
):
state_tuple = (EventTypes.Member, self.frank.to_string())
self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
)

membership = self.get_success(
self.storage_controllers.state.get_current_state(
room_id, StateFilter.from_types([state_tuple])
)
)
self.assertEqual(membership[state_tuple].content["displayname"], "Frank")

# Let's be sure we are over the delay introduced by slow_update_membership
self.get_success(self.clock.sleep(Duration(milliseconds=20)), by=1)

membership = self.get_success(
self.storage_controllers.state.get_current_state(
room_id, StateFilter.from_types([state_tuple])
)
)
self.assertEqual(
membership[state_tuple].content["displayname"], "Frank Jr."
)

def test_background_update_room_membership_resume_after_restart(self) -> None:
"""Test that room membership updates triggered by changing the avatar or the display name are resumed after a restart."""

self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
)
)

room_id = self.helper.create_room_as(
self.frank.to_string(), tok=self.frank_token
)

original_update_membership = self.hs.get_room_member_handler().update_membership

async def slow_update_membership(*args: Any, **kwargs: Any) -> tuple[str, int]:
await self.clock.sleep(Duration(milliseconds=10))
return await original_update_membership(*args, **kwargs)

with patch.object(
self.hs.get_room_member_handler(),
"update_membership",
side_effect=slow_update_membership,
):
state_tuple = (EventTypes.Member, self.frank.to_string())
self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
)

# Simulate a synapse restart by emptying the list of running tasks
# and canceling the deferred
_, deferred = self.task_scheduler._running_tasks.popitem()
deferred.cancel()

# Let's be sure we are over the delay introduced by slow_update_membership
# and that the task was not executed as expected
self.get_success(self.clock.sleep(Duration(milliseconds=20)), by=1)

membership = self.get_success(
self.storage_controllers.state.get_current_state(
room_id, StateFilter.from_types([state_tuple])
)
)
self.assertEqual(membership[state_tuple].content["displayname"], "Frank")

cancelled_task = self.get_success(
self.task_scheduler.get_tasks(
actions=["update_join_states"], statuses=[TaskStatus.CANCELLED]
)
)[0]

self.get_success(
self.task_scheduler.update_task(
cancelled_task.id, status=TaskStatus.ACTIVE
)
)

# Let's be sure we are over the delay introduced by slow_update_membership
self.get_success(self.clock.sleep(Duration(milliseconds=20)), by=1)

membership = self.get_success(
self.storage_controllers.state.get_current_state(
room_id, StateFilter.from_types([state_tuple])
)
)
self.assertEqual(
membership[state_tuple].content["displayname"], "Frank Jr."
)

def test_set_my_name_if_disabled(self) -> None:
self.hs.config.registration.enable_set_displayname = False

Expand Down
Loading