diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 5ef133bf325..f4117a2a081 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -495,16 +495,14 @@ def update_body_from_data(self, body: Any) -> None: if hdrs.CONTENT_LENGTH not in self.headers: self.headers[hdrs.CONTENT_LENGTH] = str(size) - # set content-type - if (hdrs.CONTENT_TYPE not in self.headers and - hdrs.CONTENT_TYPE not in self.skip_auto_headers): - self.headers[hdrs.CONTENT_TYPE] = body.content_type - # copy payload headers if body.headers: for (key, value) in body.headers.items(): - if key not in self.headers: - self.headers[key] = value + if key in self.headers: + continue + if key in self.skip_auto_headers: + continue + self.headers[key] = value def update_expect_continue(self, expect: bool=False) -> None: if expect: diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 940d5a50e74..b16598eb77b 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -702,9 +702,6 @@ def __init__(self, subtype: str='mixed', super().__init__(None, content_type=ctype) self._parts = [] # type: List[_Part] # noqa - self._headers = CIMultiDict() # type: CIMultiDict[str] - assert self.content_type is not None - self._headers[CONTENT_TYPE] = self.content_type def __enter__(self) -> 'MultipartWriter': return self @@ -769,28 +766,18 @@ def append( headers = CIMultiDict() if isinstance(obj, Payload): - if obj.headers is not None: - obj.headers.update(headers) - else: - if isinstance(headers, CIMultiDict): - obj._headers = headers - else: - obj._headers = CIMultiDict(headers) + obj.headers.update(headers) return self.append_payload(obj) else: try: - return self.append_payload(get_payload(obj, headers=headers)) + payload = get_payload(obj, headers=headers) except LookupError: - raise TypeError + raise TypeError('Cannot create payload from %r' % obj) + else: + return self.append_payload(payload) def append_payload(self, payload: Payload) -> Payload: """Adds a new body part to multipart writer.""" - # content-type - assert payload.headers is not None - if CONTENT_TYPE not in payload.headers: - assert payload.content_type is not None - payload.headers[CONTENT_TYPE] = payload.content_type - # compression encoding = payload.headers.get(CONTENT_ENCODING, '').lower() # type: Optional[str] # noqa if encoding and encoding not in ('deflate', 'gzip', 'identity'): diff --git a/aiohttp/payload.py b/aiohttp/payload.py index a002d1288d1..21b83b62509 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -7,15 +7,13 @@ import warnings from abc import ABC, abstractmethod from itertools import chain -from typing import ( # noqa +from typing import ( IO, TYPE_CHECKING, Any, ByteString, - Callable, Dict, Iterable, - List, Optional, Text, TextIO, @@ -47,6 +45,10 @@ TOO_LARGE_BYTES_BODY = 2 ** 20 # 1 MB +if TYPE_CHECKING: + from typing import List # noqa + + class LookupError(Exception): pass @@ -120,9 +122,8 @@ def register(self, class Payload(ABC): + _default_content_type = 'application/octet-stream' # type: str _size = None # type: Optional[float] - _headers = None # type: Optional[_CIMultiDict] - _content_type = 'application/octet-stream' # type: Optional[str] def __init__(self, value: Any, @@ -137,18 +138,20 @@ def __init__(self, filename: Optional[str]=None, encoding: Optional[str]=None, **kwargs: Any) -> None: - self._value = value self._encoding = encoding self._filename = filename - if headers is not None: - self._headers = CIMultiDict(headers) - if content_type is sentinel and hdrs.CONTENT_TYPE in self._headers: - content_type = self._headers[hdrs.CONTENT_TYPE] - - if content_type is sentinel: - content_type = None - - self._content_type = content_type + self._headers = CIMultiDict() # type: _CIMultiDict + self._value = value + if content_type is not sentinel and content_type is not None: + self._headers[hdrs.CONTENT_TYPE] = content_type + elif self._filename is not None: + content_type = mimetypes.guess_type(self._filename)[0] + if content_type is None: + content_type = self._default_content_type + self._headers[hdrs.CONTENT_TYPE] = content_type + else: + self._headers[hdrs.CONTENT_TYPE] = self._default_content_type + self._headers.update(headers or {}) @property def size(self) -> Optional[float]: @@ -161,15 +164,12 @@ def filename(self) -> Optional[str]: return self._filename @property - def headers(self) -> Optional[_CIMultiDict]: + def headers(self) -> _CIMultiDict: """Custom item headers""" return self._headers @property def _binary_headers(self) -> bytes: - if self.headers is None: - # FIXME: This case actually is unreachable. - return b'' # pragma: no cover return ''.join( [k + ': ' + v + '\r\n' for k, v in self.headers.items()] ).encode('utf-8') + b'\r\n' @@ -180,24 +180,15 @@ def encoding(self) -> Optional[str]: return self._encoding @property - def content_type(self) -> Optional[str]: + def content_type(self) -> str: """Content type""" - if self._content_type is not None: - return self._content_type - elif self._filename is not None: - mime = mimetypes.guess_type(self._filename)[0] - return 'application/octet-stream' if mime is None else mime - else: - return Payload._content_type + return self._headers[hdrs.CONTENT_TYPE] def set_content_disposition(self, disptype: str, quote_fields: bool=True, **params: Any) -> None: """Sets ``Content-Disposition`` header.""" - if self._headers is None: - self._headers = CIMultiDict() - self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header( disptype, quote_fields=quote_fields, **params) @@ -292,7 +283,10 @@ def __init__(self, super().__init__(value, *args, **kwargs) if self._filename is not None and disposition is not None: - self.set_content_disposition(disposition, filename=self._filename) + if hdrs.CONTENT_DISPOSITION not in self.headers: + self.set_content_disposition( + disposition, filename=self._filename + ) async def write(self, writer: AbstractStreamWriter) -> None: loop = asyncio.get_event_loop() diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 9d2eedcaa2e..7295b5c83ff 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -854,8 +854,8 @@ async def test_writer_serialize_with_content_encoding_gzip(buf, stream, await writer.write(stream) headers, message = bytes(buf).split(b'\r\n\r\n', 1) - assert (b'--:\r\nContent-Encoding: gzip\r\n' - b'Content-Type: text/plain; charset=utf-8' == headers) + assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n' + b'Content-Encoding: gzip' == headers) decompressor = zlib.decompressobj(wbits=16+zlib.MAX_WBITS) data = decompressor.decompress(message.split(b'\r\n')[0]) @@ -869,8 +869,8 @@ async def test_writer_serialize_with_content_encoding_deflate(buf, stream, await writer.write(stream) headers, message = bytes(buf).split(b'\r\n\r\n', 1) - assert (b'--:\r\nContent-Encoding: deflate\r\n' - b'Content-Type: text/plain; charset=utf-8' == headers) + assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n' + b'Content-Encoding: deflate' == headers) thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--\r\n' assert thing == message @@ -883,8 +883,8 @@ async def test_writer_serialize_with_content_encoding_identity(buf, stream, await writer.write(stream) headers, message = bytes(buf).split(b'\r\n\r\n', 1) - assert (b'--:\r\nContent-Encoding: identity\r\n' - b'Content-Type: application/octet-stream\r\n' + assert (b'--:\r\nContent-Type: application/octet-stream\r\n' + b'Content-Encoding: identity\r\n' b'Content-Length: 16' == headers) assert thing == message.split(b'\r\n')[0] @@ -902,8 +902,8 @@ async def test_writer_with_content_transfer_encoding_base64(buf, stream, await writer.write(stream) headers, message = bytes(buf).split(b'\r\n\r\n', 1) - assert (b'--:\r\nContent-Transfer-Encoding: base64\r\n' - b'Content-Type: text/plain; charset=utf-8' == + assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n' + b'Content-Transfer-Encoding: base64' == headers) assert b'VGltZSB0byBSZWxheCE=' == message.split(b'\r\n')[0] @@ -916,8 +916,8 @@ async def test_writer_content_transfer_encoding_quote_printable(buf, stream, await writer.write(stream) headers, message = bytes(buf).split(b'\r\n\r\n', 1) - assert (b'--:\r\nContent-Transfer-Encoding: quoted-printable\r\n' - b'Content-Type: text/plain; charset=utf-8' == headers) + assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n' + b'Content-Transfer-Encoding: quoted-printable' == headers) assert (b'=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82,' b' =D0=BC=D0=B8=D1=80!' == message.split(b'\r\n')[0])