diff --git a/hls4ml/backends/quartus/passes/core_templates.py b/hls4ml/backends/quartus/passes/core_templates.py index edbfcc56b7..63c3693b0b 100644 --- a/hls4ml/backends/quartus/passes/core_templates.py +++ b/hls4ml/backends/quartus/passes/core_templates.py @@ -71,8 +71,6 @@ def format(self, node): static const bool store_weights_in_bram = false; typedef {bias_t.name} bias_t; typedef {scale_t.name} scale_t; - template - using product = nnet::product::{product_type}; }};\n""" batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});' diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index c3e35bf1d5..22aa5837ee 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -13,8 +13,8 @@ typedef {accum_t.name} accum_t; typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; - template - using product = nnet::product::{product_type}; + template + using product = nnet::product::{product_type}; }};\n""" # Conv1D templates diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 6869567359..201562f7fb 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -18,8 +18,8 @@ typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; typedef {index_t.name} index_t; - template - using product = nnet::product::{product_type}; + template + using product = nnet::product::{product_type}; }};\n""" dense_function_template = 'nnet::dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' @@ -62,8 +62,8 @@ def format(self, node): static const bool store_weights_in_bram = false; typedef {bias_t.name} bias_t; typedef {scale_t.name} scale_t; - template - using product = nnet::product::{product_type}; + template + using product = 
nnet::product::{product_type}; }};\n""" batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});' diff --git a/hls4ml/backends/vivado/passes/merge_templates.py b/hls4ml/backends/vivado/passes/merge_templates.py index 219a7f4e29..863512c4c5 100644 --- a/hls4ml/backends/vivado/passes/merge_templates.py +++ b/hls4ml/backends/vivado/passes/merge_templates.py @@ -50,8 +50,8 @@ def format(self, node): static const unsigned n_out = {n_out}; static const unsigned reuse_factor = {reuse}; typedef {accum_t.name} accum_t; - template - using product = nnet::product::{product_type}; + template + using product = nnet::product::{product_type}; }};\n""" class DotConfigTemplate(LayerConfigTemplate): diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h index 18645042f6..8075a96714 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h @@ -51,38 +51,34 @@ struct dense_config // partitioning arrays cyclically to go with roll factors? }; -template -inline typename std::enable_if>::value - and std::is_same>::value, ac_int<1, false>>::type -product(ac_int<1, false> a, ac_int<1, false> w){ +inline ac_int<1, false> product(ac_int<1, false> a, ac_int<1, false> w) +{ // specialisation for 1-bit weights and incoming data - return (ret_T) (a == w); + return (a == w); } -template -inline typename std::enable_if<(not std::is_same>::value) - and std::is_same>::value, ret_T>::type -product(data_T a, ac_int<1, false> w){ +template +auto product(data_T a, ac_int<1, false> w) -> decltype(-a) +{ // Specialisation for 1-bit weights, arbitrary data - return w == 0 ? 
(ret_T) -a : a; + if (w == 0) return -a; + else return a; } -template -inline typename std::enable_if<(not std::is_same>::value) - and std::is_same>::value, ret_T>::type -product(data_T a, ac_int<2, true> w){ +template +auto product(data_T a, ac_int<2, true> w) -> decltype(-a) +{ // Specialisation for 2-bit weights, arbitrary data - if (w == 0) return (ret_T) 0; - else if(w == -1) return (ret_T) -a; - else return (ret_T) a; // if(w == 1) + if (w == 0) return 0; + else if(w == -1) return -a; + else return a; // if(w == 1) } -template -inline typename std::enable_if<(not std::is_same>::value) - and (not std::is_same>::value), ret_T>::type -product(data_T a, weight_T w){ +template +auto product(data_T a, weight_T w) -> decltype(a*w) +{ // 'Normal' product - return (ret_T)(a * w); + return a * w; } template @@ -138,7 +134,7 @@ void dense_rf_gt( uint32 w_index = ir + (CONFIG_T::reuse_factor_rounded) * im; if (w_index >= CONFIG_T::reuse_factor_rounded*CONFIG_T::block_factor_rounded) continue; int data_index = d_index[ir][im]; - tmp_acc[im] = product(data[data_index], weights[w_index]); + tmp_acc[im] = product(data[data_index], weights[w_index]); } hls_register typename CONFIG_T::accum_t mult[CONFIG_T::multiplier_limit]; ResetMult: @@ -192,7 +188,7 @@ void dense_rf_lt( for (int im = 0, in_index = ir; im < CONFIG_T::block_factor; im++) { uint32 w_index = ir + (CONFIG_T::reuse_factor_rounded) * im; if (ir + CONFIG_T::reuse_factor * im >= CONFIG_T::n_in*CONFIG_T::n_out) continue; - mult[im] = product(data[in_index], weights[w_index]); + mult[im] = product(data[in_index], weights[w_index]); in_index += CONFIG_T::reuse_factor; if (in_index >= CONFIG_T::n_in) in_index = ir; } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h index 9a5cff0d3d..edc6ff3205 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h @@ -43,8 +43,8 @@ struct 
batchnorm_config static const bool store_weights_in_bram = false; static const unsigned n_zeros = 0; // partitioning arrays cyclically to go with roll factors? - template - using product = nnet::product::mult; + template + using product = nnet::product::mult; }; template @@ -71,7 +71,7 @@ void normalize( #pragma HLS ARRAY_PARTITION variable=bias complete int multiplier_limit = ceil(float(CONFIG_T::n_in) / float(CONFIG_T::reuse_factor)); - CONFIG_T::template product::limit(multiplier_limit); + CONFIG_T::template product::limit(multiplier_limit); } else if (CONFIG_T::io_type == io_serial) { #pragma HLS ARRAY_RESHAPE variable=scale complete dim=1 @@ -87,10 +87,10 @@ void normalize( } if (CONFIG_T::n_filt==-1) { - res[ires] = CONFIG_T::template product::product(data[ires], scale[ires]) + bias[ires]; + res[ires] = CONFIG_T::template product::product(data[ires], scale[ires]) + bias[ires]; } else { int norm_index = ires%CONFIG_T::n_filt; - res[ires] = CONFIG_T::template product::product(data[ires], scale[norm_index]) + bias[norm_index]; + res[ires] = CONFIG_T::template product::product(data[ires], scale[norm_index]) + bias[norm_index]; } } } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h index 382887fed7..826bdafe9a 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h @@ -43,7 +43,7 @@ void normalize( constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor); constexpr unsigned ii = CONFIG_T::n_in / multiplier_limit; - CONFIG_T::template product::limit(multiplier_limit); + CONFIG_T::template product::limit(multiplier_limit); BatchNormLoop: for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) { #pragma HLS PIPELINE II=ii @@ -60,7 +60,7 @@ void normalize( } else { norm_index = j % CONFIG_T::n_filt; } - out_data[j] = CONFIG_T::template product::product(in_data[j], 
scale[norm_index]) + bias[norm_index]; + out_data[j] = CONFIG_T::template product::product(in_data[j], scale[norm_index]) + bias[norm_index]; } res.write(out_data); diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h index deb1c042da..c9785335aa 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h @@ -30,8 +30,8 @@ struct dense_config static const unsigned n_zeros = 0; // partitioning arrays cyclically to go with roll factors? // Product function to use - template - using product = nnet::product::mult; + template + using product = nnet::product::mult; }; template diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h index adfaa0e1b7..dc803ff2bc 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h @@ -86,7 +86,7 @@ void dense_compressed( auto weight_cache = weights[w].weight; data_T data_cache = data[row]; //mult[col] += weight_cache * data_cache; - typename CONFIG_T::accum_t prod = CONFIG_T::template product::product(data_cache, weight_cache); + typename CONFIG_T::accum_t prod = CONFIG_T::template product::product(data_cache, weight_cache); fill_mult(col, mult, prod); } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h index 4a04671fd6..2bbab0496b 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h @@ -54,7 +54,7 @@ void dense_latency( #pragma HLS ARRAY_PARTITION variable=acc complete int multiplier_limit = ceil(float(CONFIG_T::n_in*CONFIG_T::n_out) / float(CONFIG_T::reuse_factor)) - floor(float(CONFIG_T::n_zeros) / float(CONFIG_T::reuse_factor)); - CONFIG_T::template product::limit(multiplier_limit); + CONFIG_T::template 
product::limit(multiplier_limit); } else if (CONFIG_T::io_type == io_serial){ // Only reduce cycle_factor if n_out is evenly divisible by reuse_factor @@ -90,10 +90,10 @@ void dense_latency( Product2: for(int jj = 0; jj < CONFIG_T::n_out; jj++) { if (CONFIG_T::io_type == io_serial) { int multiplier_limit = ceil(float(CONFIG_T::n_out) / float(CONFIG_T::reuse_factor)); - CONFIG_T::template product::limit(multiplier_limit); + CONFIG_T::template product::limit(multiplier_limit); } int index = ii*CONFIG_T::n_out+jj; - mult[index] = CONFIG_T::template product::product(cache, weights[index]); + mult[index] = CONFIG_T::template product::product(cache, weights[index]); } } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h index 756a627434..c0e5d17591 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h @@ -73,7 +73,8 @@ void dense_resource_rf_leq_nin( for (int im = 0; im < block_factor; im++) { #pragma HLS UNROLL - acc[out_index] += CONFIG_T::template product::product(data[in_index], weights[w_index]); + acc[out_index] += static_cast( + CONFIG_T::template product::product(data[in_index], weights[w_index])); // Increment w_index w_index += rufactor; @@ -157,7 +158,8 @@ void dense_resource_rf_gt_nin_rem0( MultLoop: for (int im = 0; im < block_factor; im++) { #pragma HLS UNROLL - acc[out_index] += CONFIG_T::template product::product(data[in_index], weights[w_index]); + acc[out_index] += static_cast( + CONFIG_T::template product::product(data[in_index], weights[w_index])); w_index += rufactor; if (w_index >= CONFIG_T::n_in * CONFIG_T::n_out) break; // check out of bounds @@ -223,7 +225,7 @@ void dense_resource_rf_gt_nin( int w_index = ir + rufactor * im; int in_index = w_index % nin; if (w_index >= CONFIG_T::n_in*CONFIG_T::n_out) continue; // check out of bounds - tmpmult[im] = CONFIG_T::template 
product::product(data[in_index], weights[w_index]); + tmpmult[im] = CONFIG_T::template product::product(data[in_index], weights[w_index]); } typename CONFIG_T::accum_t mult[multiplier_limit]; diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_merge.h b/hls4ml/templates/vivado/nnet_utils/nnet_merge.h index 48a5e172d3..b103533e02 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_merge.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_merge.h @@ -38,8 +38,8 @@ struct dot_config { static const unsigned reuse_factor = 1; typedef float accum_t; // Product function to use - template - using product = nnet::product::mult; + template + using product = nnet::product::mult; }; struct concat_config { @@ -129,7 +129,7 @@ void dot1d( #pragma HLS PIPELINE II=CONFIG_T::reuse_factor constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor); - CONFIG_T::template product::limit(multiplier_limit); + CONFIG_T::template product::limit(multiplier_limit); typename CONFIG_T::accum_t mult[CONFIG_T::n_in]; #pragma HLS ARRAY_PARTITION variable=mult complete @@ -137,7 +137,7 @@ void dot1d( Product: for(int i_mult=0; i_mult < CONFIG_T::n_in; i_mult++) { #pragma HLS UNROLL - mult[i_mult] = CONFIG_T::template product::product(data1[i_mult], data2[i_mult]); + mult[i_mult] = CONFIG_T::template product::product(data1[i_mult], data2[i_mult]); } Accum: for(int i_acc = 0; i_acc < CONFIG_T::n_in; i_acc++) { diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h index 3a597f0382..586bc65aeb 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h @@ -5,73 +5,74 @@ #include "nnet_helpers.h" #include "hls_stream.h" #include +#include namespace nnet { namespace product{ /* --- - * 5 different methods to perform the product of input and weight, depending on the - * types of each. 
+ * different methods to perform the product of input and weight, depending on the + * types of each. * --- */ -template class Product{ public: - static y_T product(x_T a, w_T w){ - // 'Normal' product - #pragma HLS INLINE - return a * w; - } static void limit(unsigned multiplier_limit) {} // Nothing to do here }; -template -class both_binary : public Product{ +template +class both_binary : public Product{ public: - static y_T product(x_T a, w_T w){ + static x_T product(x_T a, w_T w){ // specialisation for 1-bit weights and incoming data #pragma HLS INLINE return a == w; } }; -template -class weight_binary : public Product{ +template +class weight_binary : public Product{ public: - static y_T product(x_T a, w_T w){ + static auto product(x_T a, w_T w) -> decltype(-a) + { // Specialisation for 1-bit weights, arbitrary data #pragma HLS INLINE - return w == 0 ? (x_T) -a : a; + if (w == 0) return -a; + else return a; } }; -template -class data_binary : public Product{ +template +class data_binary : public Product{ public: - static y_T product(x_T a, w_T w){ + static auto product(x_T a, w_T w) -> decltype(-w) + { // Specialisation for 1-bit data, arbitrary weight #pragma HLS INLINE - return a == 0 ? 
(w_T) -w : w; + if (a == 0) return -w; + else return w; } }; -template -class weight_ternary : public Product{ +template +class weight_ternary : public Product{ public: - static y_T product(x_T a, w_T w){ + static auto product(x_T a, w_T w) -> decltype(-a) + { // Specialisation for 2-bit weights, arbitrary data #pragma HLS INLINE - if (w == 0) return (x_T) 0; - else if(w == -1) return (x_T) -a; - else return (x_T) a; // if(w == 1) + if (w == 0) return 0; + else if(w == -1) return -a; + else return a; // if(w == 1) } }; -template -class mult : public Product{ +template +class mult : public Product{ public: - static y_T product(x_T a, w_T w){ + static auto product(x_T a, w_T w) -> decltype(a*w) + { // 'Normal' product #pragma HLS INLINE return a * w; @@ -82,16 +83,20 @@ class mult : public Product{ } }; -template -class weight_exponential : public Product{ +template +class weight_exponential : public Product{ public: - static y_T product(x_T a, w_T w){ + // Construct the return type from the multiplication equivalent to the largest shifts + // ap_int is the type of the multiplicand equivalent to the largest lshift << + // ap_fixed is the type of the multiplicand equivalent to the largest rshift >> + using r_T = decltype(x_T(0) * (ap_int(1)+ap_fixed(1))); + static r_T product(x_T a, w_T w){ // Shift product for exponential weights #pragma HLS INLINE // shift by the exponent. Negative weights shift right - y_T y = a << w.weight; + r_T y = static_cast(a) << w.weight; // negate or not depending on weight sign - return w.sign == 1 ? (y_T) y : (y_T) -y; + return w.sign == 1 ?
y : static_cast(-y); } }; diff --git a/test/pytest/test_batchnorm.py b/test/pytest/test_batchnorm.py new file mode 100644 index 0000000000..25744e7f62 --- /dev/null +++ b/test/pytest/test_batchnorm.py @@ -0,0 +1,44 @@ +import pytest +from tensorflow.keras.models import Sequential +from tensorflow.keras.layers import BatchNormalization +import numpy as np +import hls4ml + + +in_shape = 16 +atol = 5e-3 + +@pytest.fixture(scope='module') +def data(): + np.random.seed(0) + X = np.random.rand(100, in_shape) + return X + + +@pytest.fixture(scope='module') +def model(): + model = Sequential() + model.add(BatchNormalization(input_shape=(in_shape,))) + model.compile() + return model + + +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_global_pool1d(model, data, io_type): + + config = hls4ml.utils.config_from_keras_model(model, + default_precision='ap_fixed<32,1>', + granularity='name') + + hls_model = hls4ml.converters.convert_from_keras_model(model, + hls_config=config, + io_type=io_type, + output_dir=f'hls4mlprj_batchnorm_{io_type}', + part='xcvu9p-flgb2104-2-i') + hls_model.compile() + + + # Predict + y_keras = np.squeeze(model.predict(data)) + y_hls = hls_model.predict(data) + np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True) diff --git a/test/pytest/test_keras_api.py b/test/pytest/test_keras_api.py index 0e424f64ee..820cb431dc 100644 --- a/test/pytest/test_keras_api.py +++ b/test/pytest/test_keras_api.py @@ -37,7 +37,7 @@ def test_dense(backend): keras_prediction = model.predict(X_input) config = hls4ml.utils.config_from_keras_model(model) - output_dir = str(test_root_path / 'hls4mlprj_keras_api_dense') + output_dir = str(test_root_path / f'hls4mlprj_keras_api_dense_{backend}') hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend) @@ -78,7 +78,7 @@ def test_activations(activation_function, backend): X_input = np.random.rand(100,1) keras_prediction = 
model.predict(X_input) config = hls4ml.utils.config_from_keras_model(model) - output_dir = str(test_root_path / 'hls4mlprj_keras_api_activations_{}'.format(activation_function.__class__.__name__)) + output_dir = str(test_root_path / 'hls4mlprj_keras_api_activations_{}_{}'.format(activation_function.__class__.__name__, backend)) hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend) hls_model.compile() hls_prediction = hls_model.predict(X_input)