fix insertion before a node with multiple inputs + support additional broadcasting #551

Merged
48 changes: 35 additions & 13 deletions hls4ml/backends/vivado/passes/repack_stream.py
@@ -1,7 +1,7 @@
 import numpy as np
 
 from hls4ml.model.optimizer import OptimizerPass
-from hls4ml.model.layers import Layer, Merge, Reshape, register_layer
+from hls4ml.model.layers import Layer, Merge, Reshape, Concatenate, register_layer
 from hls4ml.backends import get_backend
 from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
 
@@ -44,8 +44,10 @@ def initialize(self):
 broadcast_config_template = """struct config{index} : nnet::broadcast_config {{
     static const unsigned in_width = {in_width};
     static const unsigned in_height = {in_height};
-    static const unsigned n_chan = {n_chan};
-    static const unsigned n_dupl = {n_dupl};
+    static const unsigned in_chan = {in_chan};
+    static const unsigned out_width = {out_width};
+    static const unsigned out_height = {out_height};
+    static const unsigned out_chan = {out_chan};
 }};\n"""
 broadcast_include_list = ['nnet_utils/nnet_stream.h']
 
@@ -58,8 +60,10 @@ def format(self, node):
         params = self._default_config_params(node)
         params['in_height'] = node.get_input_variable().shape[0]
         params['in_width'] = node.get_input_variable().shape[1]
-        params['n_chan'] = node.get_input_variable().shape[2]
-        params['n_dupl'] = int(np.prod(node.get_output_variable().shape) / np.prod(node.get_input_variable().shape))
+        params['in_chan'] = node.get_input_variable().shape[2]
+        params['out_height'] = node.get_output_variable().shape[0]
+        params['out_width'] = node.get_output_variable().shape[1]
+        params['out_chan'] = node.get_output_variable().shape[2]
 
         return self.template.format(**params)
 
@@ -109,7 +113,7 @@ def transform(self, model, node):
 
 class BroadcastStream(OptimizerPass):
     def match(self, node):
-        if isinstance(node, Merge):
+        if isinstance(node, Merge) and not isinstance(node, Concatenate):
             inp1 = node.get_input_variable(node.inputs[0])
             inp2 = node.get_input_variable(node.inputs[1])
             return inp1.shape != inp2.shape
@@ -120,23 +124,41 @@ def transform(self, model, node):
         if model.config.backend.name not in ['Vivado'] or \
             model.config.get_config_value('IOType') != 'io_stream':
             return False
-        inp1 = node.get_input_variable(node.inputs[0])
-        inp2 = node.get_input_variable(node.inputs[1])
-        if np.prod(inp1.shape) > np.prod(inp2.shape):
+
+        inp = [node.get_input_variable(inp_name) for inp_name in node.inputs]
+
+        if np.prod(inp[0].shape) > np.prod(inp[1].shape):
             idx = 1
             attrs = {
-                'target_shape': inp1.shape
+                'target_shape': inp[0].shape
             }
         else:
             idx = 0
             attrs = {
-                'target_shape': inp2.shape
+                'target_shape': inp[1].shape
             }
 
+        def supported_broadcast(inp_shape, target_shape):
+            # Must be (H, W, C)
+            if not len(inp_shape) == 3:
+                return False
+            # Supported: (1, 1, C) -> (H, W, C)
+            if inp_shape[0] == inp_shape[1] == 1 and inp_shape[2] == target_shape[2]:
+                return True
+            # Supported: (H, W, 1) -> (H, W, C)
+            if inp_shape[2] == 1 and inp_shape[0] == target_shape[0] and inp_shape[1] == target_shape[1]:
+                return True
+            return False
+
         brdcst_inp = node.inputs[idx]
+        inp_shape = node.get_input_variable(brdcst_inp).shape
+        target_shape = attrs['target_shape']
+        if not supported_broadcast(inp_shape, target_shape):
+            raise RuntimeError('Unsupported broadcast type for stream: {} -> {};'.format(inp_shape, target_shape) + \
+                'Only (1, 1, C) -> (H, W, C) and (H, W, 1) -> (H, W, C) currently supported')
         brdcst_out = 'broadcast_' + brdcst_inp
         brdcst_layer = model.make_node('Broadcast', brdcst_out, attrs, [brdcst_inp].copy())
-        model.insert_node(brdcst_layer)
+        model.insert_node(brdcst_layer, before=node, input_idx=idx)
         node.inputs[idx] = brdcst_out
 
         return True
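Note on the new `supported_broadcast` check: the pass now rejects any stream broadcast other than the two patterns it names. A minimal NumPy sketch of what those two patterns mean (illustrative only, not part of the diff):

```python
import numpy as np

# (1, 1, C) -> (H, W, C): a single channel vector is duplicated across all H*W pixels
x = np.arange(2).reshape(1, 1, 2)
assert np.broadcast_to(x, (3, 4, 2)).shape == (3, 4, 2)

# (H, W, 1) -> (H, W, C): each pixel's single value is duplicated across all C channels
x = np.arange(12).reshape(3, 4, 1)
assert np.broadcast_to(x, (3, 4, 2)).shape == (3, 4, 2)

# Anything else, e.g. (1, W, C) -> (H, W, C), now raises RuntimeError in BroadcastStream
```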
19 changes: 12 additions & 7 deletions hls4ml/model/graph.py
@@ -395,7 +395,7 @@ def make_node(self, kind, name, attributes, inputs, outputs=None):
             self.output_vars[o] = out_var
         return node
 
-    def insert_node(self, node, before=None):
+    def insert_node(self, node, before=None, input_idx=0):
         """ Insert a new node into the model graph.
 
         The node to be inserted should be created with `make_node()` function. The optional
@@ -405,6 +405,7 @@ def insert_node(self, node, before=None):
             node (Layer): Node to insert
             before (Layer, optional): The next node in sequence before which a
                 new node should be inserted.
+            input_idx (int, optional): If the next node takes multiple inputs, the input index
         Raises:
             Exception: If an attempt to insert a node with multiple inputs is made or if
                 `before` does not specify a correct node in sequence.
@@ -414,7 +415,12 @@ def insert_node(self, node, before=None):
             raise Exception('Cannot insert a node with more than one input (for now).')
 
         prev_node = node.get_input_node(node.inputs[0])
-        next_nodes = [x for x in self.graph.values() if x.inputs[0] in prev_node.outputs]
+        next_nodes = []
+        for x in self.graph.values():
+            overlap = [value for value in x.inputs if value in prev_node.outputs]
+            if overlap:
+                next_nodes.append(x)
+
         if before is None:
             next_node = next((x for x in self.graph.values() if x.inputs[0] in prev_node.outputs), None)
         else:
@@ -423,7 +429,7 @@ def insert_node(self, node, before=None):
             next_node = before
 
         if next_node is not None:
-            next_node.inputs[0] = node.outputs[0]
+            next_node.inputs[input_idx] = node.outputs[0]
 
         new_graph = OrderedDict()
         for k, v in self.graph.items():
@@ -496,10 +502,9 @@ def _update_model_outputs(self):
         All node outputs and inputs are found. The model outputs are set to all node outputs
         that are not also node inputs.
         '''
-        node_outputs = np.array([out for node in self.graph.values() for out in node.outputs])
-        node_inputs = np.array([inp for node in self.graph.values() for inp in node.inputs])
-        model_outputs = node_outputs[np.isin(node_outputs, node_inputs, invert=True)]
-        self.outputs = model_outputs.tolist()
+        node_outputs = [out for node in self.graph.values() for out in node.outputs]
+        node_inputs = [inp for node in self.graph.values() for inp in node.inputs]
+        self.outputs = [out for out in node_outputs if out not in node_inputs]
 
     def get_weights_data(self, layer_name, var_name):
         return self.reader.get_weights_data(layer_name, var_name)
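The `input_idx` argument is what lets `BroadcastStream` above splice the new `Broadcast` layer onto the correct input of a two-input `Merge` node instead of always overwriting input 0. A sketch of the call pattern, using the `make_node`/`insert_node` API from this diff (variable names are illustrative):

```python
# 'model' is a ModelGraph, 'node' a two-input Merge layer, 'idx' the input to broadcast
brdcst_inp = node.inputs[idx]
brdcst_layer = model.make_node('Broadcast', 'broadcast_' + brdcst_inp,
                               {'target_shape': target_shape}, [brdcst_inp])
# Only input 'idx' of 'node' is rewired to the Broadcast output;
# the other input is left untouched
model.insert_node(brdcst_layer, before=node, input_idx=idx)
```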
44 changes: 37 additions & 7 deletions hls4ml/templates/vivado/nnet_utils/nnet_stream.h
@@ -8,10 +8,12 @@ namespace nnet {
 
 struct broadcast_config
 {
-    static const unsigned in_height = 10;
-    static const unsigned in_width = 10;
-    static const unsigned n_chan = 1;
-    static const unsigned n_dupl = 2;
+    static const unsigned in_height = 1;
+    static const unsigned in_width = 1;
+    static const unsigned in_chan = 3;
+    static const unsigned out_height = 2;
+    static const unsigned out_width = 2;
+    static const unsigned out_chan = 3;
 };
 
 template<class data_T, class res_T, int N>
@@ -99,11 +101,13 @@ void repack_stream(hls::stream<data_T> &data, hls::stream<res_T> &res) {
 }
 
 template<class data_T, class res_T, typename CONFIG_T>
-void broadcast_stream(hls::stream<data_T> &data, hls::stream<res_T> &res) {
-    BroadcastLoop: for (int i = 0; i < CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan / data_T::size; i++) {
+void broadcast_stream_1x1xC(hls::stream<data_T> &data, hls::stream<res_T> &res) {
+    assert(CONFIG_T::in_height == 1 && CONFIG_T::in_width == 1 && CONFIG_T::in_chan == CONFIG_T::out_chan);
+    int n_dupl = (CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::out_chan) / (CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::in_chan);
+    BroadcastLoop: for (int i = 0; i < CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::in_chan / data_T::size; i++) {
         #pragma HLS PIPELINE
         data_T in_data = data.read();
-        for (int j = 0; j < CONFIG_T::n_dupl; j++) {
+        for (int j = 0; j < n_dupl; j++) {
             #pragma HLS PIPELINE
             res_T out_data;
             #pragma HLS DATA_PACK variable=out_data
@@ -115,6 +119,32 @@ void broadcast_stream(hls::stream<data_T> &data, hls::stream<res_T> &res) {
         }
     }
 }
+
+template<class data_T, class res_T, typename CONFIG_T>
+void broadcast_stream_HxWx1(hls::stream<data_T> &data, hls::stream<res_T> &res) {
+    assert(CONFIG_T::in_chan == 1 && CONFIG_T::in_height == CONFIG_T::out_height && CONFIG_T::in_width == CONFIG_T::out_width);
+    BroadcastLoop: for (int i = 0; i < CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::in_chan / data_T::size; i++) {
+        #pragma HLS PIPELINE
+        data_T in_data = data.read();
+        res_T out_data;
+        #pragma HLS DATA_PACK variable=out_data
+        for (int k = 0; k < res_T::size; k++) {
+            #pragma HLS UNROLL
+            out_data[k] = in_data[0];
+        }
+        res.write(out_data);
+    }
+}
+
+template<class data_T, class res_T, typename CONFIG_T>
+void broadcast_stream(hls::stream<data_T> &data, hls::stream<res_T> &res) {
+    if(CONFIG_T::in_height == 1 && CONFIG_T::in_width == 1 && CONFIG_T::in_chan == CONFIG_T::out_chan) {
+        broadcast_stream_1x1xC<data_T, res_T, CONFIG_T>(data, res);
+    }
+    else if(CONFIG_T::in_chan == 1 && CONFIG_T::in_height == CONFIG_T::out_height && CONFIG_T::in_width == CONFIG_T::out_width) {
+        broadcast_stream_HxWx1<data_T, res_T, CONFIG_T>(data, res);
+    }
+}
 }
 
 #endif
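Behaviorally, the two new stream variants match these NumPy equivalents (a sketch assuming the whole frame is materialized at once, rather than streamed one channel vector at a time; not part of the diff):

```python
import numpy as np

def broadcast_1x1xC(chan_vec, out_height, out_width):
    # (1, 1, C) -> (H, W, C): re-emit the channel vector out_height*out_width times
    return np.tile(chan_vec, (out_height, out_width, 1))

def broadcast_HxWx1(frame, out_chan):
    # (H, W, 1) -> (H, W, C): copy each pixel's scalar into all out_chan channels
    return np.repeat(frame, out_chan, axis=2)

assert broadcast_1x1xC(np.arange(3), 2, 2).shape == (2, 2, 3)   # matches the default config
assert broadcast_HxWx1(np.zeros((2, 2, 1)), 3).shape == (2, 2, 3)
```

Note that the dispatching `broadcast_stream` falls through silently if neither assert-guarded pattern applies; the Python-side `supported_broadcast` check is what guarantees that case never reaches the HLS code.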
32 changes: 31 additions & 1 deletion test/pytest/test_graph.py
@@ -142,6 +142,36 @@ def test_final_reshape(iotype):
     # because of integer inputs and integer weights, we can expect exact matching
     np.testing.assert_allclose(y, y_hls, rtol=0)
 
+@pytest.mark.parametrize('shapes, layer', [
+    (((2, 2, 3), (2, 2, 1)), tf.keras.layers.Concatenate),
+    (((2, 2, 1), (2, 2, 3)), tf.keras.layers.Concatenate),
+    (((2, 2, 3), (2, 2, 1)), tf.keras.layers.Add),
+    (((2, 2, 1), (2, 2, 3)), tf.keras.layers.Add),
+    (((1, 1, 2), (3, 4, 2)), tf.keras.layers.Add),
+    (((3, 4, 2), (1, 1, 2)), tf.keras.layers.Add)])
+def test_broadcast_stream(shapes, layer):
+    ''' Test case for stream broadcast before Add but not before Concatenate '''
+    input1 = tf.keras.layers.Input(shape=shapes[0])
+    input2 = tf.keras.layers.Input(shape=shapes[1])
+    inputs = [input1, input2]
+    outputs = layer()(inputs)
+    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
+
+    # create the ModelGraph
+    config = hls4ml.utils.config_from_keras_model(model, granularity='model', default_precision='ap_fixed<32,16>')
+    odir = str(test_root_path / 'hls4mlprj_graph_broadcast_shapes_{}_{}_stream_{}'.format(str(shapes[0]).replace(' ','').replace(',','_').replace('(','').replace(')',''),
+                                                                                          str(shapes[1]).replace(' ','').replace(',','_').replace('(','').replace(')',''),
+                                                                                          layer.__name__.lower()))
+    hls_model = hls4ml.converters.convert_from_keras_model(model,
+                                                           output_dir=odir,
+                                                           backend='Vivado',
+                                                           io_type='io_stream',
+                                                           hls_config=config)
+    hls_model.compile()
+
+    # Test with integers (for exact agreement)
+    X1 = np.random.randint(0, 100, size=(1,)+shapes[0]).astype(float)
+    X2 = np.random.randint(0, 100, size=(1,)+shapes[1]).astype(float)
+    y = model.predict([X1, X2])
+    y_hls = hls_model.predict([X1, X2]).reshape(y.shape)
+    np.testing.assert_allclose(y, y_hls, rtol=0)
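One detail worth noting in the parametrization: the `Concatenate` cases never need a broadcast, because Keras concatenates along the last axis, where mismatched channel counts are expected; that is why `match()` now excludes `Concatenate`. A quick illustrative check:

```python
import numpy as np

# Concatenation along the channel axis accepts mismatched channel counts directly
a = np.ones((1, 2, 2, 3))
b = np.zeros((1, 2, 2, 1))
assert np.concatenate([a, b], axis=-1).shape == (1, 2, 2, 4)  # no broadcast involved
```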