diff --git a/sdc/datatypes/hpat_pandas_series_functions.py b/sdc/datatypes/hpat_pandas_series_functions.py index 372099ceb..4a2fed530 100644 --- a/sdc/datatypes/hpat_pandas_series_functions.py +++ b/sdc/datatypes/hpat_pandas_series_functions.py @@ -102,11 +102,12 @@ def hpat_pandas_series_iloc_impl(self, idx): return hpat_pandas_series_iloc_impl - def hpat_pandas_series_iloc_callable_impl(self, idx): - index = numpy.asarray(list(map(idx, self._series._data))) - return pandas.Series(self._series._data[index], self._series.index[index], self._series._name) + if isinstance(idx, types.Callable): + def hpat_pandas_series_iloc_callable_impl(self, idx): + index = numpy.asarray(list(map(idx, self._series._data))) + return pandas.Series(self._series._data[index], self._series.index[index], self._series._name) - return hpat_pandas_series_iloc_callable_impl + return hpat_pandas_series_iloc_callable_impl raise TypingError('{} The index must be an Integer, Slice or List of Integer or a callable.\ Given: {}'.format(_func_name, idx)) @@ -124,6 +125,8 @@ def hpat_pandas_series_iat_impl(self, idx): # Note: Loc return Series # Note: Index 0 in slice not supported # Note: Loc slice and callable with String not implement + # Note: Loc callable return float Series + series_dtype = self.series.data.dtype index_is_none = (self.series.index is None or isinstance(self.series.index, numba.types.misc.NoneType)) if isinstance(idx, types.SliceType) and index_is_none: @@ -139,7 +142,7 @@ def hpat_pandas_series_loc_slice_noidx_impl(self, idx): return hpat_pandas_series_loc_slice_noidx_impl - if isinstance(idx, (int, types.Integer, types.UnicodeType, types.StringLiteral)): + if isinstance(idx, (int, types.Number, types.UnicodeType, types.StringLiteral)): def hpat_pandas_series_loc_impl(self, idx): index = self._series.index mask = numpy.empty(len(self._series._data), numpy.bool_) @@ -149,11 +152,28 @@ def hpat_pandas_series_loc_impl(self, idx): return hpat_pandas_series_loc_impl + if isinstance(idx, types.Callable): + def hpat_pandas_series_loc_callable_impl(self, idx): + series = self._series + index = series.index + res = numpy.asarray(list(map(idx, self._series._data))) + new_series = pandas.Series(numpy.empty(0, numpy.float64), numpy.empty(0, series_dtype), series._name) + for i in numba.prange(len(res)): + tmp = series.loc[res[i]] + if len(tmp) > 0: + new_series = new_series.append(tmp) + else: + new_series = new_series.append(pandas.Series(numpy.array([numpy.nan]), numpy.array([res[i]]))) + + return new_series + + return hpat_pandas_series_loc_callable_impl + raise TypingError('{} The index must be an Number, Slice, String, List, Array or a callable.\ Given: {}'.format(_func_name, idx)) if accessor == 'at': - if isinstance(idx, (int, types.Integer, types.UnicodeType, types.StringLiteral)): + if isinstance(idx, (int, types.Number, types.UnicodeType, types.StringLiteral)): def hpat_pandas_series_at_impl(self, idx): index = self._series.index mask = numpy.empty(len(self._series._data), numpy.bool_) diff --git a/sdc/tests/test_series.py b/sdc/tests/test_series.py index 735c96e45..cbaae3765 100644 --- a/sdc/tests/test_series.py +++ b/sdc/tests/test_series.py @@ -1257,6 +1257,24 @@ def test_impl(A): S = pd.Series([2, 4, 6], ['1', '3', '5']) np.testing.assert_array_equal(hpat_func(S), test_impl(S)) + @skip_parallel + @skip_sdc_jit('Not impl in old style') + def test_series_loc_callable(self): + def test_impl(S): + return S.loc[lambda a: a ** 2] + hpat_func = self.jit(test_impl) + S = pd.Series([0, 6, 4, 7, 8], [0, 6, 66, 6, 8]) + pd.testing.assert_series_equal(hpat_func(S), test_impl(S)) + + # Loc callable return float Series + @unittest.expectedFailure + def test_series_loc_callable2(self): + def test_impl(S): + return S.loc[lambda a: a] + hpat_func = self.jit(test_impl) + S = pd.Series([0, 6, 8, 8, 8], [0, 6, 66, 6, 8]) + pd.testing.assert_series_equal(hpat_func(S), test_impl(S)) + @skip_sdc_jit('Not impl in old style') def test_series_at_str(self): def test_impl(A):