From 330a4ca725b5fe3689e0c3175a219475241ab370 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Sat, 15 Jul 2023 17:39:33 -0500 Subject: [PATCH 1/5] Add support for DepthwiseConv[12]D in Vitis and Vivado io-stream --- .../vivado/passes/convolution_templates.py | 21 +++++++- hls4ml/converters/keras/convolution.py | 15 ++++-- hls4ml/model/layers.py | 18 +++++++ .../vitis/nnet_utils/nnet_sepconv1d_stream.h | 10 ++++ .../vitis/nnet_utils/nnet_sepconv2d_stream.h | 11 ++++ .../vivado/nnet_utils/nnet_sepconv1d_stream.h | 15 ++++++ .../vivado/nnet_utils/nnet_sepconv2d_stream.h | 16 ++++++ test/pytest/test_keras_api.py | 54 +++++++++++++++++++ 8 files changed, 153 insertions(+), 7 deletions(-) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 9a7b10a6f4..a142276834 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -1,6 +1,14 @@ from hls4ml.backends.backend import get_backend from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate -from hls4ml.model.layers import Conv1D, Conv2D, Conv2DBatchnorm, DepthwiseConv2D, SeparableConv1D, SeparableConv2D +from hls4ml.model.layers import ( + Conv1D, + Conv2D, + Conv2DBatchnorm, + DepthwiseConv1D, + DepthwiseConv2D, + SeparableConv1D, + SeparableConv2D, +) # Shared multiplication template @@ -52,13 +60,16 @@ const ap_uint config{index}::pixels[] = {{{instructions}}};\n""" conv1d_function_template = 'nnet::conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +depthconv1d_function_template = ( + 'nnet::depthwise_conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +) conv1d_include_list = ['nnet_utils/nnet_conv1d.h', 'nnet_utils/nnet_conv1d_stream.h'] class Conv1DConfigTemplate(LayerConfigTemplate): def __init__(self): - super().__init__(Conv1D) + super().__init__((Conv1D, DepthwiseConv1D)) self.template = conv1d_config_template self.mult_template = conv_mult_config_template @@ -106,6 +117,12 @@ def format(self, node): return self.template.format(**params) +class DepthwiseConv1DFunctionTemplate(Conv1DFunctionTemplate): + def __init__(self): + super(Conv1DFunctionTemplate, self).__init__(DepthwiseConv1D, include_header=sepconv1d_include_list) + self.template = depthconv1d_function_template + + # Conv2D Templates conv2d_config_template = """struct config{index} : nnet::conv2d_config {{ diff --git a/hls4ml/converters/keras/convolution.py b/hls4ml/converters/keras/convolution.py index 5ebd2abee1..39780f6dc6 100644 --- a/hls4ml/converters/keras/convolution.py +++ b/hls4ml/converters/keras/convolution.py @@ -2,7 +2,7 @@ from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d, parse_data_format -@keras_handler('Conv1D', 'SeparableConv1D') +@keras_handler('Conv1D', 'SeparableConv1D', 'DepthwiseConv1D') def parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader): assert 'Conv1D' in keras_layer['class_name'] @@ -12,14 +12,19 @@ def parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader): if layer['class_name'] in ['Conv1D', 'QConv1D']: layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'kernel') - else: # SeparableConv1D - layer['depthwise_data'], layer['pointwise_data'], layer['bias_data'] = get_weights_data( - data_reader, layer['name'], ['depthwise_kernel', 'pointwise_kernel', 'bias'] + elif layer['class_name'] in ['SeparableConv1D', 'QSeparableConv1D']: + layer['depthwise_data'], layer['pointwise_data'] = get_weights_data( + data_reader, layer['name'], ['depthwise_kernel', 'pointwise_kernel'] ) + else: # DepthwiseConv1D + layer['depthwise_data'] = get_weights_data(data_reader, layer['name'], 'depthwise_kernel') layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') - layer['n_filt'] = keras_layer['config']['filters'] + if 'filters' in keras_layer['config']: + layer['n_filt'] = keras_layer['config']['filters'] + else: + layer['n_filt'] = layer['n_chan'] layer['filt_width'] = keras_layer['config']['kernel_size'][0] layer['stride_width'] = keras_layer['config']['strides'][0] layer['padding'] = keras_layer['config']['padding'] diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index d9da2cc741..f7695c5658 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -472,6 +472,23 @@ def initialize(self): self.add_bias(quantizer=self.get_attr('bias_quantizer')) +class DepthwiseConv1D(Conv1D): + def initialize(self): + if self.get_attr('data_format') == 'channels_last': + shape = [self.attributes['out_width'], self.attributes['n_chan']] + dims = [f'OUT_HEIGHT_{self.index}', f'N_CHAN_{self.index}'] + else: + shape = [self.attributes['n_chan'], self.attributes['out_width']] + dims = [f'N_CHAN_{self.index}', f'OUT_WIDTH_{self.index}'] + self.add_output_variable(shape, dims) + + self.add_weights_variable( + name='weight', var_name='w{index}', data='depthwise', quantizer=self.get_attr('depthwise_quantizer') + ) + + self.add_bias(quantizer=self.get_attr('bias_quantizer')) + + class Conv2D(Layer): _expected_attributes = [ Attribute('in_height'), @@ -1314,6 +1331,7 @@ def initialize(self): 'QConv2D': Conv2D, 'QConv2DBatchnorm': Conv2DBatchnorm, 'SeparableConv1D': SeparableConv1D, + 'DepthwiseConv1D': DepthwiseConv1D, 'SeparableConv2D': SeparableConv2D, 'DepthwiseConv2D': DepthwiseConv2D, 'BatchNormalization': BatchNormalization, diff --git a/hls4ml/templates/vitis/nnet_utils/nnet_sepconv1d_stream.h b/hls4ml/templates/vitis/nnet_utils/nnet_sepconv1d_stream.h index 7d7b6ada98..961f563eeb 100644 --- a/hls4ml/templates/vitis/nnet_utils/nnet_sepconv1d_stream.h +++ b/hls4ml/templates/vitis/nnet_utils/nnet_sepconv1d_stream.h @@ -28,6 +28,16 @@ void depthwise_conv_1d_buffer_cl(hls::stream &data, hls::stream & } } +template +void depthwise_conv_1d_cl(hls::stream &data, hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + assert(CONFIG_T::implementation == conv_implementation::linebuffer && + "Only \"linebuffer\" implementation is supported in Vitis HLS."); + #pragma HLS inline recursive + depthwise_conv_1d_buffer_cl(data, res, weights, biases); +} + template void pointwise_conv_1d_cl(hls::stream &data, hls::stream &res, typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], diff --git a/hls4ml/templates/vitis/nnet_utils/nnet_sepconv2d_stream.h b/hls4ml/templates/vitis/nnet_utils/nnet_sepconv2d_stream.h index 7df04da23b..a63976dbf4 100644 --- a/hls4ml/templates/vitis/nnet_utils/nnet_sepconv2d_stream.h +++ b/hls4ml/templates/vitis/nnet_utils/nnet_sepconv2d_stream.h @@ -92,6 +92,17 @@ void pointwise_conv_2d_cl(hls::stream &data, hls::stream &res, } } +template +void depthwise_conv_2d_cl( + hls::stream &data, hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + assert(CONFIG_T::implementation == conv_implementation::linebuffer && + "Only \"linebuffer\" implementation is supported in Vitis HLS."); + #pragma HLS inline recursive + depthwise_conv_2d_buffer_cl(data, res, weights, biases); +} + template void separable_conv_2d_cl(hls::stream &data, hls::stream &res, typename CONFIG_T::depthwise_config::weight_t diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h index 7e4630ab6c..740211e509 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h @@ -57,6 +57,21 @@ void depthwise_conv_1d_buffer_cl(hls::stream &data, hls::stream & } } +template +void depthwise_conv_1d_cl(hls::stream &data, hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + #pragma HLS inline recursive + switch (CONFIG_T::implementation) { + case conv_implementation::linebuffer: + depthwise_conv_1d_buffer_cl(data, res, weights, biases); + break; + case conv_implementation::encoded: + depthwise_conv_1d_encoded_cl(data, res, weights, biases); + break; + } +} + template void pointwise_conv_1d_cl(hls::stream &data, hls::stream &res, typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h index 856de0e55e..dd585474e5 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h @@ -75,6 +75,22 @@ void depthwise_conv_2d_buffer_cl( } } +template +void depthwise_conv_2d_cl( + hls::stream &data, hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + #pragma HLS inline recursive + switch (CONFIG_T::implementation) { + case conv_implementation::linebuffer: + depthwise_conv_2d_buffer_cl(data, res, weights, biases); + break; + case conv_implementation::encoded: + depthwise_conv_2d_encoded_cl(data, res, weights, biases); + break; + } +} + template void pointwise_conv_2d_cl(hls::stream &data, hls::stream &res, typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], diff --git a/test/pytest/test_keras_api.py b/test/pytest/test_keras_api.py index a0540e17b9..762d00a582 100644 --- a/test/pytest/test_keras_api.py +++ b/test/pytest/test_keras_api.py @@ -12,6 +12,8 @@ Conv1D, Conv2D, Dense, + DepthwiseConv1D, + DepthwiseConv2D, LeakyReLU, MaxPooling1D, MaxPooling2D, @@ -297,6 +299,58 @@ def test_conv2d(chans, padds, backend, io_type): assert list(hls_model.get_layers())[1].attributes['pad_right'] == 0 +# Currently only Vivado and Vitis is supported for io_stream. +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize('io_type', ['io_stream']) +def test_depthwise2d(backend, io_type): + ''' + Test proper handling of DepthwiseConv2D + ''' + X = np.random.rand(10, 32, 32, 3) + X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6> + model = tf.keras.models.Sequential() + model.add(DepthwiseConv2D(kernel_size=(3, 3), input_shape=(32, 32, 3))) + model.compile() + + config = hls4ml.utils.config_from_keras_model(model, granularity='name') + output_dir = str(test_root_path / f'hls4mlprj_keras_api_depthwiseconv2d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + + y_qkeras = model.predict(X) + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qkeras, y_hls4ml.reshape(y_qkeras.shape), rtol=1e-2, atol=0.01) + + +# Currently only Vivado and Vitis is supported for io_stream. +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize('io_type', ['io_stream']) +def test_depthwise1d(backend, io_type): + ''' + Test proper handling of QConv2DBatchnorm. + ''' + X = np.random.rand(10, 32, 3) + X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6> + model = tf.keras.models.Sequential() + model.add(DepthwiseConv1D(kernel_size=3, input_shape=(32, 3))) + model.compile() + + config = hls4ml.utils.config_from_keras_model(model, granularity='name') + output_dir = str(test_root_path / f'hls4mlprj_keras_api_depthwiseconv1d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + + y_qkeras = model.predict(X) + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qkeras, y_hls4ml.reshape(y_qkeras.shape), rtol=1e-2, atol=0.01) + + pooling_layers = [MaxPooling1D, MaxPooling2D, AveragePooling1D, AveragePooling2D] From 0014903c929c4f992a1cdebdfad9b992ff0d0628 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Mon, 17 Jul 2023 14:14:08 -0500 Subject: [PATCH 2/5] add basic QDepthwiseConv2D support, fix activation argument parsing --- hls4ml/converters/keras/qkeras.py | 13 ++++++++++ hls4ml/converters/keras_to_hls.py | 2 +- hls4ml/model/layers.py | 1 + test/pytest/test_qkeras.py | 43 +++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 1 deletion(-) diff --git a/hls4ml/converters/keras/qkeras.py b/hls4ml/converters/keras/qkeras.py index 94b0ad5d00..a8656030d4 100644 --- a/hls4ml/converters/keras/qkeras.py +++ b/hls4ml/converters/keras/qkeras.py @@ -49,6 +49,19 @@ def parse_qconv_layer(keras_layer, input_names, input_shapes, data_reader): return layer, output_shape +@keras_handler('QDepthwiseConv2D') +def parse_qdepthwiseqconv_layer(keras_layer, input_names, input_shapes, data_reader): + layer, output_shape = parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader) + + layer['depthwise_quantizer'] = get_quantizer_from_config(keras_layer, 'depthwise') + if keras_layer['config']['bias_quantizer'] is not None: + layer['bias_quantizer'] = get_quantizer_from_config(keras_layer, 'bias') + else: + layer['bias_quantizer'] = None + + return layer, output_shape + + @keras_handler('QActivation') def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader): assert keras_layer['class_name'] == 'QActivation' diff --git a/hls4ml/converters/keras_to_hls.py b/hls4ml/converters/keras_to_hls.py index 2122a63c33..1d2376f576 100644 --- a/hls4ml/converters/keras_to_hls.py +++ b/hls4ml/converters/keras_to_hls.py @@ -301,7 +301,7 @@ def parse_keras_model(model_arch, reader): act_layer['class_name'] = 'QActivation' act_layer['config'] = { 'name': layer['name'] + '_' + act_details['class_name'], - 'activation': act_details['class_name'], + 'activation': act_details, } act_layer, output_shape = layer_handlers['QActivation'](act_layer, None, [output_shape], reader) else: diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index f7695c5658..60c1c143d7 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1334,6 +1334,7 @@ def initialize(self): 'DepthwiseConv1D': DepthwiseConv1D, 'SeparableConv2D': SeparableConv2D, 'DepthwiseConv2D': DepthwiseConv2D, + 'QDepthwiseConv2D': DepthwiseConv2D, 'BatchNormalization': BatchNormalization, 'QBatchNormalization': BatchNormalization, 'MaxPooling1D': Pooling1D, diff --git a/test/pytest/test_qkeras.py b/test/pytest/test_qkeras.py index e18f8a65e4..c249b365b2 100644 --- a/test/pytest/test_qkeras.py +++ b/test/pytest/test_qkeras.py @@ -4,6 +4,7 @@ import numpy as np import pytest from qkeras.qconv2d_batchnorm import QConv2DBatchnorm +from qkeras.qconvolutional import QDepthwiseConv2D from qkeras.qlayers import QActivation, QDense from qkeras.quantizers import ( binary, @@ -400,6 +401,48 @@ def test_qconv2dbn(randX_100_8_8_1, backend, io_type): np.testing.assert_array_equal(y_qkeras, y_hls4ml.reshape(y_qkeras.shape)) +@pytest.fixture(scope='module') +def randX_10_32_32_3(): + return np.random.rand(10, 32, 32, 3) + + +# Currently only Vivado and Vitis is supported for io_stream. +# Note, qkeras only supports 2d version of depthwise +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize('io_type', ['io_stream']) +def test_qdepthwiseconv2d(randX_10_32_32_3, backend, io_type): + ''' + Test proper handling of QConv2DBatchnorm. + ''' + X = randX_10_32_32_3 + X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6> + model = Sequential() + model.add( + QDepthwiseConv2D( + kernel_size=(3, 3), + input_shape=(32, 32, 3), + depthwise_quantizer='quantized_bits(6, 0, alpha=1)', + depthwise_initializer='ones', + bias_quantizer='quantized_bits(4, 0, alpha=1)', + bias_initializer='zeros', + activation='quantized_relu(3, 0)', + ) + ) + model.compile() + + config = hls4ml.utils.config_from_keras_model(model, granularity='name') + output_dir = str(test_root_path / f'hls4mlprj_qkeras_qdepthwiseconv2d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + + y_qkeras = model.predict(X) + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qkeras, y_hls4ml.reshape(y_qkeras.shape), rtol=1e-2, atol=0.01) + + @pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) @pytest.mark.parametrize('strategy', ['Latency', 'Resource']) From 923e7ae7ea09d5d8c17330d0ae269f612bef5cd7 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Wed, 19 Jul 2023 10:57:08 -0500 Subject: [PATCH 3/5] remove redundant case statement --- .../vivado/nnet_utils/nnet_sepconv1d_stream.h | 12 ++---------- .../vivado/nnet_utils/nnet_sepconv2d_stream.h | 13 ++----------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h index 740211e509..e1f3978eef 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h @@ -109,16 +109,8 @@ void separable_conv_1d_cl(hls::stream &data, hls::stream &res, unsigned res_depth = CONFIG_T::depthwise_config::out_width; #pragma HLS STREAM variable=depthwise_res depth=res_depth - switch (CONFIG_T::depthwise_config::implementation) { - case conv_implementation::linebuffer: - depthwise_conv_1d_buffer_cl( - data, depthwise_res, depthwise_weights, depthwise_biases); - break; - case conv_implementation::encoded: - depthwise_conv_1d_encoded_cl( - data, depthwise_res, depthwise_weights, depthwise_biases); - break; - } + depthwise_conv_1d_cl(data, depthwise_res, depthwise_weights, + depthwise_biases); pointwise_conv_1d_cl(depthwise_res, res, pointwise_weights, pointwise_biases); } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h index dd585474e5..81d95e93d6 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h @@ -132,17 +132,8 @@ void separable_conv_2d_cl(hls::stream &data, hls::stream &res, unsigned res_depth = CONFIG_T::depthwise_config::out_height * CONFIG_T::depthwise_config::out_width; #pragma HLS STREAM variable=depthwise_res depth=res_depth - switch (CONFIG_T::depthwise_config::implementation) { - case conv_implementation::linebuffer: - depthwise_conv_2d_buffer_cl( - data, depthwise_res, depthwise_weights, depthwise_biases); - break; - case conv_implementation::encoded: - depthwise_conv_2d_encoded_cl( - data, depthwise_res, depthwise_weights, depthwise_biases); - break; - } - + depthwise_conv_2d_cl(data, depthwise_res, depthwise_weights, + depthwise_biases); pointwise_conv_2d_cl(depthwise_res, res, pointwise_weights, pointwise_biases); } From c16031889525cf8fd0aa1cba56589150b5cc7648 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Sun, 27 Aug 2023 16:03:35 -0500 Subject: [PATCH 4/5] increase bitwidth to make test comparisons succeed --- test/pytest/test_keras_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pytest/test_keras_api.py b/test/pytest/test_keras_api.py index 762d00a582..6852d8b094 100644 --- a/test/pytest/test_keras_api.py +++ b/test/pytest/test_keras_api.py @@ -312,7 +312,7 @@ def test_depthwise2d(backend, io_type): model.add(DepthwiseConv2D(kernel_size=(3, 3), input_shape=(32, 32, 3))) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision='fixed<32,12>') output_dir = str(test_root_path / f'hls4mlprj_keras_api_depthwiseconv2d_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type From b26621b676f86a3e638d4d963a0dfa3cd618bd8e Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Mon, 28 Aug 2023 17:57:50 -0500 Subject: [PATCH 5/5] Fix docstring text, increase default precision for sensitive tests --- test/pytest/test_keras_api.py | 2 +- test/pytest/test_qkeras.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/test/pytest/test_keras_api.py b/test/pytest/test_keras_api.py index 6852d8b094..64f68302ef 100644 --- a/test/pytest/test_keras_api.py +++ b/test/pytest/test_keras_api.py @@ -330,7 +330,7 @@ def test_depthwise2d(backend, io_type): @pytest.mark.parametrize('io_type', ['io_stream']) def test_depthwise1d(backend, io_type): ''' - Test proper handling of QConv2DBatchnorm. + Test proper handling of DepthwiseConv1D. ''' X = np.random.rand(10, 32, 3) X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6> diff --git a/test/pytest/test_qkeras.py b/test/pytest/test_qkeras.py index c249b365b2..d567c7f9fb 100644 --- a/test/pytest/test_qkeras.py +++ b/test/pytest/test_qkeras.py @@ -388,7 +388,7 @@ def test_qconv2dbn(randX_100_8_8_1, backend, io_type): ) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision='fixed<24,8>') output_dir = str(test_root_path / f'hls4mlprj_qkeras_qconv2dbn_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type @@ -412,7 +412,7 @@ def randX_10_32_32_3(): @pytest.mark.parametrize('io_type', ['io_stream']) def test_qdepthwiseconv2d(randX_10_32_32_3, backend, io_type): ''' - Test proper handling of QConv2DBatchnorm. + Test proper handling of QDepthwiseConv2D. ''' X = randX_10_32_32_3 X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6> @@ -422,15 +422,14 @@ def test_qdepthwiseconv2d(randX_10_32_32_3, backend, io_type): kernel_size=(3, 3), input_shape=(32, 32, 3), depthwise_quantizer='quantized_bits(6, 0, alpha=1)', - depthwise_initializer='ones', bias_quantizer='quantized_bits(4, 0, alpha=1)', - bias_initializer='zeros', + bias_initializer='he_normal', activation='quantized_relu(3, 0)', ) ) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision='fixed<24,8>') output_dir = str(test_root_path / f'hls4mlprj_qkeras_qdepthwiseconv2d_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type