Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-38339: Update mypy configuration and type annotations #211

Merged
merged 2 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,17 @@ python_files = ["tests/*.py", "tests/*/*.py"]
disallow_untyped_defs = true
disallow_incomplete_defs = true
ignore_missing_imports = true
local_partial_types = true
plugins = ["pydantic.mypy"]
no_implicit_reexport = true
show_error_codes = true
strict_equality = true
warn_redundant_casts = true
warn_unreachable = true
warn_unused_ignores = true

[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
warn_untyped_fields = true
6 changes: 3 additions & 3 deletions src/mobu/business/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from __future__ import annotations

import asyncio
from asyncio import Queue, QueueEmpty, TimeoutError
from asyncio import Queue, QueueEmpty
from collections.abc import AsyncIterable, AsyncIterator
from datetime import datetime, timezone
from enum import Enum
from typing import AsyncIterable, AsyncIterator, TypeVar
from typing import TypeVar

from structlog import BoundLogger

Expand Down Expand Up @@ -194,7 +195,6 @@ async def iter_next() -> T:
def dump(self) -> BusinessData:
return BusinessData(
name=type(self).__name__,
config=self.config,
failure_count=self.failure_count,
success_count=self.success_count,
timings=self.timings.dump(),
Expand Down
4 changes: 2 additions & 2 deletions src/mobu/business/jupyterloginloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import asyncio
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, Optional
from typing import Optional

from aiohttp import ClientError, ClientResponseError
from structlog import BoundLogger
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(
async def close(self) -> None:
await self._client.close()

def annotations(self) -> Dict[str, str]:
def annotations(self) -> dict[str, str]:
"""Timer annotations to use.

Subclasses should override this to add more annotations based on
Expand Down
4 changes: 2 additions & 2 deletions src/mobu/business/jupyterpythonloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from __future__ import annotations

from typing import Dict, Optional
from typing import Optional

from structlog import BoundLogger

Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(
super().__init__(logger, business_config, user)
self.node: Optional[str] = None

def annotations(self) -> Dict[str, str]:
def annotations(self) -> dict[str, str]:
result = super().annotations()
if self.node:
result["node"] = self.node
Expand Down
16 changes: 8 additions & 8 deletions src/mobu/business/notebookrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import random
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional
from typing import Any, Optional

import git
from git.repo import Repo
from structlog import BoundLogger

from ..exceptions import NotebookRepositoryError
Expand All @@ -37,10 +37,10 @@ def __init__(
self.notebook: Optional[Path] = None
self.running_code: Optional[str] = None
self._repo_dir = TemporaryDirectory()
self._repo: Optional[git.Repo] = None
self._notebook_paths: Optional[List[Path]] = None
self._repo: Optional[Repo] = None
self._notebook_paths: Optional[list[Path]] = None

def annotations(self) -> Dict[str, str]:
def annotations(self) -> dict[str, str]:
result = super().annotations()
if self.notebook:
result["notebook"] = self.notebook.name
Expand All @@ -58,9 +58,9 @@ def clone_repo(self) -> None:
branch = self.config.repo_branch
path = self._repo_dir.name
with self.timings.start("clone_repo"):
self._repo = git.Repo.clone_from(url, path, branch=branch)
self._repo = Repo.clone_from(url, path, branch=branch)

def find_notebooks(self) -> List[Path]:
def find_notebooks(self) -> list[Path]:
with self.timings.start("find_notebooks"):
notebooks = [
p
Expand All @@ -79,7 +79,7 @@ def next_notebook(self) -> None:
self._notebook_paths = self.find_notebooks()
self.notebook = self._notebook_paths.pop()

def read_notebook(self, notebook: Path) -> List[Dict[str, Any]]:
def read_notebook(self, notebook: Path) -> list[dict[str, Any]]:
with self.timings.start("read_notebook", {"notebook": notebook.name}):
try:
notebook_text = notebook.read_text()
Expand Down
4 changes: 2 additions & 2 deletions src/mobu/business/tapqueryrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import random
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Optional

import jinja2
import pyvo
Expand Down Expand Up @@ -87,7 +87,7 @@ def _generate_random_polygon(
poly.append(dec + r * math.cos(theta))
return ", ".join([str(x) for x in poly])

def _generate_parameters(self) -> Dict[str, Union[int, float, str]]:
def _generate_parameters(self) -> dict[str, int | float | str]:
"""Generate some random parameters for the query."""
min_ra = self._params.get("min_ra", 55.0)
max_ra = self._params.get("max_ra", 70.0)
Expand Down
4 changes: 1 addition & 3 deletions src/mobu/cachemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

from typing import List

from aiohttp import ClientSession

from .config import config
Expand Down Expand Up @@ -69,7 +67,7 @@ async def get_recommended(self) -> JupyterImage:
raise CachemachineError(self._username, "No images found")
return images[0]

async def _get_images(self) -> List[JupyterImage]:
async def _get_images(self) -> list[JupyterImage]:
headers = {"Authorization": f"bearer {self._token}"}
async with self._session.get(self._url, headers=headers) as r:
if r.status != 200:
Expand Down
7 changes: 3 additions & 4 deletions src/mobu/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import os
from dataclasses import dataclass
from typing import Optional

__all__ = ["Configuration", "config"]

Expand All @@ -13,15 +12,15 @@
class Configuration:
"""Configuration for mobu."""

alert_hook: Optional[str] = os.getenv("ALERT_HOOK")
alert_hook: str | None = os.getenv("ALERT_HOOK")
"""The slack webhook used for alerting exceptions to slack.

Set with the ``ALERT_HOOK`` environment variable.
This is an https URL which should be considered secret.
If not set or set to "None", this feature will be disabled.
"""

autostart: Optional[str] = os.getenv("AUTOSTART")
autostart: str | None = os.getenv("AUTOSTART")
"""The path to a YAML file defining what flocks to automatically start.

The YAML file should, if given, be a list of flock specifications. All
Expand All @@ -48,7 +47,7 @@ class Configuration:
Set with the ``CACHEMACHINE_IMAGE_POLICY`` environment variable.
"""

gafaelfawr_token: Optional[str] = os.getenv("GAFAELFAWR_TOKEN")
gafaelfawr_token: str | None = os.getenv("GAFAELFAWR_TOKEN")
"""The Gafaelfawr admin token to use to create user tokens.

This token is used to make an admin API call to Gafaelfawr to get a token
Expand Down
8 changes: 4 additions & 4 deletions src/mobu/dependencies/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import asyncio
from typing import Dict, List, Optional
from typing import Optional

from aiohttp import ClientSession
from aiojobs import Scheduler
Expand All @@ -19,7 +19,7 @@ class MonkeyBusinessManager:
"""Manages all of the running monkeys."""

def __init__(self) -> None:
self._flocks: Dict[str, Flock] = {}
self._flocks: dict[str, Flock] = {}
self._scheduler: Optional[Scheduler] = None
self._session: Optional[ClientSession] = None

Expand Down Expand Up @@ -56,10 +56,10 @@ def get_flock(self, name: str) -> Flock:
raise FlockNotFoundException(name)
return flock

def list_flocks(self) -> List[str]:
def list_flocks(self) -> list[str]:
return sorted(self._flocks.keys())

def summarize_flocks(self) -> List[FlockSummary]:
def summarize_flocks(self) -> list[FlockSummary]:
return [f.summary() for _, f in sorted(self._flocks.items())]

async def stop_flock(self, name: str) -> None:
Expand Down
35 changes: 17 additions & 18 deletions src/mobu/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from __future__ import annotations

from datetime import datetime
from typing import Dict, Optional
from typing import Optional, Self

from aiohttp import ClientResponse, ClientResponseError
from safir.datetime import format_datetime_for_logging
from safir.slack.blockkit import (
SlackBaseBlock,
SlackBaseField,
SlackCodeBlock,
SlackException,
SlackMessage,
Expand Down Expand Up @@ -59,7 +61,7 @@ def __init__(self, user: str, msg: str) -> None:
super().__init__(msg, user)
self.started_at: Optional[datetime] = None
self.event: Optional[str] = None
self.annotations: Dict[str, str] = {}
self.annotations: dict[str, str] = {}

def to_slack(self) -> SlackMessage:
"""Format the error as a Slack Block Kit message.
Expand All @@ -73,13 +75,14 @@ def to_slack(self) -> SlackMessage:
"""
return SlackMessage(message=str(self), fields=self.common_fields())

def common_fields(self) -> list[SlackTextField]:
def common_fields(self) -> list[SlackBaseField]:
"""Return common fields to put in any alert."""
failed_at = format_datetime_for_logging(self.failed_at)
fields = [
fields: list[SlackBaseField] = [
SlackTextField(heading="Failed at", text=failed_at),
SlackTextField(heading="User", text=self.user),
]
if self.user:
fields.append(SlackTextField(heading="User", text=self.user))
if self.started_at:
started_at = format_datetime_for_logging(self.started_at)
field = SlackTextField(heading="Started at", text=started_at)
Expand Down Expand Up @@ -149,7 +152,9 @@ def to_slack(self) -> SlackMessage:
if self.status:
intro += f" (status: {self.status})"

attachments = [SlackCodeBlock(heading="Code executed", code=self.code)]
attachments: list[SlackBaseBlock] = [
SlackCodeBlock(heading="Code executed", code=self.code)
]
if self.error:
attachment = SlackCodeBlock(heading="Error", code=self.error)
attachments.insert(0, attachment)
Expand All @@ -169,26 +174,22 @@ class JupyterResponseError(MobuSlackException):
"""Web response error from JupyterHub or JupyterLab."""

@classmethod
def from_exception(
cls, user: str, exc: ClientResponseError
) -> JupyterResponseError:
def from_exception(cls, user: str, exc: ClientResponseError) -> Self:
return cls(
url=str(exc.request_info.url),
user=user,
status=exc.status,
reason=exc.message if exc.message else type(exc).__name__,
reason=exc.message or type(exc).__name__,
method=exc.request_info.method,
)

@classmethod
async def from_response(
cls, user: str, response: ClientResponse
) -> JupyterResponseError:
async def from_response(cls, user: str, response: ClientResponse) -> Self:
return cls(
url=str(response.url),
user=user,
status=response.status,
reason=response.reason,
reason=response.reason or "",
method=response.method,
body=await response.text(),
)
Expand All @@ -199,7 +200,7 @@ def __init__(
url: str,
user: str,
status: int,
reason: Optional[str],
reason: str,
method: str,
body: Optional[str] = None,
) -> None:
Expand Down Expand Up @@ -232,9 +233,7 @@ class JupyterSpawnError(MobuSlackException):
"""The Jupyter Lab pod failed to spawn."""

@classmethod
def from_exception(
cls, user: str, log: str, exc: Exception
) -> JupyterSpawnError:
def from_exception(cls, user: str, log: str, exc: Exception) -> Self:
return cls(user, log, f"{type(exc).__name__}: {str(exc)}")

def __init__(
Expand Down
14 changes: 3 additions & 11 deletions src/mobu/jupyterclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,12 @@
import random
import re
import string
from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass
from datetime import datetime, timezone
from functools import wraps
from http.cookies import BaseCookie
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Optional,
TypeVar,
cast,
)
from typing import Any, Optional, TypeVar, cast
from uuid import uuid4

from aiohttp import (
Expand Down Expand Up @@ -501,7 +493,7 @@ def _remove_ansi_escapes(string: str) -> str:
"""
return _ANSI_REGEX.sub("", string)

def _build_jupyter_spawn_form(self, image: JupyterImage) -> Dict[str, str]:
def _build_jupyter_spawn_form(self, image: JupyterImage) -> dict[str, str]:
"""Construct the form to submit to the JupyterHub login page."""
return {
"image_list": str(image),
Expand Down
4 changes: 2 additions & 2 deletions src/mobu/models/business.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Models for monkey business."""

from typing import List, Optional
from typing import Optional

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -187,7 +187,7 @@ class BusinessData(BaseModel):

success_count: int = Field(..., title="Number of successes", example=25)

timings: List[StopwatchData] = Field(..., title="Timings of events")
timings: list[StopwatchData] = Field(..., title="Timings of events")

image: Optional[JupyterImage] = Field(
None,
Expand Down
Loading