Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Additional type hints for client REST servlets (part 4) #10728

Merged
merged 9 commits into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10728.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to REST servlets.
82 changes: 39 additions & 43 deletions synapse/rest/client/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple
from urllib.parse import urlparse

from twisted.web.server import Request

from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
Expand All @@ -28,15 +30,17 @@
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.types import JsonDict
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed, validate_email
Expand Down Expand Up @@ -68,7 +72,7 @@ def __init__(self, hs: "HomeServer"):
template_text=self.config.email_password_reset_template_text,
)

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
Expand Down Expand Up @@ -159,7 +163,7 @@ async def on_POST(self, request):
class PasswordRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
Expand All @@ -169,7 +173,7 @@ def __init__(self, hs):
self._set_password_handler = hs.get_set_password_handler()

@interactive_auth_handler
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)

# we do basic sanity checks here because the auth layer will store these
Expand All @@ -190,6 +194,7 @@ async def on_POST(self, request):
#
# In the second case, we require a password to confirm their identity.

requester = None
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
try:
Expand All @@ -206,16 +211,15 @@ async def on_POST(self, request):
# If a password is available now, hash the provided password and
# store it for later.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
new_password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
new_password_hash,
)
raise
user_id = requester.user.to_string()
else:
requester = None
try:
result, params, session_id = await self.auth_handler.check_ui_auth(
[[LoginType.EMAIL_IDENTITY]],
Expand All @@ -230,11 +234,11 @@ async def on_POST(self, request):
# If a password is available now, hash the provided password and
# store it for later.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
new_password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
new_password_hash,
)
raise

Expand Down Expand Up @@ -264,7 +268,7 @@ async def on_POST(self, request):
# If we have a password in this request, prefer it. Otherwise, use the
# password hash from an earlier request.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
password_hash: Optional[str] = await self.auth_handler.hash(new_password)
elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
Expand All @@ -288,15 +292,15 @@ async def on_POST(self, request):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self._deactivate_account_handler = hs.get_deactivate_account_handler()

@interactive_auth_handler
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
erase = body.get("erase", False)
if not isinstance(erase, bool):
Expand Down Expand Up @@ -338,7 +342,7 @@ async def on_POST(self, request):
class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/email/requestToken$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.config = hs.config
Expand All @@ -353,7 +357,7 @@ def __init__(self, hs):
template_text=self.config.email_add_threepid_template_text,
)

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
Expand Down Expand Up @@ -449,7 +453,7 @@ def __init__(self, hs: "HomeServer"):
self.store = self.hs.get_datastore()
self.identity_handler = hs.get_identity_handler()

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(
body, ["client_secret", "country", "phone_number", "send_attempt"]
Expand Down Expand Up @@ -525,11 +529,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
"/add_threepid/email/submit_token$", releases=(), unstable=True
)

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
self.clock = hs.get_clock()
Expand All @@ -539,7 +539,7 @@ def __init__(self, hs):
self.config.email_add_threepid_template_failure_html
)

async def on_GET(self, request):
async def on_GET(self, request: Request) -> None:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
Expand Down Expand Up @@ -596,18 +596,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
"/add_threepid/msisdn/submit_token$", releases=(), unstable=True
)

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.identity_handler = hs.get_identity_handler()

async def on_POST(self, request):
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
if not self.config.account_threepid_delegate_msisdn:
raise SynapseError(
400,
Expand All @@ -632,22 +628,22 @@ async def on_POST(self, request):
class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()

async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

threepids = await self.datastore.user_get_threepids(requester.user.to_string())

return 200, {"threepids": threepids}

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
Expand Down Expand Up @@ -688,15 +684,15 @@ async def on_POST(self, request):
class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()

@interactive_auth_handler
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
Expand Down Expand Up @@ -738,13 +734,13 @@ async def on_POST(self, request):
class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)

assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
Expand All @@ -767,14 +763,14 @@ async def on_POST(self, request):
class ThreepidUnbindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/unbind$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastore()

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
Expand All @@ -798,13 +794,13 @@ async def on_POST(self, request):
class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/delete$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
Expand Down Expand Up @@ -835,7 +831,7 @@ async def on_POST(self, request):
return 200, {"id_server_unbind_result": id_server_unbind_result}


def assert_valid_next_link(hs: "HomeServer", next_link: str):
def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
"""
Raises a SynapseError if a given next_link value is invalid

Expand Down Expand Up @@ -877,11 +873,11 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str):
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()

async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

response = {"user_id": requester.user.to_string()}
Expand All @@ -894,7 +890,7 @@ async def on_GET(self, request):
return 200, response


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
Expand Down
6 changes: 4 additions & 2 deletions synapse/rest/client/knock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple

from twisted.web.server import Request

Expand Down Expand Up @@ -96,7 +96,9 @@ async def on_POST(

return 200, {"room_id": room_id}

def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
def on_PUT(
self, request: Request, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)

return self.txns.fetch_or_execute_request(
Expand Down
14 changes: 7 additions & 7 deletions synapse/rest/client/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
unstable=True,
)

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.store = hs.get_datastore()
Expand All @@ -390,7 +386,7 @@ def __init__(self, hs):
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
)

async def on_GET(self, request):
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))

if not self.hs.config.enable_registration:
Expand Down Expand Up @@ -730,7 +726,11 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return 200, return_dict

async def _do_appservice_registration(
self, username, as_token, body, should_issue_refresh_token: bool = False
self,
username: str,
as_token: str,
body: JsonDict,
should_issue_refresh_token: bool = False,
) -> JsonDict:
user_id = await self.registration_handler.appservice_register(
username, as_token
Expand Down
Loading