Skip to content

Commit

Permalink
WIP strict mypy for sydent.db
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson committed Oct 14, 2021
1 parent 8ef46f4 commit 9c8c6b2
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 32 deletions.
5 changes: 3 additions & 2 deletions sydent/db/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, cast

from sydent.users.accounts import Account

Expand Down Expand Up @@ -104,6 +104,7 @@ def delToken(self, token: str) -> int:
"delete from tokens where token = ?",
(token,),
)
deleted = cur.rowcount
# Cast safety: DBAPI-2 says this is a "number"; c.f. python/typeshed#6150
deleted = cast(int, cur.rowcount)
self.sydent.db.commit()
return deleted
10 changes: 8 additions & 2 deletions sydent/db/hashing_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Actions on the hashing_metadata table which is defined in the migration process in
# sqlitedb.py
from sqlite3 import Cursor
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, Tuple

if TYPE_CHECKING:
from sydent.sydent import Sydent
Expand All @@ -33,7 +33,13 @@ def get_lookup_pepper(self) -> Optional[str]:
"""
cur = self.sydent.db.cursor()
res = cur.execute("select lookup_pepper from hashing_metadata")
row = res.fetchone()
# Annotation safety: lookup_pepper is marked as varchar(256) in the
# schema, so could be null. I.e. `row` should strictly be
# Optional[Tuple[Optional[str]].
# But I think the application code is such that either
# - hashing_metadata contains no rows
# - or it contains exactly one row with a nonnull lookup_pepper.
row: Optional[Tuple[str]] = res.fetchone()

if not row:
return None
Expand Down
8 changes: 5 additions & 3 deletions sydent/db/invite_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast

if TYPE_CHECKING:
from sydent.sydent import Sydent
Expand Down Expand Up @@ -152,7 +152,9 @@ def validateEphemeralPublicKey(self, publicKey: str) -> bool:
(publicKey,),
)
self.sydent.db.commit()
return cur.rowcount > 0
# Cast safety: DBAPI-2 says this is a "number"; c.f. python/typeshed#6150
rows = cast(int, cur.rowcount)
return rows > 0

def getSenderForToken(self, token: str) -> Optional[str]:
"""
Expand All @@ -165,7 +167,7 @@ def getSenderForToken(self, token: str) -> Optional[str]:
"""
cur = self.sydent.db.cursor()
res = cur.execute("SELECT sender FROM invite_tokens WHERE token = ?", (token,))
rows = res.fetchall()
rows: List[Tuple[str]] = res.fetchall()
if rows:
return rows[0][0]
return None
Expand Down
15 changes: 10 additions & 5 deletions sydent/db/peers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from sydent.replication.peer import RemotePeer

Expand All @@ -39,11 +39,16 @@ def getPeerByName(self, name: str) -> Optional[RemotePeer]:
(name,),
)

serverName = None
port = None
lastSentVer = None
pubkeys = {}
# Type safety: if the query returns no rows, we'll pubkeys will be empty
# and we'll return None before using serverName. Otherwise, we'll read
# at least one row and assign serverName a string value, because the
# `name` column is declared `not null` in the DB.
serverName: str = None # type: ignore[assignment]
port: Optional[int] = None
lastSentVer: Optional[int] = None
pubkeys: Dict[str, str] = {}

row: Tuple[str, Optional[int], Optional[int], str, str]
for row in res.fetchall():
serverName = row[0]
port = row[1]
Expand Down
12 changes: 6 additions & 6 deletions sydent/db/sqlitedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
import os
import sqlite3
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple

if TYPE_CHECKING:
from sydent.sydent import Sydent
Expand All @@ -42,7 +42,7 @@ def __init__(self, syd: "Sydent") -> None:
self._createSchema()
self._upgradeSchema()

def _createSchema(self):
def _createSchema(self) -> None:
logger.info("Running schema files...")
schemaDir = os.path.dirname(__file__)

Expand All @@ -64,7 +64,7 @@ def _createSchema(self):
c.close()
self.db.commit()

def _upgradeSchema(self):
def _upgradeSchema(self) -> None:
curVer = self._getSchemaVersion()

if curVer < 1:
Expand Down Expand Up @@ -212,13 +212,13 @@ def _upgradeSchema(self):
logger.info("v4 -> v5 schema migration complete")
self._setSchemaVersion(5)

def _getSchemaVersion(self):
def _getSchemaVersion(self) -> int:
cur = self.db.cursor()
cur.execute("PRAGMA user_version")
row = cur.fetchone()
row: Tuple[int] = cur.fetchone()
return row[0]

def _setSchemaVersion(self, ver):
def _setSchemaVersion(self, ver: int) -> None:
cur = self.db.cursor()
# NB. pragma doesn't support variable substitution so we
# do it in python (as a decimal so we don't risk SQL injection)
Expand Down
6 changes: 3 additions & 3 deletions sydent/db/threepid_associations.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def signedAssociationStringForThreepid(
(medium, address, time_msec(), time_msec()),
)

row = res.fetchone()
row: Optional[Tuple[str]] = res.fetchone()

if not row:
return None
Expand Down Expand Up @@ -233,7 +233,7 @@ def getMxid(self, medium: str, normalised_address: str) -> Optional[str]:
(medium, normalised_address, time_msec(), time_msec()),
)

row = res.fetchone()
row: Tuple[str] = res.fetchone()

if not row:
return None
Expand Down Expand Up @@ -352,7 +352,7 @@ def lastIdFromServer(self, server: str) -> Optional[int]:
"where originServer = ?",
(server,),
)
row = res.fetchone()
row: Tuple[int, int] = res.fetchone()

if row[1] == 0:
return None
Expand Down
2 changes: 1 addition & 1 deletion sydent/replication/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
server_name: str,
port: Optional[int],
pubkeys: Dict[str, str],
lastSentVersion: int,
lastSentVersion: Optional[int],
) -> None:
"""
:param sydent: The current Sydent instance.
Expand Down
3 changes: 2 additions & 1 deletion sydent/sydent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
import logging.handlers
import os
import sqlite3
from typing import Optional

import twisted.internet.reactor
Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__(

logger.info("Starting Sydent server")

self.db = SqliteDatabase(self).db
self.db: sqlite3.Connection = SqliteDatabase(self).db

if self.config.general.sentry_enabled:
import sentry_sdk
Expand Down
12 changes: 11 additions & 1 deletion sydent/threepid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional


def threePidAssocFromDict(d):
Expand All @@ -36,7 +37,16 @@ def threePidAssocFromDict(d):


class ThreepidAssociation:
def __init__(self, medium, address, lookup_hash, mxid, ts, not_before, not_after):
def __init__(
self,
medium: str,
address: str,
lookup_hash: Optional[str],
mxid: str,
ts: int,
not_before: int,
not_after: int,
):
"""
:param medium: The medium of the 3pid (eg. email)
:param address: The identifier (eg. email address)
Expand Down
17 changes: 9 additions & 8 deletions sydent/validators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional


class ValidationSession:
Expand All @@ -22,14 +23,14 @@ class ValidationSession:

def __init__(
self,
_id,
_medium,
_address,
_clientSecret,
_validated,
_mtime,
_token,
_sendAttemptNumber,
_id: int,
_medium: str,
_address: str,
_clientSecret: str,
_validated: int, # bool, but sqlite has no bool type
_mtime: int,
_token: Optional[str],
_sendAttemptNumber: Optional[int],
):
self.id = _id
self.medium = _medium
Expand Down

0 comments on commit 9c8c6b2

Please sign in to comment.