This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Restore quantized RNN to master #20952

Merged: 11 commits, Apr 6, 2022
13 changes: 13 additions & 0 deletions include/mxnet/op_attr_types.h
@@ -343,6 +343,19 @@ using FNeedRequantize = std::function<bool(const NodeAttrs& attrs)>;
using FAvoidQuantizeInput = std::function<
bool(const NodeAttrs& attrs, const size_t index, const std::string quantize_granularity)>;

/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be quantized asymmetrically.
*/
using FNeedAsymQuantizeInput = std::function<bool(const NodeAttrs& attrs, const size_t index)>;

/*!
* \brief Register a function to determine if the output of a quantized operator
* needs to be dequantized. This is usually used for the quantized operators
* which can produce fp32 outputs directly.
*/
using FAvoidDequantizeOutput = std::function<bool(const NodeAttrs& attrs, const size_t index)>;

/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be calibrated. This is usually used for the quantized operators
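To make the two new attributes concrete, here is an illustrative registration sketch; it is not taken from this diff. The operator name matches the one this PR adds to the AMP lists, but which inputs and outputs the lambdas select is an assumption.

```cpp
// Hypothetical sketch only: shows the attribute mechanism, not this PR's actual
// registration. The chosen input/output indices are assumptions.
#include <mxnet/op_attr_types.h>
#include <nnvm/op.h>

NNVM_REGISTER_OP(_contrib_quantized_rnn)
    .set_attr<mxnet::FNeedAsymQuantizeInput>(
        "FNeedAsymQuantizeInput",
        [](const nnvm::NodeAttrs& attrs, const size_t index) {
          // Assume only the data input (index 0) needs asymmetric (shifted uint8)
          // quantization; the remaining inputs keep the default behaviour.
          return index == 0;
        })
    .set_attr<mxnet::FAvoidDequantizeOutput>(
        "FAvoidDequantizeOutput",
        [](const nnvm::NodeAttrs& attrs, const size_t index) {
          // Assume the operator already produces fp32 outputs, so the
          // quantization pass should not append a dequantize node.
          return true;
        });
```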
2 changes: 2 additions & 0 deletions python/mxnet/amp/lists/symbol_bf16.py
@@ -119,6 +119,7 @@
'_contrib_index_copy',
'_contrib_quadratic',
'_contrib_quantize',
'_contrib_quantize_asym',
'_contrib_quantize_v2',
'_contrib_quantized_concat',
'_contrib_quantized_conv',
@@ -127,6 +128,7 @@
'_contrib_quantized_pooling',
'_contrib_quantized_elemwise_add',
'_contrib_quantized_act',
'_contrib_quantized_rnn',
'_image_crop',
'_linspace',
'_contrib_requantize',
2 changes: 2 additions & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
@@ -99,6 +99,7 @@
'_contrib_index_copy',
'_contrib_quadratic',
'_contrib_quantize',
'_contrib_quantize_asym',
'_contrib_quantize_v2',
'_contrib_quantized_concat',
'_contrib_quantized_conv',
@@ -108,6 +109,7 @@
'_contrib_quantized_elemwise_add',
'_contrib_quantized_act',
'_contrib_quantized_reshape',
'_contrib_quantized_rnn',
'_contrib_quantized_transpose',
'_npx_quantized_reshape',
'_npx_quantized_transpose',
82 changes: 60 additions & 22 deletions python/mxnet/contrib/quantization.py
@@ -33,6 +33,20 @@
from ..util import is_np_array, wrap_ctx_to_device_func


def _multilist_iterator(arg, func):
"""Iterate over multidiemnsional list and returns new list
with same dimensions, but applied `func` function on list elements.
E.g. _multilist_iterator([1, 2, [3, 4]], lambda x: x**2) = [1, 4, [9, 16]]
"""
ret = []
if isinstance(arg, list):
for el in arg:
ret.append(_multilist_iterator(el, func))
else:
return func(arg)

return ret
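As the docstring's example suggests, the helper walks arbitrarily nested lists and applies `func` only to the leaves, preserving the nesting; that is what lets the calibration code handle DataLoader batches that mix arrays with lists of arrays (for example RNN data plus a list of states). A standalone sketch of the contract, with made-up leaf values:

```python
# Self-contained copy of the helper so the example runs without MXNet;
# it mirrors the implementation added above.
def _multilist_iterator(arg, func):
    ret = []
    if isinstance(arg, list):
        for el in arg:
            ret.append(_multilist_iterator(el, func))
    else:
        return func(arg)
    return ret

# Nesting is preserved; func sees only the leaves.
shapes = [(35, 32), [(1, 32, 200), (1, 32, 200)]]
print(_multilist_iterator(shapes, lambda s: len(s)))  # [2, [3, 3]]
```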

def _quantize_params(qsym, params, min_max_dict):
"""Given a quantized symbol and a dict of params that have not been quantized,
generate quantized params. Currently only supports quantizing the arg_params
@@ -357,7 +371,7 @@ def _collect_layer_statistics(sym_block, data, collector, num_inputs, num_calib_
for batch in data:
if not isinstance(batch, list):
batch = [batch]
batch = [b.as_in_context(mx.cpu()) for b in batch]
batch = _multilist_iterator(batch, lambda b: b.as_in_context(mx.cpu()))
sym_block(*batch[:num_inputs])
num_batches += 1
if num_calib_batches is not None and num_batches >= num_calib_batches:
@@ -368,20 +382,45 @@


def _generate_list_of_data_desc(data_shapes, data_types):
""""Convert list ot tuples to list of DataDesc."""
if isinstance(data_shapes, list):
if all(isinstance(x, DataDesc) for x in data_shapes):
return data_shapes
if all(isinstance(x, tuple) for x in data_shapes):
if len(data_shapes) == 1:
data_shapes = [DataDesc(name='data', shape=data_shapes[0], dtype=data_types[0])]
"""Convert list of tuples to list of DataDesc."""
def flatten_list(arg):
ret = []
for el in arg:
if isinstance(el, list):
ret += flatten_list(el)
else:
data_shapes = [DataDesc(name='data' + str(i), shape=data_shapes[i],
dtype=data_types[i]) for i in range(len(data_shapes))]
return data_shapes
raise ValueError('data_shapes must be either a list of DataDesc or a list of Tuple')
ret.append(el)
return ret

flattened_data_types = flatten_list(data_types)
flattened_data_shapes = flatten_list(data_shapes)

if all(isinstance(x, DataDesc) for x in flattened_data_shapes):
return data_shapes

assert len(flattened_data_types) == len(flattened_data_shapes)

# use a one-element list so the counter is shared (passed by reference)
counter = [0]
def get_data_desc(data_shape, counter=counter, data_types=flattened_data_types):
if isinstance(data_shape, DataDesc):
return data_shape
elif isinstance(data_shape, tuple):
desc = DataDesc(name='data' + str(counter[0]), shape=data_shape,
dtype=data_types[counter[0]])
counter[0] += 1
return desc
else:
raise ValueError('data_shapes must be either a list of DataDesc or a list of Tuple')


if len(data_shapes) == 1 and not isinstance(data_shapes[0], list):
data_descs = [DataDesc(name='data', shape=data_shapes[0], dtype=data_types[0])]
else:
data_descs = _multilist_iterator(data_shapes, get_data_desc)

return data_descs
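A hedged usage sketch of the rewritten helper; it assumes these private functions keep their names and that MXNet is importable, and the shapes and dtypes are invented. A single flat shape still yields one DataDesc named 'data', while nested shapes yield 'data0', 'data1', ... with the nesting preserved.

```python
# Illustrative only; relies on private helpers from this module keeping their names.
import numpy as np
from mxnet.contrib.quantization import _generate_list_of_data_desc

# Single plain input -> one DataDesc named 'data'.
print(_generate_list_of_data_desc([(32, 3, 224, 224)], [np.float32]))

# Nested inputs (e.g. RNN data plus a list of initial states) -> 'data0', 'data1',
# 'data2', with the input nesting preserved in the returned structure.
shapes = [(35, 32), [(1, 32, 200), (1, 32, 200)]]
dtypes = [np.float32, [np.float32, np.float32]]
print(_generate_list_of_data_desc(shapes, dtypes))
```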

@wrap_ctx_to_device_func
def quantize_model(sym, arg_params, aux_params, data_names=('data',),
device=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy',
@@ -841,25 +880,24 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize
x = iter(calib_data)
batch = next(x)
if isinstance(batch, list):
data_shapes = [b.shape for b in batch]
data_types = [b.dtype for b in batch]
data_shapes = _multilist_iterator(batch, lambda x: x.shape)
data_types = _multilist_iterator(batch, lambda x: x.dtype)
else:
data_shapes = [batch.shape]
data_types = [batch.dtype]
else:
raise ValueError('calib_data expects mx.gluon.data.DataLoader')

if data_types is None:
data_types = [mx_real_t] * len(data_shapes)
data_types = _multilist_iterator(data_shapes, lambda x: mx_real_t)

data_descs = _generate_list_of_data_desc(data_shapes, data_types)

num_inputs = len(data_descs)
data_nd = []
for desc in data_descs:
if is_np_array():
data_nd.append(mx.np.zeros(shape=desc.shape, dtype=desc.dtype))
else:
data_nd.append(mx.nd.zeros(shape=desc.shape, dtype=desc.dtype))
arr_fn = mx.np if is_np_array() else mx.nd
data_nd = _multilist_iterator(data_descs, lambda d, F=arr_fn: F.zeros(shape=d.shape, dtype=d.dtype))

while True:
try:
network(*data_nd)
@@ -919,7 +957,7 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize
raise ValueError(
'calib_data must be provided when calib_mode=%s' % calib_mode)
if calib_mode in ['naive', 'entropy', 'custom']:
inputs = [mx.sym.var(desc.name) for desc in data_descs]
inputs = _multilist_iterator(data_descs, lambda dd: mx.sym.var(dd.name))
calib_net = SymbolBlock(symnet, inputs)
for k, v in calib_net.collect_params().items():
v.grad_req = 'null'
@@ -939,7 +977,7 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize
else:
raise ValueError('calib_mode has to be one of: naive, entropy, custom')
elif calib_mode is not None and calib_mode == 'none':
inputs = [mx.sym.var(desc.name) for desc in data_descs]
inputs = _multilist_iterator(data_descs, lambda dd: mx.sym.var(dd.name))

net = SymbolBlock(qsym, inputs)
for k, v in net.collect_params().items():
5 changes: 4 additions & 1 deletion python/mxnet/io/io.py
@@ -643,8 +643,11 @@ def provide_data(self):
@property
def provide_label(self):
"""The name and shape of label provided by this iterator."""
batch_axis = self.layout.find('N')
return [
DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
DataDesc(k, tuple(list(v.shape[:batch_axis]) + \
[self.batch_size] + list(v.shape[batch_axis + 1:])),
v.dtype, layout=self.layout)
for k, v in self.label
]

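The provide_label change substitutes the batch size at whatever axis the iterator's layout marks as 'N', instead of assuming axis 0, which is what time-major ('TNC') RNN iterators need. A minimal sketch of the shape arithmetic follows; the function name and shapes are illustrative only.

```python
# Standalone sketch of the batch-axis substitution used by provide_label above.
def label_shape_with_layout(full_shape, batch_size, layout):
    batch_axis = layout.find('N')
    return tuple(list(full_shape[:batch_axis])
                 + [batch_size]
                 + list(full_shape[batch_axis + 1:]))

print(label_shape_with_layout((35, 128, 200), 32, 'TNC'))  # (35, 32, 200)
print(label_shape_with_layout((128, 10), 32, 'NT'))        # (32, 10)
```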
61 changes: 47 additions & 14 deletions src/operator/nn/dnnl/dnnl_rnn-inl.h
@@ -32,10 +32,20 @@

#include "operator/rnn-inl.h"
#include "dnnl_base-inl.h"
#include "operator/quantization/quantized_rnn-inl.h"

namespace mxnet {
namespace op {

struct DNNLRnnParam : public dmlc::Parameter<DNNLRnnParam> {
bool quantized;

DMLC_DECLARE_PARAMETER(DNNLRnnParam) {
DMLC_DECLARE_FIELD(quantized).set_default(false).describe(
"Whether it's a quantized RNN operator");
}
};

struct DNNLRnnLayerParam {
using memory = dnnl::memory;
using dims = dnnl::memory::dims;
@@ -66,6 +76,10 @@ struct DNNLRnnLayerParam {
size_t native_single_b_size; // bias size of a single cell from framework
size_t single_state_size; // state size of a single cell, hy, cy

bool quantized; // whether this layer is quantized
bool enable_u8_output; // true by default; only false when this is the last fusion layer of the
// quantized rnn operator

DNNLRnnLayerParam(int num_layer,
index_t batch_size,
index_t seq_len,
@@ -82,18 +96,21 @@
input_size(input_size),
state_size(state_size),
proj_size(proj_size),
seq_len(seq_len) {}
seq_len(seq_len),
quantized(false),
enable_u8_output(false) {}

void SetDims();
};

typedef std::vector<DNNLRnnLayerParam> LayerParamVector;
struct DNNLRnnFullParam {
RNNParam default_param;
DNNLRnnParam dnnl_param;
LayerParamVector layer_params;
};

DNNLRnnFullParam DNNLRnnFullParamParser(const RNNParam& rnn_param,
DNNLRnnFullParam DNNLRnnFullParamParser(const nnvm::NodeAttrs& attrs,
const index_t seq_len,
const index_t batch_size,
const index_t input_size);
@@ -105,7 +122,7 @@ class DNNLRnnMemMgr {
// The memory buffer in NDArray life-cycle
NDArray workspace_;
// This points to the memory buffer from a NDArray
char* curr_mem;
char* curr_mem = nullptr;
// The total bytes of the workspace of a DNNLRnnOp
size_t mem_size = 0;
// The current available memory bytes
@@ -121,7 +138,7 @@
* \param size byte number
* \param ctx Context of device environment
*/
void Init(dim_t size, const Context& ctx);
void Init(const dim_t size, const Context& ctx);

// Return the bytes number of the buffer
const size_t Size() {
@@ -135,6 +152,8 @@ class DNNLRnnMemMgr {
dnnl::memory* Alloc(const dnnl::memory::desc& md);
};

typedef std::shared_ptr<dnnl::primitive_attr> shared_dnnl_attr_t;

/*
* Rnn Primitive.
*/
@@ -144,15 +163,15 @@ class RnnPrimitive {
* lstm_forward, lbr_gru_forward, vanilla_rnn_forward
*/
template <typename rnn_fwd, typename... Args>
static RnnPrimitive Create(Args&&... args) {
static RnnPrimitive Create(const shared_dnnl_attr_t attr, Args&&... args) {
RnnPrimitive rnn_fwd_prim;
auto fwd_desc = typename rnn_fwd::desc(std::forward<Args>(args)...);
rnn_fwd_prim.fwd_pd_.reset(
new typename rnn_fwd::primitive_desc(fwd_desc, CpuEngine::Get()->get_engine()),
[](typename rnn_fwd::primitive_desc* pd) {
delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd);
});
new typename rnn_fwd::primitive_desc(
fwd_desc, attr ? *attr : dnnl::primitive_attr(), CpuEngine::Get()->get_engine()),
[](void* pd) { delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd); });
auto fwd_pd = reinterpret_cast<typename rnn_fwd::primitive_desc*>(rnn_fwd_prim.fwd_pd_.get());
rnn_fwd_prim.attr_ = attr;
rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc();
rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc();
rnn_fwd_prim.weights_proj_desc_ = fwd_pd->weights_projection_desc();
@@ -164,6 +183,7 @@ class RnnPrimitive {
}

RnnPrimitive() {
this->attr_ = nullptr;
this->fwd_pd_ = nullptr;
this->primitive_ = nullptr;
this->weights_layer_desc_ = dnnl::memory::desc();
@@ -173,6 +193,7 @@
}

RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) {
this->attr_ = rnn_fwd_prim.attr_;
this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
@@ -183,6 +204,7 @@

RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) {
if (this != &rnn_fwd_prim) {
this->attr_ = rnn_fwd_prim.attr_;
this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
@@ -217,9 +239,14 @@ class RnnPrimitive {
return workspace_desc_;
}

const dnnl::primitive_attr& GetPrimAttr() const {
return *attr_;
}

private:
std::shared_ptr<void> fwd_pd_;
std::shared_ptr<dnnl::primitive> primitive_;
shared_dnnl_attr_t attr_;
dnnl::memory::desc weights_layer_desc_;
dnnl::memory::desc weights_iter_desc_;
dnnl::memory::desc weights_proj_desc_;
@@ -229,7 +256,8 @@
RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param,
const bool is_train,
const NDArray& data,
const NDArray& params);
const NDArray& params,
const shared_dnnl_attr_t attr = nullptr);

/*
* Use this to manage memory and primitive of DNNL RNN forward inference.
@@ -240,11 +268,12 @@ class DNNLRnnForward {
const DNNLRnnLayerParam& layer_param,
const bool is_train,
const NDArray& data,
const NDArray& params)
const NDArray& params,
const shared_dnnl_attr_t attr = nullptr)
: ctx_(ctx),
initialized_(false),
param_(layer_param),
fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params)) {}
fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params, attr)) {}

void SetNewDataMem(void* x,
void* hx,
@@ -263,6 +292,10 @@ class DNNLRnnForward {
return fwd_inf_.GetPrim();
}

void ResetFwd(const NDArray& data, const NDArray& params, const shared_dnnl_attr_t& attr) {
fwd_inf_ = GetRnnFwdPrim(this->param_, false, data, params, attr);
}

const size_t GetSize() const {
const size_t size = fwd_inf_.GetLayerDesc().get_size() + fwd_inf_.GetIterDesc().get_size() +
fwd_inf_.GetProjDesc().get_size();
@@ -482,13 +515,13 @@ class DNNLRnnBackward {
*/
class DNNLRnnOp {
public:
explicit DNNLRnnOp(const RNNParam& param,
explicit DNNLRnnOp(const nnvm::NodeAttrs& attrs,
const int seq_len,
const int batch_size,
const int input_size)
: initialized_(false),
weights_version_(0),
full_param_(DNNLRnnFullParamParser(param, seq_len, batch_size, input_size)) {}
full_param_(DNNLRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {}

void Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,