
Restore quantized RNN to master (#20952)
* Restore quantized RNN

sanity

* Add tests & disable LSTMP from quantization

* apply review comments

* change link

* Add new lines at the EOF

* Add ops to amp lists

* Remove unused features

* Fix DataDesc handling in quantization

* fix website

* fix sanity

* remove magic number

Co-authored-by: Bartlomiej Gawrych <barlomiej.gawrych@intel.com>
bgawrych and Bartlomiej Gawrych authored Apr 6, 2022
1 parent 08737b2 commit ea6c91a
Showing 19 changed files with 1,772 additions and 142 deletions.
13 changes: 13 additions & 0 deletions include/mxnet/op_attr_types.h
@@ -343,6 +343,19 @@ using FNeedRequantize = std::function<bool(const NodeAttrs& attrs)>;
using FAvoidQuantizeInput = std::function<
bool(const NodeAttrs& attrs, const size_t index, const std::string quantize_granularity)>;

/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be quantized asymmetrically.
*/
using FNeedAsymQuantizeInput = std::function<bool(const NodeAttrs& attrs, const size_t index)>;

/*!
* \brief Register a function to determine if the output of a quantized operator
* needs to be dequantized. This is usually used for the quantized operators
* which can produce fp32 outputs directly.
*/
using FAvoidDequantizeOutput = std::function<bool(const NodeAttrs& attrs, const size_t index)>;

/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be calibrated. This is usually used for the quantized operators
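The two new attributes above let an operator opt its inputs into asymmetric (shifted, uint8) quantization and opt its outputs out of dequantization. Below is a minimal NumPy sketch, not MXNet code, of the difference between the symmetric and asymmetric schemes these attributes distinguish; the function names and the [-127, 127] / [0, 255] ranges are illustrative assumptions.

import numpy as np

def quantize_symmetric(x):
    # scale only: 0.0 maps to 0, int8 range [-127, 127]
    scale = 127.0 / np.max(np.abs(x))
    q = np.clip(np.round(x * scale), -127, 127).astype(np.int8)
    return q, scale

def quantize_asymmetric(x):
    # scale plus shift (zero point): the full uint8 range [0, 255]
    # covers [min(x), max(x)], which is the kind of shifted scheme
    # FNeedAsymQuantizeInput lets an operator request for its inputs
    scale = 255.0 / (np.max(x) - np.min(x))
    shift = -np.min(x) * scale
    q = np.clip(np.round(x * scale + shift), 0, 255).astype(np.uint8)
    return q, scale, shift

x = np.random.uniform(-1.0, 3.0, size=(4, 8)).astype(np.float32)
print(quantize_symmetric(x)[0].dtype, quantize_asymmetric(x)[0].dtype)  # int8 uint8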
2 changes: 2 additions & 0 deletions python/mxnet/amp/lists/symbol_bf16.py
@@ -97,6 +97,7 @@
'_contrib_index_copy',
'_contrib_quadratic',
'_contrib_quantize',
'_contrib_quantize_asym',
'_contrib_quantize_v2',
'_contrib_quantized_concat',
'_contrib_quantized_conv',
@@ -105,6 +106,7 @@
'_contrib_quantized_pooling',
'_contrib_quantized_elemwise_add',
'_contrib_quantized_act',
'_contrib_quantized_rnn',
'_image_crop',
'_linspace',
'_contrib_requantize',
2 changes: 2 additions & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
@@ -99,6 +99,7 @@
'_contrib_index_copy',
'_contrib_quadratic',
'_contrib_quantize',
'_contrib_quantize_asym',
'_contrib_quantize_v2',
'_contrib_quantized_concat',
'_contrib_quantized_conv',
@@ -108,6 +109,7 @@
'_contrib_quantized_elemwise_add',
'_contrib_quantized_act',
'_contrib_quantized_reshape',
'_contrib_quantized_rnn',
'_contrib_quantized_transpose',
'_npx_quantized_reshape',
'_npx_quantized_transpose',
82 changes: 60 additions & 22 deletions python/mxnet/contrib/quantization.py
@@ -33,6 +33,20 @@
from ..util import is_np_array, wrap_ctx_to_device_func


def _multilist_iterator(arg, func):
"""Iterate over multidiemnsional list and returns new list
with same dimensions, but applied `func` function on list elements.
E.g. _multilist_iterator([1, 2, [3, 4]], lambda x: x**2) = [1, 4, [9, 16]]
"""
ret = []
if isinstance(arg, list):
for el in arg:
ret.append(_multilist_iterator(el, func))
else:
return func(arg)

return ret

def _quantize_params(qsym, params, min_max_dict):
"""Given a quantized symbol and a dict of params that have not been quantized,
generate quantized params. Currently only supports quantizing the arg_params
@@ -357,7 +371,7 @@ def _collect_layer_statistics(sym_block, data, collector, num_inputs, num_calib_
for batch in data:
if not isinstance(batch, list):
batch = [batch]
batch = [b.as_in_context(mx.cpu()) for b in batch]
batch = _multilist_iterator(batch, lambda b: b.as_in_context(mx.cpu()))
sym_block(*batch[:num_inputs])
num_batches += 1
if num_calib_batches is not None and num_batches >= num_calib_batches:
@@ -368,20 +382,45 @@ def _collect_layer_statistics(sym_block, data, collector, num_inputs, num_calib_


def _generate_list_of_data_desc(data_shapes, data_types):
""""Convert list ot tuples to list of DataDesc."""
if isinstance(data_shapes, list):
if all(isinstance(x, DataDesc) for x in data_shapes):
return data_shapes
if all(isinstance(x, tuple) for x in data_shapes):
if len(data_shapes) == 1:
data_shapes = [DataDesc(name='data', shape=data_shapes[0], dtype=data_types[0])]
"""Convert list of tuples to list of DataDesc."""
def flatten_list(arg):
ret = []
for el in arg:
if isinstance(el, list):
ret += flatten_list(el)
else:
data_shapes = [DataDesc(name='data' + str(i), shape=data_shapes[i],
dtype=data_types[i]) for i in range(len(data_shapes))]
return data_shapes
raise ValueError('data_shapes must be either a list of DataDesc or a list of Tuple')
ret.append(el)
return ret

flattened_data_types = flatten_list(data_types)
flattened_data_shapes = flatten_list(data_shapes)

if all(isinstance(x, DataDesc) for x in flattened_data_shapes):
return data_shapes

assert len(flattened_data_types) == len(flattened_data_shapes)

# pass integral type as reference
counter = [0]
def get_data_desc(data_shape, counter=counter, data_types=flattened_data_types):
if isinstance(data_shape, DataDesc):
return data_shape
elif isinstance(data_shape, tuple):
desc = DataDesc(name='data' + str(counter[0]), shape=data_shape,
dtype=data_types[counter[0]])
counter[0] += 1
return desc
else:
raise ValueError('data_shapes must be either a list of DataDesc or a list of Tuple')


if len(data_shapes) == 1 and not isinstance(data_shapes[0], list):
data_descs = [DataDesc(name='data', shape=data_shapes[0], dtype=data_types[0])]
else:
data_descs = _multilist_iterator(data_shapes, get_data_desc)

return data_descs

@wrap_ctx_to_device_func
def quantize_model(sym, arg_params, aux_params, data_names=('data',),
device=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy',
@@ -841,25 +880,24 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize
x = iter(calib_data)
batch = next(x)
if isinstance(batch, list):
data_shapes = [b.shape for b in batch]
data_types = [b.dtype for b in batch]
data_shapes = _multilist_iterator(batch, lambda x: x.shape)
data_types = _multilist_iterator(batch, lambda x: x.dtype)
else:
data_shapes = [batch.shape]
data_types = [batch.dtype]
else:
raise ValueError('calib_data expects mx.gluon.data.DataLoader')

if data_types is None:
data_types = [mx_real_t] * len(data_shapes)
data_types = _multilist_iterator(data_shapes, lambda x: mx_real_t)

data_descs = _generate_list_of_data_desc(data_shapes, data_types)

num_inputs = len(data_descs)
data_nd = []
for desc in data_descs:
if is_np_array():
data_nd.append(mx.np.zeros(shape=desc.shape, dtype=desc.dtype))
else:
data_nd.append(mx.nd.zeros(shape=desc.shape, dtype=desc.dtype))
arr_fn = mx.np if is_np_array() else mx.nd
data_nd = _multilist_iterator(data_descs, lambda d, F=arr_fn: F.zeros(shape=d.shape, dtype=d.dtype))

while True:
try:
network(*data_nd)
@@ -919,7 +957,7 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize
raise ValueError(
'calib_data must be provided when calib_mode=%s' % calib_mode)
if calib_mode in ['naive', 'entropy', 'custom']:
inputs = [mx.sym.var(desc.name) for desc in data_descs]
inputs = _multilist_iterator(data_descs, lambda dd: mx.sym.var(dd.name))
calib_net = SymbolBlock(symnet, inputs)
for k, v in calib_net.collect_params().items():
v.grad_req = 'null'
@@ -939,7 +977,7 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize
else:
raise ValueError('calib_mode has to be one of: naive, entropy, custom')
elif calib_mode is not None and calib_mode == 'none':
inputs = [mx.sym.var(desc.name) for desc in data_descs]
inputs = _multilist_iterator(data_descs, lambda dd: mx.sym.var(dd.name))

net = SymbolBlock(qsym, inputs)
for k, v in net.collect_params().items():
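The quantization.py changes replace flat list comprehensions with the new `_multilist_iterator` helper so that calibration batches, shapes, and DataDescs may be arbitrarily nested lists (for example, data together with a list of RNN initial states). A small standalone sketch of that idea, with the helper re-implemented locally and the shapes purely illustrative:

def multilist_map(arg, func):
    # walk an arbitrarily nested list and apply func to every leaf,
    # preserving the nesting structure (mirrors _multilist_iterator)
    if isinstance(arg, list):
        return [multilist_map(el, func) for el in arg]
    return func(arg)

nested_shapes = [(10, 2, 32), [(1, 2, 64), (1, 2, 64)]]

# name the leaves 'data0', 'data1', ... the way _generate_list_of_data_desc does,
# using a one-element list as a mutable counter inside the closure
counter = [0]
def name_leaf(shape):
    name = 'data' + str(counter[0])
    counter[0] += 1
    return (name, shape)

print(multilist_map(nested_shapes, name_leaf))
# [('data0', (10, 2, 32)), [('data1', (1, 2, 64)), ('data2', (1, 2, 64))]]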
5 changes: 4 additions & 1 deletion python/mxnet/io/io.py
@@ -643,8 +643,11 @@ def provide_data(self):
@property
def provide_label(self):
"""The name and shape of label provided by this iterator."""
batch_axis = self.layout.find('N')
return [
DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
DataDesc(k, tuple(list(v.shape[:batch_axis]) + \
[self.batch_size] + list(v.shape[batch_axis + 1:])),
v.dtype, layout=self.layout)
for k, v in self.label
]

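The io.py change makes provide_label honour the iterator's layout string instead of assuming the batch dimension is axis 0. A hedged sketch of the shape computation it performs (the layouts and shapes below are illustrative):

def batchified_shape(shape, batch_size, layout):
    # substitute batch_size at the axis named 'N' in the layout string,
    # as the DataDesc construction above now does
    batch_axis = layout.find('N')
    return tuple(list(shape[:batch_axis]) + [batch_size] + list(shape[batch_axis + 1:]))

print(batchified_shape((1, 3, 224, 224), 32, 'NCHW'))  # (32, 3, 224, 224)
print(batchified_shape((20, 1, 512), 32, 'TNC'))       # (20, 32, 512)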
61 changes: 47 additions & 14 deletions src/operator/nn/dnnl/dnnl_rnn-inl.h
@@ -32,10 +32,20 @@

#include "operator/rnn-inl.h"
#include "dnnl_base-inl.h"
#include "operator/quantization/quantized_rnn-inl.h"

namespace mxnet {
namespace op {

struct DNNLRnnParam : public dmlc::Parameter<DNNLRnnParam> {
bool quantized;

DMLC_DECLARE_PARAMETER(DNNLRnnParam) {
DMLC_DECLARE_FIELD(quantized).set_default(false).describe(
"Whether it's a quantized RNN operator");
}
};

struct DNNLRnnLayerParam {
using memory = dnnl::memory;
using dims = dnnl::memory::dims;
@@ -66,6 +76,10 @@ struct DNNLRnnLayerParam {
size_t native_single_b_size; // bias size of a single cell from framework
size_t single_state_size; // state size of a single cell, hy, cy

bool quantized; // whether this layer is quantized
bool enable_u8_output;  // true by default; only false when this is the last fusion layer of the
// quantized RNN operator

DNNLRnnLayerParam(int num_layer,
index_t batch_size,
index_t seq_len,
@@ -82,18 +96,21 @@ struct DNNLRnnLayerParam {
input_size(input_size),
state_size(state_size),
proj_size(proj_size),
seq_len(seq_len) {}
seq_len(seq_len),
quantized(false),
enable_u8_output(false) {}

void SetDims();
};

typedef std::vector<DNNLRnnLayerParam> LayerParamVector;
struct DNNLRnnFullParam {
RNNParam default_param;
DNNLRnnParam dnnl_param;
LayerParamVector layer_params;
};

DNNLRnnFullParam DNNLRnnFullParamParser(const RNNParam& rnn_param,
DNNLRnnFullParam DNNLRnnFullParamParser(const nnvm::NodeAttrs& attrs,
const index_t seq_len,
const index_t batch_size,
const index_t input_size);
@@ -105,7 +122,7 @@ class DNNLRnnMemMgr {
// The memory buffer in NDArray life-cycle
NDArray workspace_;
// This points to the memory buffer from a NDArray
char* curr_mem;
char* curr_mem = nullptr;
// The total bytes of the workspace of a DNNLRnnOp
size_t mem_size = 0;
// The current available memory bytes
@@ -121,7 +138,7 @@ class DNNLRnnMemMgr {
* \param size byte number
* \param ctx Context of device environment
*/
void Init(dim_t size, const Context& ctx);
void Init(const dim_t size, const Context& ctx);

// Return the bytes number of the buffer
const size_t Size() {
@@ -135,6 +152,8 @@ class DNNLRnnMemMgr {
dnnl::memory* Alloc(const dnnl::memory::desc& md);
};

typedef std::shared_ptr<dnnl::primitive_attr> shared_dnnl_attr_t;

/*
* Rnn Primitive.
*/
@@ -144,15 +163,15 @@ class RnnPrimitive {
* lstm_forward, lbr_gru_forward, vanilla_rnn_forward
*/
template <typename rnn_fwd, typename... Args>
static RnnPrimitive Create(Args&&... args) {
static RnnPrimitive Create(const shared_dnnl_attr_t attr, Args&&... args) {
RnnPrimitive rnn_fwd_prim;
auto fwd_desc = typename rnn_fwd::desc(std::forward<Args>(args)...);
rnn_fwd_prim.fwd_pd_.reset(
new typename rnn_fwd::primitive_desc(fwd_desc, CpuEngine::Get()->get_engine()),
[](typename rnn_fwd::primitive_desc* pd) {
delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd);
});
new typename rnn_fwd::primitive_desc(
fwd_desc, attr ? *attr : dnnl::primitive_attr(), CpuEngine::Get()->get_engine()),
[](void* pd) { delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd); });
auto fwd_pd = reinterpret_cast<typename rnn_fwd::primitive_desc*>(rnn_fwd_prim.fwd_pd_.get());
rnn_fwd_prim.attr_ = attr;
rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc();
rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc();
rnn_fwd_prim.weights_proj_desc_ = fwd_pd->weights_projection_desc();
@@ -164,6 +183,7 @@ }
}

RnnPrimitive() {
this->attr_ = nullptr;
this->fwd_pd_ = nullptr;
this->primitive_ = nullptr;
this->weights_layer_desc_ = dnnl::memory::desc();
@@ -173,6 +193,7 @@ }
}

RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) {
this->attr_ = rnn_fwd_prim.attr_;
this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
@@ -183,6 +204,7 @@

RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) {
if (this != &rnn_fwd_prim) {
this->attr_ = rnn_fwd_prim.attr_;
this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
@@ -217,9 +239,14 @@ class RnnPrimitive {
return workspace_desc_;
}

const dnnl::primitive_attr& GetPrimAttr() const {
return *attr_;
}

private:
std::shared_ptr<void> fwd_pd_;
std::shared_ptr<dnnl::primitive> primitive_;
shared_dnnl_attr_t attr_;
dnnl::memory::desc weights_layer_desc_;
dnnl::memory::desc weights_iter_desc_;
dnnl::memory::desc weights_proj_desc_;
@@ -229,7 +256,8 @@ RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param,
RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param,
const bool is_train,
const NDArray& data,
const NDArray& params);
const NDArray& params,
const shared_dnnl_attr_t attr = nullptr);

/*
* Use this to manage memory and primitive of DNNL RNN forward inference.
Expand All @@ -240,11 +268,12 @@ class DNNLRnnForward {
const DNNLRnnLayerParam& layer_param,
const bool is_train,
const NDArray& data,
const NDArray& params)
const NDArray& params,
const shared_dnnl_attr_t attr = nullptr)
: ctx_(ctx),
initialized_(false),
param_(layer_param),
fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params)) {}
fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params, attr)) {}

void SetNewDataMem(void* x,
void* hx,
@@ -263,6 +292,10 @@ class DNNLRnnForward {
return fwd_inf_.GetPrim();
}

void ResetFwd(const NDArray& data, const NDArray& params, const shared_dnnl_attr_t& attr) {
fwd_inf_ = GetRnnFwdPrim(this->param_, false, data, params, attr);
}

const size_t GetSize() const {
const size_t size = fwd_inf_.GetLayerDesc().get_size() + fwd_inf_.GetIterDesc().get_size() +
fwd_inf_.GetProjDesc().get_size();
@@ -482,13 +515,13 @@ class DNNLRnnBackward {
*/
class DNNLRnnOp {
public:
explicit DNNLRnnOp(const RNNParam& param,
explicit DNNLRnnOp(const nnvm::NodeAttrs& attrs,
const int seq_len,
const int batch_size,
const int input_size)
: initialized_(false),
weights_version_(0),
full_param_(DNNLRnnFullParamParser(param, seq_len, batch_size, input_size)) {}
full_param_(DNNLRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {}

void Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
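Together with the DNNL header changes above, these pieces re-enable INT8 quantization of RNN layers from the Python API. An end-to-end sketch, hedged: only quantized_dtype, quantize_mode, calib_data, and calib_mode are visible in this diff; the model, shapes, and remaining details are assumptions, and a ONEDNN-enabled CPU build is required for the LSTM to actually be fused into _contrib_quantized_rnn.

import mxnet as mx
from mxnet.gluon import nn, rnn
from mxnet.contrib.quantization import quantize_net

net = nn.HybridSequential()
net.add(rnn.LSTM(hidden_size=64, num_layers=1, layout='NTC'))  # LSTMP (projection) stays fp32
net.initialize()
net.hybridize()

# calibration batches in NTC layout: (batch_size, seq_len, input_size)
calib_data = mx.gluon.data.DataLoader(
    mx.gluon.data.ArrayDataset(mx.nd.random.uniform(shape=(64, 35, 16))),
    batch_size=8)

qnet = quantize_net(net, quantized_dtype='auto', quantize_mode='full',
                    calib_data=calib_data, calib_mode='naive')
out = qnet(mx.nd.random.uniform(shape=(8, 35, 16)))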