diff --git a/src/snappy/snappy.py b/src/snappy/snappy.py index aa1a22e..6bd2b8b 100644 --- a/src/snappy/snappy.py +++ b/src/snappy/snappy.py @@ -149,23 +149,15 @@ def __init__(self): self.remains = None @staticmethod - def check_format(data): - """Checks that the given data starts with snappy framing format - stream identifier. - Raises UncompressError if it doesn't start with the identifier. - :return: None + def check_format(fin): + """Does this stream start with a stream header block? + + True indicates that the stream can likely be decoded using this class. """ - if len(data) < 6: - raise UncompressError("Too short data length") - chunk_type = struct.unpack("<L", data[:4])[0] - size = (chunk_type >> 8) - chunk_type &= 0xff - if (chunk_type != _IDENTIFIER_CHUNK or - size != len(_STREAM_IDENTIFIER)): - raise UncompressError("stream missing snappy identifier") - chunk = data[4:4 + size] - if chunk != _STREAM_IDENTIFIER: - raise UncompressError("stream has invalid snappy identifier") + try: + return fin.read(len(_STREAM_HEADER_BLOCK)) == _STREAM_HEADER_BLOCK + except: + return False def decompress(self, data: bytes): """Decompress 'data', returning a string containing the uncompressed @@ -233,14 +225,21 @@ def __init__(self): self.remains = b"" @staticmethod - def check_format(data): - """Checks that there are enough bytes for a hadoop header - - We cannot actually determine if the data is really hadoop-snappy + def check_format(fin): + """Does this look like a hadoop snappy stream? 
""" - if len(data) < 8: - raise UncompressError("Too short data length") - chunk_length = int.from_bytes(data[4:8], "big") + try: + from snappy.snappy_formats import check_unframed_format + size = fin.seek(0, 2) + fin.seek(0) + assert size >= 8 + + chunk_length = int.from_bytes(fin.read(4), "big") + assert chunk_length < size + fin.read(4) + return check_unframed_format(fin) + except: + return False def decompress(self, data: bytes): """Decompress 'data', returning a string containing the uncompressed @@ -319,16 +318,43 @@ def stream_decompress(src, decompressor.flush() # makes sure the stream ended well -def check_format(fin=None, chunk=None, - blocksize=_STREAM_TO_STREAM_BLOCK_SIZE, - decompressor_cls=StreamDecompressor): - ok = True - if chunk is None: - chunk = fin.read(blocksize) - if not chunk: - raise UncompressError("Empty input stream") - try: - decompressor_cls.check_format(chunk) - except UncompressError as err: - ok = False - return ok, chunk +def hadoop_stream_decompress( + src, + dst, + blocksize=_STREAM_TO_STREAM_BLOCK_SIZE, +): + c = HadoopStreamDecompressor() + while True: + data = src.read(blocksize) + if not data: + break + buf = c.decompress(data) + if buf: + dst.write(buf) + dst.flush() + + +def hadoop_stream_compress( + src, + dst, + blocksize=_STREAM_TO_STREAM_BLOCK_SIZE, +): + c = HadoopStreamCompressor() + while True: + data = src.read(blocksize) + if not data: + break + buf = c.compress(data) + if buf: + dst.write(buf) + dst.flush() + + +def raw_stream_decompress(src, dst): + data = src.read() + dst.write(decompress(data)) + + +def raw_stream_compress(src, dst): + data = src.read() + dst.write(compress(data)) diff --git a/src/snappy/snappy_formats.py b/src/snappy/snappy_formats.py index 51a54dd..e230e0b 100644 --- a/src/snappy/snappy_formats.py +++ b/src/snappy/snappy_formats.py @@ -8,65 +8,107 @@ from __future__ import absolute_import from .snappy import ( - stream_compress, stream_decompress, check_format, UncompressError) - + 
HadoopStreamDecompressor, StreamDecompressor, + hadoop_stream_compress, hadoop_stream_decompress, raw_stream_compress, + raw_stream_decompress, stream_compress, stream_decompress, + UncompressError +) -FRAMING_FORMAT = 'framing' # Means format auto detection. # For compression will be used framing format. # In case of decompression will try to detect a format from the input stream # header. -FORMAT_AUTO = 'auto' +DEFAULT_FORMAT = "auto" -DEFAULT_FORMAT = FORMAT_AUTO - -ALL_SUPPORTED_FORMATS = [FRAMING_FORMAT, FORMAT_AUTO] +ALL_SUPPORTED_FORMATS = ["framing", "auto"] _COMPRESS_METHODS = { - FRAMING_FORMAT: stream_compress, + "framing": stream_compress, + "hadoop": hadoop_stream_compress, + "raw": raw_stream_compress } _DECOMPRESS_METHODS = { - FRAMING_FORMAT: stream_decompress, + "framing": stream_decompress, + "hadoop": hadoop_stream_decompress, + "raw": raw_stream_decompress } # We will use framing format as the default to compression. # And for decompression, if it's not defined explicitly, we will try to # guess the format from the file header. -_DEFAULT_COMPRESS_FORMAT = FRAMING_FORMAT +_DEFAULT_COMPRESS_FORMAT = "framing" + + +def uvarint(fin): + """Read uint64 number from varint encoding in a stream""" + result = 0 + shift = 0 + while True: + byte = fin.read(1)[0] + result |= (byte & 0x7F) << shift + if (byte & 0x80) == 0: + break + shift += 7 + return result + + +def check_unframed_format(fin, reset=False): + """Can this be read using the raw codec + + This function will return True for all snappy raw streams, but + True does not mean that we can necessarily decode the stream. + """ + if reset: + fin.seek(0) + try: + size = uvarint(fin) + assert size < 2**32 - 1 + next_byte = fin.read(1)[0] + end = fin.seek(0, 2) + assert size < end + assert next_byte & 0b11 == 0 # must start with literal block + return True + except: + return False + # The tuple contains an ordered sequence of a format checking function and # a format-specific decompression function. 
# Framing format has it's header, that may be recognized. -_DECOMPRESS_FORMAT_FUNCS = ( - (check_format, stream_decompress), -) +_DECOMPRESS_FORMAT_FUNCS = { + "framed": stream_decompress, + "hadoop": hadoop_stream_decompress, + "raw": raw_stream_decompress +} def guess_format_by_header(fin): """Tries to guess a compression format for the given input file by it's header. - :return: tuple of decompression method and a chunk that was taken from the - input for format detection. + + :return: format name (str), stream decompress function (callable) """ - chunk = None - for check_method, decompress_func in _DECOMPRESS_FORMAT_FUNCS: - ok, chunk = check_method(fin=fin, chunk=chunk) - if not ok: - continue - return decompress_func, chunk - raise UncompressError("Can't detect archive format") + if StreamDecompressor.check_format(fin): + form = "framed" + elif HadoopStreamDecompressor.check_format(fin): + form = "hadoop" + elif check_unframed_format(fin, reset=True): + form = "raw" + else: + raise UncompressError("Can't detect format") + return form, _DECOMPRESS_FORMAT_FUNCS[form] def get_decompress_function(specified_format, fin): - if specified_format == FORMAT_AUTO: - decompress_func, read_chunk = guess_format_by_header(fin) - return decompress_func, read_chunk - return _DECOMPRESS_METHODS[specified_format], None + if specified_format == "auto": + format, decompress_func = guess_format_by_header(fin) + return decompress_func + return _DECOMPRESS_METHODS[specified_format] def get_compress_function(specified_format): - if specified_format == FORMAT_AUTO: + if specified_format == "auto": return _COMPRESS_METHODS[_DEFAULT_COMPRESS_FORMAT] return _COMPRESS_METHODS[specified_format] diff --git a/test_formats.py b/test_formats.py index 43afb91..6453b1e 100644 --- a/test_formats.py +++ b/test_formats.py @@ -3,12 +3,11 @@ from unittest import TestCase from snappy import snappy_formats as formats -from snappy.snappy import _CHUNK_MAX, UncompressError class TestFormatBase(TestCase): 
- compress_format = formats.FORMAT_AUTO - decompress_format = formats.FORMAT_AUTO + compress_format = "auto" + decompress_format = "auto" success = True def runTest(self): @@ -18,34 +17,58 @@ def runTest(self): compressed_stream = io.BytesIO() compress_func(instream, compressed_stream) compressed_stream.seek(0) - decompress_func, read_chunk = formats.get_decompress_function( + decompress_func = formats.get_decompress_function( self.decompress_format, compressed_stream ) + compressed_stream.seek(0) decompressed_stream = io.BytesIO() decompress_func( compressed_stream, decompressed_stream, - start_chunk=read_chunk ) decompressed_stream.seek(0) self.assertEqual(data, decompressed_stream.read()) class TestFormatFramingFraming(TestFormatBase): - compress_format = formats.FRAMING_FORMAT - decompress_format = formats.FRAMING_FORMAT + compress_format = "framing" + decompress_format = "framing" success = True class TestFormatFramingAuto(TestFormatBase): - compress_format = formats.FRAMING_FORMAT - decompress_format = formats.FORMAT_AUTO + compress_format = "framing" + decompress_format = "auto" success = True class TestFormatAutoFraming(TestFormatBase): - compress_format = formats.FORMAT_AUTO - decompress_format = formats.FRAMING_FORMAT + compress_format = "auto" + decompress_format = "framing" + success = True + + +class TestFormatHadoop(TestFormatBase): + compress_format = "hadoop" + decompress_format = "hadoop" + success = True + + +class TestFormatRaw(TestFormatBase): + compress_format = "raw" + decompress_format = "raw" + success = True + + +class TestFormatHadoopAuto(TestFormatBase): + compress_format = "hadoop" + decompress_format = "auto" + success = True + + +class TestFormatRawAuto(TestFormatBase): + compress_format = "raw" + decompress_format = "auto" success = True