Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 55 additions & 11 deletions synapse/storage/databases/main/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,18 @@ async def update_e2e_room_key(
Raises:
StoreError
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No row found")

await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
"version": version,
"version": version_int,
"room_id": room_id,
"session_id": session_id,
},
Expand All @@ -85,13 +91,19 @@ async def add_e2e_room_keys(
version: the version ID of the backup for the set of keys we're adding to
room_keys: the keys to add, in the form (roomID, sessionID, keyData)
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No row found")

values = []
for (room_id, session_id, room_key) in room_keys:
values.append(
{
"user_id": user_id,
"version": version,
"version": version_int,
"room_id": room_id,
"session_id": session_id,
"first_message_index": room_key["first_message_index"],
Expand Down Expand Up @@ -203,20 +215,26 @@ async def get_e2e_room_keys_multi(
Returns:
A map of room IDs to session IDs to room key
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return {}

return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
version,
version_int,
room_keys,
)

@staticmethod
def _get_e2e_room_keys_multi_txn(
txn: LoggingTransaction,
user_id: str,
version: str,
version: int,
room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
) -> Dict[str, Dict[str, RoomKey]]:
if not room_keys:
Expand Down Expand Up @@ -272,10 +290,16 @@ async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return 0

return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
keyvalues={"user_id": user_id, "version": version_int},
retcol="COUNT(*)",
desc="count_e2e_room_keys",
)
Expand All @@ -301,8 +325,14 @@ async def delete_e2e_room_keys(
If not specified, we delete all the keys in this version of
the backup (or for the specified room)
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return

keyvalues = {"user_id": user_id, "version": int(version)}
keyvalues = {"user_id": user_id, "version": version_int}
if room_id:
keyvalues["room_id"] = room_id
if session_id:
Expand All @@ -319,6 +349,8 @@ def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
"WHERE user_id=? AND deleted=0",
(user_id,),
)
# `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
# be `NULL` when there are no available versions.
row = cast(Tuple[Optional[int]], txn.fetchone())
if row[0] is None:
raise StoreError(404, "No current backup version")
Comment on lines +363 to 365
Copy link
Copy Markdown
Contributor Author

@squahtx squahtx Dec 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous code was buggy. The query above will always return 1 row.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, because MAX(version) will aggregate to NULL if there are no rows permitted by the where clause? Might that be worth a comment?

Expand Down Expand Up @@ -395,7 +427,7 @@ def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
if current_version is None:
current_version = 0

new_version = str(int(current_version) + 1)
new_version = current_version + 1

self.db_pool.simple_insert_txn(
txn,
Expand All @@ -408,7 +440,7 @@ def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
},
)

return new_version
return str(new_version)

return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
Expand Down Expand Up @@ -440,9 +472,16 @@ async def update_e2e_room_keys_version(
updatevalues["etag"] = version_etag

if updatevalues:
await self.db_pool.simple_update(
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No row found")

await self.db_pool.simple_update_one(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
keyvalues={"user_id": user_id, "version": version_int},
updatevalues=updatevalues,
desc="update_e2e_room_keys_version",
)
Expand All @@ -467,7 +506,12 @@ def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None:
if version is None:
this_version = self._get_current_version(txn, user_id)
else:
this_version = int(version)
try:
this_version = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it isn't there.
raise StoreError(404, "No row found")

self.db_pool.simple_delete_txn(
txn,
Expand Down