diff --git a/setup.cfg b/setup.cfg index 1d4a31c202..d86316a13e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -61,6 +61,7 @@ warn_unreachable = True strict_equality = True disallow_untyped_defs = True disallow_untyped_calls = True +show_error_codes = True files = tuf/api/, tuf/ngclient, diff --git a/tests/repository_simulator.py b/tests/repository_simulator.py index ec8a899d5f..a763fc053d 100644 --- a/tests/repository_simulator.py +++ b/tests/repository_simulator.py @@ -93,11 +93,7 @@ class RepositorySimulator(FetcherInterface): """Simulates a repository that can be used for testing.""" # pylint: disable=too-many-instance-attributes - def __init__(self): - self.md_root: Metadata[Root] = None - self.md_timestamp: Metadata[Timestamp] = None - self.md_snapshot: Metadata[Snapshot] = None - self.md_targets: Metadata[Targets] = None + def __init__(self) -> None: self.md_delegates: Dict[str, Metadata[Targets]] = {} # other metadata is signed on-demand (when fetched) but roots must be @@ -117,7 +113,7 @@ def __init__(self): # Enable hash-prefixed target file names self.prefix_targets_with_hash = True - self.dump_dir = None + self.dump_dir: Optional[str] = None self.dump_version = 0 now = datetime.utcnow() @@ -152,12 +148,12 @@ def create_key() -> Tuple[Key, SSlibSigner]: sslib_key = generate_ed25519_key() return Key.from_securesystemslib_key(sslib_key), SSlibSigner(sslib_key) - def add_signer(self, role: str, signer: SSlibSigner): + def add_signer(self, role: str, signer: SSlibSigner) -> None: if role not in self.signers: self.signers[role] = {} self.signers[role][signer.key_dict["keyid"]] = signer - def _initialize(self): + def _initialize(self) -> None: """Setup a minimal valid repository.""" targets = Targets(1, SPEC_VER, self.safe_expiry, {}, None) @@ -182,7 +178,7 @@ def _initialize(self): self.md_root = Metadata(root, OrderedDict()) self.publish_root() - def publish_root(self): + def publish_root(self) -> None: """Sign and store a new serialized version of root.""" self.md_root.signatures.clear() for signer in self.signers["root"].values(): @@ -199,12 +195,12 @@ def fetch(self, url: str) -> Iterator[bytes]: if path.startswith("/metadata/") and path.endswith(".json"): # figure out rolename and version ver_and_name = path[len("/metadata/") :][: -len(".json")] - version, _, role = ver_and_name.partition(".") + version_str, _, role = ver_and_name.partition(".") # root is always version-prefixed while timestamp is always NOT if role == "root" or ( self.root.consistent_snapshot and ver_and_name != "timestamp" ): - version = int(version) + version: Optional[int] = int(version_str) else: # the file is not version-prefixed role = ver_and_name @@ -216,11 +212,10 @@ def fetch(self, url: str) -> Iterator[bytes]: target_path = path[len("/targets/") :] dir_parts, sep, prefixed_filename = target_path.rpartition("/") # extract the hash prefix, if any + prefix: Optional[str] = None + filename = prefixed_filename if self.root.consistent_snapshot and self.prefix_targets_with_hash: prefix, _, filename = prefixed_filename.partition(".") - else: - filename = prefixed_filename - prefix = None target_path = f"{dir_parts}{sep}{filename}" yield self._fetch_target(target_path, prefix) @@ -261,8 +256,9 @@ def _fetch_metadata( return self.signed_roots[version - 1] # sign and serialize the requested metadata + md: Optional[Metadata] if role == "timestamp": - md: Metadata = self.md_timestamp + md = self.md_timestamp elif role == "snapshot": md = self.md_snapshot elif role == "targets": @@ -294,7 +290,7 @@ def _compute_hashes_and_length( hashes = {sslib_hash.DEFAULT_HASH_ALGORITHM: digest_object.hexdigest()} return hashes, len(data) - def update_timestamp(self): + def update_timestamp(self) -> None: """Update timestamp and assign snapshot version to snapshot_meta version. """ @@ -307,7 +303,7 @@ def update_timestamp(self): self.timestamp.version += 1 - def update_snapshot(self): + def update_snapshot(self) -> None: """Update snapshot, assign targets versions and update timestamp.""" for role, delegate in self.all_targets(): hashes = None @@ -322,7 +318,7 @@ def update_snapshot(self): self.snapshot.version += 1 self.update_timestamp() - def add_target(self, role: str, data: bytes, path: str): + def add_target(self, role: str, data: bytes, path: str) -> None: """Create a target from data and add it to the target_files.""" if role == "targets": targets = self.targets @@ -341,7 +337,7 @@ def add_delegation( terminating: bool, paths: Optional[List[str]], hash_prefixes: Optional[List[str]], - ): + ) -> None: """Add delegated target role to the repository.""" if delegator_name == "targets": delegator = self.targets @@ -351,7 +347,7 @@ def add_delegation( # Create delegation role = DelegatedRole(name, [], 1, terminating, paths, hash_prefixes) if delegator.delegations is None: - delegator.delegations = Delegations({}, {}) + delegator.delegations = Delegations({}, OrderedDict()) # put delegation last by default delegator.delegations.roles[role.name] = role @@ -363,7 +359,7 @@ def add_delegation( # Add metadata for the role self.md_delegates[role.name] = Metadata(targets, OrderedDict()) - def write(self): + def write(self) -> None: """Dump current repository metadata to self.dump_dir This is a debugging tool: dumping repository state before running diff --git a/tests/test_api.py b/tests/test_api.py index 401d061867..02c6521725 100755 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -14,6 +14,7 @@ import tempfile import unittest from datetime import datetime, timedelta +from typing import ClassVar, Dict from dateutil.relativedelta import relativedelta from securesystemslib import hash as sslib_hash @@ -28,6 +29,7 @@ from tuf import exceptions from tuf.api.metadata import ( DelegatedRole, + Delegations, Key, Metadata, MetaFile, @@ -47,8 +49,13 @@ class TestMetadata(unittest.TestCase): """Tests for public API of all classes in 'tuf/api/metadata.py'.""" + temporary_directory: ClassVar[str] + repo_dir: ClassVar[str] + keystore_dir: ClassVar[str] + keystore: ClassVar[Dict[str, str]] + @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: # Create a temporary directory to store the repository, metadata, and # target files. 'temporary_directory' must be deleted in # TearDownClass() so that temporary files are always removed, even when @@ -78,12 +85,12 @@ def setUpClass(cls): ) @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: # Remove the temporary repository directory, which should contain all # the metadata, targets, and key files generated for the test cases. shutil.rmtree(cls.temporary_directory) - def test_generic_read(self): + def test_generic_read(self) -> None: for metadata, inner_metadata_cls in [ ("root", Root), ("snapshot", Snapshot), @@ -120,7 +127,7 @@ def test_generic_read(self): os.remove(bad_metadata_path) - def test_compact_json(self): + def test_compact_json(self) -> None: path = os.path.join(self.repo_dir, "metadata", "targets.json") md_obj = Metadata.from_file(path) self.assertTrue( @@ -128,7 +135,7 @@ def test_compact_json(self): < len(JSONSerializer().serialize(md_obj)) ) - def test_read_write_read_compare(self): + def test_read_write_read_compare(self) -> None: for metadata in ["root", "snapshot", "timestamp", "targets"]: path = os.path.join(self.repo_dir, "metadata", metadata + ".json") md_obj = Metadata.from_file(path) @@ -140,7 +147,7 @@ def test_read_write_read_compare(self): os.remove(path_2) - def test_to_from_bytes(self): + def test_to_from_bytes(self) -> None: for metadata in ["root", "snapshot", "timestamp", "targets"]: path = os.path.join(self.repo_dir, "metadata", metadata + ".json") with open(path, "rb") as f: @@ -157,7 +164,7 @@ def test_to_from_bytes(self): metadata_obj_2 = Metadata.from_bytes(obj_bytes) self.assertEqual(metadata_obj_2.to_bytes(), obj_bytes) - def test_sign_verify(self): + def test_sign_verify(self) -> None: root_path = os.path.join(self.repo_dir, "metadata", "root.json") root = Metadata[Root].from_file(root_path).signed @@ -183,7 +190,7 @@ def test_sign_verify(self): # Test verifying with explicitly set serializer targets_key.verify_signature(md_obj, CanonicalJSONSerializer()) with self.assertRaises(exceptions.UnsignedMetadataError): - targets_key.verify_signature(md_obj, JSONSerializer()) + targets_key.verify_signature(md_obj, JSONSerializer()) # type: ignore[arg-type] sslib_signer = SSlibSigner(self.keystore["snapshot"]) # Append a new signature with the unrelated key and assert that ... @@ -206,7 +213,7 @@ def test_sign_verify(self): with self.assertRaises(exceptions.UnsignedMetadataError): targets_key.verify_signature(md_obj) - def test_verify_failures(self): + def test_verify_failures(self) -> None: root_path = os.path.join(self.repo_dir, "metadata", "root.json") root = Metadata[Root].from_file(root_path).signed @@ -248,7 +255,7 @@ def test_verify_failures(self): timestamp_key.verify_signature(md_obj) sig.signature = correct_sig - def test_metadata_base(self): + def test_metadata_base(self) -> None: # Use of Snapshot is arbitrary, we're just testing the base class # features with real data snapshot_path = os.path.join(self.repo_dir, "metadata", "snapshot.json") @@ -290,7 +297,7 @@ def test_metadata_base(self): with self.assertRaises(ValueError): Metadata.from_dict(data) - def test_metadata_snapshot(self): + def test_metadata_snapshot(self) -> None: snapshot_path = os.path.join(self.repo_dir, "metadata", "snapshot.json") snapshot = Metadata[Snapshot].from_file(snapshot_path) @@ -309,7 +316,7 @@ def test_metadata_snapshot(self): snapshot.signed.meta["role1.json"].to_dict(), fileinfo.to_dict() ) - def test_metadata_timestamp(self): + def test_metadata_timestamp(self) -> None: timestamp_path = os.path.join( self.repo_dir, "metadata", "timestamp.json" ) @@ -349,7 +356,7 @@ def test_metadata_timestamp(self): timestamp.signed.snapshot_meta.to_dict(), fileinfo.to_dict() ) - def test_metadata_verify_delegate(self): + def test_metadata_verify_delegate(self) -> None: root_path = os.path.join(self.repo_dir, "metadata", "root.json") root = Metadata[Root].from_file(root_path) snapshot_path = os.path.join(self.repo_dir, "metadata", "snapshot.json") @@ -410,14 +417,14 @@ def test_metadata_verify_delegate(self): snapshot.sign(SSlibSigner(self.keystore["timestamp"]), append=True) root.verify_delegate("snapshot", snapshot) - def test_key_class(self): + def test_key_class(self) -> None: # Test if from_securesystemslib_key removes the private key from keyval # of a securesystemslib key dictionary. sslib_key = generate_ed25519_key() key = Key.from_securesystemslib_key(sslib_key) self.assertFalse("private" in key.keyval.keys()) - def test_root_add_key_and_remove_key(self): + def test_root_add_key_and_remove_key(self) -> None: root_path = os.path.join(self.repo_dir, "metadata", "root.json") root = Metadata[Root].from_file(root_path) @@ -475,7 +482,7 @@ def test_root_add_key_and_remove_key(self): with self.assertRaises(ValueError): root.signed.remove_key("nosuchrole", keyid) - def test_is_target_in_pathpattern(self): + def test_is_target_in_pathpattern(self) -> None: # pylint: disable=protected-access supported_use_cases = [ ("foo.tgz", "foo.tgz"), @@ -507,7 +514,7 @@ def test_is_target_in_pathpattern(self): DelegatedRole._is_target_in_pathpattern(targetpath, pathpattern) ) - def test_metadata_targets(self): + def test_metadata_targets(self) -> None: targets_path = os.path.join(self.repo_dir, "metadata", "targets.json") targets = Metadata[Targets].from_file(targets_path) @@ -531,7 +538,7 @@ def test_metadata_targets(self): targets.signed.targets[filename].to_dict(), fileinfo.to_dict() ) - def test_targets_key_api(self): + def test_targets_key_api(self) -> None: targets_path = os.path.join(self.repo_dir, "metadata", "targets.json") targets: Targets = Metadata[Targets].from_file(targets_path).signed @@ -545,6 +552,7 @@ def test_targets_key_api(self): "threshold": 1, } ) + assert isinstance(targets.delegations, Delegations) targets.delegations.roles["role2"] = delegated_role key_dict = { @@ -608,7 +616,7 @@ def test_targets_key_api(self): targets.remove_key("role1", key.keyid) self.assertTrue(targets.delegations is None) - def test_length_and_hash_validation(self): + def test_length_and_hash_validation(self) -> None: # Test metadata files' hash and length verification. # Use timestamp to get a MetaFile object and snapshot @@ -648,7 +656,7 @@ def test_length_and_hash_validation(self): # Test wrong algorithm format (sslib.FormatError) snapshot_metafile.hashes = { - 256: "8f88e2ba48b412c3843e9bb26e1b6f8fc9e98aceb0fbaa97ba37b4c98717d7ab" + 256: "8f88e2ba48b412c3843e9bb26e1b6f8fc9e98aceb0fbaa97ba37b4c98717d7ab" # type: ignore[dict-item] } with self.assertRaises(exceptions.LengthOrHashMismatchError): snapshot_metafile.verify_length_and_hashes(data) @@ -678,7 +686,7 @@ def test_length_and_hash_validation(self): with self.assertRaises(exceptions.LengthOrHashMismatchError): file1_targetfile.verify_length_and_hashes(file1) - def test_targetfile_from_file(self): + def test_targetfile_from_file(self) -> None: # Test with an existing file and valid hash algorithm file_path = os.path.join(self.repo_dir, "targets", "file1.txt") targetfile_from_file = TargetFile.from_file( @@ -700,7 +708,7 @@ def test_targetfile_from_file(self): with self.assertRaises(exceptions.UnsupportedAlgorithmError): TargetFile.from_file(file_path, file_path, ["123"]) - def test_targetfile_from_data(self): + def test_targetfile_from_data(self) -> None: data = b"Inline test content" target_file_path = os.path.join(self.repo_dir, "targets", "file1.txt") @@ -714,7 +722,7 @@ def test_targetfile_from_data(self): targetfile_from_data = TargetFile.from_data(target_file_path, data) targetfile_from_data.verify_length_and_hashes(data) - def test_is_delegated_role(self): + def test_is_delegated_role(self) -> None: # test path matches # see more extensive tests in test_is_target_in_pathpattern() for paths in [ diff --git a/tests/test_fetcher_ng.py b/tests/test_fetcher_ng.py index 55dec1e301..4958da9f15 100644 --- a/tests/test_fetcher_ng.py +++ b/tests/test_fetcher_ng.py @@ -13,6 +13,7 @@ import sys import tempfile import unittest +from typing import Any, ClassVar, Iterator from unittest.mock import Mock, patch import requests @@ -28,17 +29,19 @@ class TestFetcher(unittest_toolbox.Modified_TestCase): """Test RequestsFetcher class.""" + server_process_handler: ClassVar[utils.TestServerProcess] + @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: # Launch a SimpleHTTPServer (serves files in the current dir). cls.server_process_handler = utils.TestServerProcess(log=logger) @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: # Stop server process and perform clean up. cls.server_process_handler.clean() - def setUp(self): + def setUp(self) -> None: """ Create a temporary file and launch a simple server in the current working directory. @@ -64,12 +67,12 @@ def setUp(self): # Instantiate a concrete instance of FetcherInterface self.fetcher = RequestsFetcher() - def tearDown(self): + def tearDown(self) -> None: # Remove temporary directory unittest_toolbox.Modified_TestCase.tearDown(self) # Simple fetch. - def test_fetch(self): + def test_fetch(self) -> None: with tempfile.TemporaryFile() as temp_file: for chunk in self.fetcher.fetch(self.url): temp_file.write(chunk) @@ -80,7 +83,7 @@ def test_fetch(self): ) # URL data downloaded in more than one chunk - def test_fetch_in_chunks(self): + def test_fetch_in_chunks(self) -> None: # Set a smaller chunk size to ensure that the file will be downloaded # in more than one chunk self.fetcher.chunk_size = 4 @@ -105,12 +108,12 @@ def test_fetch_in_chunks(self): self.assertEqual(chunks_count, expected_chunks_count) # Incorrect URL parsing - def test_url_parsing(self): + def test_url_parsing(self) -> None: with self.assertRaises(exceptions.URLParsingError): self.fetcher.fetch(self.random_string()) # File not found error - def test_http_error(self): + def test_http_error(self) -> None: with self.assertRaises(exceptions.FetcherHTTPError) as cm: self.url = f"{self.url_prefix}/non-existing-path" self.fetcher.fetch(self.url) @@ -118,7 +121,7 @@ def test_http_error(self): # Response read timeout error @patch.object(requests.Session, "get") - def test_response_read_timeout(self, mock_session_get): + def test_response_read_timeout(self, mock_session_get: Any) -> None: mock_response = Mock() attr = { "raw.read.side_effect": urllib3.exceptions.ReadTimeoutError( @@ -136,28 +139,28 @@ def test_response_read_timeout(self, mock_session_get): @patch.object( requests.Session, "get", side_effect=urllib3.exceptions.TimeoutError ) - def test_session_get_timeout(self, mock_session_get): + def test_session_get_timeout(self, mock_session_get: Any) -> None: with self.assertRaises(exceptions.SlowRetrievalError): self.fetcher.fetch(self.url) mock_session_get.assert_called_once() # Simple bytes download - def test_download_bytes(self): + def test_download_bytes(self) -> None: data = self.fetcher.download_bytes(self.url, self.file_length) self.assertEqual(self.file_contents, data.decode("utf-8")) # Download file smaller than required max_length - def test_download_bytes_upper_length(self): + def test_download_bytes_upper_length(self) -> None: data = self.fetcher.download_bytes(self.url, self.file_length + 4) self.assertEqual(self.file_contents, data.decode("utf-8")) # Download a file bigger than expected - def test_download_bytes_length_mismatch(self): + def test_download_bytes_length_mismatch(self) -> None: with self.assertRaises(exceptions.DownloadLengthMismatchError): self.fetcher.download_bytes(self.url, self.file_length - 4) # Simple file download - def test_download_file(self): + def test_download_file(self) -> None: with self.fetcher.download_file( self.url, self.file_length ) as temp_file: @@ -165,7 +168,7 @@ def test_download_file(self): self.assertEqual(self.file_length, temp_file.tell()) # Download file smaller than required max_length - def test_download_file_upper_length(self): + def test_download_file_upper_length(self) -> None: with self.fetcher.download_file( self.url, self.file_length + 4 ) as temp_file: @@ -173,8 +176,10 @@ def test_download_file_upper_length(self): self.assertEqual(self.file_length, temp_file.tell()) # Download a file bigger than expected - def test_download_file_length_mismatch(self): + def test_download_file_length_mismatch(self) -> Iterator[Any]: with self.assertRaises(exceptions.DownloadLengthMismatchError): + # Force download_file to execute and raise the error since it is a + # context manager and returns Iterator[IO] yield self.fetcher.download_file(self.url, self.file_length - 4) diff --git a/tests/test_metadata_serialization.py b/tests/test_metadata_serialization.py index ed3958d142..16f402dba1 100644 --- a/tests/test_metadata_serialization.py +++ b/tests/test_metadata_serialization.py @@ -11,7 +11,6 @@ import logging import sys import unittest -from typing import Dict from tests import utils from tuf.api.metadata import ( @@ -54,7 +53,7 @@ class TestSerialization(unittest.TestCase): } @utils.run_sub_tests_with_dataset(invalid_signed) - def test_invalid_signed_serialization(self, test_case_data: Dict[str, str]): + def test_invalid_signed_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) with self.assertRaises((KeyError, ValueError, TypeError)): Snapshot.from_dict(copy.deepcopy(case_dict)) @@ -69,7 +68,7 @@ def test_invalid_signed_serialization(self, test_case_data: Dict[str, str]): } @utils.run_sub_tests_with_dataset(valid_keys) - def test_valid_key_serialization(self, test_case_data: str): + def test_valid_key_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) key = Key.from_dict("id", copy.copy(case_dict)) self.assertDictEqual(case_dict, key.to_dict()) @@ -86,7 +85,7 @@ def test_valid_key_serialization(self, test_case_data: str): } @utils.run_sub_tests_with_dataset(invalid_keys) - def test_invalid_key_serialization(self, test_case_data: Dict[str, str]): + def test_invalid_key_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) with self.assertRaises((TypeError, KeyError)): keyid = case_dict.pop("keyid") @@ -101,7 +100,7 @@ def test_invalid_key_serialization(self, test_case_data: Dict[str, str]): } @utils.run_sub_tests_with_dataset(invalid_roles) - def test_invalid_role_serialization(self, test_case_data: Dict[str, str]): + def test_invalid_role_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) with self.assertRaises((KeyError, TypeError, ValueError)): Role.from_dict(copy.deepcopy(case_dict)) @@ -114,7 +113,7 @@ def test_invalid_role_serialization(self, test_case_data: Dict[str, str]): } @utils.run_sub_tests_with_dataset(valid_roles) - def test_role_serialization(self, test_case_data: str): + def test_role_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) role = Role.from_dict(copy.deepcopy(case_dict)) self.assertDictEqual(case_dict, role.to_dict()) @@ -162,7 +161,7 @@ def test_role_serialization(self, test_case_data: str): } @utils.run_sub_tests_with_dataset(valid_roots) - def test_root_serialization(self, test_case_data: str): + def test_root_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) root = Root.from_dict(copy.deepcopy(case_dict)) self.assertDictEqual(case_dict, root.to_dict()) @@ -204,7 +203,7 @@ def test_root_serialization(self, test_case_data: str): } @utils.run_sub_tests_with_dataset(invalid_roots) - def test_invalid_root_serialization(self, test_case_data: Dict[str, str]): + def test_invalid_root_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) with self.assertRaises(ValueError): Root.from_dict(copy.deepcopy(case_dict)) @@ -219,9 +218,7 @@ def test_invalid_root_serialization(self, test_case_data: Dict[str, str]): } @utils.run_sub_tests_with_dataset(invalid_metafiles) - def test_invalid_metafile_serialization( - self, test_case_data: Dict[str, str] - ): + def test_invalid_metafile_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) with self.assertRaises((TypeError, ValueError, AttributeError)): MetaFile.from_dict(copy.deepcopy(case_dict)) @@ -235,7 +232,7 @@ def test_invalid_metafile_serialization( } @utils.run_sub_tests_with_dataset(valid_metafiles) - def test_metafile_serialization(self, test_case_data: str): + def test_metafile_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) metafile = MetaFile.from_dict(copy.copy(case_dict)) self.assertDictEqual(case_dict, metafile.to_dict()) @@ -245,9 +242,7 @@ def test_metafile_serialization(self, test_case_data: str): } @utils.run_sub_tests_with_dataset(invalid_timestamps) - def test_invalid_timestamp_serialization( - self, test_case_data: Dict[str, str] - ): + def test_invalid_timestamp_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) with self.assertRaises((ValueError, KeyError)): Timestamp.from_dict(copy.deepcopy(case_dict)) @@ -260,7 +255,7 @@ def test_invalid_timestamp_serialization( } @utils.run_sub_tests_with_dataset(valid_timestamps) - def test_timestamp_serialization(self, test_case_data: str): + def test_timestamp_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) timestamp = Timestamp.from_dict(copy.deepcopy(case_dict)) self.assertDictEqual(case_dict, timestamp.to_dict()) @@ -279,7 +274,7 @@ def test_timestamp_serialization(self, test_case_data: str): } @utils.run_sub_tests_with_dataset(valid_snapshots) - def test_snapshot_serialization(self, test_case_data: str): + def test_snapshot_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) snapshot = Snapshot.from_dict(copy.deepcopy(case_dict)) self.assertDictEqual(case_dict, snapshot.to_dict()) @@ -300,7 +295,7 @@ def test_snapshot_serialization(self, test_case_data: str): } @utils.run_sub_tests_with_dataset(valid_delegated_roles) - def test_delegated_role_serialization(self, test_case_data: str): + def test_delegated_role_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) deserialized_role = DelegatedRole.from_dict(copy.copy(case_dict)) self.assertDictEqual(case_dict, deserialized_role.to_dict()) @@ -317,13 +312,22 @@ def test_delegated_role_serialization(self, test_case_data: str): } @utils.run_sub_tests_with_dataset(invalid_delegated_roles) - def test_invalid_delegated_role_serialization(self, test_case_data: str): + def test_invalid_delegated_role_serialization( + self, test_case_data: str + ) -> None: case_dict = json.loads(test_case_data) with self.assertRaises(ValueError): DelegatedRole.from_dict(copy.copy(case_dict)) invalid_delegations: utils.DataSet = { "empty delegations": "{}", + "missing keys": '{ "roles": [ \ + {"keyids": ["keyid"], "name": "a", "terminating": true, "paths": ["fn1"], "threshold": 3}, \ + {"keyids": ["keyid2"], "name": "b", "terminating": true, "paths": ["fn2"], "threshold": 4} ] \ + }', + "missing roles": '{"keys": { \ + "keyid1" : {"keytype": "rsa", "scheme": "rsassa-pss-sha256", "keyval": {"public": "foo"}}, \ + "keyid2" : {"keytype": "ed25519", "scheme": "ed25519", "keyval": {"public": "bar"}}}}', "bad keys": '{"keys": "foo", \ "roles": [{"keyids": ["keyid"], "name": "a", "paths": ["fn1", "fn2"], "terminating": false, "threshold": 3}]}', "bad roles": '{"keys": {"keyid" : {"keytype": "rsa", "scheme": "rsassa-pss-sha256", "keyval": {"public": "foo"}}}, \ @@ -337,7 +341,9 @@ def test_invalid_delegated_role_serialization(self, test_case_data: str): } @utils.run_sub_tests_with_dataset(invalid_delegations) - def test_invalid_delegation_serialization(self, test_case_data: str): + def test_invalid_delegation_serialization( + self, test_case_data: str + ) -> None: case_dict = json.loads(test_case_data) with self.assertRaises((ValueError, KeyError, AttributeError)): Delegations.from_dict(copy.deepcopy(case_dict)) @@ -359,7 +365,7 @@ def test_invalid_delegation_serialization(self, test_case_data: str): } @utils.run_sub_tests_with_dataset(valid_delegations) - def test_delegation_serialization(self, test_case_data: str): + def test_delegation_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) delegation = Delegations.from_dict(copy.deepcopy(case_dict)) self.assertDictEqual(case_dict, delegation.to_dict()) @@ -373,8 +379,8 @@ def test_delegation_serialization(self, test_case_data: str): @utils.run_sub_tests_with_dataset(invalid_targetfiles) def test_invalid_targetfile_serialization( - self, test_case_data: Dict[str, str] - ): + self, test_case_data: str + ) -> None: case_dict = json.loads(test_case_data) with self.assertRaises(KeyError): TargetFile.from_dict(copy.deepcopy(case_dict), "file1.txt") @@ -388,7 +394,7 @@ def test_invalid_targetfile_serialization( } @utils.run_sub_tests_with_dataset(valid_targetfiles) - def test_targetfile_serialization(self, test_case_data: str): + def test_targetfile_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) target_file = TargetFile.from_dict(copy.copy(case_dict), "file1.txt") self.assertDictEqual(case_dict, target_file.to_dict()) @@ -420,7 +426,7 @@ def test_targetfile_serialization(self, test_case_data: str): } @utils.run_sub_tests_with_dataset(valid_targets) - def test_targets_serialization(self, test_case_data): + def test_targets_serialization(self, test_case_data: str) -> None: case_dict = json.loads(test_case_data) targets = Targets.from_dict(copy.deepcopy(case_dict)) self.assertDictEqual(case_dict, targets.to_dict()) diff --git a/tests/test_trusted_metadata_set.py b/tests/test_trusted_metadata_set.py index 9452c9bff6..9dfacf1a1d 100644 --- a/tests/test_trusted_metadata_set.py +++ b/tests/test_trusted_metadata_set.py @@ -4,7 +4,7 @@ import sys import unittest from datetime import datetime -from typing import Callable, Optional +from typing import Callable, ClassVar, Dict, List, Optional, Tuple from securesystemslib.interface import ( import_ed25519_privatekey_from_file, @@ -18,7 +18,6 @@ Metadata, MetaFile, Root, - Signed, Snapshot, Targets, Timestamp, @@ -31,8 +30,13 @@ class TestTrustedMetadataSet(unittest.TestCase): """Tests for all public API of the TrustedMetadataSet class.""" + keystore: ClassVar[Dict[str, SSlibSigner]] + metadata: ClassVar[Dict[str, bytes]] + repo_dir: ClassVar[str] + + @classmethod def modify_metadata( - self, rolename: str, modification_func: Callable[[Signed], None] + cls, rolename: str, modification_func: Callable ) -> bytes: """Instantiate metadata from rolename type, call modification_func and sign it again with self.keystore[rolename] signer. @@ -42,13 +46,13 @@ def modify_metadata( modification_func: Function that will be called to modify the signed portion of metadata bytes. """ - metadata = Metadata.from_bytes(self.metadata[rolename]) + metadata = Metadata.from_bytes(cls.metadata[rolename]) modification_func(metadata.signed) - metadata.sign(self.keystore[rolename]) + metadata.sign(cls.keystore[rolename]) return metadata.to_bytes() @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.repo_dir = os.path.join( os.getcwd(), "repository_data", "repository", "metadata" ) @@ -81,7 +85,7 @@ def hashes_length_modifier(timestamp: Timestamp) -> None: timestamp.snapshot_meta.length = None cls.metadata["timestamp"] = cls.modify_metadata( - cls, "timestamp", hashes_length_modifier + "timestamp", hashes_length_modifier ) def setUp(self) -> None: @@ -91,7 +95,7 @@ def _update_all_besides_targets( self, timestamp_bytes: Optional[bytes] = None, snapshot_bytes: Optional[bytes] = None, - ): + ) -> None: """Update all metadata roles besides targets. Args: @@ -109,7 +113,7 @@ def _update_all_besides_targets( snapshot_bytes = snapshot_bytes or self.metadata["snapshot"] self.trusted_set.update_snapshot(snapshot_bytes) - def test_update(self): + def test_update(self) -> None: self.trusted_set.update_timestamp(self.metadata["timestamp"]) self.trusted_set.update_snapshot(self.metadata["snapshot"]) self.trusted_set.update_targets(self.metadata["targets"]) @@ -129,8 +133,10 @@ def test_update(self): self.assertTrue(count, 6) - def test_update_metadata_output(self): - timestamp = self.trusted_set.update_timestamp(self.metadata["timestamp"]) + def test_update_metadata_output(self) -> None: + timestamp = self.trusted_set.update_timestamp( + self.metadata["timestamp"] + ) snapshot = self.trusted_set.update_snapshot(self.metadata["snapshot"]) targets = self.trusted_set.update_targets(self.metadata["targets"]) delegeted_targets_1 = self.trusted_set.update_delegated_targets( @@ -145,7 +151,7 @@ def test_update_metadata_output(self): self.assertIsInstance(delegeted_targets_1.signed, Targets) self.assertIsInstance(delegeted_targets_2.signed, Targets) - def test_out_of_order_ops(self): + def test_out_of_order_ops(self) -> None: # Update snapshot before timestamp with self.assertRaises(RuntimeError): self.trusted_set.update_snapshot(self.metadata["snapshot"]) @@ -182,7 +188,7 @@ def test_out_of_order_ops(self): self.metadata["role1"], "role1", "targets" ) - def test_root_with_invalid_json(self): + def test_root_with_invalid_json(self) -> None: # Test loading initial root and root update for test_func in [TrustedMetadataSet, self.trusted_set.update_root]: # root is not json @@ -199,8 +205,8 @@ def test_root_with_invalid_json(self): with self.assertRaises(exceptions.RepositoryError): test_func(self.metadata["snapshot"]) - def test_top_level_md_with_invalid_json(self): - top_level_md = [ + def test_top_level_md_with_invalid_json(self) -> None: + top_level_md: List[Tuple[bytes, Callable[[bytes], Metadata]]] = [ (self.metadata["timestamp"], self.trusted_set.update_timestamp), (self.metadata["snapshot"], self.trusted_set.update_snapshot), (self.metadata["targets"], self.trusted_set.update_targets), @@ -222,7 +228,7 @@ def test_top_level_md_with_invalid_json(self): update_func(metadata) - def test_update_root_new_root(self): + def test_update_root_new_root(self) -> None: # test that root can be updated with a new valid version def root_new_version_modifier(root: Root) -> None: root.version += 1 @@ -230,19 +236,19 @@ def root_new_version_modifier(root: Root) -> None: root = self.modify_metadata("root", root_new_version_modifier) self.trusted_set.update_root(root) - def test_update_root_new_root_cannot_be_verified_with_threshold(self): + def test_update_root_new_root_fail_threshold_verification(self) -> None: # new_root data with threshold which cannot be verified. root = Metadata.from_bytes(self.metadata["root"]) # remove root role keyids representing root signatures - root.signed.roles["root"].keyids = [] + root.signed.roles["root"].keyids = set() with self.assertRaises(exceptions.UnsignedMetadataError): self.trusted_set.update_root(root.to_bytes()) - def test_update_root_new_root_ver_same_as_trusted_root_ver(self): + def test_update_root_new_root_ver_same_as_trusted_root_ver(self) -> None: with self.assertRaises(exceptions.ReplayedMetadataError): self.trusted_set.update_root(self.metadata["root"]) - def test_root_expired_final_root(self): + def test_root_expired_final_root(self) -> None: def root_expired_modifier(root: Root) -> None: root.expires = datetime(1970, 1, 1) @@ -253,7 +259,7 @@ def root_expired_modifier(root: Root) -> None: with self.assertRaises(exceptions.ExpiredMetadataError): tmp_trusted_set.update_timestamp(self.metadata["timestamp"]) - def test_update_timestamp_new_timestamp_ver_below_trusted_ver(self): + def test_update_timestamp_new_timestamp_ver_below_trusted_ver(self) -> None: # new_timestamp.version < trusted_timestamp.version def version_modifier(timestamp: Timestamp) -> None: timestamp.version = 3 @@ -263,7 +269,7 @@ def version_modifier(timestamp: Timestamp) -> None: with self.assertRaises(exceptions.ReplayedMetadataError): self.trusted_set.update_timestamp(self.metadata["timestamp"]) - def test_update_timestamp_snapshot_ver_below_current(self): + def test_update_timestamp_snapshot_ver_below_current(self) -> None: def bump_snapshot_version(timestamp: Timestamp) -> None: timestamp.snapshot_meta.version = 2 @@ -275,7 +281,7 @@ def bump_snapshot_version(timestamp: Timestamp) -> None: with self.assertRaises(exceptions.ReplayedMetadataError): self.trusted_set.update_timestamp(self.metadata["timestamp"]) - def test_update_timestamp_expired(self): + def test_update_timestamp_expired(self) -> None: # new_timestamp has expired def timestamp_expired_modifier(timestamp: Timestamp) -> None: timestamp.expires = datetime(1970, 1, 1) @@ -291,7 +297,7 @@ def timestamp_expired_modifier(timestamp: Timestamp) -> None: with self.assertRaises(exceptions.ExpiredMetadataError): self.trusted_set.update_snapshot(self.metadata["snapshot"]) - def test_update_snapshot_length_or_hash_mismatch(self): + def test_update_snapshot_length_or_hash_mismatch(self) -> None: def modify_snapshot_length(timestamp: Timestamp) -> None: timestamp.snapshot_meta.length = 1 @@ -302,14 +308,16 @@ def modify_snapshot_length(timestamp: Timestamp) -> None: with self.assertRaises(exceptions.RepositoryError): self.trusted_set.update_snapshot(self.metadata["snapshot"]) - def test_update_snapshot_cannot_verify_snapshot_with_threshold(self): + def test_update_snapshot_fail_threshold_verification(self) -> None: self.trusted_set.update_timestamp(self.metadata["timestamp"]) snapshot = Metadata.from_bytes(self.metadata["snapshot"]) snapshot.signatures.clear() with self.assertRaises(exceptions.UnsignedMetadataError): self.trusted_set.update_snapshot(snapshot.to_bytes()) - def test_update_snapshot_version_different_timestamp_snapshot_version(self): + def test_update_snapshot_version_diverge_timestamp_snapshot_version( + self, + ) -> None: def timestamp_version_modifier(timestamp: Timestamp) -> None: timestamp.snapshot_meta.version = 2 @@ -326,7 +334,7 @@ def timestamp_version_modifier(timestamp: Timestamp) -> None: with self.assertRaises(exceptions.BadVersionNumberError): self.trusted_set.update_targets(self.metadata["targets"]) - def test_update_snapshot_file_removed_from_meta(self): + def test_update_snapshot_file_removed_from_meta(self) -> None: self._update_all_besides_targets(self.metadata["timestamp"]) def remove_file_from_meta(snapshot: Snapshot) -> None: @@ -337,7 +345,7 @@ def remove_file_from_meta(snapshot: Snapshot) -> None: with self.assertRaises(exceptions.RepositoryError): self.trusted_set.update_snapshot(snapshot) - def test_update_snapshot_meta_version_decreases(self): + def test_update_snapshot_meta_version_decreases(self) -> None: self.trusted_set.update_timestamp(self.metadata["timestamp"]) def version_meta_modifier(snapshot: Snapshot) -> None: @@ -349,7 +357,7 @@ def version_meta_modifier(snapshot: Snapshot) -> None: with self.assertRaises(exceptions.BadVersionNumberError): self.trusted_set.update_snapshot(self.metadata["snapshot"]) - def test_update_snapshot_expired_new_snapshot(self): + def test_update_snapshot_expired_new_snapshot(self) -> None: self.trusted_set.update_timestamp(self.metadata["timestamp"]) def snapshot_expired_modifier(snapshot: Snapshot) -> None: @@ -364,7 +372,7 @@ def snapshot_expired_modifier(snapshot: Snapshot) -> None: with self.assertRaises(exceptions.ExpiredMetadataError): self.trusted_set.update_targets(self.metadata["targets"]) - def test_update_snapshot_successful_rollback_checks(self): + def test_update_snapshot_successful_rollback_checks(self) -> None: def meta_version_bump(timestamp: Timestamp) -> None: timestamp.snapshot_meta.version += 1 @@ -386,7 +394,7 @@ def version_bump(snapshot: Snapshot) -> None: # update targets to trigger final snapshot meta version check self.trusted_set.update_targets(self.metadata["targets"]) - def test_update_targets_no_meta_in_snapshot(self): + def test_update_targets_no_meta_in_snapshot(self) -> None: def no_meta_modifier(snapshot: Snapshot) -> None: snapshot.meta = {} @@ -396,7 +404,7 @@ def no_meta_modifier(snapshot: Snapshot) -> None: with self.assertRaises(exceptions.RepositoryError): self.trusted_set.update_targets(self.metadata["targets"]) - def test_update_targets_hash_different_than_snapshot_meta_hash(self): + def test_update_targets_hash_diverge_from_snapshot_meta_hash(self) -> None: def meta_length_modifier(snapshot: Snapshot) -> None: for metafile_path in snapshot.meta: snapshot.meta[metafile_path] = MetaFile(version=1, length=1) @@ -407,7 +415,7 @@ def meta_length_modifier(snapshot: Snapshot) -> None: with self.assertRaises(exceptions.RepositoryError): self.trusted_set.update_targets(self.metadata["targets"]) - def test_update_targets_version_different_snapshot_meta_version(self): + def test_update_targets_version_diverge_snapshot_meta_version(self) -> None: def meta_modifier(snapshot: Snapshot) -> None: for metafile_path in snapshot.meta: snapshot.meta[metafile_path] = MetaFile(version=2) @@ -418,7 +426,7 @@ def meta_modifier(snapshot: Snapshot) -> None: with self.assertRaises(exceptions.BadVersionNumberError): self.trusted_set.update_targets(self.metadata["targets"]) - def test_update_targets_expired_new_target(self): + def test_update_targets_expired_new_target(self) -> None: self._update_all_besides_targets() # new_delegated_target has expired def target_expired_modifier(target: Targets) -> None: diff --git a/tests/test_updater_key_rotations.py b/tests/test_updater_key_rotations.py index 2a07fc6761..9855c7f492 100644 --- a/tests/test_updater_key_rotations.py +++ b/tests/test_updater_key_rotations.py @@ -37,8 +37,8 @@ class TestUpdaterKeyRotations(unittest.TestCase): dump_dir: Optional[str] = None def setUp(self) -> None: - self.sim = None - self.metadata_dir = None + self.sim: RepositorySimulator + self.metadata_dir: str self.subtest_count = 0 # pylint: disable-next=consider-using-with self.temp_dir = tempfile.TemporaryDirectory() diff --git a/tests/test_updater_ng.py b/tests/test_updater_ng.py index e4dd7566dc..3616e649ec 100644 --- a/tests/test_updater_ng.py +++ b/tests/test_updater_ng.py @@ -12,14 +12,14 @@ import sys import tempfile import unittest -from typing import List +from typing import Callable, ClassVar, List from securesystemslib.interface import import_rsa_privatekey_from_file from securesystemslib.signer import SSlibSigner from tests import utils from tuf import exceptions, ngclient, unittest_toolbox -from tuf.api.metadata import Metadata, TargetFile +from tuf.api.metadata import Metadata, Root, TargetFile logger = logging.getLogger(__name__) @@ -27,8 +27,11 @@ class TestUpdater(unittest_toolbox.Modified_TestCase): """Test the Updater class from 'tuf/ngclient/updater.py'.""" + temporary_directory: ClassVar[str] + server_process_handler: ClassVar[utils.TestServerProcess] + @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: # Create a temporary directory to store the repository, metadata, and # target files. 'temporary_directory' must be deleted in # TearDownModule() so that temporary files are always removed, even when @@ -38,18 +41,18 @@ def setUpClass(cls): # Needed because in some tests simple_server.py cannot be found. # The reason is that the current working directory # has been changed when executing a subprocess. - cls.SIMPLE_SERVER_PATH = os.path.join(os.getcwd(), "simple_server.py") + SIMPLE_SERVER_PATH = os.path.join(os.getcwd(), "simple_server.py") # Launch a SimpleHTTPServer (serves files in the current directory). # Test cases will request metadata and target files that have been # pre-generated in 'tuf/tests/repository_data', which will be served # by the SimpleHTTPServer launched here. cls.server_process_handler = utils.TestServerProcess( - log=logger, server=cls.SIMPLE_SERVER_PATH + log=logger, server=SIMPLE_SERVER_PATH ) @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: # Cleans the resources and flush the logged lines (if any). cls.server_process_handler.clean() @@ -57,7 +60,7 @@ def tearDownClass(cls): # the metadata, targets, and key files generated for the test cases shutil.rmtree(cls.temporary_directory) - def setUp(self): + def setUp(self) -> None: # We are inheriting from custom class. unittest_toolbox.Modified_TestCase.setUp(self) @@ -124,7 +127,7 @@ def setUp(self): target_base_url=self.targets_url, ) - def tearDown(self): + def tearDown(self) -> None: # We are inheriting from custom class. unittest_toolbox.Modified_TestCase.tearDown(self) @@ -132,7 +135,9 @@ def tearDown(self): self.server_process_handler.flush_log() def _modify_repository_root( - self, modification_func, bump_version=False + self, + modification_func: Callable[[Metadata], None], + bump_version: bool = False, ) -> None: """Apply 'modification_func' to root and persist it.""" role_path = os.path.join( @@ -159,13 +164,13 @@ def _modify_repository_root( ) ) - def _assert_files(self, roles: List[str]): + def _assert_files(self, roles: List[str]) -> None: """Assert that local metadata files exist for 'roles'""" expected_files = [f"{role}.json" for role in roles] client_files = sorted(os.listdir(self.client_directory)) self.assertEqual(client_files, expected_files) - def test_refresh_and_download(self): + def test_refresh_and_download(self) -> None: # Test refresh without consistent targets - targets without hash prefix. # top-level targets are already in local cache (but remove others) @@ -179,10 +184,12 @@ def test_refresh_and_download(self): # Get targetinfos, assert that cache does not contain files info1 = self.updater.get_targetinfo("file1.txt") + assert isinstance(info1, TargetFile) self._assert_files(["root", "snapshot", "targets", "timestamp"]) # Get targetinfo for 'file3.txt' listed in the delegated role1 info3 = self.updater.get_targetinfo("file3.txt") + assert isinstance(info3, TargetFile) expected_files = ["role1", "root", "snapshot", "targets", "timestamp"] self._assert_files(expected_files) self.assertIsNone(self.updater.find_cached_target(info1)) @@ -200,7 +207,7 @@ def test_refresh_and_download(self): path = self.updater.find_cached_target(info3) self.assertEqual(path, os.path.join(self.dl_dir, info3.path)) - def test_refresh_with_only_local_root(self): + def test_refresh_with_only_local_root(self) -> None: os.remove(os.path.join(self.client_directory, "timestamp.json")) os.remove(os.path.join(self.client_directory, "snapshot.json")) os.remove(os.path.join(self.client_directory, "targets.json")) @@ -217,7 +224,7 @@ def test_refresh_with_only_local_root(self): expected_files = ["role1", "root", "snapshot", "targets", "timestamp"] self._assert_files(expected_files) - def test_implicit_refresh_with_only_local_root(self): + def test_implicit_refresh_with_only_local_root(self) -> None: os.remove(os.path.join(self.client_directory, "timestamp.json")) os.remove(os.path.join(self.client_directory, "snapshot.json")) os.remove(os.path.join(self.client_directory, "targets.json")) @@ -231,7 +238,7 @@ def test_implicit_refresh_with_only_local_root(self): expected_files = ["role1", "root", "snapshot", "targets", "timestamp"] self._assert_files(expected_files) - def test_both_target_urls_not_set(self): + def test_both_target_urls_not_set(self) -> None: # target_base_url = None and Updater._target_base_url = None updater = ngclient.Updater( self.client_directory, self.metadata_url, self.dl_dir @@ -240,7 +247,7 @@ def test_both_target_urls_not_set(self): with self.assertRaises(ValueError): updater.download_target(info) - def test_no_target_dir_no_filepath(self): + def test_no_target_dir_no_filepath(self) -> None: # filepath = None and Updater.target_dir = None updater = ngclient.Updater(self.client_directory, self.metadata_url) info = TargetFile(1, {"sha256": ""}, "targetpath") @@ -249,15 +256,17 @@ def test_no_target_dir_no_filepath(self): with self.assertRaises(ValueError): updater.download_target(info) - def test_external_targets_url(self): + def test_external_targets_url(self) -> None: self.updater.refresh() info = self.updater.get_targetinfo("file1.txt") + assert isinstance(info, TargetFile) self.updater.download_target(info, target_base_url=self.targets_url) - def test_length_hash_mismatch(self): + def test_length_hash_mismatch(self) -> None: self.updater.refresh() targetinfo = self.updater.get_targetinfo("file1.txt") + assert isinstance(targetinfo, TargetFile) length = targetinfo.length with self.assertRaises(exceptions.RepositoryError): @@ -270,13 +279,13 @@ def test_length_hash_mismatch(self): self.updater.download_target(targetinfo) # pylint: disable=protected-access - def test_updating_root(self): + def test_updating_root(self) -> None: # Bump root version, resign and refresh self._modify_repository_root(lambda root: None, bump_version=True) self.updater.refresh() self.assertEqual(self.updater._trusted_set.root.signed.version, 2) - def test_missing_targetinfo(self): + def test_missing_targetinfo(self) -> None: self.updater.refresh() # Get targetinfo for non-existing file diff --git a/tests/test_updater_top_level_update.py b/tests/test_updater_top_level_update.py index c2647e4978..a5b511a60a 100644 --- a/tests/test_updater_top_level_update.py +++ b/tests/test_updater_top_level_update.py @@ -349,7 +349,7 @@ def test_new_snapshot_unsigned(self) -> None: self._assert_files_exist(["root", "timestamp"]) - def test_new_snapshot_version_mismatch(self): + def test_new_snapshot_version_mismatch(self) -> None: # Check against timestamp role’s snapshot version # Increase snapshot version without updating timestamp @@ -414,7 +414,7 @@ def test_new_targets_unsigned(self) -> None: self._assert_files_exist(["root", "timestamp", "snapshot"]) - def test_new_targets_version_mismatch(self): + def test_new_targets_version_mismatch(self) -> None: # Check against snapshot role’s targets version # Increase targets version without updating snapshot diff --git a/tests/test_updater_with_simulator.py b/tests/test_updater_with_simulator.py index e15df5d2f7..b992a2def1 100644 --- a/tests/test_updater_with_simulator.py +++ b/tests/test_updater_with_simulator.py @@ -16,7 +16,7 @@ from tests import utils from tests.repository_simulator import RepositorySimulator -from tuf.api.metadata import SPECIFICATION_VERSION, Targets +from tuf.api.metadata import SPECIFICATION_VERSION, TargetFile, Targets from tuf.exceptions import BadVersionNumberError, UnsignedMetadataError from tuf.ngclient import Updater @@ -27,7 +27,7 @@ class TestUpdater(unittest.TestCase): # set dump_dir to trigger repository state dumps dump_dir: Optional[str] = None - def setUp(self): + def setUp(self) -> None: # pylint: disable-next=consider-using-with self.temp_dir = tempfile.TemporaryDirectory() self.metadata_dir = os.path.join(self.temp_dir.name, "metadata") @@ -49,7 +49,7 @@ def setUp(self): self.sim.dump_dir = os.path.join(self.dump_dir, name) os.mkdir(self.sim.dump_dir) - def tearDown(self): + def tearDown(self) -> None: self.temp_dir.cleanup() def _run_refresh(self) -> Updater: @@ -67,7 +67,7 @@ def _run_refresh(self) -> Updater: updater.refresh() return updater - def test_refresh(self): + def test_refresh(self) -> None: # Update top level metadata self._run_refresh() @@ -99,7 +99,7 @@ def test_refresh(self): } @utils.run_sub_tests_with_dataset(targets) - def test_targets(self, test_case_data: Tuple[str, bytes, str]): + def test_targets(self, test_case_data: Tuple[str, bytes, str]) -> None: targetpath, content, encoded_path = test_case_data path = os.path.join(self.targets_dir, encoded_path) @@ -117,7 +117,7 @@ def test_targets(self, test_case_data: Tuple[str, bytes, str]): updater = self._run_refresh() # target now exists, is not in cache yet info = updater.get_targetinfo(targetpath) - self.assertIsNotNone(info) + assert info is not None # Test without and with explicit local filepath self.assertIsNone(updater.find_cached_target(info)) self.assertIsNone(updater.find_cached_target(info, path)) @@ -136,7 +136,7 @@ def test_targets(self, test_case_data: Tuple[str, bytes, str]): self.assertEqual(path, updater.find_cached_target(info)) self.assertEqual(path, updater.find_cached_target(info, path)) - def test_fishy_rolenames(self): + def test_fishy_rolenames(self) -> None: roles_to_filenames = { "../a": "..%2Fa.json", "": ".json", @@ -162,7 +162,7 @@ def test_fishy_rolenames(self): for fname in roles_to_filenames.values(): self.assertTrue(fname in local_metadata) - def test_keys_and_signatures(self): + def test_keys_and_signatures(self) -> None: """Example of the two trickiest test areas: keys and root updates""" # Update top level metadata @@ -202,7 +202,7 @@ def test_keys_and_signatures(self): self._run_refresh() - def test_snapshot_rollback_with_local_snapshot_hash_mismatch(self): + def test_snapshot_rollback_with_local_snapshot_hash_mismatch(self) -> None: # Test triggering snapshot rollback check on a newly downloaded snapshot # when the local snapshot is loaded even when there is a hash mismatch # with timestamp.snapshot_meta. @@ -233,7 +233,7 @@ def test_snapshot_rollback_with_local_snapshot_hash_mismatch(self): self._run_refresh() @patch.object(builtins, "open", wraps=builtins.open) - def test_not_loading_targets_twice(self, wrapped_open: MagicMock): + def test_not_loading_targets_twice(self, wrapped_open: MagicMock) -> None: # Do not load targets roles more than once when traversing # the delegations tree diff --git a/tests/utils.py b/tests/utils.py index 15f2892414..e5c251d0f7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,7 +21,7 @@ """ from contextlib import contextmanager -from typing import Dict, Any, Callable +from typing import Any, Callable, Dict, IO, Optional, Callable, List, Iterator import unittest import argparse import errno @@ -48,9 +48,13 @@ # Test runner decorator: Runs the test as a set of N SubTests, # (where N is number of items in dataset), feeding the actual test # function one test case at a time -def run_sub_tests_with_dataset(dataset: DataSet): - def real_decorator(function: Callable[[unittest.TestCase, Any], None]): - def wrapper(test_cls: unittest.TestCase): +def run_sub_tests_with_dataset( + dataset: DataSet +) -> Callable[[Callable], Callable]: + def real_decorator( + function: Callable[[unittest.TestCase, Any], None] + ) -> Callable[[unittest.TestCase], None]: + def wrapper(test_cls: unittest.TestCase) -> None: for case, data in dataset.items(): with test_cls.subTest(case=case): function(test_cls, data) @@ -60,15 +64,15 @@ def wrapper(test_cls: unittest.TestCase): class TestServerProcessError(Exception): - def __init__(self, value="TestServerProcess"): + def __init__(self, value: str="TestServerProcess") -> None: self.value = value - def __str__(self): + def __str__(self) -> str: return repr(self.value) @contextmanager -def ignore_deprecation_warnings(module): +def ignore_deprecation_warnings(module: str) -> Iterator[None]: with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=DeprecationWarning, @@ -82,13 +86,16 @@ def ignore_deprecation_warnings(module): # but the current blocking connect() seems to work fast on Linux and seems # to at least work on Windows (ECONNREFUSED unfortunately has a 2 second # timeout on Windows) -def wait_for_server(host, server, port, timeout=10): +def wait_for_server(host: str, server: str, port: int, timeout: int=10) -> None: start = time.time() remaining_timeout = timeout succeeded = False while not succeeded and remaining_timeout > 0: try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock: Optional[socket.socket] = socket.socket( + socket.AF_INET, socket.SOCK_STREAM + ) + assert sock is not None sock.settimeout(remaining_timeout) sock.connect((host, port)) succeeded = True @@ -104,14 +111,14 @@ def wait_for_server(host, server, port, timeout=10): if sock: sock.close() sock = None - remaining_timeout = timeout - (time.time() - start) + remaining_timeout = int(timeout - (time.time() - start)) if not succeeded: raise TimeoutError("Could not connect to the " + server \ + " on port " + str(port) + "!") -def configure_test_logging(argv): +def configure_test_logging(argv: List[str]) -> None: # parse arguments but only handle '-v': argv may contain # other things meant for unittest argument parser parser = argparse.ArgumentParser(add_help=False) @@ -165,13 +172,14 @@ class TestServerProcess(): """ - def __init__(self, log, server='simple_server.py', - timeout=10, popen_cwd=".", extra_cmd_args=None): + def __init__(self, log: logging.Logger, server: str='simple_server.py', + timeout: int=10, popen_cwd: str=".", extra_cmd_args: Optional[List[str]]=None + ): self.server = server self.__logger = log # Stores popped messages from the queue. - self.__logged_messages = [] + self.__logged_messages: List[str] = [] if extra_cmd_args is None: extra_cmd_args = [] @@ -185,7 +193,9 @@ def __init__(self, log, server='simple_server.py', - def _start_server(self, timeout, extra_cmd_args, popen_cwd): + def _start_server( + self, timeout: int, extra_cmd_args: List[str], popen_cwd: str + ) -> None: """ Start the server subprocess and a thread responsible to redirect stdout/stderr to the Queue. @@ -201,7 +211,7 @@ def _start_server(self, timeout, extra_cmd_args, popen_cwd): - def _start_process(self, extra_cmd_args, popen_cwd): + def _start_process(self, extra_cmd_args: List[str], popen_cwd: str) -> None: """Starts the process running the server.""" # The "-u" option forces stdin, stdout and stderr to be unbuffered. @@ -213,7 +223,7 @@ def _start_process(self, extra_cmd_args, popen_cwd): - def _start_redirect_thread(self): + def _start_redirect_thread(self) -> None: """Starts a thread responsible to redirect stdout/stderr to the Queue.""" # Run log_queue_worker() in a thread. @@ -228,7 +238,7 @@ def _start_redirect_thread(self): @staticmethod - def _log_queue_worker(stream, line_queue): + def _log_queue_worker(stream: IO, line_queue: queue.Queue) -> None: """ Worker function to run in a seprate thread. Reads from 'stream', puts lines in a Queue (Queue is thread-safe). @@ -247,7 +257,7 @@ def _log_queue_worker(stream, line_queue): - def _wait_for_port(self, timeout): + def _wait_for_port(self, timeout: int) -> None: """ Validates the first item from the Queue against the port message. If validation is successful, self.port is set. @@ -279,7 +289,7 @@ def _wait_for_port(self, timeout): - def _kill_server_process(self): + def _kill_server_process(self) -> None: """Kills the server subprocess if it's running.""" if self.is_process_running(): @@ -290,7 +300,7 @@ def _kill_server_process(self): - def flush_log(self): + def flush_log(self) -> None: """Flushes the log lines from the logging queue.""" while True: @@ -311,7 +321,7 @@ def flush_log(self): - def clean(self): + def clean(self) -> None: """ Kills the subprocess and closes the TempFile. Calls flush_log to check for logged information, but not yet flushed. @@ -324,5 +334,5 @@ def clean(self): - def is_process_running(self): + def is_process_running(self) -> bool: return True if self.__server_process.poll() is None else False diff --git a/tuf/log.py b/tuf/log.py index 9865b04369..f9ae6c7721 100755 --- a/tuf/log.py +++ b/tuf/log.py @@ -182,7 +182,7 @@ def filter(self, record): -def set_log_level(log_level=_DEFAULT_LOG_LEVEL): +def set_log_level(log_level: int=_DEFAULT_LOG_LEVEL): """ Allow the default log level to be overridden. If 'log_level' is not diff --git a/tuf/unittest_toolbox.py b/tuf/unittest_toolbox.py index 063bec8df6..ac1305918b 100755 --- a/tuf/unittest_toolbox.py +++ b/tuf/unittest_toolbox.py @@ -29,6 +29,7 @@ import random import string +from typing import Optional class Modified_TestCase(unittest.TestCase): """ @@ -70,12 +71,12 @@ def setUp(): """ - def setUp(self): + def setUp(self) -> None: self._cleanup = [] - def tearDown(self): + def tearDown(self) -> None: for cleanup_function in self._cleanup: # Perform clean up by executing clean-up functions. try: @@ -87,7 +88,7 @@ def tearDown(self): - def make_temp_directory(self, directory=None): + def make_temp_directory(self, directory: Optional[str]=None) -> str: """Creates and returns an absolute path of a directory.""" prefix = self.__class__.__name__+'_' @@ -102,7 +103,9 @@ def _destroy_temp_directory(): - def make_temp_file(self, suffix='.txt', directory=None): + def make_temp_file( + self,suffix: str='.txt', directory: Optional[str]=None + ) -> str: """Creates and returns an absolute path of an empty file.""" prefix='tmp_file_'+self.__class__.__name__+'_' temp_file = tempfile.mkstemp(suffix=suffix, prefix=prefix, dir=directory) @@ -113,7 +116,9 @@ def _destroy_temp_file(): - def make_temp_data_file(self, suffix='', directory=None, data = 'junk data'): + def make_temp_data_file( + self, suffix: str='', directory: Optional[str]=None, data: str = 'junk data' + ) -> str: """Returns an absolute path of a temp file containing data.""" temp_file_path = self.make_temp_file(suffix=suffix, directory=directory) temp_file = open(temp_file_path, 'wt', encoding='utf8') @@ -123,7 +128,7 @@ def make_temp_data_file(self, suffix='', directory=None, data = 'junk data'): - def random_path(self, length = 7): + def random_path(self, length: int = 7) -> str: """Generate a 'random' path consisting of random n-length strings.""" rand_path = '/' + self.random_string(length) @@ -136,7 +141,7 @@ def random_path(self, length = 7): @staticmethod - def random_string(length=15): + def random_string(length: int=15) -> str: """Generate a random string of specified length.""" rand_str = ''