From 249b9e37c2bc1f3e24dd375a1b022eaad4b27818 Mon Sep 17 00:00:00 2001 From: Chad Dombrova Date: Tue, 17 Jun 2025 19:33:09 -0700 Subject: [PATCH 1/3] Add type annotations --- setup.py | 2 +- shotgun_api3/lib/httplib2/__init__.py | 1 + shotgun_api3/lib/httplib2/auth.py | 1 + shotgun_api3/lib/httplib2/certs.py | 1 + shotgun_api3/lib/httplib2/iri2uri.py | 1 + shotgun_api3/lib/httplib2/python3/__init__.py | 1 + shotgun_api3/lib/httplib2/python3/auth.py | 1 + shotgun_api3/lib/httplib2/python3/certs.py | 1 + shotgun_api3/lib/httplib2/python3/iri2uri.py | 1 + shotgun_api3/lib/mimetypes.py | 15 +- shotgun_api3/lib/mockgun/mockgun.py | 5 +- shotgun_api3/lib/mockgun/schema.py | 3 +- shotgun_api3/lib/pyparsing.py | 1 + shotgun_api3/lib/sgsix.py | 2 +- shotgun_api3/lib/sgutils.py | 4 +- shotgun_api3/lib/six.py | 1 + shotgun_api3/shotgun.py | 587 +++++++++++------- 17 files changed, 379 insertions(+), 249 deletions(-) diff --git a/setup.py b/setup.py index 9240486b5..2e25a17d5 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ include_package_data=True, package_data={"": ["cacerts.txt", "cacert.pem"]}, zip_safe=False, - python_requires=">=3.7.0", + python_requires=">=3.9.0", classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", diff --git a/shotgun_api3/lib/httplib2/__init__.py b/shotgun_api3/lib/httplib2/__init__.py index 42c9916d1..65bf61636 100644 --- a/shotgun_api3/lib/httplib2/__init__.py +++ b/shotgun_api3/lib/httplib2/__init__.py @@ -1,3 +1,4 @@ +# type: ignore from .. import six # Define all here to keep linters happy. It should be overwritten by the code diff --git a/shotgun_api3/lib/httplib2/auth.py b/shotgun_api3/lib/httplib2/auth.py index 53f427be1..a4380f0d4 100644 --- a/shotgun_api3/lib/httplib2/auth.py +++ b/shotgun_api3/lib/httplib2/auth.py @@ -1,3 +1,4 @@ +# type: ignore import base64 import re diff --git a/shotgun_api3/lib/httplib2/certs.py b/shotgun_api3/lib/httplib2/certs.py index 59d1ffc70..8b13aec87 100644 --- a/shotgun_api3/lib/httplib2/certs.py +++ b/shotgun_api3/lib/httplib2/certs.py @@ -1,3 +1,4 @@ +# type: ignore """Utilities for certificate management.""" import os diff --git a/shotgun_api3/lib/httplib2/iri2uri.py b/shotgun_api3/lib/httplib2/iri2uri.py index 86e361e62..3aa662987 100644 --- a/shotgun_api3/lib/httplib2/iri2uri.py +++ b/shotgun_api3/lib/httplib2/iri2uri.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +# type: ignore """Converts an IRI to a URI.""" __author__ = "Joe Gregorio (joe@bitworking.org)" diff --git a/shotgun_api3/lib/httplib2/python3/__init__.py b/shotgun_api3/lib/httplib2/python3/__init__.py index ba5fa2f23..a10d58bfe 100644 --- a/shotgun_api3/lib/httplib2/python3/__init__.py +++ b/shotgun_api3/lib/httplib2/python3/__init__.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +# type: ignore """Small, fast HTTP client library for Python.""" __author__ = "Joe Gregorio (joe@bitworking.org)" diff --git a/shotgun_api3/lib/httplib2/python3/auth.py b/shotgun_api3/lib/httplib2/python3/auth.py index 53f427be1..a4380f0d4 100644 --- a/shotgun_api3/lib/httplib2/python3/auth.py +++ b/shotgun_api3/lib/httplib2/python3/auth.py @@ -1,3 +1,4 @@ +# type: ignore import base64 import re diff --git a/shotgun_api3/lib/httplib2/python3/certs.py b/shotgun_api3/lib/httplib2/python3/certs.py index 59d1ffc70..8b13aec87 100644 --- a/shotgun_api3/lib/httplib2/python3/certs.py +++ b/shotgun_api3/lib/httplib2/python3/certs.py @@ -1,3 +1,4 @@ +# type: ignore """Utilities for certificate management.""" import os diff --git a/shotgun_api3/lib/httplib2/python3/iri2uri.py b/shotgun_api3/lib/httplib2/python3/iri2uri.py index 86e361e62..3aa662987 100644 --- a/shotgun_api3/lib/httplib2/python3/iri2uri.py +++ b/shotgun_api3/lib/httplib2/python3/iri2uri.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +# type: ignore """Converts an IRI to a URI.""" __author__ = "Joe Gregorio (joe@bitworking.org)" diff --git a/shotgun_api3/lib/mimetypes.py b/shotgun_api3/lib/mimetypes.py index bc8488535..3e6eca0e5 100644 --- a/shotgun_api3/lib/mimetypes.py +++ b/shotgun_api3/lib/mimetypes.py @@ -26,6 +26,7 @@ a patched version of the native mimetypes module and is used only in Python versions 2.7.0 - 2.7.9, which included a broken version of the mimetypes module. """ +from __future__ import print_function import os import sys @@ -568,14 +569,14 @@ def _default_mime_types(): """ def usage(code, msg=''): - print USAGE - if msg: print msg + print(USAGE) + if msg: print(msg) sys.exit(code) try: opts, args = getopt.getopt(sys.argv[1:], 'hle', ['help', 'lenient', 'extension']) - except getopt.error, msg: + except getopt.error as msg: usage(1, msg) strict = 1 @@ -590,9 +591,9 @@ def usage(code, msg=''): for gtype in args: if extension: guess = guess_extension(gtype, strict) - if not guess: print "I don't know anything about type", gtype - else: print guess + if not guess: print("I don't know anything about type", gtype) + else: print(guess) else: guess, encoding = guess_type(gtype, strict) - if not guess: print "I don't know anything about type", gtype - else: print 'type:', guess, 'encoding:', encoding \ No newline at end of file + if not guess: print("I don't know anything about type", gtype) + else: print('type:', guess, 'encoding:', encoding) diff --git a/shotgun_api3/lib/mockgun/mockgun.py b/shotgun_api3/lib/mockgun/mockgun.py index 18e4a142c..5d2548eec 100644 --- a/shotgun_api3/lib/mockgun/mockgun.py +++ b/shotgun_api3/lib/mockgun/mockgun.py @@ -120,6 +120,7 @@ from ...shotgun import _Config from .errors import MockgunError from .schema import SchemaFactory +from typing import Any # ---------------------------------------------------------------------------- # Version @@ -581,7 +582,7 @@ def _get_new_row(self, entity_type): row[field] = default_value return row - def _compare(self, field_type, lval, operator, rval): + def _compare(self, field_type: str, lval: Any, operator: str, rval: Any) -> bool: """ Compares a field using the operator and value provide by the filter. @@ -798,7 +799,7 @@ def _row_matches_filter(self, entity_type, row, sg_filter, retired_only): return self._compare(field_type, lval, operator, rval) - def _rearrange_filters(self, filters): + def _rearrange_filters(self, filters: list) -> None: """ Modifies the filter syntax to turn it into a list of three items regardless of the actual filter. Most of the filters are list of three elements, so this doesn't change much. diff --git a/shotgun_api3/lib/mockgun/schema.py b/shotgun_api3/lib/mockgun/schema.py index 5d5019df4..64edebf47 100644 --- a/shotgun_api3/lib/mockgun/schema.py +++ b/shotgun_api3/lib/mockgun/schema.py @@ -1,3 +1,4 @@ +# type: ignore """ ----------------------------------------------------------------------------- Copyright (c) 2009-2017, Shotgun Software Inc @@ -47,7 +48,7 @@ class SchemaFactory(object): _schema_cache_path = None @classmethod - def get_schemas(cls, schema_path, schema_entity_path): + def get_schemas(cls, schema_path: str, schema_entity_path: str) -> tuple: """ Retrieves the schemas from disk. diff --git a/shotgun_api3/lib/pyparsing.py b/shotgun_api3/lib/pyparsing.py index 774222bba..61b5e4dd7 100644 --- a/shotgun_api3/lib/pyparsing.py +++ b/shotgun_api3/lib/pyparsing.py @@ -22,6 +22,7 @@ # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # +# type: ignore __doc__ = \ """ diff --git a/shotgun_api3/lib/sgsix.py b/shotgun_api3/lib/sgsix.py index 6c2af1abc..1b98e7e30 100644 --- a/shotgun_api3/lib/sgsix.py +++ b/shotgun_api3/lib/sgsix.py @@ -58,7 +58,7 @@ ShotgunSSLError = SSLHandshakeError -def normalize_platform(platform, python2=True): +def normalize_platform(platform: str, python2: bool = True) -> str: """ Normalize the return of sys.platform between Python 2 and 3. diff --git a/shotgun_api3/lib/sgutils.py b/shotgun_api3/lib/sgutils.py index 0d49e4b39..53e691b6b 100644 --- a/shotgun_api3/lib/sgutils.py +++ b/shotgun_api3/lib/sgutils.py @@ -29,7 +29,7 @@ """ -def ensure_binary(s, encoding='utf-8', errors='strict'): +def ensure_binary(s, encoding='utf-8', errors='strict') -> bytes: """ Coerce **s** to bytes. @@ -44,7 +44,7 @@ def ensure_binary(s, encoding='utf-8', errors='strict'): raise TypeError(f"not expecting type '{type(s)}'") -def ensure_str(s, encoding='utf-8', errors='strict'): +def ensure_str(s, encoding='utf-8', errors='strict') -> str: """Coerce *s* to `str`. - `str` -> `str` diff --git a/shotgun_api3/lib/six.py b/shotgun_api3/lib/six.py index b22d2e57d..392e64a1f 100644 --- a/shotgun_api3/lib/six.py +++ b/shotgun_api3/lib/six.py @@ -17,6 +17,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +# type: ignore """Utilities for writing code that runs on Python 2 and 3""" diff --git a/shotgun_api3/shotgun.py b/shotgun_api3/shotgun.py index fea25deba..8f71da0fa 100644 --- a/shotgun_api3/shotgun.py +++ b/shotgun_api3/shotgun.py @@ -50,6 +50,17 @@ import json from .lib.six.moves import urllib import shutil # used for attachment download +from typing import ( + Any, + BinaryIO, + Iterable, + Literal, + NoReturn, + Optional, + TypeVar, + Union, + TYPE_CHECKING, +) from .lib.six.moves import http_client # Used for secure file upload. from .lib.httplib2 import Http, ProxyInfo, socks, ssl_error_classes from .lib.sgtimezone import SgTimezone @@ -64,6 +75,7 @@ from base64 import encodestring as base64encode +T = TypeVar("T") LOG = logging.getLogger("shotgun_api3") """ Logging instance for shotgun_api3 @@ -97,7 +109,7 @@ def _is_mimetypes_broken(): ) -if _is_mimetypes_broken(): +if TYPE_CHECKING or _is_mimetypes_broken(): from .lib import mimetypes as mimetypes else: import mimetypes @@ -209,7 +221,7 @@ class ServerCapabilities(object): the future. Therefore, usage of this class is discouraged. """ - def __init__(self, host, meta): + def __init__(self, host: str, meta: dict[str, Any]) -> None: """ ServerCapabilities.__init__ @@ -249,14 +261,14 @@ def __init__(self, host, meta): self.version = tuple(self.version[:3]) self._ensure_json_supported() - def _ensure_python_version_supported(self): + def _ensure_python_version_supported(self) -> None: """ Checks the if current Python version is supported. """ if sys.version_info < (3, 7): raise ShotgunError("This module requires Python version 3.7 or higher.") - def _ensure_support(self, feature, raise_hell=True): + def _ensure_support(self, feature: dict[str, Any], raise_hell: bool = True) -> bool: """ Checks the server version supports a given feature, raises an exception if it does not. @@ -267,6 +279,7 @@ def _ensure_support(self, feature, raise_hell=True): :param bool raise_hell: Whether to raise an exception if the feature is not supported. Defaults to ``True`` :raises: :class:`ShotgunError` if the current server version does not support ``feature`` + :rtype: bool """ if not self.version or self.version < feature["version"]: @@ -284,13 +297,13 @@ def _ensure_support(self, feature, raise_hell=True): else: return True - def _ensure_json_supported(self): + def _ensure_json_supported(self) -> None: """ Ensures server has support for JSON API endpoint added in v2.4.0. """ self._ensure_support({"version": (2, 4, 0), "label": "JSON API"}) - def ensure_include_archived_projects(self): + def ensure_include_archived_projects(self) -> None: """ Ensures server has support for archived Projects feature added in v5.3.14. """ @@ -298,7 +311,7 @@ def ensure_include_archived_projects(self): {"version": (5, 3, 14), "label": "include_archived_projects parameter"} ) - def ensure_per_project_customization(self): + def ensure_per_project_customization(self) -> bool: """ Ensures server has support for per-project customization feature added in v5.4.4. """ @@ -306,7 +319,7 @@ def ensure_per_project_customization(self): {"version": (5, 4, 4), "label": "project parameter"}, True ) - def ensure_support_for_additional_filter_presets(self): + def ensure_support_for_additional_filter_presets(self) -> bool: """ Ensures server has support for additional filter presets feature added in v7.0.0. """ @@ -314,7 +327,7 @@ def ensure_support_for_additional_filter_presets(self): {"version": (7, 0, 0), "label": "additional_filter_presets parameter"}, True ) - def ensure_user_following_support(self): + def ensure_user_following_support(self) -> bool: """ Ensures server has support for listing items a user is following, added in v7.0.12. """ @@ -405,7 +418,7 @@ class _Config(object): Container for the client configuration. """ - def __init__(self, sg): + def __init__(self, sg: "Shotgun"): """ :param sg: Shotgun connection. """ @@ -426,42 +439,42 @@ def __init__(self, sg): # If the optional timeout parameter is given, blocking operations # (like connection attempts) will timeout after that many seconds # (if it is not given, the global default timeout setting is used) - self.timeout_secs = None + self.timeout_secs: Optional[float] = None self.api_ver = "api3" self.convert_datetimes_to_utc = True - self._records_per_page = None - self.api_key = None - self.script_name = None - self.user_login = None - self.user_password = None - self.auth_token = None - self.sudo_as_login = None + self._records_per_page: Optional[int] = None + self.api_key: Optional[str] = None + self.script_name: Optional[str] = None + self.user_login: Optional[str] = None + self.user_password: Optional[str] = None + self.auth_token: Optional[str] = None + self.sudo_as_login: Optional[str] = None # Authentication parameters to be folded into final auth_params dict - self.extra_auth_params = None + self.extra_auth_params: Optional[dict[str, Any]] = None # uuid as a string - self.session_uuid = None - self.scheme = None - self.server = None - self.api_path = None + self.session_uuid: Optional[str] = None + self.scheme: Optional[str] = None + self.server: Optional[str] = None + self.api_path: Optional[str] = None # The raw_http_proxy reflects the exact string passed in # to the Shotgun constructor. This can be useful if you # need to construct a Shotgun API instance based on # another Shotgun API instance. - self.raw_http_proxy = None + self.raw_http_proxy: Optional[str] = None # if a proxy server is being used, the proxy_handler # below will contain a urllib2.ProxyHandler instance # which can be used whenever a request needs to be made. - self.proxy_handler = None - self.proxy_server = None + self.proxy_handler: Optional["urllib.request.ProxyHandler"] = None + self.proxy_server: Optional[str] = None self.proxy_port = 8080 - self.proxy_user = None - self.proxy_pass = None - self.session_token = None - self.authorization = None + self.proxy_user: Optional[str] = None + self.proxy_pass: Optional[str] = None + self.session_token: Optional[str] = None + self.authorization: Optional[str] = None self.no_ssl_validation = False self.localized = False - def set_server_params(self, base_url): + def set_server_params(self, base_url: str) -> None: """ Set the different server related fields based on the passed in URL. @@ -483,7 +496,7 @@ def set_server_params(self, base_url): ) @property - def records_per_page(self): + def records_per_page(self) -> int: """ The records per page value from the server. """ @@ -516,20 +529,20 @@ class Shotgun(object): def __init__( self, - base_url, - script_name=None, - api_key=None, - convert_datetimes_to_utc=True, - http_proxy=None, - ensure_ascii=True, - connect=True, - ca_certs=None, - login=None, - password=None, - sudo_as_login=None, - session_token=None, - auth_token=None, - ): + base_url: str, + script_name: Optional[str] = None, + api_key: Optional[str] = None, + convert_datetimes_to_utc: bool = True, + http_proxy: Optional[str] = None, + ensure_ascii: bool = True, + connect: bool = True, + ca_certs: Optional[str] = None, + login: Optional[str] = None, + password: Optional[str] = None, + sudo_as_login: Optional[str] = None, + session_token: Optional[str] = None, + auth_token: Optional[str] = None, + ) -> None: """ Initializes a new instance of the Shotgun client. @@ -659,7 +672,7 @@ def __init__( "must provide login/password, session_token or script_name/api_key" ) - self.config = _Config(self) + self.config: _Config = _Config(self) self.config.api_key = api_key self.config.script_name = script_name self.config.user_login = login @@ -696,7 +709,7 @@ def __init__( ): SHOTGUN_API_DISABLE_ENTITY_OPTIMIZATION = True - self._connection = None + self._connection: Optional[Http] = None self.__ca_certs = self._get_certs_file(ca_certs) @@ -764,7 +777,7 @@ def __init__( # this relies on self.client_caps being set first self.reset_user_agent() - self._server_caps = None + self._server_caps: Optional[ServerCapabilities] = None # test to ensure the the server supports the json API # call to server will only be made once and will raise error if connect: @@ -778,7 +791,7 @@ def __init__( self.config.user_password = None self.config.auth_token = None - def _split_url(self, base_url): + def _split_url(self, base_url: str) -> tuple[str, str]: """ Extract the hostname:port and username/password/token from base_url sent when connect to the API. @@ -810,7 +823,7 @@ def _split_url(self, base_url): # API Functions @property - def server_info(self): + def server_info(self) -> dict[str, Any]: """ Property containing server information. @@ -828,7 +841,7 @@ def server_info(self): return self.server_caps.server_info @property - def server_caps(self): + def server_caps(self) -> ServerCapabilities: """ Property containing :class:`ServerCapabilities` object. @@ -843,7 +856,7 @@ def server_caps(self): self._server_caps = ServerCapabilities(self.config.server, self.info()) return self._server_caps - def connect(self): + def connect(self) -> None: """ Connect client to the server if it is not already connected. @@ -854,7 +867,7 @@ def connect(self): self.info() return - def close(self): + def close(self) -> None: """ Close the current connection to the server. @@ -863,7 +876,7 @@ def close(self): self._close_connection() return - def info(self): + def info(self) -> dict[str, Any]: """ Get API-related metadata from the Shotgun server. @@ -896,15 +909,15 @@ def info(self): def find_one( self, - entity_type, - filters, - fields=None, - order=None, - filter_operator=None, - retired_only=False, - include_archived_projects=True, - additional_filter_presets=None, - ): + entity_type: str, + filters: Union[list, tuple, dict[str, Any]], + fields: Optional[list[str]] = None, + order: Optional[list[dict[str, Any]]] = None, + filter_operator: Optional[Literal["all", "any"]] = None, + retired_only: bool = False, + include_archived_projects: bool = True, + additional_filter_presets: Optional[list[dict[str, Any]]] = None, + ) -> Optional[dict[str, Any]]: """ Shortcut for :meth:`~shotgun_api3.Shotgun.find` with ``limit=1`` so it returns a single result. @@ -919,7 +932,7 @@ def find_one( :param list fields: Optional list of fields to include in each entity record returned. Defaults to ``["id"]``. - :param int order: Optional list of fields to order the results by. List has the format:: + :param list order: Optional list of fields to order the results by. List has the format:: [ {'field_name':'foo', 'direction':'asc'}, @@ -936,7 +949,7 @@ def find_one( same query. :param bool include_archived_projects: Optional boolean flag to include entities whose projects have been archived. Defaults to ``True``. - :param additional_filter_presets: Optional list of presets to further filter the result + :param list additional_filter_presets: Optional list of presets to further filter the result set, list has the form:: [{ @@ -976,17 +989,17 @@ def find_one( def find( self, - entity_type, - filters, - fields=None, - order=None, - filter_operator=None, - limit=0, - retired_only=False, - page=0, - include_archived_projects=True, - additional_filter_presets=None, - ): + entity_type: str, + filters: Union[list, tuple, dict[str, Any]], + fields: Optional[list[str]] = None, + order: Optional[list[dict[str, Any]]] = None, + filter_operator: Optional[Literal["all", "any"]] = None, + limit: int = 0, + retired_only: bool = False, + page: int = 0, + include_archived_projects: bool = True, + additional_filter_presets: Optional[list[dict[str, Any]]] = None, + ) -> list[dict[str, Any]]: """ Find entities matching the given filters. @@ -1064,7 +1077,7 @@ def find( same query. :param bool include_archived_projects: Optional boolean flag to include entities whose projects have been archived. Defaults to ``True``. - :param additional_filter_presets: Optional list of presets to further filter the result + :param list additional_filter_presets: Optional list of presets to further filter the result set, list has the form:: [{ @@ -1175,15 +1188,15 @@ def find( def _construct_read_parameters( self, - entity_type, - fields, - filters, - retired_only, - order, - include_archived_projects, + entity_type: str, + fields: Optional[list[str]], + filters: dict[str, Any], + retired_only: bool, + order: Optional[list[dict[str, Any]]], + include_archived_projects: bool, additional_filter_presets, - ): - params = {} + ) -> dict[str, Any]: + params: dict[str, Any] = {} params["type"] = entity_type params["return_fields"] = fields or ["id"] params["filters"] = filters @@ -1213,7 +1226,9 @@ def _construct_read_parameters( params["sorts"] = sort_list return params - def _add_project_param(self, params, project_entity): + def _add_project_param( + self, params: dict[str, Any], project_entity + ) -> dict[str, Any]: if project_entity and self.server_caps.ensure_per_project_customization(): params["project"] = project_entity @@ -1221,8 +1236,8 @@ def _add_project_param(self, params, project_entity): return params def _translate_update_params( - self, entity_type, entity_id, data, multi_entity_update_modes - ): + self, entity_type: str, entity_id: int, data, multi_entity_update_modes + ) -> dict[str, Any]: global SHOTGUN_API_DISABLE_ENTITY_OPTIMIZATION def optimize_field(field_dict): @@ -1244,13 +1259,13 @@ def optimize_field(field_dict): def summarize( self, - entity_type, - filters, - summary_fields, - filter_operator=None, - grouping=None, - include_archived_projects=True, - ): + entity_type: str, + filters: Union[list, dict[str, Any]], + summary_fields: list[dict[str, str]], + filter_operator: Optional[str] = None, + grouping: Optional[list] = None, + include_archived_projects: bool = True, + ) -> dict[str, Any]: """ Summarize field data returned by a query. @@ -1438,7 +1453,11 @@ def summarize( # So we only need to check the server version if it is False self.server_caps.ensure_include_archived_projects() - params = {"type": entity_type, "summaries": summary_fields, "filters": filters} + params: dict[str, Any] = { + "type": entity_type, + "summaries": summary_fields, + "filters": filters, + } if include_archived_projects is False: # Defaults to True on the server, so only pass it if it's False @@ -1450,7 +1469,12 @@ def summarize( records = self._call_rpc("summarize", params) return records - def create(self, entity_type, data, return_fields=None): + def create( + self, + entity_type: str, + data: dict[str, Any], + return_fields: Optional[list] = None, + ) -> dict[str, Any]: """ Create a new entity of the specified ``entity_type``. @@ -1533,7 +1557,13 @@ def create(self, entity_type, data, return_fields=None): return result - def update(self, entity_type, entity_id, data, multi_entity_update_modes=None): + def update( + self, + entity_type: str, + entity_id: int, + data: dict[str, Any], + multi_entity_update_modes: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: """ Update the specified entity with the supplied data. @@ -1554,7 +1584,7 @@ def update(self, entity_type, entity_id, data, multi_entity_update_modes=None): } :param str entity_type: Entity type to update. - :param id entity_id: id of the entity to update. + :param id entity_id: int of the entity to update. :param dict data: key/value pairs where key is the field name and value is the value to set for that field. This method does not restrict the updating of fields hidden in the web UI via the Project Tracking Settings panel. @@ -1612,7 +1642,7 @@ def update(self, entity_type, entity_id, data, multi_entity_update_modes=None): return result - def delete(self, entity_type, entity_id): + def delete(self, entity_type: str, entity_id: int) -> bool: """ Retire the specified entity. @@ -1636,7 +1666,7 @@ def delete(self, entity_type, entity_id): return self._call_rpc("delete", params) - def revive(self, entity_type, entity_id): + def revive(self, entity_type: str, entity_id: int) -> bool: """ Revive an entity that has previously been deleted. @@ -1644,7 +1674,7 @@ def revive(self, entity_type, entity_id): True :param str entity_type: Shotgun entity type to revive. - :param int entity_id: id of the entity to revive. + :param int entity_id: int of the entity to revive. :returns: ``True`` if the entity was revived, ``False`` otherwise (e.g. if the entity is not currently retired). :rtype: bool @@ -1654,7 +1684,7 @@ def revive(self, entity_type, entity_id): return self._call_rpc("revive", params) - def batch(self, requests): + def batch(self, requests: list) -> list: """ Make a batch request of several :meth:`~shotgun_api3.Shotgun.create`, :meth:`~shotgun_api3.Shotgun.update`, and :meth:`~shotgun_api3.Shotgun.delete` calls. @@ -1769,7 +1799,13 @@ def _required_keys(message, required_keys, data): records = self._call_rpc("batch", calls) return self._parse_records(records) - def work_schedule_read(self, start_date, end_date, project=None, user=None): + def work_schedule_read( + self, + start_date: str, + end_date: str, + project: Optional[dict[str, Any]] = None, + user: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: """ Return the work day rules for a given date range. @@ -1840,13 +1876,13 @@ def work_schedule_read(self, start_date, end_date, project=None, user=None): def work_schedule_update( self, - date, - working, - description=None, - project=None, - user=None, - recalculate_field=None, - ): + date: str, + working: bool, + description: Optional[str] = None, + project: Optional[dict[str, Any]] = None, + user: Optional[dict[str, Any]] = None, + recalculate_field: Optional[str] = None, + ) -> dict[str, Any]: """ Update the work schedule for a given date. @@ -1900,7 +1936,7 @@ def work_schedule_update( return self._call_rpc("work_schedule_update", params) - def follow(self, user, entity): + def follow(self, user: dict[str, Any], entity: dict[str, Any]) -> dict[str, Any]: """ Add the entity to the user's followed entities. @@ -1928,7 +1964,7 @@ def follow(self, user, entity): return self._call_rpc("follow", params) - def unfollow(self, user, entity): + def unfollow(self, user: dict[str, Any], entity: dict[str, Any]) -> dict[str, Any]: """ Remove entity from the user's followed entities. @@ -1955,7 +1991,7 @@ def unfollow(self, user, entity): return self._call_rpc("unfollow", params) - def followers(self, entity): + def followers(self, entity: dict[str, Any]) -> list: """ Return all followers for an entity. @@ -1983,7 +2019,12 @@ def followers(self, entity): return self._call_rpc("followers", params) - def following(self, user, project=None, entity_type=None): + def following( + self, + user: dict[str, Any], + project: Optional[dict[str, Any]] = None, + entity_type: Optional[str] = None, + ) -> list[dict[str, Any]]: """ Return all entity instances a user is following. @@ -2006,7 +2047,7 @@ def following(self, user, project=None, entity_type=None): self.server_caps.ensure_user_following_support() - params = {"user": user} + params: dict[str, Any] = {"user": user} if project: params["project"] = project if entity_type: @@ -2014,7 +2055,9 @@ def following(self, user, project=None, entity_type=None): return self._call_rpc("following", params) - def schema_entity_read(self, project_entity=None): + def schema_entity_read( + self, project_entity: Optional[dict[str, Any]] = None + ) -> dict[str, Any]: """ Return all active entity types, their display names, and their visibility. @@ -2049,7 +2092,7 @@ def schema_entity_read(self, project_entity=None): The returned display names for this method will be localized when the ``localize`` Shotgun config property is set to ``True``. See :ref:`localization` for more information. """ - params = {} + params: dict[str, Any] = {} params = self._add_project_param(params, project_entity) @@ -2058,7 +2101,9 @@ def schema_entity_read(self, project_entity=None): else: return self._call_rpc("schema_entity_read", None) - def schema_read(self, project_entity=None): + def schema_read( + self, project_entity: Optional[dict[str, Any]] = None + ) -> dict[str, Any]: """ Get the schema for all fields on all entities. @@ -2121,7 +2166,7 @@ def schema_read(self, project_entity=None): The returned display names for this method will be localized when the ``localize`` Shotgun config property is set to ``True``. See :ref:`localization` for more information. """ - params = {} + params: dict[str, Any] = {} params = self._add_project_param(params, project_entity) @@ -2130,7 +2175,12 @@ def schema_read(self, project_entity=None): else: return self._call_rpc("schema_read", None) - def schema_field_read(self, entity_type, field_name=None, project_entity=None): + def schema_field_read( + self, + entity_type: str, + field_name: Optional[str] = None, + project_entity: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: """ Get schema for all fields on the specified entity type or just the field name specified if provided. @@ -2195,8 +2245,12 @@ def schema_field_read(self, entity_type, field_name=None, project_entity=None): return self._call_rpc("schema_field_read", params) def schema_field_create( - self, entity_type, data_type, display_name, properties=None - ): + self, + entity_type: str, + data_type: str, + display_name: str, + properties: Optional[dict[str, Any]] = None, + ) -> str: """ Create a field for the specified entity type. @@ -2215,14 +2269,14 @@ def schema_field_create( :param str display_name: Specifies the display name of the field you are creating. The system name will be created from this display name and returned upon successful creation. - :param dict properties: Dict of valid properties for the new field. Use this to specify + :param dict properties: dict of valid properties for the new field. Use this to specify other field properties such as the 'description' or 'summary_default'. :returns: The internal Shotgun name for the new field, this is different to the ``display_name`` parameter passed in. :rtype: str """ - params = { + params: dict[str, Any] = { "type": entity_type, "data_type": data_type, "properties": [{"property_name": "name", "value": display_name}], @@ -2234,8 +2288,12 @@ def schema_field_create( return self._call_rpc("schema_field_create", params) def schema_field_update( - self, entity_type, field_name, properties, project_entity=None - ): + self, + entity_type: str, + field_name: str, + properties: dict[str, Any], + project_entity: Optional[dict[str, Any]] = None, + ) -> bool: """ Update the properties for the specified field on an entity. @@ -2249,9 +2307,9 @@ def schema_field_update( >>> sg.schema_field_update("Asset", "sg_test_number", properties) True - :param entity_type: Entity type of field to update. - :param field_name: Internal Shotgun name of the field to update. - :param properties: Dictionary with key/value pairs where the key is the property to be + :param str entity_type: Entity type of field to update. + :param str field_name: Internal Shotgun name of the field to update. + :param dict properties: Dictionary with key/value pairs where the key is the property to be updated and the value is the new value. :param dict project_entity: Optional Project entity specifying which project to modify the ``visible`` property for. If ``visible`` is present in ``properties`` and @@ -2277,7 +2335,7 @@ def schema_field_update( params = self._add_project_param(params, project_entity) return self._call_rpc("schema_field_update", params) - def schema_field_delete(self, entity_type, field_name): + def schema_field_delete(self, entity_type: str, field_name: str) -> bool: """ Delete the specified field from the entity type. @@ -2294,7 +2352,7 @@ def schema_field_delete(self, entity_type, field_name): return self._call_rpc("schema_field_delete", params) - def add_user_agent(self, agent): + def add_user_agent(self, agent: str) -> None: """ Add agent to the user-agent header. @@ -2306,7 +2364,7 @@ def add_user_agent(self, agent): """ self._user_agents.append(agent) - def reset_user_agent(self): + def reset_user_agent(self) -> None: """ Reset user agent to the default value. @@ -2330,7 +2388,7 @@ def reset_user_agent(self): "ssl %s (%s)" % (self.client_caps.ssl_version, validation_str), ] - def set_session_uuid(self, session_uuid): + def set_session_uuid(self, session_uuid: str) -> None: """ Set the browser session_uuid in the current Shotgun API instance. @@ -2348,12 +2406,12 @@ def set_session_uuid(self, session_uuid): def share_thumbnail( self, - entities, - thumbnail_path=None, - source_entity=None, - filmstrip_thumbnail=False, - **kwargs, - ): + entities: list[dict[str, Any]], + thumbnail_path: Optional[str] = None, + source_entity: Optional[dict[str, Any]] = None, + filmstrip_thumbnail: bool = False, + **kwargs: Any, + ) -> int: """ Associate a thumbnail with more than one Shotgun entity. @@ -2492,7 +2550,9 @@ def share_thumbnail( return attachment_id - def upload_thumbnail(self, entity_type, entity_id, path, **kwargs): + def upload_thumbnail( + self, entity_type: str, entity_id: int, path: str, **kwargs: Any + ) -> int: """ Upload a file from a local path and assign it as the thumbnail for the specified entity. @@ -2517,12 +2577,15 @@ def upload_thumbnail(self, entity_type, entity_id, path, **kwargs): :param int entity_id: Id of the entity to set the thumbnail for. :param str path: Full path to the thumbnail file on disk. :returns: Id of the new attachment + :rtype: int """ return self.upload( entity_type, entity_id, path, field_name="thumb_image", **kwargs ) - def upload_filmstrip_thumbnail(self, entity_type, entity_id, path, **kwargs): + def upload_filmstrip_thumbnail( + self, entity_type: str, entity_id: int, path: str, **kwargs: Any + ) -> int: """ Upload filmstrip thumbnail to specified entity. @@ -2573,13 +2636,13 @@ def upload_filmstrip_thumbnail(self, entity_type, entity_id, path, **kwargs): def upload( self, - entity_type, - entity_id, - path, - field_name=None, - display_name=None, - tag_list=None, - ): + entity_type: str, + entity_id: int, + path: str, + field_name: Optional[str] = None, + display_name: Optional[str] = None, + tag_list: Optional[str] = None, + ) -> int: """ Upload a file to the specified entity. @@ -2662,14 +2725,14 @@ def upload( def _upload_to_storage( self, - entity_type, - entity_id, - path, - field_name, - display_name, - tag_list, - is_thumbnail, - ): + entity_type: str, + entity_id: int, + path: str, + field_name: Optional[str], + display_name: Optional[str], + tag_list: Optional[str], + is_thumbnail: bool, + ) -> int: """ Internal function to upload a file to the Cloud storage and link it to the specified entity. @@ -2752,14 +2815,14 @@ def _upload_to_storage( def _upload_to_sg( self, - entity_type, - entity_id, - path, - field_name, - display_name, - tag_list, - is_thumbnail, - ): + entity_type: str, + entity_id: int, + path: str, + field_name: Optional[str], + display_name: Optional[str], + tag_list: Optional[str], + is_thumbnail: bool, + ) -> int: """ Internal function to upload a file to Shotgun and link it to the specified entity. @@ -2831,7 +2894,9 @@ def _upload_to_sg( attachment_id = int(result.split(":", 2)[1].split("\n", 1)[0]) return attachment_id - def _get_attachment_upload_info(self, is_thumbnail, filename, is_multipart_upload): + def _get_attachment_upload_info( + self, is_thumbnail: bool, filename: str, is_multipart_upload: bool + ) -> dict[str, Any]: """ Internal function to get the information needed to upload a file to Cloud storage. @@ -2849,7 +2914,7 @@ def _get_attachment_upload_info(self, is_thumbnail, filename, is_multipart_uploa else: upload_type = "Attachment" - params = {"upload_type": upload_type, "filename": filename} + params: dict[str, Any] = {"upload_type": upload_type, "filename": filename} params["multipart_upload"] = is_multipart_upload @@ -2878,7 +2943,12 @@ def _get_attachment_upload_info(self, is_thumbnail, filename, is_multipart_uploa "upload_info": upload_info, } - def download_attachment(self, attachment=False, file_path=None, attachment_id=None): + def download_attachment( + self, + attachment: Union[dict[str, Any], Literal[False]] = False, + file_path: Optional[str] = None, + attachment_id: Optional[int] = None, + ) -> Union[str, bytes, None]: """ Download the file associated with a Shotgun Attachment. @@ -2903,7 +2973,7 @@ def download_attachment(self, attachment=False, file_path=None, attachment_id=No be downloaded from the Shotgun server. :param str file_path: Optional file path to write the data directly to local disk. This avoids loading all of the data in memory and saves the file locally at the given path. - :param id attachment_id: (deprecated) Optional ``id`` of the Attachment entity in Shotgun to + :param int attachment_id: (deprecated) Optional ``id`` of the Attachment entity in Shotgun to download. .. note: @@ -2990,7 +3060,7 @@ def download_attachment(self, attachment=False, file_path=None, attachment_id=No else: return attachment - def get_auth_cookie_handler(self): + def get_auth_cookie_handler(self) -> urllib.request.HTTPCookieProcessor: """ Return an urllib cookie handler containing a cookie for FPTR authentication. @@ -3022,7 +3092,9 @@ def get_auth_cookie_handler(self): cj.set_cookie(c) return urllib.request.HTTPCookieProcessor(cj) - def get_attachment_download_url(self, attachment): + def get_attachment_download_url( + self, attachment: Optional[Union[int, dict[str, Any]]] + ) -> str: """ Return the URL for downloading provided Attachment. @@ -3080,7 +3152,9 @@ def get_attachment_download_url(self, attachment): ) return url - def authenticate_human_user(self, user_login, user_password, auth_token=None): + def authenticate_human_user( + self, user_login: str, user_password: str, auth_token: Optional[str] = None + ) -> dict[str, Any]: """ Authenticate Shotgun HumanUser. @@ -3139,7 +3213,9 @@ def authenticate_human_user(self, user_login, user_password, auth_token=None): self.config.auth_token = original_auth_token raise - def update_project_last_accessed(self, project, user=None): + def update_project_last_accessed( + self, project: dict[str, Any], user: Optional[dict[str, Any]] = None + ) -> None: """ Update a Project's ``last_accessed_by_current_user`` field to the current timestamp. @@ -3185,7 +3261,9 @@ def update_project_last_accessed(self, project, user=None): record = self._call_rpc("update_project_last_accessed_by_current_user", params) self._parse_records(record)[0] - def note_thread_read(self, note_id, entity_fields=None): + def note_thread_read( + self, note_id: int, entity_fields: Optional[dict[str, Any]] = None + ) -> list[dict[str, Any]]: """ Return the full conversation for a given note, including Replies and Attachments. @@ -3260,7 +3338,13 @@ def note_thread_read(self, note_id, entity_fields=None): result = self._parse_records(record) return result - def text_search(self, text, entity_types, project_ids=None, limit=None): + def text_search( + self, + text: str, + entity_types: dict[str, Any], + project_ids: Optional[list] = None, + limit: Optional[int] = None, + ) -> dict[str, Any]: """ Search across the specified entity types for the given text. @@ -3354,13 +3438,13 @@ def text_search(self, text, entity_types, project_ids=None, limit=None): def activity_stream_read( self, - entity_type, - entity_id, - entity_fields=None, - min_id=None, - max_id=None, - limit=None, - ): + entity_type: str, + entity_id: int, + entity_fields: Optional[dict[str, Any]] = None, + min_id: Optional[int] = None, + max_id: Optional[int] = None, + limit: Optional[int] = None, + ) -> dict[str, Any]: """ Retrieve activity stream data from Shotgun. @@ -3412,7 +3496,7 @@ def activity_stream_read( :param str entity_type: Entity type to retrieve activity stream for :param int entity_id: Entity id to retrieve activity stream for - :param list entity_fields: List of additional fields to include. + :param dict entity_fields: Dict of additional fields to include. See above for details :param int max_id: Do not retrieve ids greater than this id. This is useful when implementing paging. @@ -3450,7 +3534,7 @@ def activity_stream_read( result = self._parse_records(record)[0] return result - def nav_expand(self, path, seed_entity_field=None, entity_fields=None): + def nav_expand(self, path: str, seed_entity_field=None, entity_fields=None): """ Expand the navigation hierarchy for the supplied path. @@ -3470,7 +3554,9 @@ def nav_expand(self, path, seed_entity_field=None, entity_fields=None): }, ) - def nav_search_string(self, root_path, search_string, seed_entity_field=None): + def nav_search_string( + self, root_path: str, search_string: str, seed_entity_field=None + ): """ Search function adapted to work with the navigation hierarchy. @@ -3489,7 +3575,7 @@ def nav_search_string(self, root_path, search_string, seed_entity_field=None): }, ) - def nav_search_entity(self, root_path, entity, seed_entity_field=None): + def nav_search_entity(self, root_path: str, entity, seed_entity_field=None): """ Search function adapted to work with the navigation hierarchy. @@ -3509,7 +3595,7 @@ def nav_search_entity(self, root_path, entity, seed_entity_field=None): }, ) - def get_session_token(self): + def get_session_token(self) -> str: """ Get the session token associated with the current session. @@ -3533,7 +3619,7 @@ def get_session_token(self): return session_token - def preferences_read(self, prefs=None): + def preferences_read(self, prefs: Optional[list] = None) -> dict[str, Any]: """ Get a subset of the site preferences. @@ -3556,7 +3642,7 @@ def preferences_read(self, prefs=None): return self._call_rpc("preferences_read", {"prefs": prefs}) - def user_subscriptions_read(self): + def user_subscriptions_read(self) -> list: """ Get the list of user subscriptions. @@ -3568,8 +3654,9 @@ def user_subscriptions_read(self): return self._call_rpc("user_subscriptions_read", None) - def user_subscriptions_create(self, users): - # type: (list[dict[str, Union[str, list[str], None]) -> bool + def user_subscriptions_create( + self, users: list[dict[str, Union[str, list[str], None]]] + ) -> bool: """ Assign subscriptions to users. @@ -3590,7 +3677,7 @@ def user_subscriptions_create(self, users): return response.get("status") == "success" - def _build_opener(self, handler): + def _build_opener(self, handler) -> "urllib.request.OpenerDirector": """ Build urllib2 opener with appropriate proxy handler. """ @@ -3665,7 +3752,7 @@ def _get_certs_file(cls, ca_certs): cert_file = os.path.join(cur_dir, "lib", "certifi", "cacert.pem") return cert_file - def _turn_off_ssl_validation(self): + def _turn_off_ssl_validation(self) -> None: """ Turn off SSL certificate validation. """ @@ -3683,7 +3770,7 @@ def _turn_off_ssl_validation(self): ] # Deprecated methods from old wrapper - def schema(self, entity_type): + def schema(self, entity_type: str) -> NoReturn: """ .. deprecated:: 3.0.0 Use :meth:`~shotgun_api3.Shotgun.schema_field_read` instead. @@ -3692,7 +3779,7 @@ def schema(self, entity_type): "Deprecated: use schema_field_read('%s') instead" % entity_type ) - def entity_types(self): + def entity_types(self) -> NoReturn: """ .. deprecated:: 3.0.0 Use :meth:`~shotgun_api3.Shotgun.schema_entity_read` instead. @@ -3702,7 +3789,13 @@ def entity_types(self): # ======================================================================== # RPC Functions - def _call_rpc(self, method, params, include_auth_params=True, first=False): + def _call_rpc( + self, + method: str, + params: Any, + include_auth_params: bool = True, + first: bool = False, + ) -> Any: """ Call the specified method on the Shotgun Server sending the supplied payload. """ @@ -3766,13 +3859,13 @@ def _call_rpc(self, method, params, include_auth_params=True, first=False): return results[0] return results - def _auth_params(self): + def _auth_params(self) -> dict[str, Any]: """ Return a dictionary of the authentication parameters being used. """ # Used to authenticate HumanUser credentials if self.config.user_login and self.config.user_password: - auth_params = { + auth_params: dict[str, Any] = { "user_login": str(self.config.user_login), "user_password": str(self.config.user_password), } @@ -3821,7 +3914,7 @@ def _auth_params(self): return auth_params - def _sanitize_auth_params(self, params): + def _sanitize_auth_params(self, params: dict[str, Any]) -> dict[str, Any]: """ Given an authentication parameter dictionary, sanitize any sensitive information and return the sanitized dict copy. @@ -3832,7 +3925,9 @@ def _sanitize_auth_params(self, params): sanitized_params[k] = "********" return sanitized_params - def _build_payload(self, method, params, include_auth_params=True): + def _build_payload( + self, method: str, params, include_auth_params: bool = True + ) -> dict[str, Any]: """ Build the payload to be send to the rpc endpoint. """ @@ -3850,7 +3945,7 @@ def _build_payload(self, method, params, include_auth_params=True): return {"method_name": method, "params": call_params} - def _encode_payload(self, payload): + def _encode_payload(self, payload) -> bytes: """ Encode the payload to a string to be passed to the rpc endpoint. @@ -3862,7 +3957,9 @@ def _encode_payload(self, payload): wire = json.dumps(payload, ensure_ascii=False) return sgutils.ensure_binary(wire) - def _make_call(self, verb, path, body, headers): + def _make_call( + self, verb: str, path: str, body, headers: Optional[dict[str, Any]] + ) -> tuple[tuple[int, str], dict[str, Any], str]: """ Make an HTTP call to the server. @@ -3950,7 +4047,9 @@ def _make_call(self, verb, path, body, headers): ) time.sleep(rpc_attempt_interval) - def _http_request(self, verb, path, body, headers): + def _http_request( + self, verb: str, path: str, body, headers: dict[str, Any] + ) -> tuple[tuple[int, str], dict[str, Any], str]: """ Make the actual HTTP request. """ @@ -3974,7 +4073,9 @@ def _http_request(self, verb, path, body, headers): return (http_status, resp_headers, resp_body) - def _make_upload_request(self, request, opener): + def _make_upload_request( + self, request, opener: "urllib.request.OpenerDirector" + ) -> "urllib.request._UrlopenRet": """ Open the given request object, return the response, raises URLError on protocol errors. @@ -3986,7 +4087,7 @@ def _make_upload_request(self, request, opener): raise return result - def _parse_http_status(self, status): + def _parse_http_status(self, status: tuple) -> None: """ Parse the status returned from the http request. @@ -4004,7 +4105,9 @@ def _parse_http_status(self, status): return - def _decode_response(self, headers, body): + def _decode_response( + self, headers: dict[str, Any], body: str + ) -> Union[str, dict[str, Any]]: """ Decode the response from the server from the wire format to a python data structure. @@ -4025,7 +4128,7 @@ def _decode_response(self, headers, body): return self._json_loads(body) return body - def _json_loads(self, body): + def _json_loads(self, body: str): return json.loads(body) def _json_loads_ascii(self, body): @@ -4103,7 +4206,7 @@ def _response_errors(self, sg_response): raise Fault(sg_response.get("message", "Unknown Error")) return - def _visit_data(self, data, visitor): + def _visit_data(self, data: T, visitor) -> T: """ Walk the data (simple python types) and call the visitor. """ @@ -4113,17 +4216,17 @@ def _visit_data(self, data, visitor): recursive = self._visit_data if isinstance(data, list): - return [recursive(i, visitor) for i in data] + return [recursive(i, visitor) for i in data] # type: ignore[return-value] if isinstance(data, tuple): - return tuple(recursive(i, visitor) for i in data) + return tuple(recursive(i, visitor) for i in data) # type: ignore[return-value] if isinstance(data, dict): - return dict((k, recursive(v, visitor)) for k, v in six.iteritems(data)) + return dict((k, recursive(v, visitor)) for k, v in six.iteritems(data)) # type: ignore[return-value] return visitor(data) - def _transform_outbound(self, data): + def _transform_outbound(self, data: T) -> T: """ Transform data types or values before they are sent by the client. @@ -4174,7 +4277,7 @@ def _outbound_visitor(value): return self._visit_data(data, _outbound_visitor) - def _transform_inbound(self, data): + def _transform_inbound(self, data: T) -> T: """ Transforms data types or values after they are received from the server. """ @@ -4210,7 +4313,7 @@ def _inbound_visitor(value): # ======================================================================== # Connection Functions - def _get_connection(self): + def _get_connection(self) -> Http: """ Return the current connection or creates a new connection to the current server. """ @@ -4241,7 +4344,7 @@ def _get_connection(self): return self._connection - def _close_connection(self): + def _close_connection(self) -> None: """ Close the current connection. """ @@ -4260,7 +4363,7 @@ def _close_connection(self): # ======================================================================== # Utility - def _parse_records(self, records): + def _parse_records(self, records: list) -> list: """ Parse 'records' returned from the api to do local modifications: @@ -4316,14 +4419,14 @@ def _parse_records(self, records): return records - def _build_thumb_url(self, entity_type, entity_id): + def _build_thumb_url(self, entity_type: str, entity_id: int) -> str: """ Return the URL for the thumbnail of an entity given the entity type and the entity id. Note: This makes a call to the server for every thumbnail. - :param entity_type: Entity type the id is for. - :param entity_id: id of the entity to get the thumbnail for. + :param str entity_type: Entity type the id is for. + :param int entity_id: int of the entity to get the thumbnail for. :returns: Fully qualified url to the thumbnail. """ # Example response from the end point @@ -4339,7 +4442,7 @@ def _build_thumb_url(self, entity_type, entity_id): + "entity_type=%(e_type)s&entity_id=%(e_id)s" % entity_info ) - body = self._make_call("GET", url, None, None)[2] + body: str = self._make_call("GET", url, None, None)[2] code, thumb_url = body.splitlines() code = int(code) @@ -4364,8 +4467,12 @@ def _build_thumb_url(self, entity_type, entity_id): raise RuntimeError("Unknown code %s %s" % (code, thumb_url)) def _dict_to_list( - self, d, key_name="field_name", value_name="value", extra_data=None - ): + self, + d: Optional[dict[str, Any]], + key_name: str = "field_name", + value_name: str = "value", + extra_data=None, + ) -> list: """ Utility function to convert a dict into a list dicts using the key_name and value_name keys. @@ -4382,7 +4489,7 @@ def _dict_to_list( ret.append(d) return ret - def _dict_to_extra_data(self, d, key_name="value"): + def _dict_to_extra_data(self, d: Optional[dict], key_name="value") -> dict: """ Utility function to convert a dict into a dict compatible with the extra_data arg of _dict_to_list. @@ -4391,7 +4498,7 @@ def _dict_to_extra_data(self, d, key_name="value"): """ return dict([(k, {key_name: v}) for (k, v) in six.iteritems((d or {}))]) - def _upload_file_to_storage(self, path, storage_url): + def _upload_file_to_storage(self, path: str, storage_url: str) -> None: """ Internal function to upload an entire file to the Cloud storage. @@ -4411,7 +4518,9 @@ def _upload_file_to_storage(self, path, storage_url): LOG.debug("File uploaded to Cloud storage: %s", filename) - def _multipart_upload_file_to_storage(self, path, upload_info): + def _multipart_upload_file_to_storage( + self, path: str, upload_info: dict[str, Any] + ) -> None: """ Internal function to upload a file to the Cloud storage in multiple parts. @@ -4453,7 +4562,9 @@ def _multipart_upload_file_to_storage(self, path, upload_info): LOG.debug("File uploaded in multiple parts to Cloud storage: %s", path) - def _get_upload_part_link(self, upload_info, filename, part_number): + def _get_upload_part_link( + self, upload_info: dict[str, Any], filename: str, part_number: int + ) -> str: """ Internal function to get the url to upload the next part of a file to the Cloud storage, in a multi-part upload process. @@ -4493,7 +4604,9 @@ def _get_upload_part_link(self, upload_info, filename, part_number): LOG.debug("Got next upload link from server for multipart upload.") return result.split("\n", 2)[1] - def _upload_data_to_storage(self, data, content_type, size, storage_url): + def _upload_data_to_storage( + self, data: BinaryIO, content_type: str, size: int, storage_url: str + ) -> str: """ Internal function to upload data to Cloud storage. @@ -4548,13 +4661,15 @@ def _upload_data_to_storage(self, data, content_type, size, storage_url): LOG.debug("Part upload completed successfully.") return etag - def _complete_multipart_upload(self, upload_info, filename, etags): + def _complete_multipart_upload( + self, upload_info: dict[str, Any], filename: str, etags: Iterable[str] + ) -> None: """ Internal function to complete a multi-part upload to the Cloud storage. :param dict upload_info: Contains details received from the server, about the upload. :param str filename: Name of the file for which we want to complete the upload. - :param tupple etags: Contains the etag of each uploaded file part. + :param tuple etags: Contains the etag of each uploaded file part. """ params = { @@ -4581,7 +4696,9 @@ def _complete_multipart_upload(self, upload_info, filename, etags): if not result.startswith("1"): raise ShotgunError("Unable get upload part link: %s" % result) - def _requires_direct_s3_upload(self, entity_type, field_name): + def _requires_direct_s3_upload( + self, entity_type: str, field_name: Optional[str] + ) -> bool: """ Internal function that determines if an entity_type + field_name combination should be uploaded to cloud storage. @@ -4622,7 +4739,7 @@ def _requires_direct_s3_upload(self, entity_type, field_name): else: return False - def _send_form(self, url, params): + def _send_form(self, url: str, params: dict[str, Any]) -> str: """ Utility function to send a Form to Shotgun and process any HTTP errors that could occur. @@ -4680,7 +4797,7 @@ def __init__(self, *args, **kwargs): self.__ca_certs = kwargs.pop("ca_certs") super().__init__(self, *args, **kwargs) - def connect(self): + def connect(self) -> None: "Connect to a host on a given (SSL) port." super().connect(self) # Now that the regular HTTP socket has been created, wrap it with our SSL certs. @@ -4806,7 +4923,7 @@ def https_request(self, request): return self.http_request(request) -def _translate_filters(filters, filter_operator): +def _translate_filters(filters: Union[list, tuple], filter_operator) -> dict[str, Any]: """ Translate filters params into data structure expected by rpc call. """ @@ -4815,7 +4932,7 @@ def _translate_filters(filters, filter_operator): return _translate_filters_dict(wrapped_filters) -def _translate_filters_dict(sg_filter): +def _translate_filters_dict(sg_filter: dict[str, Any]) -> dict[str, Any]: new_filters = {} filter_operator = sg_filter.get("filter_operator") From cad90fd6d0d23321ec2d14d9ad9ea529b03b074b Mon Sep 17 00:00:00 2001 From: Chad Dombrova Date: Wed, 18 Jun 2025 21:59:55 -0700 Subject: [PATCH 2/3] Add TypedDicts --- shotgun_api3/shotgun.py | 49 +++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/shotgun_api3/shotgun.py b/shotgun_api3/shotgun.py index 8f71da0fa..07e5399d2 100644 --- a/shotgun_api3/shotgun.py +++ b/shotgun_api3/shotgun.py @@ -57,6 +57,7 @@ Literal, NoReturn, Optional, + TypedDict, TypeVar, Union, TYPE_CHECKING, @@ -88,6 +89,22 @@ LOG.setLevel(logging.WARN) +class OrderItem(TypedDict): + field_name: str + direction: str + + +class GroupingItem(TypedDict): + field: str + type: str + direction: str + + +class BaseEntity(TypedDict, total=False): + id: int + type: str + + def _is_mimetypes_broken(): """ Checks if this version of Python ships with a broken version of mimetypes @@ -912,12 +929,12 @@ def find_one( entity_type: str, filters: Union[list, tuple, dict[str, Any]], fields: Optional[list[str]] = None, - order: Optional[list[dict[str, Any]]] = None, + order: Optional[list[OrderItem]] = None, filter_operator: Optional[Literal["all", "any"]] = None, retired_only: bool = False, include_archived_projects: bool = True, additional_filter_presets: Optional[list[dict[str, Any]]] = None, - ) -> Optional[dict[str, Any]]: + ) -> Optional[BaseEntity]: """ Shortcut for :meth:`~shotgun_api3.Shotgun.find` with ``limit=1`` so it returns a single result. @@ -992,14 +1009,14 @@ def find( entity_type: str, filters: Union[list, tuple, dict[str, Any]], fields: Optional[list[str]] = None, - order: Optional[list[dict[str, Any]]] = None, + order: Optional[list[OrderItem]] = None, filter_operator: Optional[Literal["all", "any"]] = None, limit: int = 0, retired_only: bool = False, page: int = 0, include_archived_projects: bool = True, additional_filter_presets: Optional[list[dict[str, Any]]] = None, - ) -> list[dict[str, Any]]: + ) -> list[BaseEntity]: """ Find entities matching the given filters. @@ -1194,7 +1211,7 @@ def _construct_read_parameters( retired_only: bool, order: Optional[list[dict[str, Any]]], include_archived_projects: bool, - additional_filter_presets, + additional_filter_presets: Optional[list[dict[str, Any]]], ) -> dict[str, Any]: params: dict[str, Any] = {} params["type"] = entity_type @@ -1263,7 +1280,7 @@ def summarize( filters: Union[list, dict[str, Any]], summary_fields: list[dict[str, str]], filter_operator: Optional[str] = None, - grouping: Optional[list] = None, + grouping: Optional[list[GroupingItem]] = None, include_archived_projects: bool = True, ) -> dict[str, Any]: """ @@ -1563,7 +1580,7 @@ def update( entity_id: int, data: dict[str, Any], multi_entity_update_modes: Optional[dict[str, Any]] = None, - ) -> dict[str, Any]: + ) -> BaseEntity: """ Update the specified entity with the supplied data. @@ -2024,7 +2041,7 @@ def following( user: dict[str, Any], project: Optional[dict[str, Any]] = None, entity_type: Optional[str] = None, - ) -> list[dict[str, Any]]: + ) -> list[BaseEntity]: """ Return all entity instances a user is following. @@ -2056,8 +2073,8 @@ def following( return self._call_rpc("following", params) def schema_entity_read( - self, project_entity: Optional[dict[str, Any]] = None - ) -> dict[str, Any]: + self, project_entity: Optional[BaseEntity] = None + ) -> dict[str, dict[str, Any]]: """ Return all active entity types, their display names, and their visibility. @@ -2102,8 +2119,8 @@ def schema_entity_read( return self._call_rpc("schema_entity_read", None) def schema_read( - self, project_entity: Optional[dict[str, Any]] = None - ) -> dict[str, Any]: + self, project_entity: Optional[BaseEntity] = None + ) -> dict[str, dict[str, Any]]: """ Get the schema for all fields on all entities. @@ -2179,8 +2196,8 @@ def schema_field_read( self, entity_type: str, field_name: Optional[str] = None, - project_entity: Optional[dict[str, Any]] = None, - ) -> dict[str, Any]: + project_entity: Optional[BaseEntity] = None, + ) -> dict[str, dict[str, Any]]: """ Get schema for all fields on the specified entity type or just the field name specified if provided. @@ -2292,7 +2309,7 @@ def schema_field_update( entity_type: str, field_name: str, properties: dict[str, Any], - project_entity: Optional[dict[str, Any]] = None, + project_entity: Optional[BaseEntity] = None, ) -> bool: """ Update the properties for the specified field on an entity. @@ -2408,7 +2425,7 @@ def share_thumbnail( self, entities: list[dict[str, Any]], thumbnail_path: Optional[str] = None, - source_entity: Optional[dict[str, Any]] = None, + source_entity: Optional[BaseEntity] = None, filmstrip_thumbnail: bool = False, **kwargs: Any, ) -> int: From 9d4e694ebe4a7249acf2e7325fd889c01e4cded9 Mon Sep 17 00:00:00 2001 From: Chad Dombrova Date: Wed, 18 Jun 2025 22:09:18 -0700 Subject: [PATCH 3/3] Add py.typed marker --- setup.py | 2 +- shotgun_api3/py.typed | 0 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 shotgun_api3/py.typed diff --git a/setup.py b/setup.py index 2e25a17d5..0e1a02eef 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ packages=find_packages(exclude=("tests",)), script_args=sys.argv[1:], include_package_data=True, - package_data={"": ["cacerts.txt", "cacert.pem"]}, + package_data={"": ["cacerts.txt", "cacert.pem", "py.typed"]}, zip_safe=False, python_requires=">=3.9.0", classifiers=[ diff --git a/shotgun_api3/py.typed b/shotgun_api3/py.typed new file mode 100644 index 000000000..e69de29bb