From 38acd34391902945ecdbf898763f089e6b4d1eb3 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Mon, 28 Feb 2022 11:45:23 +0100 Subject: [PATCH 01/11] Restore quantized RNN sanity --- include/mxnet/op_attr_types.h | 13 + python/mxnet/contrib/quantization.py | 78 ++-- python/mxnet/io/io.py | 23 +- python/mxnet/io/utils.py | 5 + src/operator/nn/dnnl/dnnl_rnn-inl.h | 83 +++- src/operator/nn/dnnl/dnnl_rnn.cc | 179 +++++---- .../dnnl/dnnl_quantize_asym-inl.h | 161 ++++++++ .../dnnl/dnnl_quantized_rnn-inl.h | 82 ++++ .../quantization/dnnl/dnnl_quantized_rnn.cc | 365 ++++++++++++++++++ src/operator/quantization/quantize_asym-inl.h | 177 +++++++++ src/operator/quantization/quantize_asym.cc | 155 ++++++++ .../quantization/quantize_graph_pass.cc | 43 ++- src/operator/quantization/quantize_v2.cc | 2 +- src/operator/quantization/quantized_rnn-inl.h | 41 ++ src/operator/quantization/quantized_rnn.cc | 356 +++++++++++++++++ src/operator/rnn-inl.h | 12 +- src/operator/rnn.cc | 92 ++--- .../python/quantization/test_quantization.py | 74 ++++ 18 files changed, 1767 insertions(+), 174 deletions(-) create mode 100644 src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h create mode 100644 src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h create mode 100644 src/operator/quantization/dnnl/dnnl_quantized_rnn.cc create mode 100644 src/operator/quantization/quantize_asym-inl.h create mode 100644 src/operator/quantization/quantize_asym.cc create mode 100644 src/operator/quantization/quantized_rnn-inl.h create mode 100644 src/operator/quantization/quantized_rnn.cc diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index c936d3e84afa..0bc2a8f62daf 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -343,6 +343,19 @@ using FNeedRequantize = std::function; using FAvoidQuantizeInput = std::function< bool(const NodeAttrs& attrs, const size_t index, const std::string quantize_granularity)>; +/*! + * \brief Register a function to determine if the input of a quantized operator + * needs to be quantized asymmetrically. + */ +using FNeedAsymQuantizeInput = std::function; + +/*! + * \brief Register a function to determine if the output of a quantized operator + * needs to be dequantized. This is usually used for the quantized operators + * which can produce fp32 outputs directly. + */ +using FAvoidDequantizeOutput = std::function; + /*! * \brief Register a function to determine if the input of a quantized operator * needs to be calibrated. This is usually used for the quantized operators diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 10d2455cb9ae..4e6411135342 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -33,6 +33,20 @@ from ..util import is_np_array, wrap_ctx_to_device_func +def _multilist_iterator(arg, func): + """Iterate over multidiemnsional list and returns new list + with same dimensions, but applied `func` function on list elements. + E.g. _multilist_iterator([1, 2, [3, 4]], lambda x: x**2) = [1, 4, [9, 16]] + """ + ret = [] + if isinstance(arg, list): + for el in arg: + ret.append(_multilist_iterator(el, func)) + else: + return func(arg) + + return ret + def _quantize_params(qsym, params, min_max_dict): """Given a quantized symbol and a dict of params that have not been quantized, generate quantized params. Currently only supports quantizing the arg_params @@ -357,7 +371,7 @@ def _collect_layer_statistics(sym_block, data, collector, num_inputs, num_calib_ for batch in data: if not isinstance(batch, list): batch = [batch] - batch = [b.as_in_context(mx.cpu()) for b in batch] + batch = _multilist_iterator(batch, lambda b: b.as_in_context(mx.cpu())) sym_block(*batch[:num_inputs]) num_batches += 1 if num_calib_batches is not None and num_batches >= num_calib_batches: @@ -368,20 +382,41 @@ def _collect_layer_statistics(sym_block, data, collector, num_inputs, num_calib_ def _generate_list_of_data_desc(data_shapes, data_types): - """"Convert list ot tuples to list of DataDesc.""" - if isinstance(data_shapes, list): - if all(isinstance(x, DataDesc) for x in data_shapes): - return data_shapes - if all(isinstance(x, tuple) for x in data_shapes): - if len(data_shapes) == 1: - data_shapes = [DataDesc(name='data', shape=data_shapes[0], dtype=data_types[0])] + """"Convert list of tuples to list of DataDesc.""" + def flatten_list(arg): + ret = [] + for el in arg: + if isinstance(el, list): + ret += flatten_list(el) else: - data_shapes = [DataDesc(name='data' + str(i), shape=data_shapes[i], - dtype=data_types[i]) for i in range(len(data_shapes))] - return data_shapes - raise ValueError('data_shapes must be either a list of DataDesc or a list of Tuple') + ret.append(el) + return ret + + flattened_data_types = flatten_list(data_types) + flattened_data_shapes = flatten_list(data_shapes) + assert len(flattened_data_types) == len(flattened_data_shapes) + + # pass integral type as reference + counter = [0] + def get_data_desc(data_shape, counter=counter, data_types=flattened_data_types): + if isinstance(data_shape, DataDesc): + return data_shape + elif isinstance(data_shape, tuple): + desc = DataDesc(name='data' + str(counter[0]), shape=data_shape, + dtype=data_types[counter[0]]) + counter[0] += 1 + return desc + else: + raise ValueError('data_shapes must be either a list of DataDesc or a list of Tuple') + if len(data_shapes) == 1 and not isinstance(data_shapes[0], list): + data_descs = [DataDesc(name='data', shape=data_shapes[0], dtype=data_types[0])] + else: + data_descs = _multilist_iterator(data_shapes, get_data_desc) + + return data_descs + @wrap_ctx_to_device_func def quantize_model(sym, arg_params, aux_params, data_names=('data',), device=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy', @@ -841,8 +876,8 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize x = iter(calib_data) batch = next(x) if isinstance(batch, list): - data_shapes = [b.shape for b in batch] - data_types = [b.dtype for b in batch] + data_shapes = _multilist_iterator(batch, lambda x: x.shape) + data_types = _multilist_iterator(batch, lambda x: x.dtype) else: data_shapes = [batch.shape] data_types = [batch.dtype] @@ -850,16 +885,15 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize raise ValueError('calib_data expects mx.gluon.data.DataLoader') if data_types is None: - data_types = [mx_real_t] * len(data_shapes) + data_types = _multilist_iterator(data_shapes, lambda x: mx_real_t) + data_descs = _generate_list_of_data_desc(data_shapes, data_types) num_inputs = len(data_descs) data_nd = [] - for desc in data_descs: - if is_np_array(): - data_nd.append(mx.np.zeros(shape=desc.shape, dtype=desc.dtype)) - else: - data_nd.append(mx.nd.zeros(shape=desc.shape, dtype=desc.dtype)) + arr_fn = mx.np if is_np_array() else mx.nd + data_nd = _multilist_iterator(data_descs, lambda d, F=arr_fn: F.zeros(shape=d.shape, dtype=d.dtype)) + while True: try: network(*data_nd) @@ -919,7 +953,7 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize raise ValueError( 'calib_data must be provided when calib_mode=%s' % calib_mode) if calib_mode in ['naive', 'entropy', 'custom']: - inputs = [mx.sym.var(desc.name) for desc in data_descs] + inputs = _multilist_iterator(data_descs, lambda dd: mx.sym.var(dd.name)) calib_net = SymbolBlock(symnet, inputs) for k, v in calib_net.collect_params().items(): v.grad_req = 'null' @@ -939,7 +973,7 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize else: raise ValueError('calib_mode has to be one of: naive, entropy, custom') elif calib_mode is not None and calib_mode == 'none': - inputs = [mx.sym.var(desc.name) for desc in data_descs] + inputs = _multilist_iterator(data_descs, lambda dd: mx.sym.var(dd.name)) net = SymbolBlock(qsym, inputs) for k, v in net.collect_params().items(): diff --git a/python/mxnet/io/io.py b/python/mxnet/io/io.py index 4d78cd999bae..013f401108f8 100644 --- a/python/mxnet/io/io.py +++ b/python/mxnet/io/io.py @@ -37,7 +37,7 @@ from ..ndarray import array from ..ndarray import concat, tile -from .utils import _init_data, _has_instance, _getdata_by_idx +from .utils import _init_data, _has_instance, _getdata_by_idx, _slice_along_batch_axis class DataDesc(namedtuple('DataDesc', ['name', 'shape'])): """DataDesc is used to store name, shape, type and layout @@ -602,10 +602,12 @@ class NDArrayIter(DataIter): The data name. label_name : str, optional The label name. + layout : str, optional + The data layout. """ def __init__(self, data, label=None, batch_size=1, shuffle=False, last_batch_handle='pad', data_name='data', - label_name='softmax_label'): + label_name='softmax_label', layout='NCHW'): super(NDArrayIter, self).__init__(batch_size) self.data = _init_data(data, allow_empty=False, default_name=data_name) @@ -631,20 +633,27 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False, # used for 'roll_over' self._cache_data = None self._cache_label = None + self.layout = layout @property def provide_data(self): """The name and shape of data provided by this iterator.""" + batch_axis = self.layout.find('N') return [ - DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype) + DataDesc(k, tuple(list(v.shape[:batch_axis]) + \ + [self.batch_size] + list(v.shape[batch_axis + 1:])), + v.dtype, layout=self.layout) for k, v in self.data ] @property def provide_label(self): """The name and shape of label provided by this iterator.""" + batch_axis = self.layout.find('N') return [ - DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype) + DataDesc(k, tuple(list(v.shape[:batch_axis]) + \ + [self.batch_size] + list(v.shape[batch_axis + 1:])), + v.dtype, layout=self.layout) for k, v in self.label ] @@ -681,7 +690,7 @@ def next(self): data = self.getdata() label = self.getlabel() # iter should stop when last batch is not complete - if data[0].shape[0] != self.batch_size: + if data[0].shape[self.layout.find('N')] != self.batch_size: # in this case, cache it for next epoch self._cache_data = data self._cache_label = label @@ -697,7 +706,7 @@ def _getdata(self, data_source, start=None, end=None): end = data_source[0][1].shape[0] if data_source else 0 s = slice(start, end) return [ - x[1][s] + _slice_along_batch_axis(x[1], s, self.layout.find('N')) if isinstance(x[1], (np.ndarray, NDArray)) else # h5py (only supports indices in increasing order) array(x[1][sorted(self.idx[s])][[ @@ -716,7 +725,7 @@ def _concat(self, first_data, second_data): concat( first_data[i], second_data[i], - dim=0 + dim=self.layout.find('N') ) for i in range(len(first_data)) ] diff --git a/python/mxnet/io/utils.py b/python/mxnet/io/utils.py index 55ba34aea426..55f228f4556d 100644 --- a/python/mxnet/io/utils.py +++ b/python/mxnet/io/utils.py @@ -84,3 +84,8 @@ def _getdata_by_idx(data, idx): shuffle_data.append((k, array(v.asnumpy()[idx], v.context))) return shuffle_data + +def _slice_along_batch_axis(data, s, batch_axis): + """Apply slice along the batch axis""" + ret = data.slice_axis(axis=batch_axis, begin=s.start, end=s.stop) + return ret \ No newline at end of file diff --git a/src/operator/nn/dnnl/dnnl_rnn-inl.h b/src/operator/nn/dnnl/dnnl_rnn-inl.h index f28753461e58..fafed4914d64 100644 --- a/src/operator/nn/dnnl/dnnl_rnn-inl.h +++ b/src/operator/nn/dnnl/dnnl_rnn-inl.h @@ -32,10 +32,42 @@ #include "operator/rnn-inl.h" #include "dnnl_base-inl.h" +#include "operator/quantization/quantized_rnn-inl.h" namespace mxnet { namespace op { +struct DNNLRnnParam : public dmlc::Parameter { + bool quantized; + + DMLC_DECLARE_PARAMETER(DNNLRnnParam) { + DMLC_DECLARE_FIELD(quantized).set_default(false).describe( + "Whether it's a quantized RNN operator"); + } +}; + +inline void DNNLMemoryReorder(const dnnl::memory& src, const dnnl::memory& dst) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map reorderPrimitives; +#else + static MX_THREAD_LOCAL std::unordered_map reorderPrimitives; +#endif + OpSignature key{}; + key.AddSign(src); + key.AddSign(dst); + + auto it = reorderPrimitives.find(key); + if (it == reorderPrimitives.end()) { + auto reorder = dnnl::reorder(src, dst); + it = AddToCache(&reorderPrimitives, key, reorder); + } + + dnnl_args_map_t net_args; + net_args.emplace(DNNL_ARG_SRC, src); + net_args.emplace(DNNL_ARG_DST, dst); + DNNLStream::Get()->RegisterPrimArgs(it->second, net_args); +} + struct DNNLRnnLayerParam { using memory = dnnl::memory; using dims = dnnl::memory::dims; @@ -66,6 +98,10 @@ struct DNNLRnnLayerParam { size_t native_single_b_size; // bias size of a single cell from framework size_t single_state_size; // state size of a single cell, hy, cy + bool quantized; // whether this layer is quantized + bool enable_u8_output; // true by default, only be false when it is the last fusion layer of the + // quantized rnn operator + DNNLRnnLayerParam(int num_layer, index_t batch_size, index_t seq_len, @@ -82,7 +118,9 @@ struct DNNLRnnLayerParam { input_size(input_size), state_size(state_size), proj_size(proj_size), - seq_len(seq_len) {} + seq_len(seq_len), + quantized(false), + enable_u8_output(false) {} void SetDims(); }; @@ -90,10 +128,11 @@ struct DNNLRnnLayerParam { typedef std::vector LayerParamVector; struct DNNLRnnFullParam { RNNParam default_param; + DNNLRnnParam dnnl_param; LayerParamVector layer_params; }; -DNNLRnnFullParam DNNLRnnFullParamParser(const RNNParam& rnn_param, +DNNLRnnFullParam DNNLRnnFullParamParser(const nnvm::NodeAttrs& attrs, const index_t seq_len, const index_t batch_size, const index_t input_size); @@ -105,7 +144,7 @@ class DNNLRnnMemMgr { // The memory buffer in NDArray life-cycle NDArray workspace_; // This points to the memory buffer from a NDArray - char* curr_mem; + char* curr_mem = nullptr; // The total bytes of the workspace of a DNNLRnnOp size_t mem_size = 0; // The current available memory bytes @@ -121,7 +160,7 @@ class DNNLRnnMemMgr { * \param size byte number * \param ctx Context of device enviroment */ - void Init(dim_t size, const Context& ctx); + void Init(const dim_t size, const Context& ctx); // Return the bytes number of the buffer const size_t Size() { @@ -135,6 +174,8 @@ class DNNLRnnMemMgr { dnnl::memory* Alloc(const dnnl::memory::desc& md); }; +typedef std::shared_ptr shared_dnnl_attr_t; + /* * Rnn Primitive. */ @@ -144,15 +185,15 @@ class RnnPrimitive { * lstm_forward, lbr_gru_forward, vanilla_rnn_forward */ template - static RnnPrimitive Create(Args&&... args) { + static RnnPrimitive Create(const shared_dnnl_attr_t attr, Args&&... args) { RnnPrimitive rnn_fwd_prim; auto fwd_desc = typename rnn_fwd::desc(std::forward(args)...); rnn_fwd_prim.fwd_pd_.reset( - new typename rnn_fwd::primitive_desc(fwd_desc, CpuEngine::Get()->get_engine()), - [](typename rnn_fwd::primitive_desc* pd) { - delete reinterpret_cast(pd); - }); + new typename rnn_fwd::primitive_desc( + fwd_desc, attr ? *attr : dnnl::primitive_attr(), CpuEngine::Get()->get_engine()), + [](void* pd) { delete reinterpret_cast(pd); }); auto fwd_pd = reinterpret_cast(rnn_fwd_prim.fwd_pd_.get()); + rnn_fwd_prim.attr_ = attr; rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc(); rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc(); rnn_fwd_prim.weights_proj_desc_ = fwd_pd->weights_projection_desc(); @@ -164,6 +205,7 @@ class RnnPrimitive { } RnnPrimitive() { + this->attr_ = nullptr; this->fwd_pd_ = nullptr; this->primitive_ = nullptr; this->weights_layer_desc_ = dnnl::memory::desc(); @@ -173,6 +215,7 @@ class RnnPrimitive { } RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) { + this->attr_ = rnn_fwd_prim.attr_; this->fwd_pd_ = rnn_fwd_prim.fwd_pd_; this->primitive_ = rnn_fwd_prim.primitive_; this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; @@ -183,6 +226,7 @@ class RnnPrimitive { RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) { if (this != &rnn_fwd_prim) { + this->attr_ = rnn_fwd_prim.attr_; this->fwd_pd_ = rnn_fwd_prim.fwd_pd_; this->primitive_ = rnn_fwd_prim.primitive_; this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; @@ -217,9 +261,14 @@ class RnnPrimitive { return workspace_desc_; } + const dnnl::primitive_attr& GetPrimAttr() const { + return *attr_; + } + private: std::shared_ptr fwd_pd_; std::shared_ptr primitive_; + shared_dnnl_attr_t attr_; dnnl::memory::desc weights_layer_desc_; dnnl::memory::desc weights_iter_desc_; dnnl::memory::desc weights_proj_desc_; @@ -229,7 +278,8 @@ class RnnPrimitive { RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param, const bool is_train, const NDArray& data, - const NDArray& params); + const NDArray& params, + const shared_dnnl_attr_t attr = nullptr); /* * Use this to manage memory and primitive of DNNL RNN forward inference. @@ -240,11 +290,12 @@ class DNNLRnnForward { const DNNLRnnLayerParam& layer_param, const bool is_train, const NDArray& data, - const NDArray& params) + const NDArray& params, + const shared_dnnl_attr_t attr = nullptr) : ctx_(ctx), initialized_(false), param_(layer_param), - fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params)) {} + fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params, attr)) {} void SetNewDataMem(void* x, void* hx, @@ -263,6 +314,10 @@ class DNNLRnnForward { return fwd_inf_.GetPrim(); } + void ResetFwd(const NDArray& data, const NDArray& params, const shared_dnnl_attr_t& attr) { + fwd_inf_ = GetRnnFwdPrim(this->param_, false, data, params, attr); + } + const size_t GetSize() const { const size_t size = fwd_inf_.GetLayerDesc().get_size() + fwd_inf_.GetIterDesc().get_size() + fwd_inf_.GetProjDesc().get_size(); @@ -482,13 +537,13 @@ class DNNLRnnBackward { */ class DNNLRnnOp { public: - explicit DNNLRnnOp(const RNNParam& param, + explicit DNNLRnnOp(const nnvm::NodeAttrs& attrs, const int seq_len, const int batch_size, const int input_size) : initialized_(false), weights_version_(0), - full_param_(DNNLRnnFullParamParser(param, seq_len, batch_size, input_size)) {} + full_param_(DNNLRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {} void Forward(const OpContext& ctx, const std::vector& inputs, diff --git a/src/operator/nn/dnnl/dnnl_rnn.cc b/src/operator/nn/dnnl/dnnl_rnn.cc index 0d65eb99350d..8b4c585a12a6 100644 --- a/src/operator/nn/dnnl/dnnl_rnn.cc +++ b/src/operator/nn/dnnl/dnnl_rnn.cc @@ -33,6 +33,8 @@ namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(DNNLRnnParam); + inline int GetRnnGatesNum(int mode) { switch (mode) { case rnn_enum::kLstm: @@ -88,13 +90,28 @@ void DNNLRnnLayerParam::SetDims() { reserve_size = 0; } -DNNLRnnFullParam DNNLRnnFullParamParser(const RNNParam& rnn_param, +DNNLRnnFullParam DNNLRnnFullParamParser(const NodeAttrs& attrs, const index_t seq_len, const index_t batch_size, const index_t input_size) { + const RNNParam& rnn_param = nnvm::get(attrs.parsed); DNNLRnnFullParam full_param; full_param.default_param = rnn_param; - const int state_size = rnn_param.state_size; + try { + full_param.dnnl_param.Init(attrs.dict, dmlc::parameter::kAllowUnknown); + } catch (const dmlc::ParamError& e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs.op->name << "(" + << "name=\"" << attrs.name << "\""; + for (const auto& k : attrs.dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + + const int state_size = rnn_param.state_size; const int proj_size = rnn_param.projection_size.has_value() ? rnn_param.projection_size.value() : -1; const int iter_size = @@ -135,15 +152,20 @@ DNNLRnnFullParam DNNLRnnFullParamParser(const RNNParam& rnn_param, false); } - // Set dims, workspace size, and state_outputs flag + // Set dims, workspace size, state_outputs, quantized and enable_u8_output flag for (auto& layer_param : layer_params) { layer_param.SetDims(); - layer_param.state_outputs = rnn_param.state_outputs; + layer_param.state_outputs = rnn_param.state_outputs; + layer_param.quantized = full_param.dnnl_param.quantized; + layer_param.enable_u8_output = true; } + // Quantized RNN operator produces kFloat32 outputs. + if (full_param.dnnl_param.quantized) + layer_params.back().enable_u8_output = false; return full_param; } -void DNNLRnnMemMgr::Init(dim_t size, const Context& ctx) { +void DNNLRnnMemMgr::Init(const dim_t size, const Context& ctx) { workspace_ = NDArray(TShape({size}), ctx, false, mshadow::kUint8); if (workspace_.data().dptr_ == nullptr) LOG(FATAL) << "oneDNN RNN operator memory allocation error."; @@ -178,54 +200,65 @@ dnnl::memory* DNNLRnnMemMgr::Alloc(const dnnl::memory::desc& md) { RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param, const bool is_train, const NDArray& data, - const NDArray& params) { + const NDArray& params, + const shared_dnnl_attr_t attr) { using namespace dnnl; - using tag = dnnl::memory::format_tag; - const int mode = layer_param.mode; - memory::data_type data_type = get_dnnl_type(data.dtype()); - memory::data_type weight_type = get_dnnl_type(params.dtype()); + using tag = dnnl::memory::format_tag; + const int mode = layer_param.mode; + memory::data_type src_layer_dtype = get_dnnl_type(data.dtype()); + memory::data_type iter_dtype = get_dnnl_type(mshadow::kFloat32); + memory::data_type weight_dtype = + get_dnnl_type(layer_param.quantized ? mshadow::kInt8 : params.dtype()); + memory::data_type bias_dtype = get_dnnl_type(mshadow::kFloat32); + memory::data_type dst_layer_dtype = + get_dnnl_type((layer_param.quantized && layer_param.enable_u8_output) ? mshadow::kUint8 : + mshadow::kFloat32); + const prop_kind prop = is_train ? prop_kind::forward_training : prop_kind::forward_inference; const rnn_direction dnnl_rnn_direction = layer_param.bidirectional ? rnn_direction::bidirectional_concat : rnn_direction::unidirectional; - auto src_layer_desc = memory::desc(layer_param.src_dims, data_type, tag::tnc); - auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_type, tag::any); - auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_type, tag::any); - auto bias_desc = memory::desc(layer_param.bias_dims, data_type, tag::ldgo); - auto dst_layer_desc = memory::desc(layer_param.dst_dims, data_type, tag::tnc); - auto src_state_desc = memory::desc(layer_param.state_dims, data_type, tag::ldnc); - auto src_cell_desc = memory::desc(layer_param.cell_dims, data_type, tag::ldnc); + auto src_layer_desc = memory::desc(layer_param.src_dims, src_layer_dtype, tag::tnc); + auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_dtype, tag::any); + auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_dtype, tag::any); + auto bias_desc = memory::desc(layer_param.bias_dims, bias_dtype, tag::ldgo); + auto dst_layer_desc = memory::desc(layer_param.dst_dims, dst_layer_dtype, tag::tnc); + auto src_state_desc = memory::desc(layer_param.state_dims, iter_dtype, tag::ldnc); + auto src_cell_desc = memory::desc(layer_param.cell_dims, iter_dtype, tag::ldnc); auto weight_peep_desc = memory::desc(); auto weight_proj_desc = layer_param.proj_size > 0 ? - memory::desc(layer_param.weight_proj_dims, weight_type, tag::any) : + memory::desc(layer_param.weight_proj_dims, weight_dtype, tag::any) : memory::desc(); auto dst_state_desc = layer_param.state_outputs ? - memory::desc(layer_param.state_dims, data_type, tag::ldnc) : + memory::desc(layer_param.state_dims, iter_dtype, tag::ldnc) : memory::desc(); - auto dst_cell_desc = layer_param.state_outputs ? - memory::desc(layer_param.cell_dims, data_type, tag::ldnc) : - memory::desc(); + auto dst_cell_desc = + layer_param.state_outputs ? + memory::desc(layer_param.cell_dims, iter_dtype, tag::ldnc) : // no cell in 1.x + memory::desc(); auto fwd = RnnPrimitive(); switch (mode) { case rnn_enum::kLstm: - fwd = RnnPrimitive::Create(prop, + fwd = RnnPrimitive::Create(attr, + prop, dnnl_rnn_direction, src_layer_desc, src_state_desc, src_cell_desc, weight_layer_desc, weight_iter_desc, - weight_peep_desc, - weight_proj_desc, + weight_peep_desc, // peep new + weight_proj_desc, // proj new bias_desc, dst_layer_desc, dst_state_desc, dst_cell_desc); break; case rnn_enum::kGru: - fwd = RnnPrimitive::Create(prop, + fwd = RnnPrimitive::Create(attr, + prop, dnnl_rnn_direction, src_layer_desc, src_state_desc, @@ -238,6 +271,7 @@ RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param, case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: fwd = RnnPrimitive::Create( + attr, prop, mode == rnn_enum::kRnnTanh ? algorithm::eltwise_tanh : algorithm::eltwise_relu, dnnl_rnn_direction, @@ -449,11 +483,19 @@ void DNNLRnnForward::SetNewDataMem(void* x, auto& cpu_engine = CpuEngine::Get()->get_engine(); dnnl_args_map_t& args = net_args_; + int src_dtype = dtype; + int dst_dtype = dtype; + if (param_.quantized) { + src_dtype = mshadow::kUint8; + if (param_.enable_u8_output) + dst_dtype = mshadow::kUint8; + } + RNN_HANDLE_FUNC(RNN_HANDLE_FUNC_NAME); // Set various data memory - RNN_FWD_SET(SRC, param_.src_dims, format_tag::tnc, x, dtype); - RNN_FWD_SET(DST, param_.dst_dims, format_tag::tnc, y, dtype); + RNN_FWD_SET(SRC, param_.src_dims, format_tag::tnc, x, src_dtype); + RNN_FWD_SET(DST, param_.dst_dims, format_tag::tnc, y, dst_dtype); RNN_FWD_SET(SRC_ITER, param_.state_dims, format_tag::ldnc, hx, dtype); if (param_.state_outputs) { @@ -468,37 +510,30 @@ void DNNLRnnForward::SetNewDataMem(void* x, } } -inline void DNNLMemoryReorder(const dnnl::memory& src, const dnnl::memory& dst) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map reorderPrimitives; -#else - static MX_THREAD_LOCAL std::unordered_map reorderPrimitives; -#endif - OpSignature key{}; - key.AddSign(src); - key.AddSign(dst); - - auto it = reorderPrimitives.find(key); - if (it == reorderPrimitives.end()) { - auto reorder = dnnl::reorder(src, dst); - it = AddToCache(&reorderPrimitives, key, reorder); - } - - dnnl_args_map_t net_args; - net_args.emplace(DNNL_ARG_SRC, src); - net_args.emplace(DNNL_ARG_DST, dst); - DNNLStream::Get()->RegisterPrimArgs(it->second, net_args); -} - /* * Reorder the concatenated weights memory to a efficient memory block * with primitive-prefered format. */ void DNNLRnnForward::ReorderWeights() { - DNNLMemoryReorder(*weights_layer_r_, *weights_layer_); - DNNLMemoryReorder(*weights_iter_r_, *weights_iter_); - if (param_.proj_size > 0) - DNNLMemoryReorder(*weights_proj_r_, *weights_proj_); + if (param_.quantized) { + const dnnl::primitive_attr& attr = this->fwd_inf_.GetPrimAttr(); + auto ReorderWithAttr = [&](dnnl::memory& src, dnnl::memory& dst) { + auto reorder_pd = dnnl::reorder::primitive_desc(src, dst, attr); + dnnl_args_map_t net_args; + net_args[DNNL_ARG_SRC] = src; + net_args[DNNL_ARG_DST] = dst; + DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(reorder_pd), net_args); + }; + ReorderWithAttr(*weights_layer_r_, *weights_layer_); + ReorderWithAttr(*weights_iter_r_, *weights_iter_); + if (param_.proj_size > 0) + ReorderWithAttr(*weights_proj_r_, *weights_proj_); + } else { + DNNLMemoryReorder(*weights_layer_r_, *weights_layer_); + DNNLMemoryReorder(*weights_iter_r_, *weights_iter_); + if (param_.proj_size > 0) + DNNLMemoryReorder(*weights_proj_r_, *weights_proj_); + } } void AdjustGruGateOrder(char* weight, @@ -573,7 +608,7 @@ inline void EmplaceNetArgs(dnnl_args_map_t* net_args, const int arg_name, const */ void DNNLRnnForward::SetWeightsMem(void* w_ptr, void* b_ptr, const bool is_train, const int dtype) { using format_tag = dnnl::memory::format_tag; - auto dnnl_dtype = get_dnnl_type(dtype); + const auto dnnl_dtype = get_dnnl_type(dtype); const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); const size_t buffer_bytes = @@ -702,7 +737,7 @@ void DNNLRnnForward::SetWeightsMem(void* w_ptr, void* b_ptr, const bool is_train // in forward training path, we use plain memory (ldxxx) as the space for weights and // their gradients. Then, forward training primitives could fetch them from the scope // of forward inference. And from there, we don't need to reorder the plain memory to - // the optimal rnn-packed memory for forward inference. + // the optimal rnn-packed memory for forward inference ReorderWeights(); initialized_ = true; } @@ -764,6 +799,19 @@ void DNNLRnnOp::Init(const OpContext& op_ctx, const std::vector& outputs) { using format_tag = dnnl::memory::format_tag; + // Get the bytes of a real type + const NDArray& weights = inputs[rnn_enum::kParams]; + int dtype = weights.dtype(); + size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + const RNNParam& default_param = full_param_.default_param; + const size_t weights_size = + weights.data().Size() - GetRnnBiasSize(default_param.num_layers, + default_param.state_size, + default_param.bidirectional + 1, + default_param.mode); + char* weights_ptr = static_cast(weights.data().dptr_); + char* bias_ptr = weights_ptr + weights_size * dtype_bytes; + // In the `autograd.record()` context, RNNOp is required to run into // `forward_training` mode. const bool is_training = (op_ctx.is_train || op_ctx.need_grad); @@ -772,7 +820,7 @@ void DNNLRnnOp::Init(const OpContext& op_ctx, if (fwd_inf_vec_.size() < num_fusion) { for (auto& layer_param : full_param_.layer_params) { fwd_inf_vec_.emplace_back( - ctx, layer_param, false, inputs[rnn_enum::kData], inputs[rnn_enum::kParams]); + ctx, layer_param, false, inputs[rnn_enum::kData], inputs[rnn_enum::kParams], nullptr); } } @@ -783,19 +831,6 @@ void DNNLRnnOp::Init(const OpContext& op_ctx, } } - // Get the bytes of a real type - const NDArray& weights = inputs[rnn_enum::kParams]; - int dtype = weights.dtype(); - size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); - - const RNNParam& default_param = full_param_.default_param; - char* weights_ptr = static_cast(weights.data().dptr_); - char* bias_ptr = - weights_ptr + (weights.data().Size() - GetRnnBiasSize(default_param.num_layers, - default_param.state_size, - default_param.bidirectional + 1, - default_param.mode)) * - dtype_bytes; for (auto& fwd_layer : fwd_inf_vec_) { size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes; size_t single_b_bytes = fwd_layer.GetParam().native_single_b_size * dtype_bytes; @@ -819,7 +854,7 @@ void DNNLRnnOp::Init(const OpContext& op_ctx, CHECK_EQ(num_fusion, fwd_inf_vec_.size()) << "Layer vector's size has a different value than the number of fusion."; if (dst_.size() < num_fusion - 1) { - int data_dtype = outputs[rnn_enum::kOut].dtype(); + const int data_dtype = outputs[rnn_enum::kOut].dtype(); const size_t data_dbytes = mshadow::mshadow_sizeof(data_dtype); mgr_.Init((outputs[rnn_enum::kOut].data().Size() * data_dbytes + kDNNLAlign) * (num_fusion - 1), op_ctx.run_ctx.ctx); @@ -1121,7 +1156,7 @@ void DNNLRnnOp::Forward(const OpContext& ctx, } // Get data type - int data_dtype = inputs[rnn_enum::kData].dtype(); + int data_dtype = outputs[rnn_enum::kOut].dtype(); // Get temporary memory for output, state_out, statecell_out const int num_layers = default_param.num_layers; const int seq_length = default_param.seq_length_; diff --git a/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h b/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h new file mode 100644 index 000000000000..83e72e0a0d9e --- /dev/null +++ b/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dnnl_quantize_asym-inl.h + * \brief implementation of asymmetric quantize operation using DNNL + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZE_ASYM_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZE_ASYM_INL_H_ +#if MXNET_USE_ONEDNN == 1 + +#include +#include +#include "operator/nn/dnnl/dnnl_base-inl.h" +#include "operator/quantization/quantize_asym-inl.h" + +namespace mxnet { +namespace op { + +class DNNLQuantizeAsymOp { + public: + explicit DNNLQuantizeAsymOp(const nnvm::NodeAttrs& attrs) + : param_(nnvm::get(attrs.parsed)) {} + + void Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + + private: + QuantizeAsymParam param_; + bool initialized_{false}; + float cached_scale_{0.f}; + float cached_shift_{0.f}; + dnnl::memory::desc o_desc_; + dnnl_args_map_t args_; + std::shared_ptr fwd_pd_; +}; + +void DNNLQuantizeAsymOp::Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using mshadow::red::limits::MaxValue; + using mshadow::red::limits::MinValue; + NDArray in_buffer = inputs[0]; + float scale = 0.f; + float shift = 0.f; + + // Pass through quantized data + if (inputs[0].dtype() == mshadow::kUint8) { + *outputs[1].data().dptr() = 1; + *outputs[2].data().dptr() = 0; + if (req[0] != kWriteInplace) { + const_cast(outputs[0]).CopyFrom(*inputs[0].GetDNNLData()); + DNNLStream::Get()->Submit(); + } + } else { + in_buffer = inputs[0].Reorder2Default(); + const dnnl::memory* i_mem = in_buffer.GetDNNLData(); + float* in_ptr = in_buffer.data().dptr(); + const int nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + if (inputs[0].dtype() == mshadow::kInt8) { + *outputs[1].data().dptr() = 1; + *outputs[2].data().dptr() = 128; +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(in_buffer.shape().Size()); ++i) { + in_ptr[i] += 128.0f; + } + } else if (inputs[0].dtype() == mshadow::kFloat32) { + if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) { + scale = + MaxValue() / (param_.max_calib_range.value() - param_.min_calib_range.value()); + shift = MaxValue() - param_.max_calib_range.value() * scale; + } else { + float data_min = mshadow::red::limits::MaxValue(); + float data_max = mshadow::red::limits::MinValue(); + std::vector data_maxs(nthreads, data_max); + std::vector data_mins(nthreads, data_min); +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(in_buffer.shape().Size()); i++) { + int tid = omp_get_thread_num(); + if (in_ptr[i] > data_maxs[tid]) + data_maxs[tid] = in_ptr[i]; + if (in_ptr[i] < data_mins[tid]) + data_mins[tid] = in_ptr[i]; + } + for (index_t i = 0; i < nthreads; i++) { + if (data_maxs[i] > data_max) + data_max = data_maxs[i]; + if (data_mins[i] < data_min) + data_min = data_mins[i]; + } + scale = MaxValue() / (data_max - data_min); + shift = MaxValue() - data_max * scale; + } + + if (initialized_ && (cached_scale_ != scale || cached_shift_ != shift)) + initialized_ = false; + } + + *outputs[1].data().dptr() = scale; + *outputs[2].data().dptr() = shift; + + if (!initialized_) { + cached_scale_ = scale; + cached_shift_ = shift; + dnnl::primitive_attr attr; + attr.set_rnn_data_qparams(scale, shift); + const dnnl::engine& cpu_engine = mxnet::CpuEngine::Get()->get_engine(); + const dnnl::memory::desc& i_desc = i_mem->get_desc(); + o_desc_ = i_desc; + o_desc_.data.data_type = get_dnnl_type_t(outputs[0].dtype()); + dnnl::reorder::primitive_desc reorder_pd(cpu_engine, i_desc, cpu_engine, o_desc_, attr); + fwd_pd_ = std::make_shared(reorder_pd); + initialized_ = true; + } + dnnl_output_t o_mem = CreateDNNLMem(outputs[0], o_desc_, req[0]); + args_[DNNL_ARG_FROM] = *i_mem; + args_[DNNL_ARG_TO] = *o_mem.second; + DNNLStream::Get()->RegisterPrimArgs(*fwd_pd_, args_); + CommitOutput(outputs[0], o_mem); + DNNLStream::Get()->Submit(); + } +} + +void DNNLQuantizeAsymForward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (inputs[0].shape().ndim() == 3 && inputs[0].dtype() == mshadow::kFloat32) { + DNNLQuantizeAsymOp& op = state_ptr.get_state(); + op.Forward(ctx, inputs, req, outputs); + } else { + FallBackCompute(QuantizeAsymForward, state_ptr, ctx, inputs, req, outputs); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_ONEDNN == 1 +#endif // MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZE_ASYM_INL_H_ \ No newline at end of file diff --git a/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h b/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h new file mode 100644 index 000000000000..e3c28e6a4711 --- /dev/null +++ b/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dnnl_quantized_rnn-inl.h + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RNN_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RNN_INL_H_ + +#if MXNET_USE_ONEDNN == 1 + +#include +#include "operator/nn/dnnl/dnnl_rnn-inl.h" +#include "operator/rnn-inl.h" +#include "operator/quantization/quantized_rnn-inl.h" + +namespace mxnet { +namespace op { + +class DNNLQuantizedRnnOp { + public: + explicit DNNLQuantizedRnnOp(const nnvm::NodeAttrs& attrs, + const int seq_len, + const int batch_size, + const int input_size) + : initialized_(false), + weights_ver_(0), + rnn_attr_(new dnnl::primitive_attr), + full_param_(DNNLRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {} + + void Forward(const OpContext& op_ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + + private: + bool initialized_; + size_t weights_ver_; + shared_dnnl_attr_t rnn_attr_; + DNNLRnnFullParam full_param_; + DNNLRnnMemMgr mgr_; + std::vector fwd_inf_vec_; // forward inference layers + + // Used to store the intermediate results of multi-layer + std::vector dst_; + // According to + // https://intel.github.io/mkl-dnn/dev_guide_int8_computations.html, the + // non-symmetric quantization is assumed by LSTM primitive. Namely, the + // formula is: + // data_f32 = (data_u8 - shift) / scale + float cached_data_shift_{0.0}; + float cached_data_scale_{0.0}; + void Init(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_ONEDNN == 1 +#endif // MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RNN_INL_H_ \ No newline at end of file diff --git a/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc b/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc new file mode 100644 index 000000000000..7ecb5ec58184 --- /dev/null +++ b/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dnnl_quantized_rnn.cc + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#if MXNET_USE_ONEDNN == 1 + +#include "operator/quantization/quantization_utils.h" +#include "operator/quantization/dnnl/dnnl_quantized_rnn-inl.h" + +namespace mxnet { +namespace op { + +std::vector GetDNNLRnnWeightsQParams(const DNNLRnnFullParam& full_param, float* w_ptr) { + const int nthreads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + const RNNParam& default_param = full_param.default_param; + const LayerParamVector& layer_params = full_param.layer_params; + + const DNNLRnnLayerParam& layer_param0 = layer_params.at(0); + const size_t w_size0 = layer_param0.single_w_size; + const size_t wx_size0 = 4 * layer_param0.state_size * layer_param0.input_size; + const size_t wh_size0 = 4 * layer_param0.state_size * layer_param0.state_size; + + int directions = 1; + float* wx = w_ptr; + float* wh = wx + wx_size0; + float* fake_wx = wx; + float* fake_wh = wh; + + std::vector wx_goi_max; + std::vector wh_goi_max; + if (default_param.bidirectional) { + directions = 2; + wx_goi_max.resize(wx_size0); + wh_goi_max.resize(wh_size0); + fake_wx = wx_goi_max.data(); + fake_wh = wh_goi_max.data(); +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(wx_size0); ++i) { + fake_wx[i] = MaxAbs(wx[i], wx[i + w_size0]); + } +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(wh_size0); ++i) { + fake_wh[i] = MaxAbs(wh[i], wh[i + w_size0]); + } + } + std::vector w_max(4 * layer_param0.state_size, 0.0); + const index_t input_size = layer_param0.input_size; // input + const index_t state_size = layer_param0.state_size; // state + const index_t gates_nblks = 4 * layer_param0.state_size; // gates * state + for (index_t go = 0; go < gates_nblks; ++go) { + float tmp_max = w_max[go]; + for (index_t i = 0; i < input_size; ++i) { + tmp_max = MaxAbs(fake_wx[go * input_size + i], tmp_max); + } + for (index_t i = 0; i < state_size; ++i) { + tmp_max = MaxAbs(fake_wh[go * state_size + i], tmp_max); + } + w_max[go] = tmp_max; + } + wx += layer_param0.single_w_size * directions; + wh += layer_param0.single_w_size * directions; + + std::vector goi_max(wh_size0, 0.0); + for (size_t lyr = 1; lyr < layer_params.size(); ++lyr) { + const DNNLRnnLayerParam& layer_param = layer_params.at(lyr); + const int weight_nblks = layer_param.num_layer * directions; + for (int blk = 0; blk < weight_nblks; ++blk) { +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(wh_size0); ++i) { + goi_max[i] = MaxAbs(wx[i], wh[i]); + } + for (index_t go = 0; go < gates_nblks; ++go) { + float tmp = w_max[go]; +// NOTES: min/max reductions were supported since OpenMP 3.1, which was +// released in Jul 2011 (hence the version number). +#if _OPENMP >= 201107 +#pragma omp parallel for reduction(max : tmp) num_threads(nthreads) +#endif + for (index_t i = 0; i < state_size; ++i) { + tmp = Max(goi_max[go * state_size + i], tmp); + } + w_max[go] = tmp; + } + } + wx += layer_param.single_w_size * directions; + wh = wx + wh_size0; + } +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(w_max.size()); ++i) { + w_max[i] = mshadow::red::limits::MaxValue() / w_max[i]; + } + return w_max; +} + +void DNNLQuantizedRnnOp::Init(const OpContext& op_ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using format_tag = dnnl::memory::format_tag; + + // Get the bytes of a real type + const Context& ctx = op_ctx.run_ctx.ctx; + const NDArray& weights = inputs[rnn_enum::kParams]; + int dtype = weights.dtype(); + int weights_dtype = weights.dtype(); + size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + const RNNParam& default_param = full_param_.default_param; + const size_t weights_size = + weights.data().Size() - GetRnnBiasSize(default_param.num_layers, + default_param.state_size, + default_param.bidirectional + 1, + default_param.mode); + char* weights_ptr = static_cast(weights.data().dptr_); + char* bias_ptr = weights_ptr + weights_size * dtype_bytes; + + // In the `autograd.record()` context, RNNOp is required to run into + // `forward_training` mode. + + const size_t num_fusion = full_param_.layer_params.size(); + if (fwd_inf_vec_.size() < num_fusion) { + size_t buffer_size = 0; // Element number, instead of bytes, in the buffer + for (auto& layer_param : full_param_.layer_params) { + buffer_size += layer_param.workspace_size + layer_param.reserve_size; + } + buffer_size += outputs[rnn_enum::kOut].data().Size() * (num_fusion - 1); + buffer_size += kDNNLAlign * num_fusion * 5; // Add margin for alignment + + for (auto& layer_param : full_param_.layer_params) { + fwd_inf_vec_.emplace_back( + ctx, layer_param, false, inputs[rnn_enum::kData], inputs[rnn_enum::kParams], rnn_attr_); + buffer_size += fwd_inf_vec_.back().GetSize(); + } + mgr_.Init(buffer_size, ctx); + } + + for (auto& fwd_layer : fwd_inf_vec_) { + size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes; + size_t single_b_bytes = fwd_layer.GetParam().native_single_b_size * dtype_bytes; + size_t directions = fwd_layer.GetParam().bidirectional ? 2 : 1; + size_t layer_weights_bytes = single_w_bytes * directions; + size_t layer_bias_bytes = single_b_bytes * directions; // Native MXNet has double bias + + if (!fwd_layer.IsInitialized()) + fwd_layer.SetWeightsMem(weights_ptr, bias_ptr, false, weights_dtype); + weights_ptr += layer_weights_bytes; + bias_ptr += layer_bias_bytes; + } + + CHECK_EQ(num_fusion, fwd_inf_vec_.size()) + << "Layer vector's size has a different value than the number of fusion."; + if (dst_.size() < num_fusion - 1) { + const int data_dtype = outputs[rnn_enum::kOut].dtype(); + // Here we need `fwd_inf_vec_.size() - 1` spaces for the intermediate + // results of the multiple fused layers. And for the result of the last + // fused layer, `outputs[rnn_enum::kOut]` could provide the space. Hence, + // `forward_inf_vec_.back()` is excluded when allocates the spaces for + // intermediate results. + for (std::vector::const_iterator fwd = fwd_inf_vec_.begin(); + fwd != fwd_inf_vec_.end() - 1; + ++fwd) + dst_.push_back( + mgr_.Alloc({fwd->GetParam().dst_dims, get_dnnl_type(data_dtype), format_tag::tnc})); + } + + initialized_ = true; +} + +template +inline void RegisterDNNLRnn(DNNLRnnX const& rnn) { + DNNLStream::Get()->RegisterPrimArgs(rnn.GetFwd(), rnn.GetArgsMap()); +} + +template <> +inline void RegisterDNNLRnn(DNNLRnnBackward const& rnn) { + DNNLStream::Get()->RegisterPrimArgs(rnn.GetBwd(), rnn.GetArgsMap()); + rnn.SetNativeWeightsGrads(); +} + +void DNNLQuantizedRnnOp::Forward(const OpContext& op_ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + TmpMemMgr::Get()->Init(op_ctx.requested[0]); + + const RNNParam& default_param = full_param_.default_param; + const uint32_t num_base_inputs = GetRnnNumInputs(default_param); + float data_scale = inputs[num_base_inputs + quantized_rnn::kDataScale].data().dptr()[0]; + float data_shift = inputs[num_base_inputs + quantized_rnn::kDataShift].data().dptr()[0]; + + const bool need_reset_weight = (!dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && + weights_ver_ != inputs[rnn_enum::kParams].version()) ? + true : + false; + const NDArray& weights = inputs.at(rnn_enum::kParams); + float* weights_ptr = weights.data().dptr(); + if (!initialized_ || fwd_inf_vec_.empty()) { + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + rnn_attr_->set_rnn_data_qparams(data_scale, data_shift); + if (need_reset_weight || fwd_inf_vec_.empty()) + rnn_attr_->set_rnn_weights_qparams(0 + (1 << 3) + (1 << 4), + GetDNNLRnnWeightsQParams(full_param_, weights_ptr)); + } + + // Initialize weights version + if (!initialized_ && weights_ver_ == 0) { + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + } + + if (!fwd_inf_vec_.empty() && + ((cached_data_scale_ != data_scale || cached_data_shift_ != data_shift))) { + initialized_ = false; + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + } + + // Check if weights NDArray was changed. If so, reset initialized_ + if (fwd_inf_vec_.size() > 0 && weights_ver_ != inputs[rnn_enum::kParams].version()) { + initialized_ = false; + for (auto& fwd : fwd_inf_vec_) + fwd.Reset(); + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + } + + if (!initialized_ || fwd_inf_vec_.empty()) { + Init(op_ctx, inputs, req, outputs); + } + + // Get data type + int data_dtype = outputs[rnn_enum::kOut].dtype(); + // Get temporary memory for output, state_out, statecell_out + const int num_layers = default_param.num_layers; + const int seq_length = default_param.seq_length_; + const int batch_size = default_param.batch_size_; + const int state_size = default_param.state_size; + const int directions = default_param.bidirectional ? 2 : 1; + dnnl::memory::desc dst_desc({seq_length, batch_size, directions * state_size}, + get_dnnl_type(data_dtype), + dnnl::memory::format_tag::tnc); + dnnl::memory::desc state_desc({num_layers, directions, batch_size, state_size}, + get_dnnl_type(data_dtype), + dnnl::memory::format_tag::ldnc); + auto out_mem = CreateDNNLMem(outputs[rnn_enum::kOut], dst_desc, req[rnn_enum::kOut]); + dnnl_output_t stateout_mem; + dnnl_output_t statecellout_mem; + + // Get input & output NDArray + char* src = static_cast(inputs[rnn_enum::kData].data().dptr_); + char* src_state = static_cast(inputs[rnn_enum::kState].data().dptr_); + char* dst = static_cast(out_mem.second->get_data_handle()); + char* dst_state = nullptr; // Output state + char* src_state_cell = nullptr; // Used in LSTM for cell state + char* dst_state_cell = nullptr; // Used in LSTM for cell state + const size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ * + default_param.state_size * mshadow::mshadow_sizeof(data_dtype); + + if (default_param.state_outputs && req[rnn_enum::kStateOut] != kNullOp) { + stateout_mem = + CreateDNNLMem(outputs[rnn_enum::kStateOut], state_desc, req[rnn_enum::kStateOut]); + dst_state = static_cast(stateout_mem.second->get_data_handle()); + } + + if (default_param.mode == rnn_enum::kLstm) { + src_state_cell = static_cast(inputs[rnn_enum::kStateCell].data().dptr_); + if (default_param.state_outputs && req[rnn_enum::kStateCellOut] != kNullOp) { + statecellout_mem = + CreateDNNLMem(outputs[rnn_enum::kStateCellOut], state_desc, req[rnn_enum::kStateCellOut]); + dst_state_cell = static_cast(statecellout_mem.second->get_data_handle()); + } + } + + if (fwd_inf_vec_.size() == 1) { + fwd_inf_vec_.front().SetNewDataMem( + src, src_state, src_state_cell, dst, dst_state, dst_state_cell, data_dtype); + } else { + CHECK_EQ(fwd_inf_vec_.size(), dst_.size() + 1) << "Output memory error."; + size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ * + default_param.state_size * mshadow::mshadow_sizeof(data_dtype); + + // Set input data memory for the first layer. This stores intermediate + // output results in this->xxx, used as the source input of the next layer. + fwd_inf_vec_.front().SetNewDataMem(src, + src_state, + src_state_cell, + this->dst_.front()->get_data_handle(), + dst_state, + dst_state_cell, + data_dtype); + // 1st_lyr -> dst_handle -> next_lyr -> dst_handle -> next_lyr -> ... + for (size_t lyr = 1; lyr < fwd_inf_vec_.size() - 1; ++lyr) { + src_state += cell_bytes; + if (src_state_cell) + src_state_cell += cell_bytes; + if (dst_state) + dst_state += cell_bytes; + if (dst_state_cell) + dst_state_cell += cell_bytes; + fwd_inf_vec_.at(lyr).SetNewDataMem(this->dst_.at(lyr - 1)->get_data_handle(), + src_state, + src_state_cell, + this->dst_.at(lyr)->get_data_handle(), + dst_state, + dst_state_cell, + data_dtype); + } + // Set output data memory for the last layer. + src_state += cell_bytes; + if (src_state_cell) + src_state_cell += cell_bytes; + if (dst_state) + dst_state += cell_bytes; + if (dst_state_cell) + dst_state_cell += cell_bytes; + fwd_inf_vec_.back().SetNewDataMem(this->dst_.back()->get_data_handle(), + src_state, + src_state_cell, + dst, + dst_state, + dst_state_cell, + data_dtype); + } + + for (auto& inf_lyr : fwd_inf_vec_) + RegisterDNNLRnn(inf_lyr); + + CommitOutput(outputs[rnn_enum::kOut], out_mem); + if (default_param.state_outputs) { + CommitOutput(outputs[rnn_enum::kStateOut], stateout_mem); + if (default_param.mode == rnn_enum::kLstm) + CommitOutput(outputs[rnn_enum::kStateCellOut], statecellout_mem); + } + DNNLStream::Get()->Submit(); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_ONEDNN == 1 \ No newline at end of file diff --git a/src/operator/quantization/quantize_asym-inl.h b/src/operator/quantization/quantize_asym-inl.h new file mode 100644 index 000000000000..4d3fb554db8d --- /dev/null +++ b/src/operator/quantization/quantize_asym-inl.h @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantize_asym-inl.h + * \brief implementation of asymmetric quantize operation + */ +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_ + +#include +#include +#include +#include +#include + +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "../tensor/broadcast_reduce_op.h" +#include "./quantization_utils.h" + +namespace mxnet { +namespace op { + +struct QuantizeAsymParam : public dmlc::Parameter { + dmlc::optional min_calib_range; + dmlc::optional max_calib_range; + + DMLC_DECLARE_PARAMETER(QuantizeAsymParam) { + DMLC_DECLARE_FIELD(min_calib_range) + .set_default(dmlc::optional()) + .describe( + "The minimum scalar value in the form of float32. If " + "present, it will be used to " + "quantize the fp32 data."); + DMLC_DECLARE_FIELD(max_calib_range) + .set_default(dmlc::optional()) + .describe( + "The maximum scalar value in the form of float32. If " + "present, it will be used to " + "quantize the fp32 data."); + } +}; + +// quantize float to uint8_t +struct quantize_asymmetric { + template + MSHADOW_XINLINE static void Map(int i, + DstDType* out, + float* oscale, + float* oshift, + const SrcDType* in, + const float scale, + const float shift) { + out[i] = static_cast(in[i] * scale + shift + 0.5); + *oscale = scale; + *oshift = shift; + } +}; + +template +class QuantizeAsymOp { + public: + explicit QuantizeAsymOp(const nnvm::NodeAttrs& attrs) : attrs_(attrs) {} + + void Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + using mshadow::red::limits::MaxValue; + using mshadow::red::limits::MinValue; + + CHECK_EQ(outputs[0].type_flag_, mshadow::kUint8) + << "Asymmetric quantization only supports uint8 outputs."; + mshadow::Stream* s = ctx.get_stream(); + const int input_data_dtype = inputs[0].type_flag_; + if (input_data_dtype == mshadow::kUint8) { + *outputs[1].dptr() = 1; + *outputs[2].dptr() = 0; + UnaryOp::IdentityCompute(attrs_, ctx, {inputs[0]}, req, outputs); + } else if (input_data_dtype == mshadow::kInt8) { + const float scale = 1; + const float shift = 128; + Kernel::Launch(s, + outputs[0].Size(), + outputs[0].dptr(), + outputs[1].dptr(), + outputs[2].dptr(), + inputs[0].dptr(), + scale, + shift); + } else if (input_data_dtype == mshadow::kFloat32) { + const QuantizeAsymParam& param = nnvm::get(attrs_.parsed); + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + const float scale = + MaxValue() / (param.max_calib_range.value() - param.min_calib_range.value()); + const float shift = MaxValue() - param.max_calib_range.value() * scale; + Kernel::Launch(s, + outputs[0].Size(), + outputs[0].dptr(), + outputs[1].dptr(), + outputs[2].dptr(), + inputs[0].dptr(), + scale, + shift); + } else { + mxnet::TShape src_shape, dst_shape; + const size_t float_bytes = sizeof(float); + const size_t temp_reduce_size = ConfigReduce( + s, inputs[0].shape_, mxnet::TShape(1, 1), &src_shape, &dst_shape); + Tensor temp_space = ctx.requested[0].get_space_typed( + Shape1(2 * float_bytes + temp_reduce_size), s); + const int dev_id = ctx.run_ctx.ctx.dev_id; + TBlob in_min_t( + reinterpret_cast(temp_space.dptr_), Shape1(1), xpu::kDevMask, dev_id); + TBlob in_max_t( + reinterpret_cast(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask, dev_id); + Tensor workspace( + temp_space.dptr_ + 2 * float_bytes, Shape1(temp_reduce_size), s); + broadcast::Reduce( + s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + broadcast::Reduce( + s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + const float scale = + MaxValue() / (*in_max_t.dptr() - *in_min_t.dptr()); + const float shift = MaxValue() - *in_max_t.dptr() * scale; + Kernel::Launch(s, + outputs[0].Size(), + outputs[0].dptr(), + outputs[1].dptr(), + outputs[2].dptr(), + inputs[0].dptr(), + scale, + shift); + } + } else { + LOG(FATAL) << "Asymmetric quantizaiton only supports int8, uint8 and " + "float inputs"; + } + } + + private: + nnvm::NodeAttrs attrs_; +}; + +template +void QuantizeAsymForward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + QuantizeAsymOp& op = state_ptr.get_state>(); + op.Forward(ctx, inputs, req, outputs); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_ \ No newline at end of file diff --git a/src/operator/quantization/quantize_asym.cc b/src/operator/quantization/quantize_asym.cc new file mode 100644 index 000000000000..24d1e9d53c6d --- /dev/null +++ b/src/operator/quantization/quantize_asym.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantize_asym.cc + * \brief implementation of asymmetric quantize operation + */ + +#include "operator/quantization/quantize_asym-inl.h" +#if MXNET_USE_ONEDNN == 1 +#include "operator/quantization/dnnl/dnnl_quantize_asym-inl.h" +#endif + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(QuantizeAsymParam); + +inline bool QuantizeAsymShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 3U); + + mxnet::TShape dshape = in_attrs->at(0); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(1, 1)); + SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape(1, 1)); + + if (out_attrs->at(0).ndim() > 0) { + dshape[0] = out_attrs->at(0)[0]; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape); + } + + return !shape_is_none(out_attrs->at(0)); +} + +inline bool QuantizeAsymType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 3U); + + CHECK_EQ(in_attrs->at(0), mshadow::kFloat32); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8); + TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32); + + return !type_is_none(out_attrs->at(0)); +} + +bool QuantizeAsymStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + *dispatch_mode = DispatchMode::kFCompute; +#if MXNET_USE_ONEDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + out_attrs->at(0) = kDefaultStorage; + out_attrs->at(1) = kDefaultStorage; + out_attrs->at(2) = kDefaultStorage; + return true; +} + +OpStatePtr CreateQuantizeAsymState(const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& in_shapes, + const std::vector& in_types) { + OpStatePtr state; + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(attrs); + } else { +#if MXNET_USE_ONEDNN == 1 + if (in_shapes[0].ndim() == 3 && in_types[0] == mshadow::kFloat32) { + state = OpStatePtr::Create(attrs); + return state; + } +#else + state = OpStatePtr::Create>(attrs); +#endif + } + return state; +} + +NNVM_REGISTER_OP(_contrib_quantize_asym) + .describe(R"code(Quantize a input tensor from float to uint8_t. +Output `scale` and `shift` are scalar floats that specify the quantization parameters for the input +data. +The output is calculated using the following equation: +`out[i] = in[i] * scale + shift + 0.5`, +where `scale = uint8_range / (max_range - min_range)` and +`shift = numeric_limits::max - max_range * scale`. +.. Note:: + This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE) + .set_attr_parser(ParamParser) + .set_num_inputs(1) + .set_num_outputs(3) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) + .set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "scale", "shift"}; + }) + .set_attr("FInferShape", QuantizeAsymShape) + .set_attr("FInferType", QuantizeAsymType) + .set_attr("FInferStorageType", QuantizeAsymStorageType) + .set_attr("FGradient", MakeZeroGradNodes) + .set_attr("FCreateOpState", CreateQuantizeAsymState) +#if MXNET_USE_ONEDNN == 1 + .set_attr("TIsDNNL", true) + .set_attr("FStatefulComputeEx", DNNLQuantizeAsymForward) +#endif + .set_attr("FStatefulCompute", QuantizeAsymForward) + .set_attr("FNeedCalibrateInput", + [](const NodeAttrs& attrs) { return std::vector{0}; }) + .set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + const QuantizeAsymParam& param = + nnvm::get(attrs.parsed); + if (param.max_calib_range.has_value() && + param.max_calib_range.has_value()) { + return std::vector(); + } else { + return std::vector( + 1, ResourceRequest::kTempSpace); + } + }) + .add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`") + .add_arguments(QuantizeAsymParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet \ No newline at end of file diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 3835f1a3a9c9..a4e3086653b0 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -288,6 +288,10 @@ Graph QuantizeGraph(Graph&& src) { static const auto& avoid_quantize_input_map = Op::GetAttr("FAvoidQuantizeInput"); static const auto& flist_inputs = nnvm::Op::GetAttr("FListInputNames"); + static const auto& avoid_dequantize_map = + Op::GetAttr("FAvoidDequantizeOutput"); + static const auto& need_asym_quantize_map = + Op::GetAttr("FNeedAsymQuantizeInput"); const auto offline_params = src.GetAttr>("offline_params"); const auto quantized_dtype = src.GetAttr("quantized_dtype"); const auto quantize_granularity = src.GetAttr("quantize_granularity"); @@ -331,7 +335,14 @@ Graph QuantizeGraph(Graph&& src) { if (avoid_quantize_input_map.count(node->op()) && avoid_quantize_input_map[node->op()](node->attrs, i, quantize_granularity)) { new_node->inputs.emplace_back(mirror_entry); - } else if (!quantized_node_map.count(e.node)) { + } else if (!quantized_node_map.count(e.node) || + (avoid_dequantize_map.count(e.node->op()) && + avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) { + // If the input of current quantized node has non-support of quantization, a quantize op + // is supposed to insert into the position after the input node to quantize the float + // input to int8/uint8 type. Also, a quantized operator with avoid-dequantize attribute + // can produce float outputs directly. A quantize op is necessary to convert them into + // int8/uint8 type as the input of current quantized node. if (mirror_entry_map.count(e)) { new_node->inputs.emplace_back(mirror_entry_map[e]); } else { @@ -354,10 +365,20 @@ Graph QuantizeGraph(Graph&& src) { new_name = node->attrs.name + "_" + e.node->attrs.name; } } - - ObjectPtr quantize_node = InsertNode( - "_contrib_quantize_v2", new_name + suffix + "_quantize", new_node, mirror_entry); - quantize_node->attrs.dict["out_type"] = quantized_dtype; + ObjectPtr quantize_node; + if (need_asym_quantize_map.count(node->op()) && + need_asym_quantize_map[node->op()](node->attrs, i)) { + quantize_node = InsertNode("_contrib_quantize_asym", + new_name + suffix + "_quantize", + new_node, + mirror_entry); + } else { + quantize_node = InsertNode( + "_contrib_quantize_v2", new_name + suffix + "_quantize", new_node, mirror_entry); + // If current node is rnn op, the quantize op is supposed to quantize the result of + // pre-node to uint8, as quantized rnn op requires uint8 input. + quantize_node->attrs.dict["out_type"] = quantized_dtype; + } quantize_node->op()->attr_parser(&(quantize_node->attrs)); mirror_entry_map[e] = NodeEntry{quantize_node, 0, e.version}; } @@ -439,9 +460,13 @@ Graph QuantizeGraph(Graph&& src) { for (const auto& e : node->inputs) { ObjectPtr mirror_node = mirror_map.at(e.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version}; - // if input node is quantized operator, add dequantize node + // If input node is quantized operator, add dequantize node. But if input node is a + // quantized operator with avoid-dequantize attribute, its output may be already in float + // type, which dosen't need a dequantize op. if (quantized_node_map.count(e.node) && - (mirror_node->op() != Op::Get("_contrib_dequantize"))) { + mirror_node->op() != Op::Get("_contrib_dequantize") && + !(avoid_dequantize_map.count(e.node->op()) && + avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) { // here we calculate the output number (exclude min/max, in order to // calculate min/max index from mirror node) based on assumption that // there is only 1 min and 1 max output from mirror node (which is @@ -473,7 +498,9 @@ Graph QuantizeGraph(Graph&& src) { std::vector outputs; for (const auto& e : src.outputs) { - if (quantized_node_map.count(e.node)) { + if (quantized_node_map.count(e.node) && + !(avoid_dequantize_map.count(e.node->op()) && + avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) { // Only insert dequantize for those Ops supports quantize and not excluded. ObjectPtr mirror_node = mirror_map.at(e.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version}; diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index e08bd0d5f76d..497ea37cdf28 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -18,7 +18,7 @@ */ /*! - * \file quantize.cc + * \file quantize_v2.cc * \brief */ diff --git a/src/operator/quantization/quantized_rnn-inl.h b/src/operator/quantization/quantized_rnn-inl.h new file mode 100644 index 000000000000..d5d9dd80a6ee --- /dev/null +++ b/src/operator/quantization/quantized_rnn-inl.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantized_rnn-inl.h + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_ + +namespace mxnet { +namespace op { + +namespace quantized_rnn { +enum QuantizedRnnInputs { kData, kParams, kState, kStateCell }; +enum QuantizedRnnInputMinMax { kDataScale, kDataShift }; +enum QuantizedRnnOutputs { kOut, kStateOut, kStateCellOut }; +} // namespace quantized_rnn + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_ \ No newline at end of file diff --git a/src/operator/quantization/quantized_rnn.cc b/src/operator/quantization/quantized_rnn.cc new file mode 100644 index 000000000000..c396e816080d --- /dev/null +++ b/src/operator/quantization/quantized_rnn.cc @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantized_rnn.cc + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#include +#include +#include + +#include "operator/rnn-inl.h" +#include "operator/quantization/quantization_utils.h" +#include "operator/quantization/quantized_rnn-inl.h" + +#if MXNET_USE_ONEDNN == 1 +#include "operator/quantization/dnnl/dnnl_quantized_rnn-inl.h" +#endif + +namespace mxnet { +namespace op { + +uint32_t QuantizedRnnNumInputs(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + return 6U; +} + +uint32_t QuantizedRnnNumOutputs(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + return param.state_outputs ? 3U : 1U; +} + +std::vector QuantizedRnnInputNames(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + return std::vector{ + "data", "parameters", "state", "state_cell", "min_data", "max_data"}; +} + +std::vector QuantizedRnnOutputNames(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + if (param.state_outputs) { + return std::vector{"output", "state_output", "statecell_ouput"}; + } else { + return std::vector{"output"}; + } +} + +bool QuantizedRnnShape(const nnvm::NodeAttrs& attrs, + std::vector* in_shape, + std::vector* out_shape) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode."; + + const uint32_t num_inputs = QuantizedRnnNumInputs(attrs); + const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs); + CHECK_EQ(in_shape->size(), num_inputs) + << "Arguments' size of quantized RNN operator is mismatched. Expected " << num_inputs + << " argmuments but got " << in_shape->size() << "."; + CHECK_EQ(out_shape->size(), num_outputs); + + const mxnet::TShape dshape = in_shape->at(quantized_rnn::kData); + if (!mxnet::ndim_is_known(dshape)) + return false; + CHECK_EQ(dshape.ndim(), 3U) << "Input data of RNN operator should be 3-rank " + "tensor of dim [steps, batch, input size]"; + const dim_t batch_size = dshape[1]; + const dim_t input_size = dshape[2]; + const dim_t directions = param.bidirectional ? 2 : 1; + const dim_t total_lyrs = directions * param.num_layers; + const dim_t state_size = param.state_size; + SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kState, Shape3(total_lyrs, batch_size, state_size)); + if (param.mode == rnn_enum::kLstm) + SHAPE_ASSIGN_CHECK( + *in_shape, quantized_rnn::kStateCell, Shape3(total_lyrs, batch_size, state_size)); + + const int param_size_fp = GetRnnParamSize( + param.num_layers, input_size, state_size, directions, param.mode, param.projection_size); + SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kParams, Shape1(param_size_fp)); + const uint32_t num_base_inputs = GetRnnNumInputs(param); + for (size_t i = num_base_inputs; i < num_inputs; ++i) + SHAPE_ASSIGN_CHECK(*in_shape, i, Shape1(1)); + + out_shape->clear(); + out_shape->push_back({dshape[0], batch_size, directions * state_size}); // output dim: [T, N, C] + if (param.state_outputs) { + out_shape->push_back({total_lyrs, batch_size, state_size}); // state dim: [L*D, N, C] + if (param.mode == rnn_enum::kLstm) + out_shape->push_back({total_lyrs, batch_size, state_size}); // cell dim: [L*D, N, C] + } + return true; +} + +bool QuantizedRnnType(const nnvm::NodeAttrs& attrs, + std::vector* in_type, + std::vector* out_type) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode."; + + const uint32_t num_inputs = QuantizedRnnNumInputs(attrs); + const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs); + CHECK_EQ(in_type->size(), num_inputs); + CHECK_EQ(out_type->size(), num_outputs); + + CHECK_EQ(in_type->at(quantized_rnn::kData), mshadow::kUint8) + << "Quantized RNN operator only supports uint8 input, while " + << in_type->at(quantized_rnn::kData) << " is given."; + TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kParams, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kState, mshadow::kFloat32); + const uint32_t num_base_inputs = GetRnnNumInputs(param); + if (param.mode == rnn_enum::kLstm) + TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kStateCell, mshadow::kFloat32); + for (size_t i = num_base_inputs; i < num_inputs; ++i) + TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kFloat32); + + TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kOut, mshadow::kFloat32); + if (param.state_outputs) { + TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kStateOut, mshadow::kFloat32); + if (param.mode == rnn_enum::kLstm) + TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kStateCellOut, mshadow::kFloat32); + } + return true; +} + +bool QuantizedRnnStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + const uint32_t num_inputs = QuantizedRnnNumInputs(attrs); + const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs); + CHECK_EQ(in_attrs->size(), num_inputs); + CHECK_EQ(out_attrs->size(), num_outputs); + +#if MXNET_USE_ONEDNN == 1 + return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); +#else + *dispatch_mode = DispatchMode::kFCompute; + + for (auto& v : *out_attrs) { + v = kDefaultStorage; + if (common::stype_string(v).compare("unknown") == 0) { + return false; + } + } + + for (auto& v : *in_attrs) { + v = kDefaultStorage; + if (common::stype_string(v).compare("unknown") == 0) { + return false; + } + } + return true; +#endif +} + +void QuantizedRnnParamParser(nnvm::NodeAttrs* attrs) { + RNNParam param; + attrs->dict["quantized"] = "true"; + try { + param.Init(attrs->dict, dmlc::parameter::kAllowUnknown); + } catch (const dmlc::ParamError& e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs->op->name << "(" + << "name=\"" << attrs->name << "\""; + for (const auto& k : attrs->dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + attrs->parsed = std::move(param); +} + +OpStatePtr CreateQuantizedRnnState(const nnvm::NodeAttrs& attrs, + const Context ctx, + const mxnet::ShapeVector& in_shapes, + const std::vector& in_types) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode."; + OpStatePtr state = OpStatePtr(); +#if MXNET_USE_ONEDNN == 1 + const int data_type = in_types[quantized_rnn::kData]; + const int weight_type = in_types[quantized_rnn::kParams]; + if (data_type == mshadow::kUint8 && weight_type == mshadow::kFloat32) { + const mxnet::TShape& data_shape = in_shapes[quantized_rnn::kData]; + state = + OpStatePtr::Create(attrs, data_shape[0], data_shape[1], data_shape[2]); + } +#else + LOG(FATAL) << "Quantized RNN operator relies on oneDNN library." + << " Please build MXNet with USE_ONEDNN=ON to leverage this operator."; +#endif + return state; +} + +void QuantizedRnnForwardCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& in_data, + const std::vector& req, + const std::vector& out_data) { + LOG(FATAL) << "Quantized RNN operator relies on oneDNN library." + << " Please build MXNet with USE_ONEDNN=ON to leverage this operator."; +} + +#if MXNET_USE_ONEDNN == 1 +void QuantizedRnnForwardCPUEx(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& in_data, + const std::vector& req, + const std::vector& out_data) { + DNNLQuantizedRnnOp& op = state_ptr.get_state(); + op.Forward(ctx, in_data, req, out_data); +} +#endif // MXNET_USE_ONEDNN == 1 + +bool NeedAsymQuantizeRnnInput(const NodeAttrs& attrs, const size_t index_to_check) { + bool need_asym_quantize = false; + switch (index_to_check) { + case rnn_enum::kData: { + need_asym_quantize = true; + break; + } + default: { + need_asym_quantize = false; + } + } + return need_asym_quantize; +} + +bool AvoidRnnQuantizeInput(const NodeAttrs& attrs, + const size_t index_to_check, + const std::string quantize_granularity) { + std::unordered_set avoid_indexes; + avoid_indexes.insert({quantized_rnn::kParams, quantized_rnn::kState, quantized_rnn::kStateCell}); + + return avoid_indexes.count(index_to_check); +} + +bool AvoidRnnDequantizeOutput(const NodeAttrs& attrs, const size_t index_to_check) { + return true; +} + +static std::vector QuantizedRnnResourceEx(const NodeAttrs& attrs, + const int dev_mask, + const DispatchMode dispatch_mode) { + std::vector request; + if (dev_mask == kGPU) { +#if MXNET_USE_CUDNN == 1 + LOG(FATAL) << "Currently, quantized RNN is not supported on the GPU platform."; +#endif + } else { +#if MXNET_USE_ONEDNN == 1 + request.emplace_back(ResourceRequest::kTempSpace); +#endif + } + return request; +} + +NNVM_REGISTER_OP(_contrib_quantized_rnn) + .describe( + R"code(RNN operator for input data type of uint8. The weight of each gates is converted +to int8, while bias is accumulated in type float32. The hidden state and cell state are in type +float32. For the input data, two more arguments of type float32 must be provided representing the +thresholds of quantizing argument from data type float32 to uint8. The final outputs contain the +recurrent result in float32. It only supports quantization for Vanilla LSTM network. +.. Note:: + This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE) + .set_num_inputs(QuantizedRnnNumInputs) + .set_num_outputs(QuantizedRnnNumOutputs) + .set_attr_parser(QuantizedRnnParamParser) + .set_attr("FListInputNames", QuantizedRnnInputNames) + .set_attr("FListOutputNames", QuantizedRnnOutputNames) + .set_attr("FInferShape", QuantizedRnnShape) + .set_attr("FInferType", QuantizedRnnType) + .set_attr("FInferStorageType", QuantizedRnnStorageType) + .set_attr("FCreateOpState", CreateQuantizedRnnState) + .set_attr("FStatefulCompute", QuantizedRnnForwardCPU) +#if MXNET_USE_ONEDNN == 1 + .set_attr("TIsDNNL", true) + .set_attr("FStatefulComputeEx", QuantizedRnnForwardCPUEx) +#endif + .set_attr("FResourceRequestEx", QuantizedRnnResourceEx) + .add_argument("data", "NDArray-or-Symbol", "Input data.") + .add_argument("parameters", "NDArray-or-Symbol", "weight.") + .add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN") + .add_argument("state_cell", + "NDArray-or-Symbol", + "initial cell state for LSTM networks (only for LSTM)") + .add_argument("data_scale", "NDArray-or-Symbol", "quantization scale of data.") + .add_argument("data_shift", "NDArray-or-Symbol", "quantization shift of data.") + .add_arguments(RNNParam::__FIELDS__()); + +NNVM_REGISTER_OP(RNN) + .set_attr("FQuantizable", + [](const NodeAttrs& attrs) { +#if MXNET_USE_ONEDNN == 1 + const RNNParam& param = nnvm::get(attrs.parsed); + if (param.mode != rnn_enum::kLstm) + LOG(INFO) << "Quantized RNN only supports LSTM mode."; + return param.mode == rnn_enum::kLstm ? QuantizeType::kMust : + QuantizeType::kNone; +#else + LOG(INFO) << "Quantized RNN is not supported by this MXNet release. Please enable oneDNN to " + << "use the feature."; + return QuantizeType::kNone; +#endif // MXNET_USE_ONEDNN == 1 + }) + .set_attr("FQuantizedOp", + [](const NodeAttrs& attrs) { + nnvm::ObjectPtr node = nnvm::Node::Create(); + node->attrs.op = Op::Get("_contrib_quantized_rnn"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + node->attrs.dict["quantized"] = "true"; + if (node->op()->attr_parser != nullptr) { + node->op()->attr_parser(&(node->attrs)); + } + return node; + }) + .set_attr("FNeedAsymQuantizeInput", NeedAsymQuantizeRnnInput) + .set_attr("FAvoidQuantizeInput", AvoidRnnQuantizeInput) + .set_attr("FAvoidDequantizeOutput", AvoidRnnDequantizeOutput); + +} // namespace op +} // namespace mxnet \ No newline at end of file diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index c34855468c5c..eac274f96a9d 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -291,9 +291,9 @@ inline size_t GetRNNReserveSpaceSize(int num_layer, return size; } -inline size_t GetNumInputArguments(RNNParam param_) { - size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4U : 3U; - if (param_.use_sequence_length) +inline size_t GetRnnNumInputs(RNNParam param) { + size_t num_inputs = (param.mode == rnn_enum::kLstm) ? 4U : 3U; + if (param.use_sequence_length) num_inputs += 1U; return num_inputs; } @@ -748,7 +748,7 @@ class RNNOp { using namespace mshadow::expr; CHECK(param_.p >= 0.0f && param_.p < 1.0f) << "unsupported dropout value, should be 0 <= dropout < 1"; - size_t num_inputs = GetNumInputArguments(param_); + size_t num_inputs = GetRnnNumInputs(param_); // kOut size_t num_outputs = 1; @@ -1125,7 +1125,7 @@ class RNNOp { CHECK(param_.p >= 0.0f && param_.p < 1.0f) << "unsupported dropout value, should be 0 <= dropout < 1"; - size_t num_inputs = GetNumInputArguments(param_); + size_t num_inputs = GetRnnNumInputs(param_); // kOut size_t num_outputs = 1; @@ -1369,7 +1369,7 @@ class RNNOp { const std::vector& out_data) { using namespace mshadow; - size_t num_inputs = GetNumInputArguments(param_); + size_t num_inputs = GetRnnNumInputs(param_); // kOut size_t num_outputs = 1; if (param_.state_outputs) { diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index e4b84dd0d927..5a03b06674c8 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -34,31 +34,41 @@ namespace mxnet { namespace op { DMLC_REGISTER_PARAMETER(RNNParam); -static inline std::vector ListArguments(const RNNParam& param_) { +static inline std::vector ListRnnInputNames(const RNNParam& param) { // All RNNs start off with same 3 input arguments std::vector arguments{"data", "parameters", "state"}; // LSTMs also have an additional state_cell argument - if (param_.mode == rnn_enum::kLstm) { + if (param.mode == rnn_enum::kLstm) { arguments.emplace_back("state_cell"); } // All RNNs have option of additional sequence_length argument - if (param_.use_sequence_length) { + if (param.use_sequence_length) { arguments.emplace_back("sequence_length"); } return arguments; } +static inline std::vector ListRnnOutputNames(const RNNParam& param) { + std::vector names{"output"}; + if (param.state_outputs) { + names.emplace_back("state_output"); + if (param.mode == rnn_enum::kLstm) + names.emplace_back("statecell_output"); + } + return names; +} + static bool RNNShape(const nnvm::NodeAttrs& attrs, std::vector* in_shape, std::vector* out_shape) { - const RNNParam& param_ = nnvm::get(attrs.parsed); using namespace mshadow; + const RNNParam& param = nnvm::get(attrs.parsed); - // Query param_ object to figure out what the expectd input arguments are - std::vector expected_arguments = ListArguments(param_); + // Query param object to figure out what the expectd input arguments are + std::vector expected_arguments = ListRnnInputNames(param); CHECK_EQ(in_shape->size(), expected_arguments.size()) << "Input shape mismatch. Expected " << expected_arguments.size() @@ -76,29 +86,29 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, } int batch_size = dshape[1]; int input_size = dshape[2]; - int numDirections = param_.bidirectional ? 2 : 1; - int total_layers = numDirections * param_.num_layers; // double for bidirectional + int numDirections = param.bidirectional ? 2 : 1; + int total_layers = numDirections * param.num_layers; // double for bidirectional int layer_size = - (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size; + (param.projection_size.has_value()) ? param.projection_size.value() : param.state_size; SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kState, Shape3(total_layers, batch_size, layer_size)); - if (param_.mode == rnn_enum::kLstm) { + if (param.mode == rnn_enum::kLstm) { SHAPE_ASSIGN_CHECK( - *in_shape, rnn_enum::kStateCell, Shape3(total_layers, batch_size, param_.state_size)); + *in_shape, rnn_enum::kStateCell, Shape3(total_layers, batch_size, param.state_size)); } // calculate parameter vector length - int param_size = GetRnnParamSize(param_.num_layers, + int param_size = GetRnnParamSize(param.num_layers, input_size, - param_.state_size, + param.state_size, numDirections, - param_.mode, - param_.projection_size); + param.mode, + param.projection_size); SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size)); // Check on sequence_length shape if using - if (param_.use_sequence_length) { + if (param.use_sequence_length) { size_t seq_len_input_idx = rnn_enum::kSequenceLength; - if (param_.mode != rnn_enum::kLstm) + if (param.mode != rnn_enum::kLstm) --seq_len_input_idx; SHAPE_ASSIGN_CHECK(*in_shape, seq_len_input_idx, Shape1(batch_size)); @@ -107,29 +117,29 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, out_shape->clear(); // output: [sequence len, batch, output size] TShape oshape = dshape; - if (param_.projection_size.has_value()) { - oshape[2] = numDirections * param_.projection_size.value(); + if (param.projection_size.has_value()) { + oshape[2] = numDirections * param.projection_size.value(); } else { - oshape[2] = numDirections * param_.state_size; + oshape[2] = numDirections * param.state_size; } out_shape->push_back(oshape); - if (param_.state_outputs) { + if (param.state_outputs) { // outStateShape: [layer_num, batch, state size] TShape outStateShape = dshape; outStateShape[0] = total_layers; outStateShape[1] = batch_size; - if (param_.projection_size.has_value()) { - outStateShape[2] = param_.projection_size.value(); + if (param.projection_size.has_value()) { + outStateShape[2] = param.projection_size.value(); } else { - outStateShape[2] = param_.state_size; + outStateShape[2] = param.state_size; } out_shape->push_back(outStateShape); // Deal with lstm cell state - if (param_.mode == rnn_enum::kLstm) { + if (param.mode == rnn_enum::kLstm) { TShape cellStateShape = dshape; cellStateShape[0] = total_layers; cellStateShape[1] = batch_size; - cellStateShape[2] = param_.state_size; + cellStateShape[2] = param.state_size; out_shape->push_back(cellStateShape); } } @@ -140,34 +150,34 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, static bool RNNType(const nnvm::NodeAttrs& attrs, std::vector* in_type, std::vector* out_type) { - const RNNParam& param_ = nnvm::get(attrs.parsed); + const RNNParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(in_type->size(), GetNumInputArguments(param_)); + CHECK_EQ(in_type->size(), GetRnnNumInputs(param)); size_t seq_len_input_idx = rnn_enum::kSequenceLength; - if (param_.mode != rnn_enum::kLstm) + if (param.mode != rnn_enum::kLstm) --seq_len_input_idx; int dtype = (*in_type)[0]; CHECK_NE(dtype, -1) << "First input must have specified type"; - std::vector arguments = ListArguments(param_); + std::vector arguments = ListRnnInputNames(param); for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { TYPE_ASSIGN_CHECK(*in_type, i, dtype); } else { // If using sequence length argument, it has its own indexing type // All other input arguments must match the main data type - if (!(param_.use_sequence_length && i == seq_len_input_idx)) { + if (!(param.use_sequence_length && i == seq_len_input_idx)) { UNIFORM_TYPE_CHECK((*in_type)[i], dtype, arguments[i]); } } } out_type->clear(); out_type->push_back(dtype); - if (param_.state_outputs) { + if (param.state_outputs) { out_type->push_back(dtype); // Deal with lstm cell state - if (param_.mode == rnn_enum::kLstm) { + if (param.mode == rnn_enum::kLstm) { out_type->push_back(dtype); } } @@ -248,7 +258,7 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs& attrs, #if MXNET_USE_ONEDNN == 1 if (ctx.dev_type == kCPU && SupportDNNLRnn(param, in_types[rnn_enum::kData])) { const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData]; - state = OpStatePtr::Create(param, data_shape[0], data_shape[1], data_shape[2]); + state = OpStatePtr::Create(attrs, data_shape[0], data_shape[1], data_shape[2]); return state; } #endif // MXNET_USE_ONEDNN == 1 @@ -370,7 +380,7 @@ The definition of GRU here is slightly different from paper but compatible with .set_attr_parser(ParamParser) .set_num_inputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return GetNumInputArguments(params); + return GetRnnNumInputs(params); }) .set_num_outputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); @@ -386,18 +396,12 @@ The definition of GRU here is slightly different from paper but compatible with .set_attr("FListInputNames", [](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return ListArguments(params); + return ListRnnInputNames(params); }) .set_attr("FListOutputNames", [](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - std::vector names{"output"}; - if (params.state_outputs) { - names.emplace_back("state_output"); - if (params.mode == rnn_enum::kLstm) - names.emplace_back("statecell_output"); - } - return names; + return ListRnnOutputNames(params); }) .set_attr("FInferShape", RNNShape) .set_attr("FInferType", RNNType) @@ -441,7 +445,7 @@ NNVM_REGISTER_OP(_backward_RNN) }) .set_num_outputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return GetNumInputArguments(params); + return GetRnnNumInputs(params); }) .set_attr_parser(ParamParser) .set_attr("TIsLayerOpBackward", true) diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index dcd4bbd5b546..c0f807830a67 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -1414,3 +1414,77 @@ def get_threshold(nd): assert 'layer1' in min_max_dict assert_almost_equal(onp.array([min_max_dict['layer1'][1]]), expected_threshold, rtol=1e-2, atol=1e-4) + + + +@use_np +def test_quantized_rnn(): + data_low = -1 + data_high = 1 + def check_quantized_rnn(num_layers, bidirectional, seq_len, batch_size, input_dim, state_size): + data_shape = (seq_len, batch_size, input_dim) + + print(num_layers, bidirectional, seq_len, batch_size, input_dim, state_size) + rnn_fp32 = mx.gluon.rnn.LSTM(hidden_size=state_size, + num_layers = num_layers, + bidirectional=bidirectional) + # run fp32 bn + data = mx.np.random.uniform(low=data_low, high=data_high, size=data_shape) + states_shape = (num_layers * 2 if bidirectional else num_layers, batch_size, state_size) + states = [mx.np.zeros((states_shape)) for _ in range(batch_size)] + + rnn_fp32.initialize() + rnn_fp32.hybridize() + + ref_out = rnn_fp32(data, states) + # print(ref_out) + fp32_params = rnn_fp32.collect_params() + sym, p = rnn_fp32.export(None) + rnn_fp32.export("WTF") + data_min = mx.np.min(data) + data_max = mx.np.max(data) + data_scale = 128.0 / (data_max - data_min) + data_shift = 128.0 - data_max * data_scale + qdata = (data * data_scale + data_shift + 0.5).astype('uint8') + + class RNNDataLoader(mx.gluon.data.DataLoader): + def __init__(self, data, states): + super().__init__(mx.gluon.data.SimpleDataset([]), 1) + self.data = data + self.states = states + + def __iter__(self): + return self + + def __next__(self): + return [self.data, self.states] + + def __bool__(self): + return bool(self.dataiter.iter_next()) + + # generate int8 bn from fp32 bn + # dataset = mx.gluon.data.ArrayDataset(data, [states[0]], [states[1]]) + # calib_data = mx.gluon.data.DataLoader(dataset, batch_size=batch_size) + calib_data = RNNDataLoader(data, states) + # qsym, qparams = mx.contrib.quant._quantize_symbol(sym, device=mx.current_device(), + # offline_params=p, quantize_mode='full') + # qsym.save("XDD") + # inputs = [mx.sym.var('data0'), mx.sym.var('data1'), mx.sym.var('data2')] + # calib_net = mx.gluon.SymbolBlock(qsym, inputs) + quant_rnn = mx.contrib.quant.quantize_net(rnn_fp32, + quantized_dtype='auto', + quantize_mode='full', + calib_data=calib_data, + calib_mode='naive', + num_calib_batches=1, + device=mx.current_device()) + #calib_net.load_dict(p, cast_dtype=True, dtype_source='saved', allow_missing=True) + output_int8_to_fp32 = quant_rnn(data, states) + + assert_almost_equal(ref_out[0].asnumpy(), output_int8_to_fp32[0].asnumpy(), rtol=1e-1, atol=8) + print(ref_out[0][0]) + print("============") + print(output_int8_to_fp32[0][0]) + for qdtype in ['int8', 'uint8']: + check_quantized_rnn(1, False, 5, 2, 16, 16) + check_quantized_rnn(1, True, 5, 2, 16, 16) From d4fca7daaea6e2b1ef70e082233794242f566e3f Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Tue, 1 Mar 2022 14:27:29 +0100 Subject: [PATCH 02/11] Add tests & disable LSTMP from quantization --- src/operator/quantization/quantized_rnn.cc | 9 +- .../python/quantization/test_quantization.py | 107 +++++++++++------- 2 files changed, 76 insertions(+), 40 deletions(-) diff --git a/src/operator/quantization/quantized_rnn.cc b/src/operator/quantization/quantized_rnn.cc index c396e816080d..6c4ac1e592e0 100644 --- a/src/operator/quantization/quantized_rnn.cc +++ b/src/operator/quantization/quantized_rnn.cc @@ -288,6 +288,7 @@ static std::vector QuantizedRnnResourceEx(const NodeAttrs& attr } NNVM_REGISTER_OP(_contrib_quantized_rnn) + .add_alias("_npx_contrib_quantized_rnn") .describe( R"code(RNN operator for input data type of uint8. The weight of each gates is converted to int8, while bias is accumulated in type float32. The hidden state and cell state are in type @@ -328,8 +329,12 @@ NNVM_REGISTER_OP(RNN) const RNNParam& param = nnvm::get(attrs.parsed); if (param.mode != rnn_enum::kLstm) LOG(INFO) << "Quantized RNN only supports LSTM mode."; - return param.mode == rnn_enum::kLstm ? QuantizeType::kMust : - QuantizeType::kNone; + if (param.mode == rnn_enum::kLstm && + !param.projection_size.has_value()) { + return QuantizeType::kMust; + } else { + return QuantizeType::kNone; + } #else LOG(INFO) << "Quantized RNN is not supported by this MXNet release. Please enable oneDNN to " << "use the feature."; diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index c0f807830a67..444c21d9e5f7 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -1415,37 +1415,24 @@ def get_threshold(nd): assert_almost_equal(onp.array([min_max_dict['layer1'][1]]), expected_threshold, rtol=1e-2, atol=1e-4) - - @use_np -def test_quantized_rnn(): +def test_rnn_quantization(): data_low = -1 data_high = 1 - def check_quantized_rnn(num_layers, bidirectional, seq_len, batch_size, input_dim, state_size): + def check_rnn_quantization(num_layers, bidirectional, seq_len, batch_size, input_dim, state_size): data_shape = (seq_len, batch_size, input_dim) - print(num_layers, bidirectional, seq_len, batch_size, input_dim, state_size) rnn_fp32 = mx.gluon.rnn.LSTM(hidden_size=state_size, num_layers = num_layers, bidirectional=bidirectional) - # run fp32 bn + data = mx.np.random.uniform(low=data_low, high=data_high, size=data_shape) states_shape = (num_layers * 2 if bidirectional else num_layers, batch_size, state_size) states = [mx.np.zeros((states_shape)) for _ in range(batch_size)] rnn_fp32.initialize() rnn_fp32.hybridize() - ref_out = rnn_fp32(data, states) - # print(ref_out) - fp32_params = rnn_fp32.collect_params() - sym, p = rnn_fp32.export(None) - rnn_fp32.export("WTF") - data_min = mx.np.min(data) - data_max = mx.np.max(data) - data_scale = 128.0 / (data_max - data_min) - data_shift = 128.0 - data_max * data_scale - qdata = (data * data_scale + data_shift + 0.5).astype('uint8') class RNNDataLoader(mx.gluon.data.DataLoader): def __init__(self, data, states): @@ -1462,29 +1449,73 @@ def __next__(self): def __bool__(self): return bool(self.dataiter.iter_next()) - # generate int8 bn from fp32 bn - # dataset = mx.gluon.data.ArrayDataset(data, [states[0]], [states[1]]) - # calib_data = mx.gluon.data.DataLoader(dataset, batch_size=batch_size) calib_data = RNNDataLoader(data, states) - # qsym, qparams = mx.contrib.quant._quantize_symbol(sym, device=mx.current_device(), - # offline_params=p, quantize_mode='full') - # qsym.save("XDD") - # inputs = [mx.sym.var('data0'), mx.sym.var('data1'), mx.sym.var('data2')] - # calib_net = mx.gluon.SymbolBlock(qsym, inputs) quant_rnn = mx.contrib.quant.quantize_net(rnn_fp32, - quantized_dtype='auto', - quantize_mode='full', - calib_data=calib_data, - calib_mode='naive', - num_calib_batches=1, - device=mx.current_device()) - #calib_net.load_dict(p, cast_dtype=True, dtype_source='saved', allow_missing=True) - output_int8_to_fp32 = quant_rnn(data, states) - - assert_almost_equal(ref_out[0].asnumpy(), output_int8_to_fp32[0].asnumpy(), rtol=1e-1, atol=8) - print(ref_out[0][0]) - print("============") - print(output_int8_to_fp32[0][0]) + quantized_dtype='auto', + quantize_mode='full', + calib_data=calib_data, + calib_mode='naive', + num_calib_batches=1, + device=mx.current_device()) + qout = quant_rnn(data, states) + + qsym, _ = quant_rnn.export(None) + assert qsym.tojson().find("quantized_rnn") != -1 + + ref_out = [ref_out[0], ref_out[1][0], ref_out[1][1]] + for i in range(len(qout)): + mse = onp.mean((ref_out[i].asnumpy() - qout[i].asnumpy())**2) + assert mse < 0.001 + + for qdtype in ['int8', 'uint8']: + check_rnn_quantization(1, False, 5, 2, 16, 16) + check_rnn_quantization(1, True, 5, 2, 16, 16) + + + +@use_np +def test_quantized_rnn(): + def check_quantized_rnn(num_layers, bidirectional, seq_len, batch_size, input_dim, state_size): + ndir = 2 if bidirectional else 1 + size = ndir*state_size*4 + first_lyr_param_size = (input_dim + state_size + 2) * size + other_lyr_param_size = (state_size * ndir + state_size + 2) * size + full_param_size = first_lyr_param_size + (num_layers - 1) * other_lyr_param_size + + data = mx.np.random.uniform(-1, 1, (seq_len, batch_size, input_dim)) + state = mx.np.random.uniform(-1, 1, (num_layers*ndir, batch_size, state_size)) + state_cell = mx.np.random.uniform(0, 1, (num_layers*ndir, batch_size, state_size)) + params = mx.np.random.normal(0, 1, (full_param_size,)) + + out = npx.rnn(data=data, + parameters=params, + mode='lstm', + state=state, + state_size=state_size, + state_cell=state_cell, + num_layers=num_layers, + bidirectional=bidirectional) + + data_min = mx.np.min(data) + data_max = mx.np.max(data) + data_scale = mx.np.array(128.0 / (data_max - data_min)).reshape((1,)) + data_shift = mx.np.array(128.0 - data_max * data_scale).reshape((1,)) + + qdata = (data * data_scale + data_shift + 0.5).astype('uint8') + qout = npx.contrib_quantized_rnn(data=qdata, + parameters=params, + mode='lstm', + state=state, + state_size=state_size, + state_cell=state_cell, + num_layers=num_layers, + bidirectional=bidirectional, + data_scale=data_scale, + data_shift=data_shift) + + mse = onp.mean((out.asnumpy() - qout.asnumpy())**2) + assert mse < 0.001 + for qdtype in ['int8', 'uint8']: check_quantized_rnn(1, False, 5, 2, 16, 16) - check_quantized_rnn(1, True, 5, 2, 16, 16) + check_quantized_rnn(1, True, 5, 2, 16, 16) \ No newline at end of file From ef1bb90915dd24b81c0bd735393d9e7cba23dfcd Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Thu, 10 Mar 2022 16:45:58 +0100 Subject: [PATCH 03/11] apply review comments --- src/operator/nn/dnnl/dnnl_rnn-inl.h | 22 ------------- src/operator/nn/dnnl/dnnl_rnn.cc | 33 +++++++++++++++---- .../python/quantization/test_quantization.py | 10 +++--- 3 files changed, 31 insertions(+), 34 deletions(-) diff --git a/src/operator/nn/dnnl/dnnl_rnn-inl.h b/src/operator/nn/dnnl/dnnl_rnn-inl.h index fafed4914d64..6165dfaeb4c4 100644 --- a/src/operator/nn/dnnl/dnnl_rnn-inl.h +++ b/src/operator/nn/dnnl/dnnl_rnn-inl.h @@ -46,28 +46,6 @@ struct DNNLRnnParam : public dmlc::Parameter { } }; -inline void DNNLMemoryReorder(const dnnl::memory& src, const dnnl::memory& dst) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map reorderPrimitives; -#else - static MX_THREAD_LOCAL std::unordered_map reorderPrimitives; -#endif - OpSignature key{}; - key.AddSign(src); - key.AddSign(dst); - - auto it = reorderPrimitives.find(key); - if (it == reorderPrimitives.end()) { - auto reorder = dnnl::reorder(src, dst); - it = AddToCache(&reorderPrimitives, key, reorder); - } - - dnnl_args_map_t net_args; - net_args.emplace(DNNL_ARG_SRC, src); - net_args.emplace(DNNL_ARG_DST, dst); - DNNLStream::Get()->RegisterPrimArgs(it->second, net_args); -} - struct DNNLRnnLayerParam { using memory = dnnl::memory; using dims = dnnl::memory::dims; diff --git a/src/operator/nn/dnnl/dnnl_rnn.cc b/src/operator/nn/dnnl/dnnl_rnn.cc index 8b4c585a12a6..bdda9b5e2259 100644 --- a/src/operator/nn/dnnl/dnnl_rnn.cc +++ b/src/operator/nn/dnnl/dnnl_rnn.cc @@ -233,10 +233,9 @@ RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param, auto dst_state_desc = layer_param.state_outputs ? memory::desc(layer_param.state_dims, iter_dtype, tag::ldnc) : memory::desc(); - auto dst_cell_desc = - layer_param.state_outputs ? - memory::desc(layer_param.cell_dims, iter_dtype, tag::ldnc) : // no cell in 1.x - memory::desc(); + auto dst_cell_desc = layer_param.state_outputs ? + memory::desc(layer_param.cell_dims, iter_dtype, tag::ldnc) : + memory::desc(); auto fwd = RnnPrimitive(); switch (mode) { @@ -249,8 +248,8 @@ RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param, src_cell_desc, weight_layer_desc, weight_iter_desc, - weight_peep_desc, // peep new - weight_proj_desc, // proj new + weight_peep_desc, + weight_proj_desc, bias_desc, dst_layer_desc, dst_state_desc, @@ -510,6 +509,28 @@ void DNNLRnnForward::SetNewDataMem(void* x, } } +inline void DNNLMemoryReorder(const dnnl::memory& src, const dnnl::memory& dst) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map reorderPrimitives; +#else + static MX_THREAD_LOCAL std::unordered_map reorderPrimitives; +#endif + OpSignature key{}; + key.AddSign(src); + key.AddSign(dst); + + auto it = reorderPrimitives.find(key); + if (it == reorderPrimitives.end()) { + auto reorder = dnnl::reorder(src, dst); + it = AddToCache(&reorderPrimitives, key, reorder); + } + + dnnl_args_map_t net_args; + net_args.emplace(DNNL_ARG_SRC, src); + net_args.emplace(DNNL_ARG_DST, dst); + DNNLStream::Get()->RegisterPrimArgs(it->second, net_args); +} + /* * Reorder the concatenated weights memory to a efficient memory block * with primitive-prefered format. diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 444c21d9e5f7..6b74a49a9d56 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -1467,9 +1467,8 @@ def __bool__(self): mse = onp.mean((ref_out[i].asnumpy() - qout[i].asnumpy())**2) assert mse < 0.001 - for qdtype in ['int8', 'uint8']: - check_rnn_quantization(1, False, 5, 2, 16, 16) - check_rnn_quantization(1, True, 5, 2, 16, 16) + check_rnn_quantization(1, False, 5, 2, 16, 16) + check_rnn_quantization(1, True, 5, 2, 16, 16) @@ -1516,6 +1515,5 @@ def check_quantized_rnn(num_layers, bidirectional, seq_len, batch_size, input_di mse = onp.mean((out.asnumpy() - qout.asnumpy())**2) assert mse < 0.001 - for qdtype in ['int8', 'uint8']: - check_quantized_rnn(1, False, 5, 2, 16, 16) - check_quantized_rnn(1, True, 5, 2, 16, 16) \ No newline at end of file + check_quantized_rnn(1, False, 5, 2, 16, 16) + check_quantized_rnn(1, True, 5, 2, 16, 16) \ No newline at end of file From 2f1c63817b50eb113d3f402d0a3f2547d6fed2be Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Thu, 10 Mar 2022 17:20:05 +0100 Subject: [PATCH 04/11] change link --- src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h b/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h index e3c28e6a4711..9030fb38f5ab 100644 --- a/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h +++ b/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h @@ -63,7 +63,7 @@ class DNNLQuantizedRnnOp { // Used to store the intermediate results of multi-layer std::vector dst_; // According to - // https://intel.github.io/mkl-dnn/dev_guide_int8_computations.html, the + // https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html, the // non-symmetric quantization is assumed by LSTM primitive. Namely, the // formula is: // data_f32 = (data_u8 - shift) / scale From 08ccf2cc88c6b227dfefc7f62210527a9c31f2e8 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Tue, 15 Mar 2022 13:26:29 +0100 Subject: [PATCH 05/11] Add new lines at the EOF --- src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h | 2 +- src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h | 2 +- src/operator/quantization/dnnl/dnnl_quantized_rnn.cc | 2 +- src/operator/quantization/quantize_asym-inl.h | 2 +- src/operator/quantization/quantize_asym.cc | 4 +++- src/operator/quantization/quantized_rnn-inl.h | 2 +- src/operator/quantization/quantized_rnn.cc | 3 ++- 7 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h b/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h index 83e72e0a0d9e..9bbbd2d9eb54 100644 --- a/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h +++ b/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h @@ -158,4 +158,4 @@ void DNNLQuantizeAsymForward(const OpStatePtr& state_ptr, } // namespace mxnet #endif // MXNET_USE_ONEDNN == 1 -#endif // MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZE_ASYM_INL_H_ \ No newline at end of file +#endif // MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZE_ASYM_INL_H_ diff --git a/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h b/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h index 9030fb38f5ab..cdd5417e3ea3 100644 --- a/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h +++ b/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h @@ -79,4 +79,4 @@ class DNNLQuantizedRnnOp { } // namespace mxnet #endif // MXNET_USE_ONEDNN == 1 -#endif // MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RNN_INL_H_ \ No newline at end of file +#endif // MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RNN_INL_H_ diff --git a/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc b/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc index 7ecb5ec58184..b79640d60369 100644 --- a/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc +++ b/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc @@ -362,4 +362,4 @@ void DNNLQuantizedRnnOp::Forward(const OpContext& op_ctx, } // namespace op } // namespace mxnet -#endif // MXNET_USE_ONEDNN == 1 \ No newline at end of file +#endif // MXNET_USE_ONEDNN == 1 diff --git a/src/operator/quantization/quantize_asym-inl.h b/src/operator/quantization/quantize_asym-inl.h index 4d3fb554db8d..3aa44c4e4fd6 100644 --- a/src/operator/quantization/quantize_asym-inl.h +++ b/src/operator/quantization/quantize_asym-inl.h @@ -174,4 +174,4 @@ void QuantizeAsymForward(const OpStatePtr& state_ptr, } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_ \ No newline at end of file +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_ diff --git a/src/operator/quantization/quantize_asym.cc b/src/operator/quantization/quantize_asym.cc index 24d1e9d53c6d..b67635fe685a 100644 --- a/src/operator/quantization/quantize_asym.cc +++ b/src/operator/quantization/quantize_asym.cc @@ -22,6 +22,8 @@ * \brief implementation of asymmetric quantize operation */ +#include + #include "operator/quantization/quantize_asym-inl.h" #if MXNET_USE_ONEDNN == 1 #include "operator/quantization/dnnl/dnnl_quantize_asym-inl.h" @@ -152,4 +154,4 @@ where `scale = uint8_range / (max_range - min_range)` and .add_arguments(QuantizeAsymParam::__FIELDS__()); } // namespace op -} // namespace mxnet \ No newline at end of file +} // namespace mxnet diff --git a/src/operator/quantization/quantized_rnn-inl.h b/src/operator/quantization/quantized_rnn-inl.h index d5d9dd80a6ee..6ab53cef867c 100644 --- a/src/operator/quantization/quantized_rnn-inl.h +++ b/src/operator/quantization/quantized_rnn-inl.h @@ -38,4 +38,4 @@ enum QuantizedRnnOutputs { kOut, kStateOut, kStateCellOut }; } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_ \ No newline at end of file +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_ diff --git a/src/operator/quantization/quantized_rnn.cc b/src/operator/quantization/quantized_rnn.cc index 6c4ac1e592e0..bb5e67c3f4dc 100644 --- a/src/operator/quantization/quantized_rnn.cc +++ b/src/operator/quantization/quantized_rnn.cc @@ -24,6 +24,7 @@ */ #include +#include #include #include @@ -358,4 +359,4 @@ NNVM_REGISTER_OP(RNN) .set_attr("FAvoidDequantizeOutput", AvoidRnnDequantizeOutput); } // namespace op -} // namespace mxnet \ No newline at end of file +} // namespace mxnet From 5f74260932bc15a453d937f9f4ed0cbe520fd6b3 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Tue, 15 Mar 2022 18:13:17 +0100 Subject: [PATCH 06/11] Add ops to amp lists --- python/mxnet/amp/lists/symbol_bf16.py | 2 ++ python/mxnet/amp/lists/symbol_fp16.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py index dd545a778578..ed2416a477cb 100644 --- a/python/mxnet/amp/lists/symbol_bf16.py +++ b/python/mxnet/amp/lists/symbol_bf16.py @@ -119,6 +119,7 @@ '_contrib_index_copy', '_contrib_quadratic', '_contrib_quantize', + '_contrib_quantize_asym', '_contrib_quantize_v2', '_contrib_quantized_concat', '_contrib_quantized_conv', @@ -127,6 +128,7 @@ '_contrib_quantized_pooling', '_contrib_quantized_elemwise_add', '_contrib_quantized_act', + '_contrib_quantized_rnn', '_image_crop', '_linspace', '_contrib_requantize', diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py index 52a8f459f302..ad1f0ad4b293 100644 --- a/python/mxnet/amp/lists/symbol_fp16.py +++ b/python/mxnet/amp/lists/symbol_fp16.py @@ -99,6 +99,7 @@ '_contrib_index_copy', '_contrib_quadratic', '_contrib_quantize', + '_contrib_quantize_asym', '_contrib_quantize_v2', '_contrib_quantized_concat', '_contrib_quantized_conv', @@ -108,6 +109,7 @@ '_contrib_quantized_elemwise_add', '_contrib_quantized_act', '_contrib_quantized_reshape', + '_contrib_quantized_rnn', '_contrib_quantized_transpose', '_npx_quantized_reshape', '_npx_quantized_transpose', From dc2dfbba740f5e973b72b311303c1b8b51ac2602 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Wed, 16 Mar 2022 09:41:52 +0100 Subject: [PATCH 07/11] Remove unused features --- python/mxnet/io/io.py | 18 ++++++------------ python/mxnet/io/utils.py | 5 ----- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/python/mxnet/io/io.py b/python/mxnet/io/io.py index 013f401108f8..b0a01290c4f1 100644 --- a/python/mxnet/io/io.py +++ b/python/mxnet/io/io.py @@ -37,7 +37,7 @@ from ..ndarray import array from ..ndarray import concat, tile -from .utils import _init_data, _has_instance, _getdata_by_idx, _slice_along_batch_axis +from .utils import _init_data, _has_instance, _getdata_by_idx class DataDesc(namedtuple('DataDesc', ['name', 'shape'])): """DataDesc is used to store name, shape, type and layout @@ -602,12 +602,10 @@ class NDArrayIter(DataIter): The data name. label_name : str, optional The label name. - layout : str, optional - The data layout. """ def __init__(self, data, label=None, batch_size=1, shuffle=False, last_batch_handle='pad', data_name='data', - label_name='softmax_label', layout='NCHW'): + label_name='softmax_label'): super(NDArrayIter, self).__init__(batch_size) self.data = _init_data(data, allow_empty=False, default_name=data_name) @@ -633,16 +631,12 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False, # used for 'roll_over' self._cache_data = None self._cache_label = None - self.layout = layout @property def provide_data(self): """The name and shape of data provided by this iterator.""" - batch_axis = self.layout.find('N') return [ - DataDesc(k, tuple(list(v.shape[:batch_axis]) + \ - [self.batch_size] + list(v.shape[batch_axis + 1:])), - v.dtype, layout=self.layout) + DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype) for k, v in self.data ] @@ -690,7 +684,7 @@ def next(self): data = self.getdata() label = self.getlabel() # iter should stop when last batch is not complete - if data[0].shape[self.layout.find('N')] != self.batch_size: + if data[0].shape[0] != self.batch_size: # in this case, cache it for next epoch self._cache_data = data self._cache_label = label @@ -706,7 +700,7 @@ def _getdata(self, data_source, start=None, end=None): end = data_source[0][1].shape[0] if data_source else 0 s = slice(start, end) return [ - _slice_along_batch_axis(x[1], s, self.layout.find('N')) + x[1][s] if isinstance(x[1], (np.ndarray, NDArray)) else # h5py (only supports indices in increasing order) array(x[1][sorted(self.idx[s])][[ @@ -725,7 +719,7 @@ def _concat(self, first_data, second_data): concat( first_data[i], second_data[i], - dim=self.layout.find('N') + dim=0 ) for i in range(len(first_data)) ] diff --git a/python/mxnet/io/utils.py b/python/mxnet/io/utils.py index 55f228f4556d..55ba34aea426 100644 --- a/python/mxnet/io/utils.py +++ b/python/mxnet/io/utils.py @@ -84,8 +84,3 @@ def _getdata_by_idx(data, idx): shuffle_data.append((k, array(v.asnumpy()[idx], v.context))) return shuffle_data - -def _slice_along_batch_axis(data, s, batch_axis): - """Apply slice along the batch axis""" - ret = data.slice_axis(axis=batch_axis, begin=s.start, end=s.stop) - return ret \ No newline at end of file From 075d40165e8d2c2ab0f946589fccb6cb10a993dc Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Wed, 16 Mar 2022 15:42:53 +0100 Subject: [PATCH 08/11] Fix DataDesc handling in quantization --- python/mxnet/contrib/quantization.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 4e6411135342..64942353db29 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -382,7 +382,7 @@ def _collect_layer_statistics(sym_block, data, collector, num_inputs, num_calib_ def _generate_list_of_data_desc(data_shapes, data_types): - """"Convert list of tuples to list of DataDesc.""" + """Convert list of tuples to list of DataDesc.""" def flatten_list(arg): ret = [] for el in arg: @@ -394,6 +394,10 @@ def flatten_list(arg): flattened_data_types = flatten_list(data_types) flattened_data_shapes = flatten_list(data_shapes) + + if all(isinstance(x, DataDesc) for x in flattened_data_shapes): + return data_shapes + assert len(flattened_data_types) == len(flattened_data_shapes) # pass integral type as reference From 08178b30ac63d391f0da0ced408e5c1ac25e9ef6 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Fri, 18 Mar 2022 10:31:56 +0100 Subject: [PATCH 09/11] fix website --- src/operator/quantization/quantize_asym.cc | 8 +++++--- src/operator/quantization/quantized_rnn.cc | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/operator/quantization/quantize_asym.cc b/src/operator/quantization/quantize_asym.cc index b67635fe685a..4cb2669cd1c9 100644 --- a/src/operator/quantization/quantize_asym.cc +++ b/src/operator/quantization/quantize_asym.cc @@ -107,12 +107,14 @@ OpStatePtr CreateQuantizeAsymState(const nnvm::NodeAttrs& attrs, NNVM_REGISTER_OP(_contrib_quantize_asym) .describe(R"code(Quantize a input tensor from float to uint8_t. -Output `scale` and `shift` are scalar floats that specify the quantization parameters for the input -data. -The output is calculated using the following equation: +Output `scale` and `shift` are scalar floats that specify the quantization +parameters for the input data. The output is calculated using the following equation: + `out[i] = in[i] * scale + shift + 0.5`, + where `scale = uint8_range / (max_range - min_range)` and `shift = numeric_limits::max - max_range * scale`. + .. Note:: This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE) .set_attr_parser(ParamParser) diff --git a/src/operator/quantization/quantized_rnn.cc b/src/operator/quantization/quantized_rnn.cc index bb5e67c3f4dc..2529ead6832f 100644 --- a/src/operator/quantization/quantized_rnn.cc +++ b/src/operator/quantization/quantized_rnn.cc @@ -290,12 +290,12 @@ static std::vector QuantizedRnnResourceEx(const NodeAttrs& attr NNVM_REGISTER_OP(_contrib_quantized_rnn) .add_alias("_npx_contrib_quantized_rnn") - .describe( - R"code(RNN operator for input data type of uint8. The weight of each gates is converted + .describe(R"code(RNN operator for input data type of uint8. The weight of each gates is converted to int8, while bias is accumulated in type float32. The hidden state and cell state are in type float32. For the input data, two more arguments of type float32 must be provided representing the thresholds of quantizing argument from data type float32 to uint8. The final outputs contain the recurrent result in float32. It only supports quantization for Vanilla LSTM network. + .. Note:: This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE) .set_num_inputs(QuantizedRnnNumInputs) From 8324032da367b01ebd805c09fc59a18241daf205 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Fri, 18 Mar 2022 13:49:18 +0100 Subject: [PATCH 10/11] fix sanity --- src/operator/quantization/quantized_rnn.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/operator/quantization/quantized_rnn.cc b/src/operator/quantization/quantized_rnn.cc index 2529ead6832f..88c80bca3cc7 100644 --- a/src/operator/quantization/quantized_rnn.cc +++ b/src/operator/quantization/quantized_rnn.cc @@ -290,11 +290,12 @@ static std::vector QuantizedRnnResourceEx(const NodeAttrs& attr NNVM_REGISTER_OP(_contrib_quantized_rnn) .add_alias("_npx_contrib_quantized_rnn") - .describe(R"code(RNN operator for input data type of uint8. The weight of each gates is converted -to int8, while bias is accumulated in type float32. The hidden state and cell state are in type -float32. For the input data, two more arguments of type float32 must be provided representing the -thresholds of quantizing argument from data type float32 to uint8. The final outputs contain the -recurrent result in float32. It only supports quantization for Vanilla LSTM network. + .describe(R"code(RNN operator for input data type of uint8. The weight of each +gates is converted to int8, while bias is accumulated in type float32. +The hidden state and cell state are in type float32. For the input data, two more arguments +of type float32 must be provided representing the thresholds of quantizing argument from +data type float32 to uint8. The final outputs contain the recurrent result in float32. +It only supports quantization for Vanilla LSTM network. .. Note:: This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE) From 410ad17a1239b8bd06e473d67ad0bf04e97a8279 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Mon, 28 Mar 2022 14:39:45 +0200 Subject: [PATCH 11/11] remove magic number --- .../quantization/dnnl/dnnl_quantized_rnn.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc b/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc index b79640d60369..73393d9b4c36 100644 --- a/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc +++ b/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc @@ -33,13 +33,14 @@ namespace op { std::vector GetDNNLRnnWeightsQParams(const DNNLRnnFullParam& full_param, float* w_ptr) { const int nthreads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + const int num_gates = 4; const RNNParam& default_param = full_param.default_param; const LayerParamVector& layer_params = full_param.layer_params; const DNNLRnnLayerParam& layer_param0 = layer_params.at(0); const size_t w_size0 = layer_param0.single_w_size; - const size_t wx_size0 = 4 * layer_param0.state_size * layer_param0.input_size; - const size_t wh_size0 = 4 * layer_param0.state_size * layer_param0.state_size; + const size_t wx_size0 = num_gates * layer_param0.state_size * layer_param0.input_size; + const size_t wh_size0 = num_gates * layer_param0.state_size * layer_param0.state_size; int directions = 1; float* wx = w_ptr; @@ -64,10 +65,10 @@ std::vector GetDNNLRnnWeightsQParams(const DNNLRnnFullParam& full_param, fake_wh[i] = MaxAbs(wh[i], wh[i + w_size0]); } } - std::vector w_max(4 * layer_param0.state_size, 0.0); - const index_t input_size = layer_param0.input_size; // input - const index_t state_size = layer_param0.state_size; // state - const index_t gates_nblks = 4 * layer_param0.state_size; // gates * state + std::vector w_max(num_gates * layer_param0.state_size, 0.0); + const index_t input_size = layer_param0.input_size; // input + const index_t state_size = layer_param0.state_size; // state + const index_t gates_nblks = num_gates * layer_param0.state_size; // gates * state for (index_t go = 0; go < gates_nblks; ++go) { float tmp_max = w_max[go]; for (index_t i = 0; i < input_size; ++i) {