Add demo for dataset generation. #7

Closed
165 changes: 165 additions & 0 deletions d4rl-generation/.gitignore
@@ -0,0 +1,165 @@
experiment_logs/
datasets/
logs/
*.hdf5

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
54 changes: 54 additions & 0 deletions d4rl-generation/README.md
@@ -0,0 +1,54 @@
# Kabuki dataset generation

This repo serves to demonstrate a workflow for creating
offline datasets compatible with Kabuki.

It uses dm-acme for online RL training. The data
is logged via EnvLogger, and the logged data is then post-processed
into an HDF5 file. However, this is only a proof of concept, and
neither dm-acme nor EnvLogger constitutes part of the specification.

## Spec for the new HDF5 dataset in Kabuki
We describe the standard used by the new HDF5 datasets in D4RL-V2.
Unlike previous versions of the D4RL datasets, the new datasets follow
a formalized standard, described below.

The new datasets will continue to use HDF5 as the storage format. Although HDF5 files
do not natively support storing distinct episodes as separate entries,
HDF5 is a widely adopted standard format that can easily be used across different
frameworks and languages.

Previous iterations of the D4RL datasets have some outstanding issues.
One notable issue is that terminal observations are not captured.
While the omission of terminal observations is not particularly problematic
for offline actor-critic algorithms such as CQL, BCQ, and IQL,
it does pose issues for researchers working on offline imitation learning,
where proper handling of terminal transitions is known to significantly influence
algorithm performance.
Therefore, terminal observations are recorded in the new version of the datasets,
which should capture as much information from the original environment as possible.

In the new version, the dataset follows the convention introduced by the RLDS project.
The agent's experience is stored in the dataset as a sequence of episodes, each consisting of a variable number of steps. The steps are stored as a flattened dictionary of arrays
in the state-action-reward (SAR) alignment. Concretely, each step consists of:

* is_first, is_last: whether the step contains the first/last observation of an episode.
* observation: observation for the step.
* action: action taken after observing the `observation` of the step.
* reward: reward obtained after applying the action in the step.
* is_terminal: whether the observation is terminal (is_last = True with is_terminal = False indicates that the episode was truncated).
* discount: discount factor at this step. This may be unfamiliar to gym.Env users but
is consistent with the discount used in dm_env: discount = 0 indicates that
the *next* step is terminal, and 1.0 otherwise.

Refer to https://github.com/google-research/rlds for a more detailed description.
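
To make the layout concrete, below is a minimal sketch of reading the flattened step
arrays back with `h5py` (the file name is illustrative; the keys match the fields listed above):

```python
import h5py
import numpy as np

with h5py.File("dataset.hdf5", "r") as f:
    observation = f["observation"][:]  # shape: (total_steps, *observation_shape)
    action = f["action"][:]            # shape: (total_steps, *action_shape)
    reward = f["reward"][:]
    discount = f["discount"][:]
    is_first = f["is_first"][:]        # True on the first step of each episode
    is_last = f["is_last"][:]
    is_terminal = f["is_terminal"][:]

# Episodes are concatenated along the first axis; their boundaries can be
# recovered from the is_first flags.
episode_starts = np.flatnonzero(is_first)
print(f"{len(episode_starts)} episodes, {len(reward)} steps in total")
```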

## Generating datasets
While HDF5 is the final format for storing benchmark datasets in D4RL,
it is not the format used during the data collection process. In this repo,
we demonstrate using EnvLogger to record the interactions made by an
RL agent during online learning. The logged experience is then post-processed
(and potentially stitched with other datasets) to produce the final HDF5 files.
We provide `convert_dataset.py` to show how this can be done by converting
from EnvLogger's Riegeli file format to a single HDF5 file.
Alternatively, EnvLogger's RLDS backend can be used to generate an RLDS-compatible TensorFlow dataset, which can then be converted to an HDF5 file.
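
As a rough sketch of the recording side (the task name, directory, and random-action
loop are illustrative stand-ins for an actual Acme training run), an environment can be
wrapped with EnvLogger so that every interaction is written to Riegeli files:

```python
import envlogger

from helpers import make_environment

env = make_environment(suite="gym", task="HalfCheetah-v2", seed=0)

# Every reset()/step() call made through the wrapper is recorded under
# data_directory using EnvLogger's default Riegeli backend.
with envlogger.EnvLogger(env, data_directory="/tmp/experience") as env:
    for _ in range(10):
        timestep = env.reset()
        while not timestep.last():
            action = env.action_spec().generate_value()  # placeholder policy
            timestep = env.step(action)
```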
77 changes: 77 additions & 0 deletions d4rl-generation/convert_dataset.py
@@ -0,0 +1,77 @@
"""Convert a dataset logged by the EnvLogger Riegeli backend to HDF5."""
from envlogger import reader
import numpy as np
import h5py
import tree
from absl import flags
from absl import app

_DATASET_DIR = flags.DEFINE_string(
    "dataset_dir", None, "Directory containing the EnvLogger (Riegeli) data."
)
_OUTPUT_FILE = flags.DEFINE_string(
    "output_file", "dataset.hdf5", "Path of the HDF5 file to write."
)
flags.mark_flag_as_required("dataset_dir")


def _convert_envlogger_episode_to_rlds_steps(episode):
    """Convert an episode of envlogger.StepData to RLDS-compatible steps."""
    observations = np.stack([step.timestep.observation for step in episode])
    # RLDS uses the SAR alignment while envlogger uses ARS.
    # The following lines handle converting from the ARS to SAR alignment.
    actions = np.stack([step.action for step in episode[1:]])
    # Add dummy action to the last step containing the terminal observation.
    actions = np.concatenate(
        [actions, np.expand_dims(np.zeros_like(actions[0]), axis=0)]
    )
    # Add dummy reward to the last step containing the terminal observation.
    rewards = np.stack([step.timestep.reward for step in episode[1:]])
    rewards = np.concatenate(
        [rewards, np.expand_dims(np.zeros_like(rewards[0]), axis=0)]
    )
    # Add dummy discount to the last step containing the terminal observation.
    discounts = np.stack([step.timestep.discount for step in episode[1:]])
    discounts = np.concatenate(
        [discounts, np.expand_dims(np.zeros_like(discounts[0]), axis=0)]
    )
    # The is_first/last/terminal flags are already aligned in the ARS alignment.
    is_first = np.array([step.timestep.first() for step in episode])
    is_last = np.array([step.timestep.last() for step in episode])
    is_terminal = np.array(
        [step.timestep.last() and step.timestep.discount == 0.0 for step in episode]
    )
    return {
        "observation": observations,
        "action": actions,
        "reward": rewards,
        "discount": discounts,
        "is_first": is_first,
        "is_last": is_last,
        "is_terminal": is_terminal,
    }


def write_to_hdf5_file(episodes, filename):
    """Write episodes in EnvLogger format to an HDF5 file."""
    all_steps = []
    for episode in episodes:
        all_steps.append(_convert_envlogger_episode_to_rlds_steps(episode))
    all_steps = tree.map_structure(lambda *xs: np.concatenate(xs), *all_steps)
    f = h5py.File(filename, "w")
    for key in all_steps.keys():
        f.create_dataset(key, data=all_steps[key], compression="gzip")
    f.close()


def main(_):
    output_file = _OUTPUT_FILE.value
    with reader.Reader(data_directory=_DATASET_DIR.value) as r:
        print(r.observation_spec())
        print(r.metadata())
        write_to_hdf5_file(r.episodes, output_file)
    # Inspect the created HDF5 file.
    f = h5py.File(output_file, "r")
    for k in f:
        print(k, f[k].shape)
    f.close()


if __name__ == "__main__":
    app.run(main)
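
Once an EnvLogger trace exists on disk, the converter above can be invoked with, for example,
`python convert_dataset.py --dataset_dir=/tmp/experience --output_file=dataset.hdf5`
(both paths are illustrative).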
69 changes: 69 additions & 0 deletions d4rl-generation/helpers.py
@@ -0,0 +1,69 @@
"""Shared helpers for rl_continuous experiments."""
from typing import Optional
from acme import wrappers
import dm_env
import gym
from acme.utils import loggers as acme_loggers
from absl import logging

_VALID_TASK_SUITES = ("gym", "control")


def make_environment(suite: str, task: str, seed=None) -> dm_env.Environment:
    """Makes the requested continuous control environment.

    Args:
      suite: One of 'gym' or 'control'.
      task: Task to load. If `suite` is 'control', the task must be formatted as
        f'{domain_name}:{task_name}'.
      seed: Optional random seed for the environment.

    Returns:
      An environment satisfying the dm_env interface expected by Acme agents.
    """
    if suite not in _VALID_TASK_SUITES:
        raise ValueError(
            f"Unsupported suite: {suite}. Expected one of {_VALID_TASK_SUITES}"
        )

    if suite == "gym":
        env = gym.make(task)
        env.seed(seed)
        # Make sure the environment obeys the dm_env.Environment interface.
        env = wrappers.GymWrapper(env)

    elif suite == "control":
        # Load dm_control lazily to avoid requiring a MuJoCo license when not using it.
        from dm_control import suite as dm_suite  # pylint: disable=g-import-not-at-top

        domain_name, task_name = task.split(":")
        env = dm_suite.load(domain_name, task_name, task_kwargs={"random": seed})
        env = wrappers.ConcatObservationWrapper(env)

    # Wrap the environment so the expected continuous action spec is [-1, 1].
    # Note: this is a no-op on 'control' tasks.
    env = wrappers.CanonicalSpecWrapper(env, clip=True)
    env = wrappers.SinglePrecisionWrapper(env)
    return env


def get_default_logger_factory(workdir: str, save_data=True, time_delta: float = 1.0):
    """Create a custom logger factory for use in the experiment."""

    def logger_factory(label: str, steps_key: Optional[str] = None, task_id: int = 0):
        del steps_key, task_id

        print_fn = logging.info
        terminal_logger = acme_loggers.TerminalLogger(label=label, print_fn=print_fn)

        loggers = [terminal_logger]

        if save_data:
            loggers.append(acme_loggers.CSVLogger(workdir, label=label))

        # Dispatch to all writers and filter Nones and by time.
        logger = acme_loggers.Dispatcher(loggers, acme_loggers.to_numpy)
        logger = acme_loggers.NoneFilter(logger)
        logger = acme_loggers.TimeFilter(logger, time_delta)

        return logger

    return logger_factory
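
A hypothetical usage sketch of the helpers above (the task, workdir, and logged values are
illustrative):

```python
from helpers import get_default_logger_factory, make_environment

# Build a dm_env-compatible environment from the dm_control suite.
env = make_environment(suite="control", task="cartpole:balance", seed=0)

# Create a logger that prints to the terminal and writes CSV files to workdir.
logger_factory = get_default_logger_factory(workdir="/tmp/experiment_logs")
train_logger = logger_factory("train")
train_logger.write({"episode_return": 0.0, "actor_steps": 0})
```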