|
13 | 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
16 | | - |
17 | 16 | from twisted.internet import defer |
18 | 17 |
|
19 | 18 | from ._base import BaseHandler |
20 | 19 | from synapse.api.constants import LoginType |
21 | | -from synapse.types import UserID |
22 | 20 | from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError |
| 21 | +from synapse.module_api import ModuleApi |
| 22 | +from synapse.types import UserID |
23 | 23 | from synapse.util.async import run_on_reactor |
24 | 24 | from synapse.util.caches.expiringcache import ExpiringCache |
25 | 25 |
|
@@ -63,10 +63,7 @@ def __init__(self, hs): |
63 | 63 | reset_expiry_on_get=True, |
64 | 64 | ) |
65 | 65 |
|
66 | | - account_handler = _AccountHandler( |
67 | | - hs, check_user_exists=self.check_user_exists |
68 | | - ) |
69 | | - |
| 66 | + account_handler = ModuleApi(hs, self) |
70 | 67 | self.password_providers = [ |
71 | 68 | module(config=config, account_handler=account_handler) |
72 | 69 | for module, config in hs.config.password_providers |
@@ -843,66 +840,3 @@ def _generate_base_macaroon(self, user_id): |
843 | 840 | macaroon.add_first_party_caveat("gen = 1") |
844 | 841 | macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) |
845 | 842 | return macaroon |
846 | | - |
847 | | - |
848 | | -class _AccountHandler(object): |
849 | | - """A proxy object that gets passed to password auth providers so they |
850 | | - can register new users etc if necessary. |
851 | | - """ |
852 | | - def __init__(self, hs, check_user_exists): |
853 | | - self.hs = hs |
854 | | - |
855 | | - self._check_user_exists = check_user_exists |
856 | | - self._store = hs.get_datastore() |
857 | | - |
858 | | - def get_qualified_user_id(self, username): |
859 | | - """Qualify a user id, if necessary |
860 | | -
|
861 | | - Takes a user id provided by the user and adds the @ and :domain to |
862 | | - qualify it, if necessary |
863 | | -
|
864 | | - Args: |
865 | | - username (str): provided user id |
866 | | -
|
867 | | - Returns: |
868 | | - str: qualified @user:id |
869 | | - """ |
870 | | - if username.startswith('@'): |
871 | | - return username |
872 | | - return UserID(username, self.hs.hostname).to_string() |
873 | | - |
874 | | - def check_user_exists(self, user_id): |
875 | | - """Check if user exists. |
876 | | -
|
877 | | - Args: |
878 | | - user_id (str): Complete @user:id |
879 | | -
|
880 | | - Returns: |
881 | | - Deferred[str|None]: Canonical (case-corrected) user_id, or None |
882 | | - if the user is not registered. |
883 | | - """ |
884 | | - return self._check_user_exists(user_id) |
885 | | - |
886 | | - def register(self, localpart): |
887 | | - """Registers a new user with given localpart |
888 | | -
|
889 | | - Returns: |
890 | | - Deferred: a 2-tuple of (user_id, access_token) |
891 | | - """ |
892 | | - reg = self.hs.get_handlers().registration_handler |
893 | | - return reg.register(localpart=localpart) |
894 | | - |
895 | | - def run_db_interaction(self, desc, func, *args, **kwargs): |
896 | | - """Run a function with a database connection |
897 | | -
|
898 | | - Args: |
899 | | - desc (str): description for the transaction, for metrics etc |
900 | | - func (func): function to be run. Passed a database cursor object |
901 | | - as well as *args and **kwargs |
902 | | - *args: positional args to be passed to func |
903 | | - **kwargs: named args to be passed to func |
904 | | -
|
905 | | - Returns: |
906 | | - Deferred[object]: result of func |
907 | | - """ |
908 | | - return self._store.runInteraction(desc, func, *args, **kwargs) |
0 commit comments