diff --git a/src/future/builtins/new_min_max.py b/src/future/builtins/new_min_max.py index 8fd63fdf..6f0c2a86 100644 --- a/src/future/builtins/new_min_max.py +++ b/src/future/builtins/new_min_max.py @@ -1,9 +1,13 @@ +import itertools + from future import utils if utils.PY2: from __builtin__ import max as _builtin_max, min as _builtin_min else: from builtins import max as _builtin_max, min as _builtin_min +_SENTINEL = object() + def newmin(*args, **kwargs): return new_min_max(_builtin_min, *args, **kwargs) @@ -29,21 +33,24 @@ def new_min_max(_builtin_func, *args, **kwargs): if len(args) == 0: raise TypeError - if len(args) != 1 and kwargs.get('default') is not None: + if len(args) != 1 and kwargs.get('default', _SENTINEL) is not _SENTINEL: raise TypeError if len(args) == 1: + iterator = iter(args[0]) try: - next(iter(args[0])) + first = next(iterator) except StopIteration: - if kwargs.get('default') is not None: + if kwargs.get('default', _SENTINEL) is not _SENTINEL: return kwargs.get('default') else: - raise ValueError('iterable is an empty sequence') + raise ValueError('{}() arg is an empty sequence'.format(_builtin_func.__name__)) + else: + iterator = itertools.chain([first], iterator) if kwargs.get('key') is not None: - return _builtin_func(args[0], key=kwargs.get('key')) + return _builtin_func(iterator, key=kwargs.get('key')) else: - return _builtin_func(args[0]) + return _builtin_func(iterator) if len(args) > 1: if kwargs.get('key') is not None: diff --git a/tests/test_future/test_builtins.py b/tests/test_future/test_builtins.py index d983f9d6..ca07b9ef 100644 --- a/tests/test_future/test_builtins.py +++ b/tests/test_future/test_builtins.py @@ -1105,6 +1105,7 @@ def test_max(self): with self.assertRaises(TypeError): max(1, 2, default=0) self.assertEqual(max([], default=0), 0) + self.assertIs(max([], default=None), None) def test_min(self): self.assertEqual(min('123123'), '1') @@ -1123,6 +1124,7 @@ class BadSeq: def __getitem__(self, index): raise ValueError self.assertRaises(ValueError, min, BadSeq()) + self.assertEqual(max(x for x in [5, 4, 3]), 5) for stmt in ( "min(key=int)", # no args @@ -1149,11 +1151,15 @@ def __getitem__(self, index): sorted(data, key=f)[0]) self.assertEqual(min([], default=5), 5) self.assertEqual(min([], default=0), 0) + self.assertIs(min([], default=None), None) with self.assertRaises(TypeError): max(None, default=5) with self.assertRaises(TypeError): max(1, 2, default=0) + # Test iterables that can only be looped once #510 + self.assertEqual(min(x for x in [5]), 5) + def test_next(self): it = iter(range(2)) self.assertEqual(next(it), 0)