|
35 | 35 | from synapse.handlers.device import DeviceHandler |
36 | 36 | from synapse.logging.context import make_deferred_yieldable, run_in_background |
37 | 37 | from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace |
| 38 | +from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet |
38 | 39 | from synapse.types import ( |
39 | 40 | JsonDict, |
40 | 41 | JsonMapping, |
@@ -89,6 +90,12 @@ def __init__(self, hs: "HomeServer"): |
89 | 90 | edu_updater.incoming_signing_key_update, |
90 | 91 | ) |
91 | 92 |
|
| 93 | + self.device_key_uploader = self.upload_device_keys_for_user |
| 94 | + else: |
| 95 | + self.device_key_uploader = ( |
| 96 | + ReplicationUploadKeysForUserRestServlet.make_client(hs) |
| 97 | + ) |
| 98 | + |
92 | 99 | # doesn't really work as part of the generic query API, because the |
93 | 100 | # query request requires an object POST, but we abuse the |
94 | 101 | # "query handler" interface. |
@@ -796,36 +803,17 @@ async def upload_keys_for_user( |
796 | 803 | "one_time_keys": A mapping from algorithm to number of keys for that |
797 | 804 | algorithm, including those previously persisted. |
798 | 805 | """ |
799 | | - # This can only be called from the main process. |
800 | | - assert isinstance(self.device_handler, DeviceHandler) |
801 | | - |
802 | 806 | time_now = self.clock.time_msec() |
803 | 807 |
|
804 | 808 | # TODO: Validate the JSON to make sure it has the right keys. |
805 | 809 | device_keys = keys.get("device_keys", None) |
806 | 810 | if device_keys: |
807 | | - logger.info( |
808 | | - "Updating device_keys for device %r for user %s at %d", |
809 | | - device_id, |
810 | | - user_id, |
811 | | - time_now, |
| 811 | + await self.device_key_uploader( |
| 812 | + user_id=user_id, |
| 813 | + device_id=device_id, |
| 814 | + keys={"device_keys": device_keys}, |
812 | 815 | ) |
813 | | - log_kv( |
814 | | - { |
815 | | - "message": "Updating device_keys for user.", |
816 | | - "user_id": user_id, |
817 | | - "device_id": device_id, |
818 | | - } |
819 | | - ) |
820 | | - # TODO: Sign the JSON with the server key |
821 | | - changed = await self.store.set_e2e_device_keys( |
822 | | - user_id, device_id, time_now, device_keys |
823 | | - ) |
824 | | - if changed: |
825 | | - # Only notify about device updates *if* the keys actually changed |
826 | | - await self.device_handler.notify_device_update(user_id, [device_id]) |
827 | | - else: |
828 | | - log_kv({"message": "Not updating device_keys for user", "user_id": user_id}) |
| 816 | + |
829 | 817 | one_time_keys = keys.get("one_time_keys", None) |
830 | 818 | if one_time_keys: |
831 | 819 | log_kv( |
@@ -861,18 +849,56 @@ async def upload_keys_for_user( |
861 | 849 | {"message": "Did not update fallback_keys", "reason": "no keys given"} |
862 | 850 | ) |
863 | 851 |
|
| 852 | + result = await self.store.count_e2e_one_time_keys(user_id, device_id) |
| 853 | + |
| 854 | + set_tag("one_time_key_counts", str(result)) |
| 855 | + return {"one_time_key_counts": result} |
| 856 | + |
| 857 | + @tag_args |
| 858 | + async def upload_device_keys_for_user( |
| 859 | + self, user_id: str, device_id: str, keys: JsonDict |
| 860 | + ) -> None: |
| 861 | + """ |
| 862 | + Args: |
| 863 | + user_id: user whose keys are being uploaded. |
| 864 | + device_id: device whose keys are being uploaded. |
| 865 | + device_keys: the `device_keys` of an /keys/upload request. |
| 866 | +
|
| 867 | + """ |
| 868 | + # This can only be called from the main process. |
| 869 | + assert isinstance(self.device_handler, DeviceHandler) |
| 870 | + |
| 871 | + time_now = self.clock.time_msec() |
| 872 | + |
| 873 | + device_keys = keys["device_keys"] |
| 874 | + logger.info( |
| 875 | + "Updating device_keys for device %r for user %s at %d", |
| 876 | + device_id, |
| 877 | + user_id, |
| 878 | + time_now, |
| 879 | + ) |
| 880 | + log_kv( |
| 881 | + { |
| 882 | + "message": "Updating device_keys for user.", |
| 883 | + "user_id": user_id, |
| 884 | + "device_id": device_id, |
| 885 | + } |
| 886 | + ) |
| 887 | + # TODO: Sign the JSON with the server key |
| 888 | + changed = await self.store.set_e2e_device_keys( |
| 889 | + user_id, device_id, time_now, device_keys |
| 890 | + ) |
| 891 | + if changed: |
| 892 | + # Only notify about device updates *if* the keys actually changed |
| 893 | + await self.device_handler.notify_device_update(user_id, [device_id]) |
| 894 | + |
864 | 895 | # the device should have been registered already, but it may have been |
865 | 896 | # deleted due to a race with a DELETE request. Or we may be using an |
866 | 897 | # old access_token without an associated device_id. Either way, we |
867 | 898 | # need to double-check the device is registered to avoid ending up with |
868 | 899 | # keys without a corresponding device. |
869 | 900 | await self.device_handler.check_device_registered(user_id, device_id) |
870 | 901 |
|
871 | | - result = await self.store.count_e2e_one_time_keys(user_id, device_id) |
872 | | - |
873 | | - set_tag("one_time_key_counts", str(result)) |
874 | | - return {"one_time_key_counts": result} |
875 | | - |
876 | 902 | async def _upload_one_time_keys_for_user( |
877 | 903 | self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict |
878 | 904 | ) -> None: |
|
0 commit comments