Skip to content

Commit

Permalink
Softsign optimization (fastmachinelearning#585)
Browse files Browse the repository at this point in the history
* Softsign LUT optimization

* Test file: Activation softsign

* Changing minimal accurancy to 9.8 and adding new texts using Vivado and Quartus

Co-authored-by: Nemer Chiedde <chiedde@marchied.in2p3.fr>
  • Loading branch information
nemerchiedde and Nemer Chiedde committed Sep 14, 2022
1 parent 0bebc33 commit 3806792
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 8 deletions.
23 changes: 18 additions & 5 deletions hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,17 +361,30 @@ void softplus(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
template<class data_T, class res_T, typename CONFIG_T>
void softsign(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
{
static const int MAX_VALUE=8;
// Initialize the lookup table
#include "activation_tables/softsign_table.tb"

// Index into the lookup table based on data
#pragma unroll
for (int ii=0; ii<CONFIG_T::n_in; ii++) {
ac_int<16> data_round = (data[ii]*CONFIG_T::table_size/16).to_int();
ac_int<16> index = data_round + 8*CONFIG_T::table_size/16;
if (index < 0) index = 0;
if (index > CONFIG_T::table_size-1) index = CONFIG_T::table_size-1;
res[ii] = (res_T) softsign_table[index];
data_T temp hls_register;
res_T temp2 hls_register;
if(data[ii] < 0 ){
temp = -data[ii];
}
else{
temp = data[ii];
}
ac_int<16> index = (temp*CONFIG_T::table_size/MAX_VALUE).to_int();
if (temp > MAX_VALUE) index = CONFIG_T::table_size-1;
temp2 = (res_T) softsign_table[index];
if(data[ii] < 0 ){
res[ii] = -temp2;
}
else{
res[ii] = temp2;
}
}
}

Expand Down
11 changes: 8 additions & 3 deletions hls4ml/writer/quartus_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,8 @@ def __write_softplus_table(self, model, path):
h_file.close()

def __write_softsign_table(self, model, path):
MAX_VALUE = 8
MIN_VALUE = 0
table_name = 'softsign_table'
table_size = self.__get_table_size(model, 'softsign')

Expand All @@ -893,10 +895,13 @@ def __write_softsign_table(self, model, path):

sep = ''
for i in range(table_size):
in_val = 2 * 8.0 * (i - float(table_size) / 2.0) / float(table_size)

in_val = i * (MAX_VALUE-MIN_VALUE)/float(table_size) + (MAX_VALUE-MIN_VALUE)/(float(table_size)*2) + MIN_VALUE

real_val = in_val / (np.fabs(in_val) + 1.)
h_file.write(sep + str(real_val))
sep = ", "
if(real_val >= 0):
h_file.write(sep + str(real_val))
sep = ", "

h_file.write('};\n')
h_file.close()
Expand Down
54 changes: 54 additions & 0 deletions test/pytest/test_softsign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import hls4ml
import tensorflow as tf
import numpy as np
import pytest
from sklearn.metrics import accuracy_score
from pathlib import Path

test_root_path = Path(__file__).parent

def flat_distribution(shape):
return np.random.rand(*shape)


@pytest.fixture()
def generate_data(function, input_shape):
return function((1000, *input_shape))


# TODO: include latency strategy with flat_distribution when it can be made to pass
@pytest.mark.parametrize('backend,strategy,function,input_shape,io_type', [
('Vivado', 'stable', flat_distribution, (4,), 'io_parallel'),
('Quartus', 'stable', flat_distribution, (4,), 'io_parallel'),
# IO_stram avaliable just for VIVADO
('Vivado', 'stable', flat_distribution, (4,), 'io_stream'),
('Vivado', 'stable', flat_distribution, (4, 4, 3), 'io_stream')
])
def test_softsign(backend, strategy, generate_data, input_shape, io_type):
X = generate_data
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Activation(input_shape=input_shape, activation='softsign', name='softsign'))
model.compile()

f_type = 'ac_fixed<18,8,true,AC_RND,AC_SAT>' if backend == 'Quartus' else 'ap_fixed<18,8,AP_RND,AP_SAT>'
cfg = hls4ml.utils.config_from_keras_model(model, granularity='name')
cfg['LayerName']['softsign']['Strategy'] = strategy
cfg['LayerName']['softsign']['inv_table_t'] = f_type
cfg['LayerName']['softsign']['exp_table_t'] = f_type

odir = str(test_root_path / 'hls4mlprj_softsign_{}'.format(strategy))
hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=cfg, io_type=io_type,
output_dir=odir, backend=backend)
hls_model.compile()

y_keras = model.predict(X)
y_hls4ml = hls_model.predict(X).reshape(y_keras.shape)

acc_hls4ml = accuracy_score(np.argmax(y_keras, axis=-1).ravel(), np.argmax(y_hls4ml, axis=-1).ravel())

print('Accuracy hls4ml relative to keras: {}'.format(acc_hls4ml))

assert acc_hls4ml >= 0.98


0 comments on commit 3806792

Please sign in to comment.