From ea6bd62f933126a700633d23ea52a166ce37d1d2 Mon Sep 17 00:00:00 2001 From: Iwan Aucamp Date: Fri, 29 Jul 2022 23:02:24 +0200 Subject: [PATCH] fix: issues with string destination handling in `{Graph,Result}.serialize` Change `{Graph,Result}.serialize` to only handle string destinations as URIs if the scheme is `file` and to treat them as operating system paths in all other cases. This is for the following reasons: - `Result.serialize` was treating URI paths as OS paths, which only works in some cases; it does not work for special characters and percent-encoding. - Many valid Unix and Windows paths parse using `urlparse` and have no netloc, e.g. `C:\some\path` and `some:/path`, however they should be treated as paths and not as URIs. - `Graph` and `Result` should behave consistently. Some caveats in this change: - non-file URIs will now be treated as OS paths which may result in slightly confusing error messages, such as `FileNotFoundError: [Errno 2] No such file or directory: 'http://example.com/'` if http://example.com/ is passed. - some valid file URIs (e.g. `file:/path/to/file` from https://datatracker.ietf.org/doc/html/rfc8089) are also valid Posix paths but will be treated as file-URIs instead of Posix paths. For `Graph.serialize` this ambiguity can be avoided by using `pathlib.Path`, but for `Result.serialize` there is currently no way to avoid it, though I will work on https://github.com/RDFLib/rdflib/issues/1834 soon and in that provide a way to avoid the ambiguity there also. 
--- CHANGELOG.md | 22 ++ rdflib/graph.py | 29 +- rdflib/query.py | 24 +- requirements.flake8.txt | 2 +- setup.py | 2 +- test/test_serializers/test_serializer.py | 214 ++++++++++---- test/test_sparql/test_result.py | 352 ++++++++++++----------- test/utils/destination.py | 54 ++++ test/utils/result.py | 105 +++++++ 9 files changed, 549 insertions(+), 255 deletions(-) create mode 100644 test/utils/destination.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d8d81a722d..11fa854f33 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -159,6 +159,28 @@ and will be removed for release. + + + + + + + +- Improve file-URI and path handling in `Graph.serialize` and `Result.serialize` to + address problems with Windows path handling in `Result.serialize` and to make + the behavior between `Graph.serialize` and `Result.serialize` more consistent. + Closed [issue #2067](https://github.com/RDFLib/rdflib/issues/2067). + [PR #2068](https://github.com/RDFLib/rdflib/pull/2068). + - String values for the `destination` argument will now only be treated as + file URIs if `urllib.parse.urlparse` returns their scheme as `file`. + - Simplified file writing to avoid a temporary file. 
+ + + + + + + diff --git a/rdflib/graph.py b/rdflib/graph.py index fdb6445e22..bf29e3d1ed 100644 --- a/rdflib/graph.py +++ b/rdflib/graph.py @@ -1,9 +1,6 @@ import logging -import os import pathlib import random -import shutil -import tempfile from io import BytesIO from typing import ( IO, @@ -1201,20 +1198,20 @@ def serialize( serializer.serialize(stream, base=base, encoding=encoding, **args) else: if isinstance(destination, pathlib.PurePath): - location = str(destination) + os_path = str(destination) else: location = cast(str, destination) - scheme, netloc, path, params, _query, fragment = urlparse(location) - if netloc != "": - raise ValueError( - f"destination {destination} is not a local file reference" - ) - fd, name = tempfile.mkstemp() - stream = os.fdopen(fd, "wb") - serializer.serialize(stream, base=base, encoding=encoding, **args) - stream.close() - dest = url2pathname(path) if scheme == "file" else location - shutil.move(name, dest) + scheme, netloc, path, params, _query, fragment = urlparse(location) + if scheme == "file": + if netloc != "": + raise ValueError( + f"the file URI {location!r} has an authority component which is not supported" + ) + os_path = url2pathname(path) + else: + os_path = location + with open(os_path, "wb") as stream: + serializer.serialize(stream, encoding=encoding, **args) return self def print(self, format="turtle", encoding="utf-8", out=None): @@ -1276,7 +1273,7 @@ def parse( ... ... ... 
''' - >>> import tempfile + >>> import os, tempfile >>> fd, file_name = tempfile.mkstemp() >>> f = os.fdopen(fd, "w") >>> dummy = f.write(my_data) # Returns num bytes written diff --git a/rdflib/query.py b/rdflib/query.py index fcda6f3bd8..c4c6e33b5e 100644 --- a/rdflib/query.py +++ b/rdflib/query.py @@ -1,12 +1,10 @@ import itertools -import os -import shutil -import tempfile import types import warnings from io import BytesIO from typing import IO, TYPE_CHECKING, List, Optional, Union, cast from urllib.parse import urlparse +from urllib.request import url2pathname __all__ = [ "Processor", @@ -267,16 +265,16 @@ def serialize( else: location = cast(str, destination) scheme, netloc, path, params, query, fragment = urlparse(location) - if netloc != "": - print( - "WARNING: not saving as location" + "is not a local file reference" - ) - return None - fd, name = tempfile.mkstemp() - stream = os.fdopen(fd, "wb") - serializer.serialize(stream, encoding=encoding, **args) - stream.close() - shutil.move(name, path) + if scheme == "file": + if netloc != "": + raise ValueError( + f"the file URI {location!r} has an authority component which is not supported" + ) + os_path = url2pathname(path) + else: + os_path = location + with open(os_path, "wb") as stream: + serializer.serialize(stream, encoding=encoding, **args) return None def __len__(self): diff --git a/requirements.flake8.txt b/requirements.flake8.txt index a948a79ae7..535436d692 100644 --- a/requirements.flake8.txt +++ b/requirements.flake8.txt @@ -1,3 +1,3 @@ flake8 -flakeheaven; python_version >= '3.8.0' +flakeheaven >= 2.1.3; python_version >= '3.8.0' pep8-naming diff --git a/setup.py b/setup.py index 45eac275ae..e487822cae 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ "dev": [ "black==22.6.0", "flake8", - "flakeheaven; python_version >= '3.8.0'", + "flakeheaven >= 2.1.3; python_version >= '3.8.0'", "isort", "mypy", "pep8-naming", diff --git a/test/test_serializers/test_serializer.py 
b/test/test_serializers/test_serializer.py index 60da3c51a9..db4b167e00 100644 --- a/test/test_serializers/test_serializer.py +++ b/test/test_serializers/test_serializer.py @@ -2,11 +2,13 @@ import itertools import logging import re -from contextlib import contextmanager +import socket +from contextlib import ExitStack from dataclasses import dataclass, field from functools import lru_cache -from pathlib import Path, PurePath +from pathlib import Path, PosixPath, PurePath from test.utils import GraphHelper, get_unique_plugins +from test.utils.destination import DestinationType, DestParmType, DestRef from typing import ( IO, Callable, @@ -15,10 +17,12 @@ List, Optional, Set, + TextIO, Tuple, Union, cast, ) +from urllib.parse import urlsplit, urlunsplit import pytest from _pytest.mark.structures import Mark, MarkDecorator, ParameterSet @@ -154,11 +158,10 @@ def test_serialize_to_path(tmp_path: Path, simple_graph: Graph): def test_serialize_to_neturl(simple_graph: Graph): - with pytest.raises(ValueError) as raised: + with pytest.raises(FileNotFoundError): simple_graph.serialize( destination="http://example.com/", format="nt", encoding="utf-8" ) - assert "destination" in f"{raised.value}" def test_serialize_to_fileurl(tmp_path: Path, simple_graph: Graph): @@ -182,35 +185,23 @@ def test_serialize_badformat(simple_graph: Graph) -> None: assert "badformat" in f"{ctx.value}" -@dataclass(frozen=True) -class DestRef: - param: Union[Path, PurePath, str, IO[bytes]] - path: Path - - -class DestinationType(str, enum.Enum): - PATH = enum.auto() - PURE_PATH = enum.auto() - STR_PATH = enum.auto() - BINARY_IO = enum.auto() - RETURN = enum.auto() - - @contextmanager - def make_ref(self, tmp_path: Path) -> Generator[Optional[DestRef], None, None]: - path = tmp_path / f"file-{self.name}" - if self is DestinationType.RETURN: - yield None - elif self is DestinationType.PATH: - yield DestRef(path, path) - elif self is DestinationType.PURE_PATH: - yield DestRef(PurePath(path), path) - elif 
self is DestinationType.STR_PATH: - yield DestRef(f"{path}", path) - elif self is DestinationType.BINARY_IO: - with path.open("wb") as bfh: - yield DestRef(bfh, path) - else: - raise ValueError(f"unsupported type {type!r}") +DESTINATION_TYPES = { + DestinationType.RETURN, + DestinationType.PATH, + DestinationType.PURE_PATH, + DestinationType.STR_PATH, + DestinationType.FILE_URI, + DestinationType.BINARY_IO, +} + + +GraphDestParamType = Union[Path, PurePath, str, IO[bytes]] + + +def narrow_dest_param(param: DestParmType) -> GraphDestParamType: + assert not (hasattr(param, "write") and hasattr(param, "encoding")) + assert not isinstance(param, TextIO) + return param class GraphType(str, enum.Enum): @@ -382,9 +373,8 @@ def make_serialize_parse_tests() -> Generator[ParameterSet, None, None]: Tuple[str, GraphType, DestinationType, Optional[str]], Union[MarkDecorator, Mark], ] = {} - destination_types = set(DestinationType) for serializer_name, destination_type in itertools.product( - serializer_dict.keys(), destination_types + serializer_dict.keys(), DESTINATION_TYPES ): format = serializer_dict[serializer_name] encodings: Set[Optional[str]] = {*format.info.encodings, None} @@ -456,7 +446,7 @@ def test_serialize_parse( else: raise ValueError(f"graph_type {graph_type!r} is not supported") with destination_type.make_ref(tmp_path) as dest_ref: - destination = None if dest_ref is None else dest_ref.param + destination = None if dest_ref is None else narrow_dest_param(dest_ref.param) serialize_result = graph.serialize( destination=destination, format=serializer_name, @@ -507,6 +497,12 @@ def dest_ref(self) -> DestRef: raise RuntimeError("dest_ref is None") return self.opt_dest_ref + @property + def dest_param(self) -> GraphDestParamType: + if self.opt_dest_ref is None: + raise RuntimeError("dest_ref is None") + return narrow_dest_param(self.opt_dest_ref.param) + SerializeFunctionType = Callable[[Graph, SerializeArgs], SerializeResultType] StrSerializeFunctionType = 
Callable[[Graph, SerializeArgs], str] @@ -536,23 +532,17 @@ def dest_ref(self) -> DestRef: file_serialize_functions: List[FileSerializeFunctionType] = [ - lambda graph, args: graph.serialize(args.dest_ref.param), - lambda graph, args: graph.serialize(args.dest_ref.param, encoding=None), - lambda graph, args: graph.serialize(args.dest_ref.param, encoding="utf-8"), - lambda graph, args: graph.serialize(args.dest_ref.param, args.format), - lambda graph, args: graph.serialize( - args.dest_ref.param, args.format, encoding=None - ), - lambda graph, args: graph.serialize(args.dest_ref.param, args.format, None, None), - lambda graph, args: graph.serialize( - args.dest_ref.param, args.format, encoding="utf-8" - ), - lambda graph, args: graph.serialize( - args.dest_ref.param, args.format, None, encoding="utf-8" - ), + lambda graph, args: graph.serialize(args.dest_param), + lambda graph, args: graph.serialize(args.dest_param, encoding=None), + lambda graph, args: graph.serialize(args.dest_param, encoding="utf-8"), + lambda graph, args: graph.serialize(args.dest_param, args.format), + lambda graph, args: graph.serialize(args.dest_param, args.format, encoding=None), + lambda graph, args: graph.serialize(args.dest_param, args.format, None, None), + lambda graph, args: graph.serialize(args.dest_param, args.format, encoding="utf-8"), lambda graph, args: graph.serialize( - args.dest_ref.param, args.format, None, "utf-8" + args.dest_param, args.format, None, encoding="utf-8" ), + lambda graph, args: graph.serialize(args.dest_param, args.format, None, "utf-8"), ] @@ -585,7 +575,6 @@ def test_serialize_overloads( serialize_function: SerializeFunctionType, ) -> None: format = GraphFormat.TURTLE - serializer = format.info.serializer with destination_type.make_ref(tmp_path) as dest_ref: serialize_result = serialize_function( @@ -603,3 +592,124 @@ def test_serialize_overloads( serialized_data = dest_ref.path.read_text(encoding="utf-8") check_serialized(format, simple_graph, 
serialized_data) + + +def make_test_serialize_to_strdest_tests() -> Generator[ParameterSet, None, None]: + destination_types: Set[DestinationType] = { + DestinationType.FILE_URI, + DestinationType.STR_PATH, + } + name_prefixes = [ + r"a_b-", + r"a%b-", + r"a%20b-", + r"a b-", + r"a b-", + r"a@b", + r"a$b", + r"a!b", + ] + if isinstance(Path.cwd(), PosixPath): + # not valid on windows https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file#naming-conventions + name_prefixes.extend( + [ + r"a:b-", + r"a|b", + ] + ) + for destination_type, name_prefix in itertools.product( + destination_types, name_prefixes + ): + yield pytest.param( + destination_type, + name_prefix, + id=f"{destination_type.name}-{name_prefix}", + ) + + +@pytest.mark.parametrize( + ["destination_type", "name_prefix"], + make_test_serialize_to_strdest_tests(), +) +def test_serialize_to_strdest( + tmp_path: Path, + simple_graph: Graph, + destination_type: DestinationType, + name_prefix: str, +) -> None: + """ + Serialization works correctly with the given arguments and generates output + that can be parsed to a graph that is identical to the original graph. 
+ """ + format = GraphFormat.TURTLE + encoding = "utf-8" + + def path_factory( + tmp_path: Path, type: DestinationType, encoding: Optional[str] + ) -> Path: + return tmp_path / f"{name_prefix}file-{type.name}-{encoding}" + + with destination_type.make_ref( + tmp_path, + encoding=encoding, + path_factory=path_factory, + ) as dest_ref: + assert dest_ref is not None + destination = narrow_dest_param(dest_ref.param) + serialize_result = simple_graph.serialize( + destination=destination, + format=format.info.serializer, + encoding=encoding, + ) + + logging.debug("serialize_result = %r, dest_ref = %s", serialize_result, dest_ref) + + assert isinstance(serialize_result, Graph) + assert dest_ref.path.exists() + serialized_data = dest_ref.path.read_bytes().decode( + "utf-8" if encoding is None else encoding + ) + + logging.debug("serialized_data = %s", serialized_data) + check_serialized(format, simple_graph, serialized_data) + + +@pytest.mark.parametrize( + ["authority"], + [ + ("localhost",), + ("127.0.0.1",), + ("example.com",), + (socket.gethostname(),), + (socket.getfqdn(),), + ], +) +def test_serialize_to_fileuri_with_authortiy( + tmp_path: Path, + simple_graph: Graph, + authority: str, +) -> None: + """ + Serializing to a file URI with authority raises an error. 
+ """ + destination_type = DestinationType.FILE_URI + format = GraphFormat.TURTLE + + with ExitStack() as exit_stack: + dest_ref = exit_stack.enter_context( + destination_type.make_ref( + tmp_path, + ) + ) + assert dest_ref is not None + assert isinstance(dest_ref.param, str) + urlparts = urlsplit(dest_ref.param)._replace(netloc=authority) + use_url = urlunsplit(urlparts) + logging.debug("use_url = %s", use_url) + catcher = exit_stack.enter_context(pytest.raises(ValueError)) + simple_graph.serialize( + destination=use_url, + format=format.info.serializer, + ) + assert False # this should never happen as serialize should always fail + assert catcher.value is not None diff --git a/test/test_sparql/test_result.py b/test/test_sparql/test_result.py index 6e46f8f4d0..8f211e3371 100644 --- a/test/test_sparql/test_result.py +++ b/test/test_sparql/test_result.py @@ -3,12 +3,17 @@ import itertools import logging import re -import sys -from contextlib import contextmanager -from dataclasses import dataclass +import socket +from contextlib import ExitStack from io import BytesIO, StringIO -from pathlib import Path -from test.utils.result import ResultType +from pathlib import Path, PosixPath, PurePath +from test.utils.destination import DestinationType, DestParmType +from test.utils.result import ( + ResultFormat, + ResultFormatInfo, + ResultFormatTrait, + ResultType, +) from typing import ( IO, BinaryIO, @@ -24,6 +29,7 @@ Type, Union, ) +from urllib.parse import urlsplit, urlunsplit import pytest from _pytest.mark.structures import Mark, MarkDecorator, ParameterSet @@ -151,117 +157,29 @@ def check_serialized(format: str, result: Result, data: str) -> None: assert result == parsed_result -class ResultFormatTrait(enum.Enum): - HAS_SERIALIZER = enum.auto() - HAS_PARSER = enum.auto() - - -@dataclass(frozen=True) -class ResultFormat: - name: str - supported_types: Set[ResultType] - traits: Set[ResultFormatTrait] - encodings: Set[str] - - -class ResultFormats(Dict[str, 
ResultFormat]): - @classmethod - def make(cls, *result_format: ResultFormat) -> "ResultFormats": - result = cls() - for item in result_format: - result[item.name] = item - return result - - -result_formats = ResultFormats.make( - ResultFormat( - "csv", - {ResultType.SELECT}, - { - ResultFormatTrait.HAS_PARSER, - ResultFormatTrait.HAS_SERIALIZER, - }, - {"utf-8", "utf-16"}, - ), - ResultFormat( - "txt", - {ResultType.SELECT}, - { - ResultFormatTrait.HAS_SERIALIZER, - }, - {"utf-8"}, - ), - ResultFormat( - "json", - {ResultType.SELECT}, - { - ResultFormatTrait.HAS_PARSER, - ResultFormatTrait.HAS_SERIALIZER, - }, - {"utf-8", "utf-16"}, - ), - ResultFormat( - "xml", - {ResultType.SELECT}, - { - ResultFormatTrait.HAS_PARSER, - ResultFormatTrait.HAS_SERIALIZER, - }, - {"utf-8", "utf-16"}, - ), - ResultFormat( - "tsv", - {ResultType.SELECT}, - { - ResultFormatTrait.HAS_PARSER, - }, - {"utf-8", "utf-16"}, - ), -) - - -class DestinationType(enum.Enum): +class SourceType(enum.Enum): TEXT_IO = enum.auto() BINARY_IO = enum.auto() - STR_PATH = enum.auto() -class SourceType(enum.Enum): - TEXT_IO = enum.auto() - BINARY_IO = enum.auto() +DESTINATION_TYPES = { + DestinationType.TEXT_IO, + DestinationType.BINARY_IO, + DestinationType.STR_PATH, + DestinationType.FILE_URI, + DestinationType.RETURN, +} +ResultDestParamType = Union[str, IO[bytes], TextIO] -@dataclass(frozen=True) -class DestRef: - param: Union[str, IO[bytes], TextIO] - path: Path - - -@contextmanager -def make_dest( - tmp_path: Path, type: Optional[DestinationType], encoding: str -) -> Iterator[Optional[DestRef]]: - if type is None: - yield None - return - path = tmp_path / f"file-{type}" - if type is DestinationType.STR_PATH: - yield DestRef(f"{path}", path) - elif type is DestinationType.BINARY_IO: - with path.open("wb") as bfh: - yield DestRef(bfh, path) - elif type is DestinationType.TEXT_IO: - assert encoding is not None - with path.open("w", encoding=encoding) as fh: - yield DestRef(fh, path) - else: - raise 
ValueError(f"unsupported type {type}") + +def narrow_dest_param(param: DestParmType) -> ResultDestParamType: + assert not isinstance(param, PurePath) + return param def make_select_result_serialize_parse_tests() -> Iterator[ParameterSet]: - xfails: Dict[ - Tuple[str, Optional[DestinationType], str], Union[MarkDecorator, Mark] - ] = { + xfails: Dict[Tuple[str, DestinationType, str], Union[MarkDecorator, Mark]] = { ("csv", DestinationType.TEXT_IO, "utf-8"): pytest.mark.xfail(raises=TypeError), ("csv", DestinationType.TEXT_IO, "utf-16"): pytest.mark.xfail(raises=TypeError), ("json", DestinationType.TEXT_IO, "utf-8"): pytest.mark.xfail(raises=TypeError), @@ -271,77 +189,47 @@ def make_select_result_serialize_parse_tests() -> Iterator[ParameterSet]: ("txt", DestinationType.BINARY_IO, "utf-8"): pytest.mark.xfail( raises=TypeError ), - ("txt", DestinationType.BINARY_IO, "utf-16"): pytest.mark.xfail( - raises=TypeError - ), ("txt", DestinationType.STR_PATH, "utf-8"): pytest.mark.xfail(raises=TypeError), - ("txt", DestinationType.STR_PATH, "utf-16"): pytest.mark.xfail( - raises=TypeError - ), + ("txt", DestinationType.FILE_URI, "utf-8"): pytest.mark.xfail(raises=TypeError), } - if sys.platform == "win32": - xfails[("csv", DestinationType.STR_PATH, "utf-8")] = pytest.mark.xfail( - raises=FileNotFoundError, - reason="string path handling does not work on windows", - ) - xfails[("csv", DestinationType.STR_PATH, "utf-16")] = pytest.mark.xfail( - raises=FileNotFoundError, - reason="string path handling does not work on windows", - ) - xfails[("json", DestinationType.STR_PATH, "utf-8")] = pytest.mark.xfail( - raises=FileNotFoundError, - reason="string path handling does not work on windows", - ) - xfails[("json", DestinationType.STR_PATH, "utf-16")] = pytest.mark.xfail( - raises=FileNotFoundError, - reason="string path handling does not work on windows", - ) - xfails[("xml", DestinationType.STR_PATH, "utf-8")] = pytest.mark.xfail( - raises=FileNotFoundError, - reason="string path 
handling does not work on windows", - ) - xfails[("xml", DestinationType.STR_PATH, "utf-16")] = pytest.mark.xfail( - raises=FileNotFoundError, - reason="string path handling does not work on windows", - ) - formats = [ - format - for format in result_formats.values() - if ResultFormatTrait.HAS_SERIALIZER in format.traits - and ResultType.SELECT in format.supported_types + format_infos = [ + format_info + for format_info in ResultFormat.info_set() + if ResultFormatTrait.HAS_SERIALIZER in format_info.traits + and ResultType.SELECT in format_info.supported_types ] - destination_types: Set[Optional[DestinationType]] = {None} - destination_types.update(set(DestinationType)) - for format, destination_type in itertools.product(formats, destination_types): - for encoding in format.encodings: - xfail = xfails.get((format.name, destination_type, encoding)) + for format_info, destination_type in itertools.product( + format_infos, DESTINATION_TYPES + ): + for encoding in format_info.encodings: + xfail = xfails.get((format_info.name, destination_type, encoding)) marks = (xfail,) if xfail is not None else () yield pytest.param( - (format, destination_type, encoding), - id=f"{format.name}-{None if destination_type is None else destination_type.name}-{encoding}", + (format_info, destination_type, encoding), + id=f"{format_info.name}-{destination_type.name}-{encoding}", marks=marks, ) @pytest.mark.parametrize( - ["args"], + ["test_args"], make_select_result_serialize_parse_tests(), ) def test_select_result_serialize_parse( tmp_path: Path, select_result: Result, - args: Tuple[ResultFormat, Optional[DestinationType], str], + test_args: Tuple[ResultFormatInfo, DestinationType, str], ) -> None: """ Round tripping of a select query through the serializer and parser of a specific format results in an equivalent result object. 
""" - format, destination_type, encoding = args - with make_dest(tmp_path, destination_type, encoding) as dest_ref: - destination = None if dest_ref is None else dest_ref.param + format_info, destination_type, encoding = test_args + with destination_type.make_ref(tmp_path, encoding) as dest_ref: + destination = None if dest_ref is None else narrow_dest_param(dest_ref.param) serialize_result = select_result.serialize( destination=destination, - format=format.name, + format=format_info.name, encoding=encoding, ) @@ -354,7 +242,7 @@ def test_select_result_serialize_parse( serialized_data = dest_bytes.decode(encoding) logging.debug("serialized_data = %s", serialized_data) - check_serialized(format.name, select_result, serialized_data) + check_serialized(format_info.name, select_result, serialized_data) def serialize_select(select_result: Result, format: str, encoding: str) -> bytes: @@ -376,11 +264,11 @@ def serialize_select(select_result: Result, format: str, encoding: str) -> bytes def make_select_result_parse_serialized_tests() -> Iterator[ParameterSet]: xfails: Dict[Tuple[str, Optional[SourceType], str], Union[MarkDecorator, Mark]] = {} - formats = [ - format - for format in result_formats.values() - if ResultFormatTrait.HAS_PARSER in format.traits - and ResultType.SELECT in format.supported_types + format_infos = [ + format_info + for format_info in ResultFormat.info_set() + if ResultFormatTrait.HAS_PARSER in format_info.traits + and ResultType.SELECT in format_info.supported_types ] source_types = set(SourceType) xfails[("csv", SourceType.BINARY_IO, "utf-16")] = pytest.mark.xfail( @@ -392,32 +280,32 @@ def make_select_result_parse_serialized_tests() -> Iterator[ParameterSet]: xfails[("tsv", SourceType.BINARY_IO, "utf-16")] = pytest.mark.xfail( raises=UnicodeDecodeError, ) - for format, destination_type in itertools.product(formats, source_types): - for encoding in format.encodings: - xfail = xfails.get((format.name, destination_type, encoding)) + for format_info, 
destination_type in itertools.product(format_infos, source_types): + for encoding in format_info.encodings: + xfail = xfails.get((format_info.format, destination_type, encoding)) marks = (xfail,) if xfail is not None else () yield pytest.param( - (format, destination_type, encoding), - id=f"{format.name}-{None if destination_type is None else destination_type.name}-{encoding}", + (format_info, destination_type, encoding), + id=f"{format_info.name}-{None if destination_type is None else destination_type.name}-{encoding}", marks=marks, ) @pytest.mark.parametrize( - ["args"], + ["test_args"], make_select_result_parse_serialized_tests(), ) def test_select_result_parse_serialized( tmp_path: Path, select_result: Result, - args: Tuple[ResultFormat, SourceType, str], + test_args: Tuple[ResultFormatInfo, SourceType, str], ) -> None: """ Parsing a serialized result produces the expected result object. """ - format, source_type, encoding = args + format_info, source_type, encoding = test_args - serialized_data = serialize_select(select_result, format.name, encoding) + serialized_data = serialize_select(select_result, format_info.name, encoding) logging.debug("serialized_data = %s", serialized_data.decode(encoding)) @@ -429,6 +317,126 @@ def test_select_result_parse_serialized( else: raise ValueError(f"Invalid source_type {source_type}") - parsed_result = Result.parse(source, format=format.name) + parsed_result = Result.parse(source, format=format_info.name) assert select_result == parsed_result + + +def make_test_serialize_to_strdest_tests() -> Iterator[ParameterSet]: + destination_types: Set[DestinationType] = { + DestinationType.FILE_URI, + DestinationType.STR_PATH, + } + name_prefixes = [ + r"a_b-", + r"a%b-", + r"a%20b-", + r"a b-", + r"a b-", + r"a@b", + r"a$b", + r"a!b", + ] + if isinstance(Path.cwd(), PosixPath): + # not valid on windows https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file#naming-conventions + name_prefixes.extend( + [ + r"a:b-", + 
r"a|b", + ] + ) + for destination_type, name_prefix in itertools.product( + destination_types, name_prefixes + ): + yield pytest.param( + destination_type, + name_prefix, + id=f"{destination_type.name}-{name_prefix}", + ) + + +@pytest.mark.parametrize( + ["destination_type", "name_prefix"], + make_test_serialize_to_strdest_tests(), +) +def test_serialize_to_strdest( + tmp_path: Path, + select_result: Result, + destination_type: DestinationType, + name_prefix: str, +) -> None: + """ + Various ways of specifying the destination argument of ``Result.serialize`` + as a string works correctly. + """ + format_info = ResultFormat.JSON.info + encoding = "utf-8" + + def path_factory( + tmp_path: Path, type: DestinationType, encoding: Optional[str] + ) -> Path: + return tmp_path / f"{name_prefix}file-{type.name}-{encoding}" + + with destination_type.make_ref( + tmp_path, + encoding=encoding, + path_factory=path_factory, + ) as dest_ref: + assert dest_ref is not None + destination = narrow_dest_param(dest_ref.param) + serialize_result = select_result.serialize( + destination=destination, + format=format_info.name, + encoding=encoding, + ) + + assert serialize_result is None + dest_bytes = dest_ref.path.read_bytes() + serialized_data = dest_bytes.decode(encoding) + + logging.debug("serialized_data = %s", serialized_data) + check_serialized(format_info.name, select_result, serialized_data) + + +@pytest.mark.parametrize( + ["authority"], + [ + ("localhost",), + ("127.0.0.1",), + ("example.com",), + (socket.gethostname(),), + (socket.getfqdn(),), + ], +) +def test_serialize_to_fileuri_with_authortiy( + tmp_path: Path, + select_result: Result, + authority: str, +) -> None: + """ + Serializing to a file URI with authority raises an error. 
+ """ + destination_type = DestinationType.FILE_URI + format_info = ResultFormat.JSON.info + encoding = "utf-8" + + with ExitStack() as exit_stack: + dest_ref = exit_stack.enter_context( + destination_type.make_ref( + tmp_path, + encoding=encoding, + ) + ) + assert dest_ref is not None + assert isinstance(dest_ref.param, str) + urlparts = urlsplit(dest_ref.param)._replace(netloc=authority) + use_url = urlunsplit(urlparts) + logging.debug("use_url = %s", use_url) + catcher = exit_stack.enter_context(pytest.raises(ValueError)) + select_result.serialize( + destination=use_url, + format=format_info.name, + encoding=encoding, + ) + assert False # this should never happen as serialize should always fail + assert catcher.value is not None diff --git a/test/utils/destination.py b/test/utils/destination.py new file mode 100644 index 0000000000..ad767d1c47 --- /dev/null +++ b/test/utils/destination.py @@ -0,0 +1,54 @@ +import enum +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path, PurePath +from typing import IO, Callable, Generator, Optional, TextIO, Union + +DestParmType = Union[Path, PurePath, str, IO[bytes], TextIO] + + +@dataclass(frozen=True) +class DestRef: + param: DestParmType + path: Path + + +class DestinationType(str, enum.Enum): + RETURN = enum.auto() + PATH = enum.auto() + PURE_PATH = enum.auto() + STR_PATH = enum.auto() + FILE_URI = enum.auto() + BINARY_IO = enum.auto() + TEXT_IO = enum.auto() + + @contextmanager + def make_ref( + self, + tmp_path: Path, + encoding: Optional[str] = None, + path_factory: Callable[[Path, "DestinationType", Optional[str]], Path] = ( + lambda tmp_path, type, encoding: tmp_path / f"file-{type.name}-{encoding}" + ), + ) -> Generator[Optional[DestRef], None, None]: + path = path_factory(tmp_path, self, encoding) + # path = tmp_path / f"file-{self.name}" + if self is DestinationType.RETURN: + yield None + elif self is DestinationType.PATH: + yield DestRef(path, path) + elif self is 
DestinationType.PURE_PATH: + yield DestRef(PurePath(path), path) + elif self is DestinationType.STR_PATH: + yield DestRef(f"{path}", path) + elif self is DestinationType.FILE_URI: + yield DestRef(path.as_uri(), path) + elif self is DestinationType.BINARY_IO: + with path.open("wb") as bfh: + yield DestRef(bfh, path) + elif self is DestinationType.TEXT_IO: + assert encoding is not None + with path.open("w", encoding=encoding) as fh: + yield DestRef(fh, path) + else: + raise ValueError(f"unsupported type {type!r}") diff --git a/test/utils/result.py b/test/utils/result.py index 98e31b98f9..57f1f460e9 100644 --- a/test/utils/result.py +++ b/test/utils/result.py @@ -145,3 +145,108 @@ def assert_bindings_sequences_equal( assert lhs_only == [] assert rhs_only == [] assert (len(common) == len(lhs)) and (len(common) == len(rhs)) + + +ResultFormatInfoDict = Dict["ResultFormat", "ResultFormatInfo"] + + +class ResultFormatTrait(enum.Enum): + HAS_SERIALIZER = enum.auto() + HAS_PARSER = enum.auto() + + +class ResultFormat(str, enum.Enum): + CSV = "csv" + TXT = "txt" + JSON = "json" + XML = "xml" + TSV = "tsv" + + @classmethod + @lru_cache(maxsize=None) + def info_dict(cls) -> "ResultFormatInfoDict": + return ResultFormatInfo.make_dict( + ResultFormatInfo( + ResultFormat.CSV, + frozenset({ResultType.SELECT}), + frozenset( + { + ResultFormatTrait.HAS_PARSER, + ResultFormatTrait.HAS_SERIALIZER, + } + ), + frozenset({"utf-8", "utf-16"}), + ), + ResultFormatInfo( + ResultFormat.TXT, + frozenset({ResultType.SELECT}), + frozenset( + { + ResultFormatTrait.HAS_SERIALIZER, + } + ), + frozenset({"utf-8"}), + ), + ResultFormatInfo( + ResultFormat.JSON, + frozenset({ResultType.SELECT}), + frozenset( + { + ResultFormatTrait.HAS_PARSER, + ResultFormatTrait.HAS_SERIALIZER, + } + ), + frozenset({"utf-8", "utf-16"}), + ), + ResultFormatInfo( + ResultFormat.XML, + frozenset({ResultType.SELECT}), + frozenset( + { + ResultFormatTrait.HAS_PARSER, + ResultFormatTrait.HAS_SERIALIZER, + } + ), + 
frozenset({"utf-8", "utf-16"}), + ), + ResultFormatInfo( + ResultFormat.TSV, + frozenset({ResultType.SELECT}), + frozenset( + { + ResultFormatTrait.HAS_PARSER, + } + ), + frozenset({"utf-8", "utf-16"}), + ), + ) + + @property + def info(self) -> "ResultFormatInfo": + return self.info_dict()[self] + + @classmethod + @lru_cache(maxsize=None) + def set(cls) -> Set["ResultFormat"]: + return set(cls) + + @classmethod + @lru_cache(maxsize=None) + def info_set(cls) -> Set["ResultFormatInfo"]: + return {format.info for format in cls.set()} + + +@dataclass(frozen=True) +class ResultFormatInfo: + format: ResultFormat + supported_types: FrozenSet[ResultType] + traits: FrozenSet[ResultFormatTrait] + encodings: FrozenSet[str] + + @classmethod + def make_dict(cls, *items: "ResultFormatInfo") -> ResultFormatInfoDict: + return dict((info.format, info) for info in items) + + @property + def name(self) -> "str": + return f"{self.format.value}"