Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 494223959
  • Loading branch information
tensorflower-gardener authored and fyangf committed Mar 21, 2023
1 parent a619f51 commit b50c213
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
13 changes: 10 additions & 3 deletions official/vision/configs/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

import dataclasses
import os
from typing import List, Optional, Union
from typing import Optional, List, Sequence, Union

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling.hyperparams import base_config
from official.vision.configs import common
from official.vision.configs import decoders
from official.vision.configs import backbones
Expand Down Expand Up @@ -65,8 +66,14 @@ class Parser(hyperparams.Config):

@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""Input config for training."""
input_path: str = ''
"""Input config for training.
Attributes:
weights: Sampling weights for each corresponding input_path. If used, then
input_path must be a config with matching keys.
"""
input_path: Union[Sequence[str], str, base_config.Config] = ''
weights: Optional[base_config.Config] = None
global_batch_size: int = 0
is_training: bool = False
dtype: str = 'bfloat16'
Expand Down
19 changes: 18 additions & 1 deletion official/vision/dataloaders/input_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,31 @@

"""Dataset reader for vision model garden."""

from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Mapping, Optional, Tuple

import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import input_reader


def build_weighted_sampling_combine_fn(
weights: Mapping[Any, Any]) -> Callable[[tf.data.Dataset], tf.data.Dataset]:
"""Builds a combine_fn using weighted sampling."""

def combine_fn(datasets: Mapping[Any, tf.data.Dataset]) -> tf.data.Dataset:
"""Combines multiple datasets using weighted sampling."""
ds = []
ws = []
for k, dataset in datasets.items():
ds.append(dataset)
ws.append(weights[k])
return tf.data.Dataset.sample_from_datasets(
ds, ws, stop_on_empty_dataset=True)

return combine_fn


def calculate_batch_sizes(total_batch_size: int,
pseudo_label_ratio: float) -> Tuple[int, int]:
"""Calculates labeled and pseudo-labeled dataset batch sizes.
Expand Down
25 changes: 25 additions & 0 deletions official/vision/tasks/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@

from official.common import dataset_fn
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.vision.configs import retinanet as exp_cfg
from official.vision.dataloaders import input_reader
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import retinanet_input
from official.vision.dataloaders import tf_example_decoder
Expand Down Expand Up @@ -130,10 +132,33 @@ def build_inputs(self,
skip_crowd_during_training=params.parser.skip_crowd_during_training,
max_num_instances=params.parser.max_num_instances)

combine_fn = None
if params.is_training and params.weights:
# Combine multiple datasets using weighted sampling.
if (not isinstance(params.input_path, cfg.base_config.Config) or
not isinstance(params.weights, cfg.base_config.Config)):
raise ValueError(
'input_path and weights must both be a Config to use weighted '
'sampling.')
input_paths = params.input_path.as_dict()
weights = params.weights.as_dict()
if len(input_paths) != len(weights):
raise ValueError(
'The number of input_path and weights must be the same, but got %d '
'input_paths and %d weights.' % (len(input_paths), len(weights)))

for k in input_paths.keys():
if k not in weights:
raise ValueError(
'input_path key \'%s\' does not have a corresponding weight.' % k)

combine_fn = input_reader.build_weighted_sampling_combine_fn(weights)

reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
combine_fn=combine_fn,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)

Expand Down

0 comments on commit b50c213

Please sign in to comment.