Skip to content

Commit

Permalink
Implement MultiDense layer
Browse files Browse the repository at this point in the history
Add MultiDense layer

Add MultiDense layer improvements

Recreate initializer per replica to make sure seed is properly set

Add tolerences to test

Add multi_dense path in generate_nn

Add MultiDropout

Replace old dense layer everywhere

Remove MultiDropout, not necessary

Update developing weights structure

Remove MultiDropout once more

Fix naming inconsistency wrt parallel-prefactor
  • Loading branch information
APJansen committed Jan 29, 2024
1 parent b1a36c4 commit d8f28ff
Show file tree
Hide file tree
Showing 9 changed files with 335 additions and 45 deletions.
Binary file modified n3fit/runcards/examples/developing_weights.h5
Binary file not shown.
80 changes: 52 additions & 28 deletions n3fit/src/n3fit/backends/keras_backend/base_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,44 @@
The names of the layer and the activation function are the ones to be used in the n3fit runcard.
"""

from tensorflow.keras.layers import Lambda, LSTM, Dropout, Concatenate
from tensorflow.keras.layers import concatenate, Input # pylint: disable=unused-import
from tensorflow import expand_dims, math, nn
from tensorflow.keras.layers import ( # pylint: disable=unused-import
Dropout,
Input,
Lambda,
concatenate,
)
from tensorflow.keras.layers import Dense as KerasDense
from tensorflow import expand_dims
from tensorflow.keras.layers import LSTM, Concatenate # pylint: disable=unused-import
from tensorflow.keras.regularizers import l1_l2
from tensorflow import nn, math

from n3fit.backends import MetaLayer
from n3fit.backends.keras_backend.multi_dense import MultiDense


# Custom activation functions
def square_activation(x):
""" Squares the input """
return x*x
"""Squares the input"""
return x * x


def modified_tanh(x):
""" A non-saturating version of the tanh function """
return math.abs(x)*nn.tanh(x)
"""A non-saturating version of the tanh function"""
return math.abs(x) * nn.tanh(x)


def leaky_relu(x):
""" Computes the Leaky ReLU activation function """
"""Computes the Leaky ReLU activation function"""
return nn.leaky_relu(x, alpha=0.2)


custom_activations = {
"square" : square_activation,
"square": square_activation,
"leaky_relu": leaky_relu,
"modified_tanh": modified_tanh,
}


def LSTM_modified(**kwargs):
"""
LSTM asks for a sample X timestep X features kind of thing so we need to reshape the input
Expand All @@ -61,9 +71,11 @@ def ReshapedLSTM(input_tensor):

return ReshapedLSTM


class Dense(KerasDense, MetaLayer):
pass


def dense_per_flavour(basis_size=8, kernel_initializer="glorot_normal", **dense_kwargs):
"""
Generates a list of layers which can take as an input either one single layer
Expand All @@ -85,7 +97,7 @@ def dense_per_flavour(basis_size=8, kernel_initializer="glorot_normal", **dense_

# Need to generate a list of dense layers
dense_basis = [
base_layer_selector("dense", kernel_initializer=initializer, **dense_kwargs)
base_layer_selector("single_dense", kernel_initializer=initializer, **dense_kwargs)
for initializer in kernel_initializer
]

Expand Down Expand Up @@ -116,13 +128,26 @@ def apply_dense(xinput):

layers = {
"dense": (
MultiDense,
{
"input_shape": (1,),
"replica_seeds": None,
"kernel_initializer": "glorot_normal",
"units": 5,
"activation": "sigmoid",
"kernel_regularizer": None,
"replica_input": True,
},
),
# This one is only used inside dense_per_flavour
"single_dense": (
Dense,
{
"input_shape": (1,),
"kernel_initializer": "glorot_normal",
"units": 5,
"activation": "sigmoid",
"kernel_regularizer": None
"kernel_regularizer": None,
},
),
"dense_per_flavour": (
Expand All @@ -143,31 +168,28 @@ def apply_dense(xinput):
"concatenate": (Concatenate, {}),
}

regularizers = {
'l1_l2': (l1_l2, {'l1': 0., 'l2': 0.})
}
regularizers = {'l1_l2': (l1_l2, {'l1': 0.0, 'l2': 0.0})}


def base_layer_selector(layer_name, **kwargs):
"""
Given a layer name, looks for it in the `layers` dictionary and returns an instance.
Given a layer name, looks for it in the `layers` dictionary and returns an instance.
The layer dictionary defines a number of defaults
but they can be overwritten/enhanced through kwargs
The layer dictionary defines a number of defaults
but they can be overwritten/enhanced through kwargs
Parameters
----------
`layer_name
str with the name of the layer
`**kwargs`
extra optional arguments to pass to the layer (beyond their defaults)
Parameters
----------
`layer_name
str with the name of the layer
`**kwargs`
extra optional arguments to pass to the layer (beyond their defaults)
"""
try:
layer_tuple = layers[layer_name]
except KeyError as e:
raise NotImplementedError(
"Layer not implemented in keras_backend/base_layers.py: {0}".format(
layer_name
)
"Layer not implemented in keras_backend/base_layers.py: {0}".format(layer_name)
) from e

layer_class = layer_tuple[0]
Expand All @@ -182,6 +204,7 @@ def base_layer_selector(layer_name, **kwargs):

return layer_class(**layer_args)


def regularizer_selector(reg_name, **kwargs):
"""Given a regularizer name looks in the `regularizer` dictionary and
return an instance.
Expand All @@ -204,7 +227,8 @@ def regularizer_selector(reg_name, **kwargs):
reg_tuple = regularizers[reg_name]
except KeyError:
raise NotImplementedError(
"Regularizer not implemented in keras_backend/base_layers.py: {0}".format(reg_name))
"Regularizer not implemented in keras_backend/base_layers.py: {0}".format(reg_name)
)

reg_class = reg_tuple[0]
reg_args = reg_tuple[1]
Expand Down
179 changes: 179 additions & 0 deletions n3fit/src/n3fit/backends/keras_backend/multi_dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from typing import List

import tensorflow as tf
from tensorflow.keras.initializers import Initializer
from tensorflow.keras.layers import Dense, Dropout


class MultiDense(Dense):
"""
Dense layer for multiple replicas at the same time.
Inputs to this layer may contain multiple replicas, for the later layers.
In this case, the `replica_input` argument should be set to `True`, which is the default.
The input shape in this case is (batch_size, replicas, gridsize, features).
For the first layer, there are no replicas yet, and so the `replica_input` argument
should be set to `False`.
The input shape in this case is (batch_size, gridsize, features).
Weights are initialized using a `replica_seeds` list of seeds, and are identical to the
weights of a list of single dense layers with the same `replica_seeds`.
Example
-------
>>> from tensorflow.keras import Sequential
>>> from tensorflow.keras.layers import Dense
>>> from tensorflow.keras.initializers import GlorotUniform
>>> import tensorflow as tf
>>> replicas = 2
>>> multi_dense_model = Sequential([
>>> MultiDense(units=8, replica_seeds=[42, 43], replica_input=False, kernel_initializer=GlorotUniform(seed=0)),
>>> MultiDense(units=4, replica_seeds=[52, 53], kernel_initializer=GlorotUniform(seed=0)),
>>> ])
>>> single_models = [
>>> Sequential([
>>> Dense(units=8, kernel_initializer=GlorotUniform(seed=42 + r)),
>>> Dense(units=4, kernel_initializer=GlorotUniform(seed=52 + r)),
>>> ])
>>> for r in range(replicas)
>>> ]
>>> gridsize, features = 100, 2
>>> multi_dense_model.build(input_shape=(None, gridsize, features))
>>> for single_model in single_models:
>>> single_model.build(input_shape=(None, gridsize, features))
>>> test_input = tf.random.uniform(shape=(1, gridsize, features))
>>> multi_dense_output = multi_dense_model(test_input)
>>> single_dense_output = tf.stack([single_model(test_input) for single_model in single_models], axis=1)
>>> tf.reduce_all(tf.equal(multi_dense_output, single_dense_output))
Parameters
----------
replica_seeds: List[int]
List of seeds per replica for the kernel initializer.
kernel_initializer: Initializer
Initializer class for the kernel.
replica_input: bool (default: True)
Whether the input already contains multiple replicas.
"""

def __init__(
self,
replica_seeds: List[int],
kernel_initializer: Initializer,
replica_input: bool = True,
**kwargs
):
super().__init__(**kwargs)
self.replicas = len(replica_seeds)
self.replica_seeds = replica_seeds
self.kernel_initializer = MultiInitializer(
single_initializer=kernel_initializer, replica_seeds=replica_seeds
)
self.bias_initializer = MultiInitializer(
single_initializer=self.bias_initializer, replica_seeds=replica_seeds
)
self.replica_input = replica_input

def build(self, input_shape):
input_dim = input_shape[-1]
self.kernel = self.add_weight(
name="kernel",
shape=(self.replicas, input_dim, self.units),
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
)
if self.use_bias:
self.bias = self.add_weight(
name="bias",
shape=(self.replicas, 1, self.units),
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
)
else:
self.bias = None
self.input_spec.axes = {-1: input_dim}
self.built = True

def call(self, inputs):
"""
Compute output of shape (batch_size, replicas, gridsize, units).
For the first layer, (self.replica_input is False), this is equivalent to
applying each replica separately and concatenating along the last axis.
If the input already contains multiple replica outputs, it is equivalent
to applying each replica to its corresponding input.
"""
if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
inputs = tf.cast(inputs, dtype=self._compute_dtype_object)

input_axes = 'brnf' if self.replica_input else 'bnf'
einrule = input_axes + ',rfg->brng'
outputs = tf.einsum(einrule, inputs, self.kernel)

# Reshape the output back to the original ndim of the input.
if not tf.executing_eagerly():
output_shape = self.compute_output_shape(inputs.shape.as_list())
outputs.set_shape(output_shape)

if self.use_bias:
outputs = outputs + self.bias

if self.activation is not None:
outputs = self.activation(outputs)

return outputs

def compute_output_shape(self, input_shape):
# Remove the replica axis from the input shape.
if self.replica_input:
input_shape = input_shape[:1] + input_shape[2:]

output_shape = super().compute_output_shape(input_shape)

# Add back the replica axis to the output shape.
output_shape = output_shape[:1] + [self.replicas] + output_shape[1:]

return output_shape

def get_config(self):
config = super().get_config()
config.update({"replica_input": self.replica_input, "replica_seeds": self.replica_seeds})
return config


class MultiInitializer(Initializer):
"""
Multi replica initializer that exactly replicates a stack of single replica initializers.
Weights are stacked on the first axis, and per replica seeds are added to a base seed of the
given single replica initializer.
Parameters
----------
single_initializer: Initializer
Initializer class for the kernel.
replica_seeds: List[int]
List of seeds per replica for the kernel initializer.
"""

def __init__(self, single_initializer: Initializer, replica_seeds: List[int]):
self.initializer_class = type(single_initializer)
self.initializer_config = single_initializer.get_config()
self.base_seed = single_initializer.seed if hasattr(single_initializer, "seed") else None
self.replica_seeds = replica_seeds

def __call__(self, shape, dtype=None, **kwargs):
shape = shape[1:] # Remove the replica axis from the shape.
per_replica_weights = []
for replica_seed in self.replica_seeds:
if self.base_seed is not None:
self.initializer_config["seed"] = self.base_seed + replica_seed
single_initializer = self.initializer_class.from_config(self.initializer_config)

per_replica_weights.append(single_initializer(shape, dtype, **kwargs))

return tf.stack(per_replica_weights, axis=0)
9 changes: 4 additions & 5 deletions n3fit/src/n3fit/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ def check_consistent_parallel(parameters, parallel_models, same_trvl_per_replica
"Replicas cannot be run in parallel with different training/validation "
" masks, please set `same_trvl_per_replica` to True in the runcard"
)
if parameters.get("layer_type") != "dense":
raise CheckError("Parallelization has only been tested with layer_type=='dense'")
if parameters.get("layer_type") == "dense_per_flavour":
raise CheckError("Parallelization has not been tested with layer_type=='dense_per_flavour'")


@make_argcheck
Expand Down Expand Up @@ -427,10 +427,9 @@ def check_fiatlux_pdfs_id(replicas, fiatlux):
f"Cannot generate a photon replica with id larger than the number of replicas of the PDFs set {luxset.name}:\nreplica id={max_id}, replicas of {luxset.name} = {pdfs_ids}"
)


@make_argcheck
def check_multireplica_qed(replicas, fiatlux):
if fiatlux is not None:
if len(replicas) > 1:
raise CheckError(
"At the moment, running a multireplica QED fits is not allowed."
)
raise CheckError("At the moment, running a multireplica QED fits is not allowed.")
Loading

0 comments on commit d8f28ff

Please sign in to comment.