Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Enable large tensor support for interp (#19363)
Browse files Browse the repository at this point in the history
Co-authored-by: Rohit Kumar Srivastava <srivastava.141@buckeyemail.osu.edu>
  • Loading branch information
access2rohit and Rohit Kumar Srivastava authored Oct 23, 2020
1 parent eed080f commit 187c75d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/operator/numpy/np_interp_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ struct NumpyInterpParam : public dmlc::Parameter<NumpyInterpParam> {
};

struct interp {
MSHADOW_XINLINE static void Map(int i,
MSHADOW_XINLINE static void Map(index_t i,
double* out,
const double* x,
const double* xp,
const double* fp,
const int dsize,
const size_t dsize,
const double left,
const double right,
const bool has_left,
Expand All @@ -99,19 +99,19 @@ struct interp {
} else if (x_value < xp_low) {
out[i] = lval;
} else {
int imin = 0;
int imax = dsize;
int imid;
index_t imin = 0;
index_t imax = static_cast<index_t>(dsize);
index_t imid;
while (imin < imax) {
imid = static_cast<int>((imax + imin) / 2);
imid = static_cast<index_t>((imax + imin) / 2);
if (x_value >= xp[imid]) {
imin = imid + 1;
} else {
imax = imid;
}
} // biserction search

int j = imin;
index_t j = imin;
if (j == dsize) {
out[i] = fp[dsize-1];
} else if (x_value == xp[j-1]) {
Expand All @@ -130,28 +130,28 @@ struct interp {
};

struct interp_period {
MSHADOW_XINLINE static void Map(int i,
MSHADOW_XINLINE static void Map(index_t i,
double* out,
const double* x,
const double* xp,
const double* fp,
const index_t* idx,
const int dsize,
const size_t dsize,
const double period) {
double x_value = x[i];
int imin = 0;
int imax = dsize;
int imid;
index_t imin = 0;
index_t imax = static_cast<index_t>(dsize);
index_t imid;
while (imin < imax) {
imid = static_cast<int>((imax + imin) / 2);
imid = static_cast<index_t>((imax + imin) / 2);
if (x_value >= xp[idx[imid]]) {
imin = imid + 1;
} else {
imax = imid;
}
} // biserction search

int j = imin;
index_t j = imin;
double xp_below, xp_above;
double fp1, fp2;
if (j == 0) {
Expand Down
17 changes: 17 additions & 0 deletions tests/nightly/test_np_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2133,3 +2133,20 @@ def test_nan_to_num():
assert inp.grad.shape == inp.shape
assert inp.grad[0, -1] == 0 and inp.grad[1, -1] == 0
assert inp.grad[0, 0] == 1 and inp.grad[2, -1] == 0


@use_np
def test_interp():
xp = np.array([1, 2, 3])
fp = np.array([3, 2, 1])
inp = np.ones((2, INT_OVERFLOW))
inp[-1, -1] = 2.5
inp.attach_grad()
with mx.autograd.record():
B = np.interp(inp, xp, fp)
B.backward()
assert B.shape == inp.shape
assert B[-1, -1] == 1.5
assert inp.grad.shape == inp.shape
assert inp.grad[-1, -1] == 0

0 comments on commit 187c75d

Please sign in to comment.