Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add missing type hints to synapse.api. (#11109)
Browse files Browse the repository at this point in the history
* Convert UserPresenceState to attrs.
* Remove args/kwargs from error classes and explicitly pass msg/errorcode.
  • Loading branch information
clokep authored Oct 18, 2021
1 parent cc33d9e commit 3ab55d4
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 99 deletions.
1 change: 1 addition & 0 deletions changelog.d/11109.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to `synapse.api` module.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

[mypy-synapse.api.*]
disallow_untyped_defs = True

[mypy-synapse.events.*]
disallow_untyped_defs = True

Expand Down
14 changes: 11 additions & 3 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ async def get_user_by_req(

async def validate_appservice_can_control_user_id(
self, app_service: ApplicationService, user_id: str
):
) -> None:
"""Validates that the app service is allowed to control
the given user.
Expand Down Expand Up @@ -618,5 +618,13 @@ async def check_user_in_room_or_world_readable(
% (user_id, room_id),
)

async def check_auth_blocking(self, *args, **kwargs) -> None:
await self._auth_blocking.check_auth_blocking(*args, **kwargs)
async def check_auth_blocking(
self,
user_id: Optional[str] = None,
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
) -> None:
await self._auth_blocking.check_auth_blocking(
user_id=user_id, threepid=threepid, user_type=user_type, requester=requester
)
69 changes: 23 additions & 46 deletions synapse/api/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import typing
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from twisted.web import http

Expand Down Expand Up @@ -143,7 +143,7 @@ def __init__(self, code: int, msg: str, errcode: str = Codes.UNKNOWN):
super().__init__(code, msg)
self.errcode = errcode

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode)


Expand Down Expand Up @@ -175,7 +175,7 @@ def __init__(
else:
self._additional_fields = dict(additional_fields)

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, **self._additional_fields)


Expand All @@ -196,7 +196,7 @@ def __init__(self, msg: str, consent_uri: str):
)
self._consent_uri = consent_uri

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)


Expand Down Expand Up @@ -262,14 +262,10 @@ def __init__(self, session_id: str, result: "JsonDict"):
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""

def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.UNRECOGNIZED
if len(args) == 0:
message = "Unrecognized request"
else:
message = args[0]
super().__init__(400, message, **kwargs)
def __init__(
self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED
):
super().__init__(400, msg, errcode)


class NotFoundError(SynapseError):
Expand All @@ -284,10 +280,8 @@ class AuthError(SynapseError):
other poorly-defined times.
"""

def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN
super().__init__(*args, **kwargs)
def __init__(self, code: int, msg: str, errcode: str = Codes.FORBIDDEN):
super().__init__(code, msg, errcode)


class InvalidClientCredentialsError(SynapseError):
Expand Down Expand Up @@ -321,7 +315,7 @@ def __init__(
super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
self._soft_logout = soft_logout

def error_dict(self):
def error_dict(self) -> "JsonDict":
d = super().error_dict()
d["soft_logout"] = self._soft_logout
return d
Expand All @@ -345,7 +339,7 @@ def __init__(
self.limit_type = limit_type
super().__init__(code, msg, errcode=errcode)

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(
self.msg,
self.errcode,
Expand All @@ -357,32 +351,17 @@ def error_dict(self):
class EventSizeError(SynapseError):
"""An error raised when an event is too big."""

def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.TOO_LARGE
super().__init__(413, *args, **kwargs)


class EventStreamError(SynapseError):
"""An error raised when there a problem with the event stream."""

def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.BAD_PAGINATION
super().__init__(*args, **kwargs)
def __init__(self, msg: str):
super().__init__(413, msg, Codes.TOO_LARGE)


class LoginError(SynapseError):
"""An error raised when there was a problem logging in."""

pass


class StoreError(SynapseError):
"""An error raised when there was a problem storing some data."""

pass


class InvalidCaptchaError(SynapseError):
def __init__(
Expand All @@ -395,7 +374,7 @@ def __init__(
super().__init__(code, msg, errcode)
self.error_url = error_url

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, error_url=self.error_url)


Expand All @@ -412,7 +391,7 @@ def __init__(
super().__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)


Expand Down Expand Up @@ -443,10 +422,8 @@ def __init__(self, msg: str = "Homeserver does not support this room version"):
class ThreepidValidationError(SynapseError):
"""An error raised when there was a problem authorising an event."""

def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN
super().__init__(*args, **kwargs)
def __init__(self, msg: str, errcode: str = Codes.FORBIDDEN):
super().__init__(400, msg, errcode)


class IncompatibleRoomVersionError(SynapseError):
Expand All @@ -466,7 +443,7 @@ def __init__(self, room_version: str):

self._room_version = room_version

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, room_version=self._room_version)


Expand Down Expand Up @@ -494,7 +471,7 @@ class RequestSendFailed(RuntimeError):
errors (like programming errors).
"""

def __init__(self, inner_exception, can_retry):
def __init__(self, inner_exception: BaseException, can_retry: bool):
super().__init__(
"Failed to send request: %s: %s"
% (type(inner_exception).__name__, inner_exception)
Expand All @@ -503,7 +480,7 @@ def __init__(self, inner_exception, can_retry):
self.can_retry = can_retry


def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict":
"""Utility method for constructing an error response for client-server
interactions.
Expand Down Expand Up @@ -551,7 +528,7 @@ def __init__(
msg = "%s %s: %s" % (level, code, reason)
super().__init__(msg)

def get_dict(self):
def get_dict(self) -> "JsonDict":
return {
"level": self.level,
"code": self.code,
Expand Down Expand Up @@ -580,7 +557,7 @@ def __init__(self, code: int, msg: str, response: bytes):
super().__init__(code, msg)
self.response = response

def to_synapse_error(self):
def to_synapse_error(self) -> SynapseError:
"""Make a SynapseError based on an HTTPResponseException
This is useful when a proxied request has failed, and we need to
Expand Down
18 changes: 9 additions & 9 deletions synapse/api/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,24 +231,24 @@ def lazy_load_members(self) -> bool:
def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members()

def filter_presence(self, events):
def filter_presence(
self, events: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
return self._presence_filter.filter(events)

def filter_account_data(self, events):
def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._account_data.filter(events)

def filter_room_state(self, events):
def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_state_filter.filter(self._room_filter.filter(events))

def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_timeline_filter.filter(self._room_filter.filter(events))

def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))

def filter_room_account_data(
self, events: Iterable[FilterEvent]
) -> List[FilterEvent]:
def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_account_data.filter(self._room_filter.filter(events))

def blocks_all_presence(self) -> bool:
Expand Down Expand Up @@ -309,7 +309,7 @@ def check(self, event: FilterEvent) -> bool:
# except for presence which actually gets passed around as its own
# namedtuple type.
if isinstance(event, UserPresenceState):
sender = event.user_id
sender: Optional[str] = event.user_id
room_id = None
ev_type = "m.presence"
contains_url = False
Expand Down
51 changes: 25 additions & 26 deletions synapse/api/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,49 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
from typing import Any, Optional

import attr

from synapse.api.constants import PresenceState
from synapse.types import JsonDict


class UserPresenceState(
namedtuple(
"UserPresenceState",
(
"user_id",
"state",
"last_active_ts",
"last_federation_update_ts",
"last_user_sync_ts",
"status_msg",
"currently_active",
),
)
):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class UserPresenceState:
"""Represents the current presence state of the user.
user_id (str)
last_active (int): Time in msec that the user last interacted with server.
last_federation_update (int): Time in msec since either a) we sent a presence
user_id
last_active: Time in msec that the user last interacted with server.
last_federation_update: Time in msec since either a) we sent a presence
update to other servers or b) we received a presence update, depending
on if is a local user or not.
last_user_sync (int): Time in msec that the user last *completed* a sync
last_user_sync: Time in msec that the user last *completed* a sync
(or event stream).
status_msg (str): User set status message.
status_msg: User set status message.
"""

def as_dict(self):
return dict(self._asdict())
user_id: str
state: str
last_active_ts: int
last_federation_update_ts: int
last_user_sync_ts: int
status_msg: Optional[str]
currently_active: bool

def as_dict(self) -> JsonDict:
return attr.asdict(self)

@staticmethod
def from_dict(d):
def from_dict(d: JsonDict) -> "UserPresenceState":
return UserPresenceState(**d)

def copy_and_replace(self, **kwargs):
return self._replace(**kwargs)
def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState":
return attr.evolve(self, **kwargs)

@classmethod
def default(cls, user_id):
def default(cls, user_id: str) -> "UserPresenceState":
"""Returns a default presence state."""
return cls(
user_id=user_id,
Expand Down
4 changes: 2 additions & 2 deletions synapse/api/ratelimiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ async def can_do_action(

return allowed, time_allowed

def _prune_message_counts(self, time_now_s: float):
def _prune_message_counts(self, time_now_s: float) -> None:
"""Remove message count entries that have not exceeded their defined
rate_hz limit
Expand Down Expand Up @@ -190,7 +190,7 @@ async def ratelimit(
update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[float] = None,
):
) -> None:
"""Checks if an action can be performed. If not, raises a LimitExceededError
Checks if the user has ratelimiting disabled in the database by looking
Expand Down
Loading

0 comments on commit 3ab55d4

Please sign in to comment.