Skip to content

Commit

Permalink
Refactor the way how Payload headers are handled
Browse files Browse the repository at this point in the history
This change actually solves three separated, but heavy coupled
issues:

1. `Payload.content_type` may conflict with `Payload.headers[CONTENT_TYPE]`.

While in the end priority goes to the former one, it seems quite strange that
Payload object may have dual state about what content type it contains.

2.IOPayload respects Content-Disposition which comes with headers.

3. ClientRequest.skip_autoheaders now filters Payload.headers as well.

This issue was eventually found due to refactoring: Payload object
may setup some autoheaders, but those will bypass skip logic.
  • Loading branch information
kxepal committed Jan 4, 2019
1 parent 4c7bc9f commit ba5bf63
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 66 deletions.
12 changes: 5 additions & 7 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,16 +495,14 @@ def update_body_from_data(self, body: Any) -> None:
if hdrs.CONTENT_LENGTH not in self.headers:
self.headers[hdrs.CONTENT_LENGTH] = str(size)

# set content-type
if (hdrs.CONTENT_TYPE not in self.headers and
hdrs.CONTENT_TYPE not in self.skip_auto_headers):
self.headers[hdrs.CONTENT_TYPE] = body.content_type

# copy payload headers
if body.headers:
for (key, value) in body.headers.items():
if key not in self.headers:
self.headers[key] = value
if key in self.headers:
continue
if key in self.skip_auto_headers:
continue
self.headers[key] = value

def update_expect_continue(self, expect: bool=False) -> None:
if expect:
Expand Down
23 changes: 5 additions & 18 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,9 +702,6 @@ def __init__(self, subtype: str='mixed',
super().__init__(None, content_type=ctype)

self._parts = [] # type: List[_Part] # noqa
self._headers = CIMultiDict() # type: CIMultiDict[str]
assert self.content_type is not None
self._headers[CONTENT_TYPE] = self.content_type

def __enter__(self) -> 'MultipartWriter':
return self
Expand Down Expand Up @@ -769,28 +766,18 @@ def append(
headers = CIMultiDict()

if isinstance(obj, Payload):
if obj.headers is not None:
obj.headers.update(headers)
else:
if isinstance(headers, CIMultiDict):
obj._headers = headers
else:
obj._headers = CIMultiDict(headers)
obj.headers.update(headers)
return self.append_payload(obj)
else:
try:
return self.append_payload(get_payload(obj, headers=headers))
payload = get_payload(obj, headers=headers)
except LookupError:
raise TypeError
raise TypeError('Cannot create payload from %r' % obj)
else:
return self.append_payload(payload)

def append_payload(self, payload: Payload) -> Payload:
"""Adds a new body part to multipart writer."""
# content-type
assert payload.headers is not None
if CONTENT_TYPE not in payload.headers:
assert payload.content_type is not None
payload.headers[CONTENT_TYPE] = payload.content_type

# compression
encoding = payload.headers.get(CONTENT_ENCODING, '').lower() # type: Optional[str] # noqa
if encoding and encoding not in ('deflate', 'gzip', 'identity'):
Expand Down
56 changes: 25 additions & 31 deletions aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@
import warnings
from abc import ABC, abstractmethod
from itertools import chain
from typing import ( # noqa
from typing import (
IO,
TYPE_CHECKING,
Any,
ByteString,
Callable,
Dict,
Iterable,
List,
Optional,
Text,
TextIO,
Expand Down Expand Up @@ -47,6 +45,10 @@
TOO_LARGE_BYTES_BODY = 2 ** 20 # 1 MB


if TYPE_CHECKING:
from typing import List # noqa


class LookupError(Exception):
pass

Expand Down Expand Up @@ -120,9 +122,8 @@ def register(self,

class Payload(ABC):

_default_content_type = 'application/octet-stream' # type: str
_size = None # type: Optional[float]
_headers = None # type: Optional[_CIMultiDict]
_content_type = 'application/octet-stream' # type: Optional[str]

def __init__(self,
value: Any,
Expand All @@ -137,18 +138,20 @@ def __init__(self,
filename: Optional[str]=None,
encoding: Optional[str]=None,
**kwargs: Any) -> None:
self._value = value
self._encoding = encoding
self._filename = filename
if headers is not None:
self._headers = CIMultiDict(headers)
if content_type is sentinel and hdrs.CONTENT_TYPE in self._headers:
content_type = self._headers[hdrs.CONTENT_TYPE]

if content_type is sentinel:
content_type = None

self._content_type = content_type
self._headers = CIMultiDict() # type: _CIMultiDict
self._value = value
if content_type is not sentinel and content_type is not None:
self._headers[hdrs.CONTENT_TYPE] = content_type
elif self._filename is not None:
content_type = mimetypes.guess_type(self._filename)[0]
if content_type is None:
content_type = self._default_content_type
self._headers[hdrs.CONTENT_TYPE] = content_type
else:
self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
self._headers.update(headers or {})

@property
def size(self) -> Optional[float]:
Expand All @@ -161,15 +164,12 @@ def filename(self) -> Optional[str]:
return self._filename

@property
def headers(self) -> Optional[_CIMultiDict]:
def headers(self) -> _CIMultiDict:
"""Custom item headers"""
return self._headers

@property
def _binary_headers(self) -> bytes:
if self.headers is None:
# FIXME: This case actually is unreachable.
return b'' # pragma: no cover
return ''.join(
[k + ': ' + v + '\r\n' for k, v in self.headers.items()]
).encode('utf-8') + b'\r\n'
Expand All @@ -180,24 +180,15 @@ def encoding(self) -> Optional[str]:
return self._encoding

@property
def content_type(self) -> Optional[str]:
def content_type(self) -> str:
"""Content type"""
if self._content_type is not None:
return self._content_type
elif self._filename is not None:
mime = mimetypes.guess_type(self._filename)[0]
return 'application/octet-stream' if mime is None else mime
else:
return Payload._content_type
return self._headers[hdrs.CONTENT_TYPE]

def set_content_disposition(self,
disptype: str,
quote_fields: bool=True,
**params: Any) -> None:
"""Sets ``Content-Disposition`` header."""
if self._headers is None:
self._headers = CIMultiDict()

self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
disptype, quote_fields=quote_fields, **params)

Expand Down Expand Up @@ -292,7 +283,10 @@ def __init__(self,
super().__init__(value, *args, **kwargs)

if self._filename is not None and disposition is not None:
self.set_content_disposition(disposition, filename=self._filename)
if hdrs.CONTENT_DISPOSITION not in self.headers:
self.set_content_disposition(
disposition, filename=self._filename
)

async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
Expand Down
20 changes: 10 additions & 10 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,8 +854,8 @@ async def test_writer_serialize_with_content_encoding_gzip(buf, stream,
await writer.write(stream)
headers, message = bytes(buf).split(b'\r\n\r\n', 1)

assert (b'--:\r\nContent-Encoding: gzip\r\n'
b'Content-Type: text/plain; charset=utf-8' == headers)
assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n'
b'Content-Encoding: gzip' == headers)

decompressor = zlib.decompressobj(wbits=16+zlib.MAX_WBITS)
data = decompressor.decompress(message.split(b'\r\n')[0])
Expand All @@ -869,8 +869,8 @@ async def test_writer_serialize_with_content_encoding_deflate(buf, stream,
await writer.write(stream)
headers, message = bytes(buf).split(b'\r\n\r\n', 1)

assert (b'--:\r\nContent-Encoding: deflate\r\n'
b'Content-Type: text/plain; charset=utf-8' == headers)
assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n'
b'Content-Encoding: deflate' == headers)

thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--\r\n'
assert thing == message
Expand All @@ -883,8 +883,8 @@ async def test_writer_serialize_with_content_encoding_identity(buf, stream,
await writer.write(stream)
headers, message = bytes(buf).split(b'\r\n\r\n', 1)

assert (b'--:\r\nContent-Encoding: identity\r\n'
b'Content-Type: application/octet-stream\r\n'
assert (b'--:\r\nContent-Type: application/octet-stream\r\n'
b'Content-Encoding: identity\r\n'
b'Content-Length: 16' == headers)

assert thing == message.split(b'\r\n')[0]
Expand All @@ -902,8 +902,8 @@ async def test_writer_with_content_transfer_encoding_base64(buf, stream,
await writer.write(stream)
headers, message = bytes(buf).split(b'\r\n\r\n', 1)

assert (b'--:\r\nContent-Transfer-Encoding: base64\r\n'
b'Content-Type: text/plain; charset=utf-8' ==
assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n'
b'Content-Transfer-Encoding: base64' ==
headers)

assert b'VGltZSB0byBSZWxheCE=' == message.split(b'\r\n')[0]
Expand All @@ -916,8 +916,8 @@ async def test_writer_content_transfer_encoding_quote_printable(buf, stream,
await writer.write(stream)
headers, message = bytes(buf).split(b'\r\n\r\n', 1)

assert (b'--:\r\nContent-Transfer-Encoding: quoted-printable\r\n'
b'Content-Type: text/plain; charset=utf-8' == headers)
assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n'
b'Content-Transfer-Encoding: quoted-printable' == headers)

assert (b'=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82,'
b' =D0=BC=D0=B8=D1=80!' == message.split(b'\r\n')[0])
Expand Down

0 comments on commit ba5bf63

Please sign in to comment.