Trvl mask layers #1661

Closed · wants to merge 29 commits into from

Conversation

goord (Collaborator) commented Jan 13, 2023

Implementation of the training/validation split per parallel replica via a set of masking layers. This means there is only one set of observables, but separate training and validation observable wrappers that mask out the correct data.

Currently, I still need to do:

  • Run the realistic nnpdf4 use case
  • Compare results of a basic run
  • Compare results of a realistic run
  • Monitor memory use, investigate scaling with the number of parallel replicas
  • Test on GPUs

goord (Collaborator Author) commented Jan 13, 2023

I am facing problems with datasets with a single datapoint: some replicas will mask these out, others won't. That is a poor fit for the design I followed, which assumes the training/validation masks can be represented as block-wise boolean arrays with the same number of 'True' values per row, such that the output tensor is not strided.
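
For illustration, a toy numpy sketch of that assumption (not code from this PR; shapes and values are made up): one boolean row per replica, each row keeping the same number of points, so applying the masks yields a dense rather than ragged training tensor.

import numpy as np

ndata = 6        # points in one dataset (toy value)
n_replicas = 3
n_tr = 4         # training points kept by every replica
rng = np.random.default_rng(0)

masks = np.zeros((n_replicas, ndata), dtype=bool)
for row in masks:
    row[rng.choice(ndata, size=n_tr, replace=False)] = True

data = np.arange(ndata, dtype=float)
# Every row keeps exactly n_tr points, so stacking gives a dense (3, 4) block.
training = np.stack([data[m] for m in masks])
print(training.shape)  # (3, 4)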

RoyStegeman (Member):

This was recently changed in #1636.

In principle I would be happy with a GPU implementation that treats datasets with a single point in the old way (i.e. including them in the training set). The difference in the result is negligible, so we could still use the GPU for most purposes; only for a very final fit before a release might it be better to have the new treatment of single-point datasets. @scarlehoff has many GPUs available, so perhaps he has an opinion on this.

scarlehoff (Member):

I'd be happy with the solution of just ignoring those datasets for the time being.

RoyStegeman (Member):

Also fine for me. The point is (to address @goord's concerns) that this is not a showstopper for the parallel fits.

goord (Collaborator Author) commented Jan 13, 2023

Yes, that would be the best solution, as things become very complicated if we can't assume an equal number of masked data points across replicas. I will try to include (or exclude?) all single-point datasets in the fit if parallel replicas is set and same_trvl_split is unset...
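
For illustration, a possible sketch of that fallback (the function name and logic are assumptions made for this comment thread, not the PR's actual code): when replicas get different splits, a single-point dataset is always assigned to training, so every replica keeps the same number of training points per dataset.

import numpy as np

def make_replica_mask(ndata, frac, rng):
    # Boolean training mask for one replica of one dataset (illustrative only).
    if ndata == 1:
        # old behaviour: a single-point dataset goes entirely into training
        return np.ones(1, dtype=bool)
    n_tr = int(frac * ndata)
    mask = np.zeros(ndata, dtype=bool)
    mask[rng.choice(ndata, size=n_tr, replace=False)] = True
    return mask

rng = np.random.default_rng(42)
print(make_replica_mask(1, 0.75, rng))  # [ True]
print(make_replica_mask(8, 0.75, rng))  # 6 True, 2 False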

scarlehoff (Member) left a review:

I think I like this solution; it is more or less what I had in mind.

Re the one-point datasets: I think we can eventually find a way around them, but let's ignore them for the moment.

Did you run any parallel fits with this code? Did they work? Is the memory footprint greatly impacted?

(Resolved review comments on n3fit/src/n3fit/layers/mask.py and n3fit/src/n3fit/model_gen.py)

masked_output_layers.append(mask_layer(output_layer))

# Finally concatenate all observables (so that experiments are one single entity)
ret = op.concatenate(masked_output_layers)
Member:

Why is the axis removed? (I guess the default is exactly the right axis, but I'd like to have it explicit; it makes debugging easier.)

Collaborator Author:

I thought it was a bit overly explicit, but since all tensor shapes have to be as explicit as possible for TensorFlow to do the correct thing, I will re-insert it.

Member:

It is more for the person reading the code than for TensorFlow in this case, since it is hard to keep track of which axis is what :P
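
For illustration, a self-contained TensorFlow snippet of the point being discussed (plain tf.concat rather than the n3fit op wrapper; the axis value and shapes are assumptions):

import tensorflow as tf

obs_a = tf.ones((1, 1, 3))  # (batch, replicas, ndata of dataset A), toy shapes
obs_b = tf.ones((1, 1, 5))  # (batch, replicas, ndata of dataset B)
# Explicit axis: concatenate the per-dataset blocks along the data axis.
ret = tf.concat([obs_a, obs_b], axis=-1)
print(ret.shape)  # (1, 1, 8)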

tr_mask_layers = []
vl_mask_layers = []
offset = 0
apply_masks = spec_dict.get("data_transformation_tr") is None and mask_array is not None
Member:

I'm a bit worried. If mask_array is not None but there is a data_transformation_tr, then the masks will not be applied. If this combination is not allowed then it should fail at the beginning.

We usually do that by adding a check before the fit starts. In this case it should check whether the run options include both a parallel fit and a data_transformation, and if so validphys will raise an exception telling the user which options are inconsistent.

(for the time being you can just put a raise Exception here to stop it, and create the proper check at the end)
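
For illustration, a minimal sketch of what such a pre-fit check could look like; the function name, argument names and the checked conditions are assumptions, not the actual validphys implementation.

from reportengine.checks import CheckError, make_argcheck

@make_argcheck
def check_consistent_parallel(parameters, parallel_models, same_trvl_per_replica):
    # Fail before the fit starts if a parallel multi-replica fit is requested
    # together with a training data transformation, which the per-replica
    # mask layers cannot handle (illustrative condition).
    if not parallel_models:
        return
    if parameters.get("diagonal_basis") and not same_trvl_per_replica:
        raise CheckError(
            "Parallel fits with different tr/vl splits per replica are not "
            "compatible with a data transformation; set same_trvl_per_replica "
            "to true or disable the transformation."
        )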

Comment on lines 173 to 175
trmask = mask_array[:, offset:offset + dataset.ndata] if apply_masks else None
tr_mask_layers.append(Mask(trmask, axis=1, c=1) if apply_masks else None)
vl_mask_layers.append(Mask(~trmask, axis=1, c=1) if apply_masks else None)
Member:

Suggested change (replacing the three lines above):

if apply_masks:
    trmask = mask_array[:, offset:offset + dataset.ndata]
    tr_mask_layers.append(Mask(bool_mask=trmask, axis=1))
    vl_mask_layers.append(Mask(bool_mask=~trmask, axis=1))
else:
    trmask = None
    tr_mask_layers.append(None)
    vl_mask_layers.append(None)

I don't like the idea of having a list of None. I think with the check above you will always be in a consistent state, and you might be able to check whether to apply the mask somewhere else (so you don't need the None).

Collaborator Author:

Fixed in rev. bc56876

model_observables.append(obs_layer)

# shift offset for new mask array
offset = offset + dataset.ndata
Member:

Would there be a way to have a list of arrays from the start (instead of an offset that we move), such that each dataset in the list corresponds to one array?
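
For illustration, a toy numpy sketch of that alternative (names and shapes are assumptions, not the PR's code): split the full per-replica mask into one array per dataset up front, so the loop can zip over (dataset, mask) pairs instead of carrying a running offset.

import numpy as np

n_replicas = 2
ndatas = [3, 1, 4]  # points per dataset (toy values)
mask_array = np.ones((n_replicas, sum(ndatas)), dtype=bool)

split_points = np.cumsum(ndatas)[:-1]              # [3, 4]
masks_per_dataset = np.split(mask_array, split_points, axis=1)

for ndata, trmask in zip(ndatas, masks_per_dataset):
    assert trmask.shape == (n_replicas, ndata)     # one mask block per dataset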

@@ -37,6 +40,7 @@ class ObservableWrapper:

name: str
observables: list
trvl_mask_layers: list
Member:

Maybe we can have this be optional with a default of None, and if it is None then it doesn't get applied.

Because (I think, maybe I'm wrong!) we should never be in a situation in which some of the masks exist and some are None; that way you can avoid the list of None.
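
A minimal sketch of that suggestion (only the fields visible in the diff above are kept; the apply helper and its name are illustrative, not the real ObservableWrapper):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ObservableWrapper:
    name: str
    observables: list
    trvl_mask_layers: Optional[list] = None  # masks only applied when not None

    def apply_masks(self, outputs):
        # With no masks provided, pass the observables through unchanged.
        if self.trvl_mask_layers is None:
            return outputs
        return [mask(out) for mask, out in zip(self.trvl_mask_layers, outputs)]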

Collaborator Author:

Fixed in rev. bc56876. I did not insert the default value yet, though.

Comment on lines 14 to 15

import numpy
Member:

Suggested change: remove the import numpy line.

scarlehoff mentioned this pull request Jan 20, 2023

goord (Collaborator Author) commented Feb 6, 2023

Hi @scarlehoff, I haven't replied to your suggestions because I propose to first get the validation in order. After testing with the basic_parallel test case with only a single dataset, I believe I still have a bug somewhere in the implementation, because my chi2 values differ significantly from the master sequential (or parallel) runs.

variable | master-sequential | master-parallel | trvl-mask-layers parallel
-------- | ----------------- | --------------- | -------------------------
chi2     | 1.02              | 1.01            | 20.0
erf_tr   | 1.61              | 2.0             | 3.77
erf_vl   | 1.81              | 1.46            | 55.4

These are means over 500 replicas.

goord (Collaborator Author) commented Feb 15, 2023

I have solved a few problems with the implementation: the first (obvious) one was that the masked truth values weren't correctly propagating to the loss function. Another issue was that a replica-specific inverse covariance was not taken into account.

The latest commits result in a basic runcard fit that is bitwise identical to the sequential run for a small number of epochs. More validation results will follow.

goord (Collaborator Author) commented Mar 1, 2023

Memory use is now OK, slightly lower than on the master branch:

[memory-use plot]

goord (Collaborator Author) commented May 10, 2023

Regarding memory usage: due to the different masks per replica, the lru_cache for fittable_datasets_masked in n3fit_data.py is no longer triggered when same_trvl_per_replica is set to false. This causes the following loop in n3fit_data_utils.py to be executed for each replica:

for dspec, mask in zip_longest(datasets, tr_masks):
    # Load all fktables with the appropriate cuts
    fktables = [fk.load_with_cuts(dspec.cuts) for fk in dspec.fkspecs]
    # And now put them in a FittableDataSet object
    loaded_obs.append(FittableDataSet(dspec.name, fktables, dspec.op, dspec.frac, mask))

Although the FK table loading itself sits behind its own lru_cache and is not re-executed, the with_cuts call is not cached and triggers a copy of the FK tables for each replica, even though the cuts are identical and independent of the mask; this inflates the memory footprint.

The easiest fix is to wrap the fk.load_with_cuts(dspec.cuts) call in a dedicated function decorated with lru_cache. I wouldn't call this an elegant solution though; any thoughts, @scarlehoff?

scarlehoff (Member):

Yes, I think that should work, since all replicas are using the same fktables now (i.e., it should work if the datasets are the same for all replicas and the only thing that changes is tr_masks, but I think this is indeed the case).


@functools.lru_cache
def load_cached_fk_tables(fk, cuts):
    return fk.load_with_cuts(cuts)
scarlehoff (Member) commented May 10, 2023:

Wouldn't it be better to wrap fk.load_with_cuts directly?

https://github.com/NNPDF/nnpdf/blob/f02b49a8e6eb785af6a56dc3195133abff571046/validphys2/src/validphys/core.py#LL442C4-L442C4

(in practical terms it should be the same)

Given an fktable and a set of cuts, I see no reason why we would want two different objects, so that is probably a positive change in other parts of the code as well*

*hopefully

(of course, for the purposes of testing and benchmarking maybe it is better to start here)
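
For reference, a hedged sketch of what decorating the method itself might look like (the class and method names follow the link above; whether the arguments are suitable lru_cache keys would need checking, and this is not the actual validphys code):

import functools

class FKTableSpec:
    # ... rest of the class as in validphys.core ...

    @functools.lru_cache
    def load_with_cuts(self, cuts):
        # the original loading logic would go here; the cache makes repeated
        # calls with the same cuts return the shared object instead of a copy
        ...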

Collaborator Author:

Yes, that would be the more elegant solution, but it needs thorough testing because it impacts the code in more places: there will be shared objects where there were copies before. I can have a look.

goord marked this pull request as ready for review June 19, 2023 12:37

goord (Collaborator Author) commented Jun 19, 2023

New validation and performance tests are underway.

goord (Collaborator Author) commented Jul 13, 2023

A 100-replica fit comparison with the recent fits by @APJansen is available here: https://vp.nnpdf.science/Hz7Gwu95TzCUCH4oYoQOQA==

Commit: Intermediate merge of Aron's stuff
RoyStegeman (Member):

Can this be closed in favor of #1788?

RoyStegeman mentioned this pull request Aug 30, 2023
RoyStegeman mentioned this pull request Nov 1, 2023
scarlehoff (Member):

Let me echo @RoyStegeman's question:

Can this be closed in favor of #1788?

goord closed this Nov 13, 2023

goord (Collaborator Author) commented Nov 13, 2023

Closed in favor of #1788.
