Skip to content

Commit

Permalink
Add multi_dense path in generate_nn
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Jan 10, 2024
1 parent 9be877d commit 5e99568
Showing 1 changed file with 25 additions and 6 deletions.
31 changes: 25 additions & 6 deletions n3fit/src/n3fit/model_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,17 +741,28 @@ def initializer_generator(seed, i_layer):
# list_of_pdf_layers[d][r] is the layer at depth d for replica r
list_of_pdf_layers = []
for i_layer, (nodes_out, activation) in enumerate(zip(nodes_list, activations)):
layers = [
base_layer_selector(
if layer_type == "multi_dense":
layers = base_layer_selector(
layer_type,
kernel_initializer=initializer_generator(replica_seed, i_layer),
replica_seeds=replica_seeds,
kernel_initializer=initializer_generator(0, i_layer),
units=nodes_out,
activation=activation,
input_shape=(nodes_in,),
replica_input=(i_layer != 0),
**custom_args,
)
for replica_seed in replica_seeds
]
else:
layers = [
base_layer_selector(
layer_type,
kernel_initializer=initializer_generator(replica_seed, i_layer),
units=nodes_out,
activation=activation,
input_shape=(nodes_in,),
**custom_args,
)
for replica_seed in replica_seeds
]
list_of_pdf_layers.append(layers)
nodes_in = int(nodes_out)

Expand All @@ -766,6 +777,14 @@ def initializer_generator(seed, i_layer):
list_of_pdf_layers[-1] = [lambda x: concat(layer(x)) for layer in list_of_pdf_layers[-1]]

# Apply all layers to the input to create the models
if layer_type == "multi_dense":
pdfs = x_input
for layer in list_of_pdf_layers:
pdfs = layer(pdfs)
model = MetaModel({'NN_input': x_input}, pdfs, name=f"NNs")

return model

pdfs = [layer(x_input) for layer in list_of_pdf_layers[0]]

for layers in list_of_pdf_layers[1:]:
Expand Down

0 comments on commit 5e99568

Please sign in to comment.