Commit 60d0672
Add qr backward for wide matrices with m < n (apache#18197)
D-Roberts authored Jul 17, 2020
1 parent a77f774 commit 60d0672
Showing 2 changed files with 261 additions and 71 deletions.
185 changes: 154 additions & 31 deletions src/operator/numpy/linalg/np_qr-inl.h
@@ -483,19 +483,53 @@ struct assign_helper {
}
};

// backprop helper: extract the right m x (n - m) block of a wide m x n matrix (used to get y from a and dv from dr)
struct QrBackHelper_G1 {
template<typename DType>
MSHADOW_XINLINE static void Map(const int k, const int m, const int n, const DType *in_data,
const int ldin, DType *out_data, const int ldout) {
const int offin(k * m * ldin);
const int offout(k * m * ldout);
for (index_t i = 0; i < m; ++i) {
for (index_t j = 0; j < n - m; ++j) {
out_data[offout + i * ldout + j] = in_data[offin + m + i * ldin + j];
}
}
}
};

// backprop helper: assemble the wide gradient da = (dx | dy) from dx (m x m) and dy (m x (n - m))
struct QrBackHelper_G2 {
template<typename DType>
MSHADOW_XINLINE static void Map(const int k, const int m, const int n, const DType *in_data_x,
const int ldinx, const DType *in_data_y, const int ldiny,
DType *out_data, const int ldout) {
const int offiny(k * m * ldiny);
const int offinx(k * m * ldinx);
const int offout(k * m * ldout);
for (index_t i = 0; i < m; ++i) {
for (index_t j = 0; j < n - m; ++j) {
out_data[offout + m + i * ldout + j] = in_data_y[offiny + i * ldiny + j];
}
for (index_t j = 0; j < m; ++j) {
out_data[offout + i * ldout + j] = in_data_x[offinx + i * ldinx + j];
}
}
}
};
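Both helpers above are plain block copies over row-major storage: QrBackHelper_G1 pulls the right m x (n - m) block out of a wide m x n matrix for each batch item k, and QrBackHelper_G2 writes an m x m block and an m x (n - m) block back side by side into the m x n gradient. The standalone sketch below reproduces that indexing for a single matrix using std::vector; the names ExtractRightBlock and AssembleWideGrad are illustrative only and are not MXNet functions.

// ---- standalone sketch (not part of this patch) ----
#include <cstdio>
#include <vector>

// Same indexing as QrBackHelper_G1: copy A[:, m:] into Y (row-major, one matrix).
void ExtractRightBlock(int m, int n, const std::vector<double>& a,
                       std::vector<double>* y) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n - m; ++j)
      (*y)[i * (n - m) + j] = a[i * n + m + j];
}

// Same indexing as QrBackHelper_G2: dA = (dX | dY).
void AssembleWideGrad(int m, int n, const std::vector<double>& dx,
                      const std::vector<double>& dy, std::vector<double>* da) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < m; ++j)
      (*da)[i * n + j] = dx[i * m + j];
    for (int j = 0; j < n - m; ++j)
      (*da)[i * n + m + j] = dy[i * (n - m) + j];
  }
}

int main() {
  const int m = 2, n = 5;
  std::vector<double> a = {1, 2, 3, 4, 5,
                           6, 7, 8, 9, 10};
  std::vector<double> y(m * (n - m));
  ExtractRightBlock(m, n, a, &y);          // y == {3, 4, 5, 8, 9, 10}
  std::vector<double> dx = {1, 0, 0, 1};   // stand-in for a real gradient block
  std::vector<double> da(m * n);
  AssembleWideGrad(m, n, dx, y, &da);      // each row becomes (dx_row | y_row)
  for (double v : da) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}
// ---- end sketch ----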

// Reference https://journals.aps.org/prx/pdf/10.1103/PhysRevX.9.031041
struct qr_backward {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dA,
const Tensor<xpu, 3, DType>& dQ,
const Tensor<xpu, 3, DType>& dR,
const Tensor<xpu, 3, DType>& Q,
const Tensor<xpu, 3, DType>& R,
const Tensor<xpu, 3, DType>& M,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
// Implements case m >= n; da = [dq + q@copyltu(M)]@r**(-T)
// Where M = r@(dr**T) - (dq**T)@q
// Reference: https://arxiv.org/abs/1710.08717
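// Note: copyltu(M) denotes copying the lower triangle of M (diagonal included)
// onto the upper triangle, i.e. copyltu(M) = tril(M) + tril(M, -1)**T, following
// the referenced derivation.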
Stream<xpu> *s = ctx.get_stream<xpu>();
if (dQ.dptr_ != dA.dptr_) Copy(dA, dQ, s);
// M = R@dR_T
@@ -514,15 +548,30 @@ struct qr_backward {

template<typename xpu>
size_t QrBackwardWorkspaceSize(const TBlob& a,
const TBlob& q,
const TBlob& r,
const TBlob& grad_a) {
const mxnet::TShape& a_shape = a.shape_;
const int a_ndim = a_shape.ndim();
const int n = a.size(a_ndim - 1);
const int m = a.size(a_ndim - 2);

if (0U == a.Size()) { return 0U; }

MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, {
size_t work_space_size = 0;
// for grad a and M
work_space_size += a.Size();
if (m >= n) {
work_space_size += r.Size();
} else {
const mxnet::TShape& q_shape = q.shape_;
mxnet::TShape v_shape(q_shape);
v_shape[a_ndim - 1] = n - m;
// allocate space for: m, u, dq_prime, du, dx (shaped like Q)
work_space_size += 5 * q.Size();
// allocate space for: y, dv (shaped like V, the partition of R)
work_space_size += 2 * v_shape.Size();
}
return work_space_size * sizeof(DType);
});
LOG(FATAL) << "InternalError: cannot reach here";
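As a quick sanity check of the m < n branch above: for a hypothetical single 3 x 5 double-precision input, m = 3 and n = 5, so Q-shaped buffers are 3 x 3 and V-shaped buffers are 3 x 2, and the branch requests 3*5 + 5*(3*3) + 2*(3*2) = 72 elements, i.e. 576 bytes. A minimal sketch of that arithmetic:

// ---- standalone sketch (not part of this patch) ----
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t a_size = 3 * 5;  // grad-a scratch, shaped like A
  const std::size_t q_size = 3 * 3;  // M, U, dQ_prime, dU, dX are Q-shaped
  const std::size_t v_size = 3 * 2;  // Y, dV are V-shaped
  const std::size_t elems  = a_size + 5 * q_size + 2 * v_size;  // 72
  std::printf("%zu elements, %zu bytes\n", elems, elems * sizeof(double));
  return 0;
}
// ---- end sketch ----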
@@ -542,36 +591,116 @@ void QrBackwardImpl(const TBlob& grad_a,
const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
const mxnet::TShape& a_shape = a.shape_;
const mxnet::TShape& q_shape = q.shape_;
const mxnet::TShape& r_shape = r.shape_;
const int a_ndim = a_shape.ndim();
const int m = a.size(a_ndim - 2);
const int n = a.size(a_ndim - 1);

if (kNullOp == req[0]) { return; }

if (0U == a_shape.Size()) { return; }

MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, {
// common for all shapes (m, n)
DType *grad_a_ptr = reinterpret_cast<DType*>(workspace.dptr_);
TBlob grad_a_data(grad_a_ptr, a_shape, xpu::kDevMask);

if (m >= n) {
// Q has the same shape as A (m, n) and R is (n, n)
DType *m_ptr = grad_a_ptr + a_shape.Size();
TBlob temp_m(m_ptr, r_shape, xpu::kDevMask);
// dR_T
mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
s, r_shape.Size(), grad_r.dptr<DType>(), m_ptr, n, n, n * n);
qr_backward::op(grad_a_data.FlatToKD<xpu, 3, DType>(s),
grad_q.FlatToKD<xpu, 3, DType>(s),
grad_r.FlatToKD<xpu, 3, DType>(s),
q.FlatToKD<xpu, 3, DType>(s),
r.FlatToKD<xpu, 3, DType>(s),
temp_m.FlatToKD<xpu, 3, DType>(s),
ctx, attrs);
} else {
// R has the same shape as A (m, n) and Q is (m, m)
// Partition A = (X | Y); R = (U | V)
// X and U are (m, m); Y and V are (m, n - m)
mxnet::TShape v_shape(q_shape);
v_shape[a_ndim - 1] = n - m;

DType *m_ptr = grad_a_ptr + a_shape.Size();
DType *u_ptr = m_ptr + q_shape.Size();
DType *dq_prime_ptr = u_ptr + q_shape.Size();
DType *dv_ptr = dq_prime_ptr + q_shape.Size();
DType *y_ptr = dv_ptr + v_shape.Size();
DType *du_ptr = y_ptr + v_shape.Size();
DType *dx_ptr = du_ptr + q_shape.Size();

TBlob temp_m(m_ptr, q_shape, xpu::kDevMask);
TBlob u_data(u_ptr, q_shape, xpu::kDevMask);
TBlob dq_prime_data(dq_prime_ptr, q_shape, xpu::kDevMask);
TBlob dv_data(dv_ptr, v_shape, xpu::kDevMask);
TBlob y_data(y_ptr, v_shape, xpu::kDevMask);
TBlob du_data(du_ptr, q_shape, xpu::kDevMask);
TBlob dx_data(dx_ptr, q_shape, xpu::kDevMask);

Tensor<xpu, 3, DType> R = r.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> dR = grad_r.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> Q = q.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> dQ = grad_q.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> dQ_prime = dq_prime_data.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> A = a.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> dA = grad_a_data.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> U = u_data.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> dU = du_data.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> dV = dv_data.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> Y = y_data.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> dX = dx_data.FlatToKD<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> M = temp_m.FlatToKD<xpu, 3, DType>(s);

// U
for (index_t i = 0; i < R.size(0); ++i) {
const Tensor<xpu, 2, DType>& Ri = R[i];
const Tensor<xpu, 2, DType>& Ui = U[i];
Tensor<xpu, 2, DType> Um(Ri.dptr_, Shape2(m, m), Ri.stride_, s);
Copy(Ui, Um, s);
}
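// Note: Um above is an m x m strided view of the leftmost m columns of Ri
// (it reuses Ri.stride_ as the row stride), so Copy extracts U = R[:, :m];
// the loop below builds dU from dR the same way.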
// dU
for (index_t i = 0; i < dR.size(0); ++i) {
const Tensor<xpu, 2, DType>& dRi = dR[i];
const Tensor<xpu, 2, DType>& dUi = dU[i];
Tensor<xpu, 2, DType> dUm(dRi.dptr_, Shape2(m, m), dRi.stride_, s);
Copy(dUi, dUm, s);
}
// Y
mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch(
s, A.size(0), m, n, A.dptr_, A.stride_, Y.dptr_, Y.stride_);
// dV
mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch(
s, dR.size(0), m, n, dR.dptr_, dR.stride_, dV.dptr_, dV.stride_);
// store dU_T in M
mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
s, q_shape.Size(), dU.dptr_, m_ptr, m, m, m * m);
// dq_prime = dQ
Copy(dQ_prime, dQ, s);
// dq_prime = dQ+Y@dV.T
gemm::op(Y, dV, dQ_prime, DType(1.0), DType(1.0), false, true, s);
// dX from qr_backward::op applied to the square factorization X = Q@U (inputs dq_prime, dU, Q, U, M)
qr_backward::op(dX,
dQ_prime,
dU,
Q,
U,
M,
ctx, attrs);
// dY = Q@dV; reuse Y memory for dY
gemm::op(Q, dV, Y, DType(1.0), DType(0.0), false, false, s);
// copy dX and dY to dA
mxnet_op::Kernel<QrBackHelper_G2, xpu>::Launch(
s, dA.size(0), m, n, dX.dptr_, dX.stride_, Y.dptr_, Y.stride_, dA.dptr_, dA.stride_);
}
// common for all shapes
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
mxnet_op::Kernel<assign_helper<req_type>, xpu>::Launch(
s, a_shape.Size(), grad_a_data.dptr<DType>(), grad_a.dptr<DType>());
});
});
}
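For reference, the m < n branch of QrBackwardImpl above implements the following recursion, written here in the same informal notation as the code comments; the bracketed map is the qr_backward::op formula quoted earlier, applied to the square factorization X = Q@U:

// A = (X | Y); R = (U | V); X = Q@U is itself a square QR factorization
// dq_prime = dQ + Y@dV**T
// M = U@dU**T - (dq_prime**T)@Q        (formed inside qr_backward::op)
// dX = [dq_prime + Q@copyltu(M)]@U**(-T)
// dY = Q@dV
// dA = (dX | dY)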
@@ -594,14 +723,8 @@ void NumpyLaQrBackward(const nnvm::NodeAttrs& attrs,
const TBlob& q = inputs[3];
const TBlob& r = inputs[4];
const TBlob& grad_a = outputs[0];

size_t workspace_size = QrBackwardWorkspaceSize<xpu>(a, q, r, grad_a);
Tensor<xpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<xpu, 1, char>(Shape1(workspace_size), ctx.get_stream<xpu>());
QrBackwardImpl<xpu>(grad_a, grad_q, grad_r, a, q, r, req, workspace, ctx, attrs);