From db71b9c5f0d6c68bf69f721871657640a596eff0 Mon Sep 17 00:00:00 2001
From: shuo-ouyang <1414114532@qq.com>
Date: Thu, 26 Mar 2020 13:11:42 +0800
Subject: [PATCH] 1bit gradient compression implementation

---
 ci/docker/runtime_functions.sh         |   6 +-
 python/mxnet/kvstore/kvstore.py        |   5 +-
 src/kvstore/gradient_compression-inl.h | 107 +++++++++++++++++
 src/kvstore/gradient_compression.cc    |  84 +++++++++----
 src/kvstore/gradient_compression.cu    |  10 ++
 src/kvstore/gradient_compression.h     |   8 +-
 tests/nightly/dist_sync_kvstore.py     | 119 ++++++++++++++++--
 tests/nightly/test_kvstore.py          | 159 +++++++++++++++++++------
 8 files changed, 428 insertions(+), 70 deletions(-)

diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index e1f36996d43d..87d29227cff4 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1315,8 +1315,10 @@ integrationtest_ubuntu_cpu_dist_kvstore() {
     python3 ../../tools/launch.py -n 7 --launcher local python3 dist_sync_kvstore.py --type=gluon_type_cpu
     python3 ../../tools/launch.py -n 7 --launcher local python3 dist_sync_kvstore.py
     python3 ../../tools/launch.py -n 7 --launcher local python3 dist_sync_kvstore.py --no-multiprecision
-    python3 ../../tools/launch.py -n 7 --launcher local python3 dist_sync_kvstore.py --type=compressed_cpu
-    python3 ../../tools/launch.py -n 7 --launcher local python3 dist_sync_kvstore.py --type=compressed_cpu --no-multiprecision
+    python3 ../../tools/launch.py -n 7 --launcher local python3 dist_sync_kvstore.py --type=compressed_cpu_1bit
+    python3 ../../tools/launch.py -n 7 --launcher local python3 dist_sync_kvstore.py --type=compressed_cpu_1bit --no-multiprecision
+    python3 ../../tools/launch.py -n 7 --launcher local python3 dist_sync_kvstore.py --type=compressed_cpu_2bit
+    python3 ../../tools/launch.py -n 7 --launcher local python3 dist_sync_kvstore.py --type=compressed_cpu_2bit --no-multiprecision
     python3 ../../tools/launch.py -n 3 --launcher local python3 test_server_profiling.py
     popd
 }
diff --git a/python/mxnet/kvstore/kvstore.py b/python/mxnet/kvstore/kvstore.py
index 11ec3f98178f..eec6aa5453f0 100644
--- a/python/mxnet/kvstore/kvstore.py
+++ b/python/mxnet/kvstore/kvstore.py
@@ -498,6 +498,9 @@ def set_gradient_compression(self, compression_params):
         """ Specifies type of low-bit quantization for gradient compression \
         and additional arguments depending on the type of compression being used.
 
+        The 1bit compression works as follows: values in the gradient that are above the
+        threshold will be set to +1, whereas values at or below the threshold will be set to -1.
+
         2bit Gradient Compression takes a positive float `threshold`.
         The technique works by thresholding values such that positive values in the
         gradient above threshold will be set to threshold. Negative values whose absolute
@@ -538,7 +541,7 @@ def set_gradient_compression(self, compression_params):
             A dictionary specifying the type and parameters for gradient compression.
             The key `type` in this dictionary is a required string argument and specifies
             the type of gradient compression.
-            Currently `type` can be only `2bit`
+            Currently `type` can only be `1bit` or `2bit`
             Other keys in this dictionary are optional and specific to the type
             of gradient compression.
         """
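Before the C++ changes, a minimal usage sketch of the API extended above may help orient readers. This is my own illustration, not part of the patch; the key name is arbitrary, and `local_allreduce_device` is the kvstore type the nightly tests use for compression:

```python
import mxnet as mx

kv = mx.kv.create('local_allreduce_device')
# '1bit' is the new type added by this patch: values whose running
# residual exceeds the threshold are sent as +1, everything else as -1
kv.set_gradient_compression({'type': '1bit', 'threshold': 0.5})

shape = (2, 3)
kv.init('w', mx.nd.zeros(shape))   # init bypasses compression
kv.push('w', mx.nd.ones(shape))    # gradient is quantized on push
out = mx.nd.zeros(shape)
kv.pull('w', out=out)              # aggregated, dequantized result
```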
""" diff --git a/src/kvstore/gradient_compression-inl.h b/src/kvstore/gradient_compression-inl.h index 9b69bd11472c..7d70dff59617 100644 --- a/src/kvstore/gradient_compression-inl.h +++ b/src/kvstore/gradient_compression-inl.h @@ -32,11 +32,106 @@ namespace mxnet { namespace kvstore { // these gpu functions are defined in gradient_compression.cu +void Quantize1BitImpl(mshadow::Stream *s, const std::vector &inputs, + const float threshold); +void Dequantize1BitImpl(mshadow::Stream *s, const std::vector &inputs, + const float threshold); void Quantize2BitImpl(mshadow::Stream *s, const std::vector &inputs, const float threshold); void Dequantize2BitImpl(mshadow::Stream *s, const std::vector &inputs, const float threshold); +struct quantize_1bit { + MSHADOW_XINLINE static void Map(int out_block_id, + int original_size, + float *out, + float *grad, + float *residual, + const float threshold) { + // this block contains the compressed representation of + // upto 32 values starting from out_block_id*32 + float *compr_block = out + out_block_id; + // init to 0 + *compr_block = 0; + // start and end are indices in original grad array + const int start = out_block_id << 5; + const int end = (start + 32 <= original_size) ? start + 32 : original_size; + + char *block_ptr = reinterpret_cast < char * > (compr_block); + // masks used to quantize data + const uint8_t bits[] = {0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01}; + for (int i = start; i < end; ++i) { + // adds offset to reach appropriate byte + char *curr_byte = block_ptr + ((i - start) >> 3); + // adds gradient to existing residual to get updated grad + residual[i] += grad[i]; + if (residual[i] > threshold) { + // set data to 1 + *curr_byte |= bits[(i & 7)]; + // reduce residual by 1 + residual[i] -= 1; + } else { + // set data to 0 + *curr_byte &= ~bits[(i & 7)]; + // add residual by 1 + // because current position will be dequantized to -1 + residual[i] += 1; + } + } + } +}; + +template +void Quantize1BitKernelLaunch(mshadow::Stream *s, const std::vector &inputs, + const float threshold) { + mxnet::op::mxnet_op::Kernel + ::Launch(s, + inputs[2].Size(), // compressed array size + inputs[0].Size(), // original size + inputs[2].dptr(), // compressed array + inputs[0].dptr(), // original array + inputs[1].dptr(), // residual array + threshold); // threshold +} + +struct dequantize_1bit { + MSHADOW_XINLINE static void Map(int i, + float *out, + float *in, + const float threshold) { + // get position of dequantized value to fill + float *outval = out + i; + // gets byte which holds quantized value for this position + char *ch_ptr = reinterpret_cast < char * > (in + (i >> 5)); + ch_ptr += ((i & 31) >> 3); + // masks used to quantize data + const uint8_t bits[] = {0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01}; + // col denotes which bit of a byte is set for this value + // col=0 implies the first bit, col=1 implies the second bit,... 
+struct dequantize_1bit {
+  MSHADOW_XINLINE static void Map(int i,
+                                  float *out,
+                                  float *in,
+                                  const float threshold) {
+    // get the position of the dequantized value to fill
+    float *outval = out + i;
+    // gets the byte which holds the quantized value for this position
+    char *ch_ptr = reinterpret_cast<char *>(in + (i >> 5));
+    ch_ptr += ((i & 31) >> 3);
+    // masks used to dequantize data
+    const uint8_t bits[] = {0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01};
+    // col denotes which bit of the byte is set for this value
+    // col=0 implies the first bit, col=1 implies the second bit, ...
+    const int col = i & 7;
+    const uint8_t mask = bits[col];
+    const uint8_t masked = *ch_ptr & mask;
+    if (masked == mask) {
+      *outval = +1;
+    } else {
+      // if the current bit of the byte is 0,
+      // dequantize it to -1
+      *outval = -1;
+    }
+  }
+};
+
+template <typename xpu>
+void Dequantize1BitKernelLaunch(mshadow::Stream<xpu> *s, const std::vector<mxnet::TBlob> &inputs,
+                                const float threshold) {
+  mxnet::op::mxnet_op::Kernel<dequantize_1bit, xpu>
+    ::Launch(s,
+             inputs[1].Size(),         // original size
+             inputs[1].dptr<float>(),  // out array
+             inputs[0].dptr<float>(),  // compressed array
+             threshold);               // threshold
+}
+
 struct quantize_2bit {
   MSHADOW_XINLINE static void Map(int out_block_id,
                                   int original_size,
@@ -138,6 +233,18 @@ void Dequantize2BitKernelLaunch(mshadow::Stream<xpu> *s, const std::vector<mxnet::TBlob> &inputs,
 
+inline void Quantize1BitImpl(mshadow::Stream<mshadow::cpu> *s,
+                             const std::vector<mxnet::TBlob> &inputs,
+                             const float threshold) {
+  Quantize1BitKernelLaunch(s, inputs, threshold);
+}
+
+inline void Dequantize1BitImpl(mshadow::Stream<mshadow::cpu> *s,
+                               const std::vector<mxnet::TBlob> &inputs,
+                               const float threshold) {
+  Dequantize1BitKernelLaunch(s, inputs, threshold);
+}
+
 inline void Quantize2BitImpl(mshadow::Stream<mshadow::cpu> *s,
                              const std::vector<mxnet::TBlob> &inputs,
                              const float threshold) {
diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc
index 30aaec91e27f..86a183dd6688 100644
--- a/src/kvstore/gradient_compression.cc
+++ b/src/kvstore/gradient_compression.cc
@@ -41,8 +41,10 @@ void GradientCompression::SetParams(const std::vector<std::pair<std::string, std::string> > &kwargs) {
   if (params.type == "2bit") {
     SetTwoBitCompression(params.threshold);
+  } else if (params.type == "1bit") {
+    SetOneBitCompression(params.threshold);
   } else {
@@ ... @@ std::string GradientCompression::get_type_str() {
   return std::to_string(static_cast<int>(type_));
 }
 
+void GradientCompression::SetOneBitCompression(const float threshold) {
+  type_ = CompressionType::kOneBit;
+  threshold_ = threshold;
+}
+
 void GradientCompression::SetTwoBitCompression(const float threshold) {
   type_ = CompressionType::kTwoBit;
   threshold_ = threshold;
@@ -83,7 +90,9 @@ void GradientCompression::DecodeParams(const std::string &s) {
 }
 
 int GradientCompression::GetCompressionFactor() {
-  if (type_ == CompressionType::kTwoBit) {
+  if (type_ == CompressionType::kOneBit) {
+    return 32;
+  } else if (type_ == CompressionType::kTwoBit) {
     return 16;
   } else {
     LOG(FATAL) << "Unsupported compression type: " << get_type_str();
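A brief aside on what the factor returned above means for buffer sizing (my own sanity check, not patch code; the helper name is hypothetical): with 1-bit compression each float of the compressed buffer carries 32 original values, so the compressed length is roughly the original length divided by the factor, rounded up, with the tail block zero-padded by the kernels:

```python
def compressed_floats(original_size, factor=32):
    # ceil(original_size / factor)
    return (original_size + factor - 1) // factor

assert compressed_floats(64) == 2             # two full 32-value blocks
assert compressed_floats(65) == 3             # last block holds one value
assert compressed_floats(6, factor=16) == 1   # 2-bit case: 16 values per float
```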
@@ -106,16 +115,34 @@ void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to,
   const int a = from.ctx().dev_mask();
   const int b = to->ctx().dev_mask();
   const float threshold = threshold_;
-  if (type_ == CompressionType::kTwoBit) {
-    if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
+  if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
+    if (type_ == CompressionType::kOneBit) {
+      mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) {
+        std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
+        Quantize1BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
+      }, from.ctx(), {from.var()}, {to->var(), residual->var()},
+      mxnet::FnProperty::kNormal, priority, "QuantizeCPU");
+    } else if (type_ == CompressionType::kTwoBit) {
       mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) {
         std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
         Quantize2BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
       }, from.ctx(), {from.var()}, {to->var(), residual->var()},
       mxnet::FnProperty::kNormal, priority, "QuantizeCPU");
     } else {
+      LOG(FATAL) << "Unsupported quantization of type " << get_type_str();
+    }
+  } else {
 #if MXNET_USE_CUDA
-    if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
+    if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
+      if (type_ == CompressionType::kOneBit) {
+        mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) {
+          std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
+          Quantize1BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
+          // wait for the GPU kernel to complete
+          ctx.get_stream<mshadow::gpu>()->Wait();
+        }, from.ctx(), {from.var()}, {to->var(), residual->var()},
+        mxnet::FnProperty::kNormal, priority, "QuantizeGPU");
+      } else if (type_ == CompressionType::kTwoBit) {
         mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) {
           std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
           Quantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
@@ -124,14 +151,14 @@ void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to,
         }, from.ctx(), {from.var()}, {to->var(), residual->var()},
         mxnet::FnProperty::kNormal, priority, "QuantizeGPU");
       } else {
-        LOG(FATAL) << "unknown device mask";
+        LOG(FATAL) << "Unsupported quantization of type " << get_type_str();
       }
+    } else {
+      LOG(FATAL) << "unknown device mask";
+    }
 #else
     LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
 #endif
-    }
-  } else {
-    LOG(FATAL) << "Unsupported quantization of type " << get_type_str();
   }
 }
@@ -142,35 +169,52 @@ void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to,
   const int a = from.ctx().dev_mask();
   const int b = to->ctx().dev_mask();
   const float threshold = threshold_;
-  if (type_ == CompressionType::kTwoBit) {
-    if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
+  if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
+    if (type_ == CompressionType::kOneBit) {
+      mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
+        std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
+        Dequantize1BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
+      }, from.ctx(), {from.var()}, {to->var()},
+      mxnet::FnProperty::kNormal, priority, "DequantizeCPU");
+    } else if (type_ == CompressionType::kTwoBit) {
       mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
         std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
         Dequantize2BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
       }, from.ctx(), {from.var()}, {to->var()},
       mxnet::FnProperty::kNormal, priority, "DequantizeCPU");
     } else {
+      LOG(FATAL) << "Unsupported dequantization of type " << get_type_str();
+    }
+  } else {
 #if MXNET_USE_CUDA
-    if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
+    if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
+      if (type_ == CompressionType::kOneBit) {
         mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
           std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
-          Dequantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
+          Dequantize1BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
           // wait for the GPU kernel to complete
           ctx.get_stream<mshadow::gpu>()->Wait();
         }, from.ctx(), {from.var()}, {to->var()},
         mxnet::FnProperty::kNormal, priority, "DequantizeGPU");
+      } else if (type_ == CompressionType::kTwoBit) {
+        mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
+          std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
+          Dequantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
+          // wait for the GPU kernel to complete
+          ctx.get_stream<mshadow::gpu>()->Wait();
+        }, from.ctx(), {from.var()}, {to->var()},
+        mxnet::FnProperty::kNormal, priority, "DequantizeGPU");
       } else {
-        LOG(FATAL) << "unknown device mask";
+        LOG(FATAL) << "Unsupported dequantization of type " << get_type_str();
       }
+    } else {
+      LOG(FATAL) << "unknown device mask";
+    }
 #else
-    LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+    LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
 #endif
-    }
-  } else {
-    LOG(FATAL) << "Unsupported dequantization of type " << get_type_str();
   }
 }
-
 }  // namespace kvstore
 }  // namespace mxnet
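The asymmetric `-= 1` / `+= 1` updates in the kernels implement error feedback: whatever the 1-bit code cannot express is carried in the residual into the next round. A small standalone illustration of why the averaged updates still track the true gradient (my own sketch, not patch code):

```python
import numpy as np

def send_1bit(grad, residual, threshold=0.0):
    residual += grad
    sent = np.where(residual > threshold, 1.0, -1.0)
    residual -= sent   # same effect as the -=1 / +=1 branches in quantize_1bit
    return sent

grad = np.full(4, 0.3)
residual = np.zeros(4)
updates = [send_1bit(grad, residual) for _ in range(1000)]
# each step sends only +/-1, but the running mean recovers the gradient
print(np.mean(updates, axis=0))   # ~0.3
```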
diff --git a/src/kvstore/gradient_compression.cu b/src/kvstore/gradient_compression.cu
index b0d9662520b2..c5bacc227306 100644
--- a/src/kvstore/gradient_compression.cu
+++ b/src/kvstore/gradient_compression.cu
@@ -27,6 +27,16 @@
 namespace mxnet {
 namespace kvstore {
 
+void Quantize1BitImpl(mshadow::Stream<mshadow::gpu>* s, const std::vector<mxnet::TBlob>& inputs,
+                      const float threshold) {
+  Quantize1BitKernelLaunch(s, inputs, threshold);
+}
+
+void Dequantize1BitImpl(mshadow::Stream<mshadow::gpu>* s, const std::vector<mxnet::TBlob>& inputs,
+                        const float threshold) {
+  Dequantize1BitKernelLaunch(s, inputs, threshold);
+}
+
 void Quantize2BitImpl(mshadow::Stream<mshadow::gpu>* s, const std::vector<mxnet::TBlob>& inputs,
                       const float threshold) {
   Quantize2BitKernelLaunch(s, inputs, threshold);
diff --git a/src/kvstore/gradient_compression.h b/src/kvstore/gradient_compression.h
index f40b45f5a513..5496ada31bba 100644
--- a/src/kvstore/gradient_compression.h
+++ b/src/kvstore/gradient_compression.h
@@ -35,7 +35,7 @@ namespace mxnet {
 namespace kvstore {
 
 enum class CompressionType {
-  kNone, kTwoBit
+  kNone, kOneBit, kTwoBit
 };
 
 struct GradientCompressionParam : public dmlc::Parameter<GradientCompressionParam> {
@@ -72,6 +72,12 @@ class GradientCompression {
    */
   std::string get_type_str();
 
+  /*!
+   * \brief sets one bit gradient compression
+   * \param threshold float value used for thresholding gradients
+   */
+  void SetOneBitCompression(const float threshold);
+
   /*!
    * \brief sets two bit gradient compression
    * \param threshold float value used for thresholding gradients
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 4523a361cf88..215062266437 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -25,13 +25,13 @@
 import numpy as np
 import numpy.random as rnd
 from mxnet.test_utils import assert_almost_equal, assert_exception
-from test_kvstore import compute_expected_2bit_quantization
+from test_kvstore import compute_expected_quantization, compute_1bit, compute_2bit
 
 def check_diff(A, x, rank=None):
     """ assert A == x
         x can be scalar as well as numpy array
     """
-    assert (np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x.asnumpy())
+    assert (np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x)
 
 # setup
 shape = (2, 3)
@@ -88,9 +88,8 @@ def set_optimizer(use_multiprecision):
     kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate, multi_precision=use_multiprecision))
     return kv
 
-def init_kv_compressed(kv):
-    threshold = 0.5
-    kv.set_gradient_compression({'type': '2bit', 'threshold': threshold})
+def init_kv_compressed(kv, compression='2bit', threshold=.5):
+    kv.set_gradient_compression({'type': compression, 'threshold': threshold})
     # init kv compression keys
     for k, s in compr_keys_shapes:
         kv.init(k, mx.nd.zeros(s))
@@ -230,6 +229,104 @@ def check_big_row_sparse_keys(dtype, nrepeat):
             check_big_row_sparse_keys(dtype, nrepeat)
     print('worker ' + str(my_rank) + ' is done with non compression tests')
 
+def test_sync_1bit_compression(threshold, nrepeat):
+
+    def check_compr_pull_before_push():
+        for k, s in compr_keys_shapes:
+            val = mx.nd.ones(s)
+            kv.pull(k, val)
+            check_diff(val, 0)
+        for k, s in compr_init_keys_shapes:
+            # tests that GC is not used for init of a key
+            val = mx.nd.zeros(s)
+            kv.pull(k, val)
+            check_diff(val, 1)
+
+    def check_compr_ones():
+        for k, s in compr_keys_shapes:
+            val = mx.nd.zeros(s)
+            kv.pull(k, val)
+            curr_val = val[0][0].asnumpy()[0]
+            kv.push(k, mx.nd.ones(s))
+            out = mx.nd.zeros(s)
+            kv.pull(k, out=out)
+            newval = curr_val + rate * nworker
+            check_diff(out, newval)
+
+    def check_compr_neg_ones():
+        for k, s in compr_keys_shapes:
+            val = mx.nd.zeros(s)
+            kv.pull(k, val)
+            curr_val = val[0][0].asnumpy()[0]
+            kv.push(k, -1 * mx.nd.ones(s))
+            out = mx.nd.ones(s)
+            kv.pull(k, out=out)
+            # the current value should be zero after the calls to
+            # check_compr_ones and check_compr_neg_ones
+            check_diff(out, 0)
+
+    def check_compr_residual(threshold):
+        curr_residual = 0
+        curr_val = rate * nworker if 2 + curr_residual > threshold else -rate * nworker
+        for k, s in compr_keys_shapes:
+            kv.push(k, 2 * mx.nd.ones(s))
+            out = mx.nd.zeros(s)
+            kv.pull(k, out)
+            check_diff(out, curr_val)
+
+        curr_residual = 1 if 2 > threshold else 3
+        curr_val += rate * nworker if 0 + curr_residual > threshold else -rate * nworker
+        for k, s in compr_keys_shapes:
+            kv.push(k, mx.nd.zeros(s))
+            out = mx.nd.zeros(s)
+            kv.pull(k, out)
+            check_diff(out, curr_val)
+
+        curr_residual += -1 if curr_residual > threshold else +1
+        curr_val += rate * nworker if -2 + curr_residual > threshold else -rate * nworker
+        for k, s in compr_keys_shapes:
+            kv.push(k, -2 * mx.nd.ones(s))
+            out = mx.nd.zeros(s)
+            kv.pull(k, out)
+            check_diff(out, curr_val)
+
+    def check_compr_random(threshold, nrepeat):
+        # set a seed so all workers generate the same data. knowing this helps
+        # calculate the expected value after pull
+        mx.random.seed(123)
+        rnd.seed(123)
+
+        # use new keys so the residual is 0 for the calculation of expected values
+        for k, s in compr_random_keys_shapes:
+            kv.init(k, mx.nd.zeros(s))
+        for k, s in compr_random_keys_shapes:
+            curr_residual = np.zeros(s)
+            for l in range(nrepeat):
+                orig_val = mx.nd.zeros(s)
+                kv.pull(k, orig_val)
+
+                grad = mx.nd.array(rnd.rand(s[0], s[1]))
+                # creates a copy because push changes grad because of assignment
+                grad_cpy = mx.nd.array(grad)
+                kv.push(k, grad)
+                val = mx.nd.zeros(s)
+                kv.pull(k, val)
+
+                diff = val - orig_val
+
+                # compute the expected value by simulating the operator
+                compr, curr_residual, decompr = compute_expected_quantization(grad_cpy, curr_residual, threshold, compute_1bit)
+                decompr *= nworker * rate
+                assert_almost_equal(diff.asnumpy(), decompr)
+
+    print('worker ' + str(my_rank) + ' started with 1bit compression tests')
+    check_compr_pull_before_push()
+    check_compr_ones()
+    check_compr_neg_ones()
+    check_compr_residual(threshold)
+    check_compr_random(threshold, nrepeat)
+    print('worker ' + str(my_rank) + ' is done with 1bit compression tests')
+
 def test_sync_2bit_compression(threshold, nrepeat):
     def check_compr_residual(threshold):
         for k, s in compr_keys_shapes:
@@ -316,17 +413,17 @@ def check_compr_random(threshold, nrepeat):
             diff = val - orig_val
 
             # compute expected by using simulation of operator
-            compr, curr_residual, decompr = compute_expected_2bit_quantization(grad_cpy, curr_residual, threshold)
+            compr, curr_residual, decompr = compute_expected_quantization(grad_cpy, curr_residual, threshold, compute_2bit)
             decompr *= nworker * rate
             assert_almost_equal(diff.asnumpy(), decompr)
 
-    print ('worker ' + str(my_rank) + ' started with compression tests')
+    print('worker ' + str(my_rank) + ' started with 2bit compression tests')
     check_compr_pull_before_push()
     check_compr_zero()
     check_compr_residual(threshold)
     check_compr_ones(threshold)
     check_compr_random(threshold, nrepeat)
-    print('worker ' + str(my_rank) + ' is done with compression tests')
+    print('worker ' + str(my_rank) + ' is done with 2bit compression tests')
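The expected values asserted above follow one invariant: every worker pushes identical data, the server sums the `nworker` dequantized gradients, and the test optimizer scales the update by `rate`, so one compressed push/pull round moves each weight by `rate * nworker * q`, where `q` is the dequantized value (+1 or -1 for 1-bit). A worked instance (my own sketch; 7 workers matches the launch commands, and `rate = 2` is an assumed value standing in for the constant defined in the test setup):

```python
nworker, rate, threshold = 7, 2, 0.0
residual = 0.0

pushed = 2.0                                    # kv.push(k, 2 * mx.nd.ones(s))
q = 1 if pushed + residual > threshold else -1  # 1-bit dequantized value
expected = rate * nworker * q                   # the curr_val check_diff asserts: 14
```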
 def test_sync_init(gpu_tests=False):
     def get_dtype(idx, cur_keys):
@@ -454,7 +551,11 @@ def check_trainer_sparse_step():
         kv = init_kv()
         kv = set_optimizer(use_multiprecision=opt.multiprecision)
         test_sync_push_pull(opt.nrepeat)
-    elif opt.type == 'compressed_cpu':
+    elif opt.type == 'compressed_cpu_1bit':
+        kv, threshold = init_kv_compressed(kv, '1bit', 0)
+        kv = set_optimizer(use_multiprecision=opt.multiprecision)
+        test_sync_1bit_compression(threshold, opt.nrepeat)
+    elif opt.type == 'compressed_cpu_2bit':
         kv, threshold = init_kv_compressed(kv)
         kv = set_optimizer(use_multiprecision=opt.multiprecision)
         test_sync_2bit_compression(threshold, opt.nrepeat)
diff --git a/tests/nightly/test_kvstore.py b/tests/nightly/test_kvstore.py
index ced3ee1ef8cc..65f6c4a13202 100755
--- a/tests/nightly/test_kvstore.py
+++ b/tests/nightly/test_kvstore.py
@@ -27,53 +27,73 @@
 from mxnet.test_utils import assert_almost_equal
 
+
 def check_diff_to_scalar(A, x, rank=None):
     """ assert A == x"""
     assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x)
 
-def compute_expected_2bit_quantization(arr, curr_residual, threshold):
-    from struct import pack,unpack
-    def bits2int(bits):
-        bits = [int(x) for x in bits[::-1]]
-        x = 0
-        for i in range(len(bits)):
-            x += bits[i]*2**i
-        return x
-
-    def as_float32(s):
-        return unpack("f",pack("I", bits2int(s)))[0]
-
-    # str_quant stores the quantized representation as a sequence of bits
-    str_quant = ''
+def compute_1bit(arr, curr_residual, threshold):
+    str_quant = ""
     new_residual = []
     decompr = []
 
-    arr_npy = arr.asnumpy()
-    for i, a in np.ndenumerate(arr_npy):
-        a += curr_residual[i]
-        if a >= threshold:
-            str_quant += '11'
-            new_residual.append(a - threshold)
+    for idx, val in np.ndenumerate(arr):
+        val += curr_residual[idx]
+        if val > threshold:
+            str_quant += "1"
+            new_residual.append(val - 1)
+            decompr.append(1)
+        else:
+            str_quant += "0"
+            new_residual.append(val + 1)
+            decompr.append(-1)
+
+    # pad with extra bits when the size of the array is not a multiple of 32
+    if len(str_quant) % 32 != 0:
+        str_quant += "0" * (32 - len(str_quant) % 32)
+    return str_quant, new_residual, decompr
+
+def compute_2bit(arr, curr_residual, threshold):
+    str_quant = ""
+    new_residual = []
+    decompr = []
+
+    for idx, val in np.ndenumerate(arr):
+        val += curr_residual[idx]
+        if val >= threshold:
+            str_quant += "11"
+            new_residual.append(val - threshold)
             decompr.append(threshold)
-        elif a <= (-1*threshold):
-            str_quant += '10'
-            new_residual.append(a + threshold)
-            decompr.append(-1*threshold)
+        elif val <= -threshold:
+            str_quant += "10"
+            new_residual.append(val + threshold)
+            decompr.append(-threshold)
         else:
-            str_quant += '00'
-            new_residual.append(a)
+            str_quant += "00"
+            new_residual.append(val)
             decompr.append(0)
 
-    # append extra bits when size of array not a factor of 16
-    if len(str_quant)%16 != 0:
-        str_quant += '0'*(16 - len(str_quant)%16)
+    # pad with extra bits when the size of the array is not a multiple of 16
+    if len(str_quant) % 16 != 0:
+        str_quant += "0" * (16 - len(str_quant) % 16)
+    return str_quant, new_residual, decompr
+
+def compute_expected_quantization(arr, curr_residual, threshold, quantize_func):
+
+    from struct import pack, unpack
+    def as_float32(s):
+        return unpack("f", pack("I", int(s, 2)))[0]
+
+    arr_npy = arr.asnumpy()
+    # str_quant stores the quantized representation as a sequence of bits
+    str_quant, new_residual, decompr = quantize_func(arr_npy, curr_residual, threshold)
 
     compr = []
     # converts the string generated into integers 32 chars at a time
-    i = 0
-    while i < len(str_quant):
-        compr.append(as_float32(str_quant[i:i+32]))
-        i+=32
+    i = 0
+    while i < len(str_quant):
+        compr.append(as_float32(str_quant[i:i+32]))
+        i += 32
+    return np.array(compr), np.array(new_residual).reshape(arr.shape), np.array(decompr).reshape(arr.shape)
@@ ... @@
-def test_compress_kvstore(kv_type, compression='2bit', threshold=0.5):
+def test_compress_kvstore(kv_type, compression='2bit', threshold=0.5):
+    def verify_residual_1bit(kv, threshold, rate):
+        curr_residual = 0
+        curr_val = rate * nworker if 2 + curr_residual > threshold else -rate * nworker
+        for j in range(len(keys)):
+            kv.push(keys[j], [2 * mx.nd.ones(shapes[j], mx.gpu(g)) for g in range(nworker)])
+            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
+            kv.pull(keys[j], out=out)
+
+            for o in out:
+                check_diff_to_scalar(o, curr_val)
+
+        curr_residual = 1 if 2 > threshold else 3
+        curr_val += rate * nworker if 0 + curr_residual > threshold else -rate * nworker
+        for j in range(len(keys)):
+            kv.push(keys[j], [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)])
+            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
+            kv.pull(keys[j], out=out)
+            for o in out:
+                check_diff_to_scalar(o, curr_val)
+
+        curr_residual += -1 if curr_residual > threshold else +1
+        curr_val += rate * nworker if -2 + curr_residual > threshold else -rate * nworker
+        for j in range(len(keys)):
+            kv.push(keys[j], [-2 * mx.nd.ones(shapes[j], mx.gpu(g)) for g in range(nworker)])
+            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
+            kv.pull(keys[j], out=out)
+            for o in out:
+                check_diff_to_scalar(o, curr_val)
+
     def push_zeros(kv):
         for i in range(nrepeat):
             for j in range(len(keys)):
@@ -141,7 +218,7 @@ def push_zeros(kv):
             for o in out:
                 assert_almost_equal(o.asnumpy(), exp)
 
-    def verify_residual(kv, threshold, rate):
+    def verify_residual_2bit(kv, threshold, rate):
         for j in range(len(keys)):
             kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*0.4 for g in range(nworker)])
             out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
@@ -197,8 +274,8 @@ def check_compr_random(kv, threshold):
             # on cpu
             sum_dequantized_vals = np.zeros(s)
             for g in range(nworker):
-                compr, curr_residual[g], decompr = compute_expected_2bit_quantization(
-                    grads_cpy[g], curr_residual[g], threshold)
+                compr, curr_residual[g], decompr = compute_expected_quantization(
+                    grads_cpy[g], curr_residual[g], threshold, quantize_func)
                 sum_dequantized_vals += (decompr * rate)
 
             for g in range(nworker):
@@ -206,9 +283,14 @@
 
     pull_init_test(kv)
     pull_before_push(kv)
-    push_zeros(kv)
-    curval = verify_residual(kv, threshold, rate)
-    check_neg(kv, -1*threshold, rate, curval)
+    if compression == '1bit':
+        push_ones(kv, sign=1)
+        push_ones(kv, sign=-1)
+        verify_residual_1bit(kv, threshold, rate)
+    elif compression == '2bit':
+        push_zeros(kv)
+        curval = verify_residual_2bit(kv, threshold, rate)
+        check_neg(kv, -1*threshold, rate, curval)
     check_compr_random(kv, threshold)
 
 ## group keys interface
@@ -252,7 +334,10 @@ def test_group_kvstore(kv_type, stype):
         test_kvstore('local_allreduce_device', stype)
     ## compression for local kvstore happens only when reduce is on device
-    test_compress_kvstore('local_allreduce_device')
+    test_compress_kvstore('local_allreduce_device', '1bit', -.5)
+    test_compress_kvstore('local_allreduce_device', '1bit', 0)
+    test_compress_kvstore('local_allreduce_device', '1bit', .5)
+    test_compress_kvstore('local_allreduce_device', '2bit', .5)
     for stype in stypes:
         test_group_kvstore('local_update_cpu', stype)
         test_group_kvstore('local_allreduce_cpu', stype)
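Finally, a self-contained round-trip check in the spirit of `compute_1bit` above (my own sketch, not part of the patch): quantize with error feedback, dequantize to +/-1, and observe that the residual retains exactly the difference between what was pushed and what was sent:

```python
import numpy as np

def roundtrip_1bit(arr, residual, threshold=0.0):
    out = np.empty_like(arr)
    for i, v in np.ndenumerate(arr):
        residual[i] += v
        if residual[i] > threshold:
            out[i], residual[i] = 1.0, residual[i] - 1
        else:
            out[i], residual[i] = -1.0, residual[i] + 1
    return out

arr = np.array([[0.6, -0.2, 0.1], [-0.9, 0.4, 0.0]])
res = np.zeros_like(arr)
sent = roundtrip_1bit(arr, res)
print(sent)                           # [[ 1. -1.  1.] [-1.  1. -1.]]
assert np.allclose(arr, sent + res)   # residual = pushed - sent
```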