diff --git a/fourm/demo_4M_sampler.py b/fourm/demo_4M_sampler.py index 2639974..5879727 100644 --- a/fourm/demo_4M_sampler.py +++ b/fourm/demo_4M_sampler.py @@ -258,6 +258,8 @@ def __init__(self, self.mods_sr = mods_sr or list(set(fm_sr.encoder_modalities) | set(fm_sr.decoder_modalities)) else: self.sampler_fm_sr = None + self.mods_sr = [] + # Load tokenizers self.toks = {} @@ -537,4 +539,4 @@ def modalities_to_pil(self, mod_dict, use_fixed_plotting_order=False, resize=Non plot_name = MODALITY_PLOTTING_NAME_MAP.get(mod_name, mod_name) plotted_modalities.append((img_pil, plot_name)) - return plotted_modalities \ No newline at end of file + return plotted_modalities