Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Additional type hints for client REST servlets (part 4) (#10728)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Sep 1, 2021
1 parent dc75fb7 commit d1f1b46
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 100 deletions.
1 change: 1 addition & 0 deletions changelog.d/10728.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to REST servlets.
82 changes: 39 additions & 43 deletions synapse/rest/client/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple
from urllib.parse import urlparse

from twisted.web.server import Request

from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
Expand All @@ -28,15 +30,17 @@
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.types import JsonDict
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed, validate_email
Expand Down Expand Up @@ -68,7 +72,7 @@ def __init__(self, hs: "HomeServer"):
template_text=self.config.email_password_reset_template_text,
)

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
Expand Down Expand Up @@ -159,7 +163,7 @@ async def on_POST(self, request):
class PasswordRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
Expand All @@ -169,7 +173,7 @@ def __init__(self, hs):
self._set_password_handler = hs.get_set_password_handler()

@interactive_auth_handler
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)

# we do basic sanity checks here because the auth layer will store these
Expand All @@ -190,6 +194,7 @@ async def on_POST(self, request):
#
# In the second case, we require a password to confirm their identity.

requester = None
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
try:
Expand All @@ -206,16 +211,15 @@ async def on_POST(self, request):
# If a password is available now, hash the provided password and
# store it for later.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
new_password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
new_password_hash,
)
raise
user_id = requester.user.to_string()
else:
requester = None
try:
result, params, session_id = await self.auth_handler.check_ui_auth(
[[LoginType.EMAIL_IDENTITY]],
Expand All @@ -230,11 +234,11 @@ async def on_POST(self, request):
# If a password is available now, hash the provided password and
# store it for later.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
new_password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
new_password_hash,
)
raise

Expand Down Expand Up @@ -264,7 +268,7 @@ async def on_POST(self, request):
# If we have a password in this request, prefer it. Otherwise, use the
# password hash from an earlier request.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
password_hash: Optional[str] = await self.auth_handler.hash(new_password)
elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
Expand All @@ -288,15 +292,15 @@ async def on_POST(self, request):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self._deactivate_account_handler = hs.get_deactivate_account_handler()

@interactive_auth_handler
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
erase = body.get("erase", False)
if not isinstance(erase, bool):
Expand Down Expand Up @@ -338,7 +342,7 @@ async def on_POST(self, request):
class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/email/requestToken$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.config = hs.config
Expand All @@ -353,7 +357,7 @@ def __init__(self, hs):
template_text=self.config.email_add_threepid_template_text,
)

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
Expand Down Expand Up @@ -449,7 +453,7 @@ def __init__(self, hs: "HomeServer"):
self.store = self.hs.get_datastore()
self.identity_handler = hs.get_identity_handler()

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(
body, ["client_secret", "country", "phone_number", "send_attempt"]
Expand Down Expand Up @@ -525,11 +529,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
"/add_threepid/email/submit_token$", releases=(), unstable=True
)

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
self.clock = hs.get_clock()
Expand All @@ -539,7 +539,7 @@ def __init__(self, hs):
self.config.email_add_threepid_template_failure_html
)

async def on_GET(self, request):
async def on_GET(self, request: Request) -> None:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
Expand Down Expand Up @@ -596,18 +596,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
"/add_threepid/msisdn/submit_token$", releases=(), unstable=True
)

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.identity_handler = hs.get_identity_handler()

async def on_POST(self, request):
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
if not self.config.account_threepid_delegate_msisdn:
raise SynapseError(
400,
Expand All @@ -632,22 +628,22 @@ async def on_POST(self, request):
class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()

async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

threepids = await self.datastore.user_get_threepids(requester.user.to_string())

return 200, {"threepids": threepids}

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
Expand Down Expand Up @@ -688,15 +684,15 @@ async def on_POST(self, request):
class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()

@interactive_auth_handler
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
Expand Down Expand Up @@ -738,13 +734,13 @@ async def on_POST(self, request):
class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)

assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
Expand All @@ -767,14 +763,14 @@ async def on_POST(self, request):
class ThreepidUnbindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/unbind$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastore()

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
Expand All @@ -798,13 +794,13 @@ async def on_POST(self, request):
class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/delete$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
Expand Down Expand Up @@ -835,7 +831,7 @@ async def on_POST(self, request):
return 200, {"id_server_unbind_result": id_server_unbind_result}


def assert_valid_next_link(hs: "HomeServer", next_link: str):
def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
"""
Raises a SynapseError if a given next_link value is invalid
Expand Down Expand Up @@ -877,11 +873,11 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str):
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()

async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

response = {"user_id": requester.user.to_string()}
Expand All @@ -894,7 +890,7 @@ async def on_GET(self, request):
return 200, response


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
Expand Down
6 changes: 4 additions & 2 deletions synapse/rest/client/knock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple

from twisted.web.server import Request

Expand Down Expand Up @@ -96,7 +96,9 @@ async def on_POST(

return 200, {"room_id": room_id}

def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
def on_PUT(
self, request: Request, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)

return self.txns.fetch_or_execute_request(
Expand Down
14 changes: 7 additions & 7 deletions synapse/rest/client/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
unstable=True,
)

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.store = hs.get_datastore()
Expand All @@ -390,7 +386,7 @@ def __init__(self, hs):
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
)

async def on_GET(self, request):
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))

if not self.hs.config.enable_registration:
Expand Down Expand Up @@ -730,7 +726,11 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return 200, return_dict

async def _do_appservice_registration(
self, username, as_token, body, should_issue_refresh_token: bool = False
self,
username: str,
as_token: str,
body: JsonDict,
should_issue_refresh_token: bool = False,
) -> JsonDict:
user_id = await self.registration_handler.appservice_register(
username, as_token
Expand Down
Loading

0 comments on commit d1f1b46

Please sign in to comment.