Experimental GTFN Executor Caching #1197

Merged Mar 17, 2023 · 8 commits · Changes from 6 commits
41 changes: 40 additions & 1 deletion src/gt4py/next/otf/workflow.py
@@ -15,7 +15,7 @@
from __future__ import annotations

import dataclasses
from typing import Callable, Generic, Protocol, TypeVar


StartT = TypeVar("StartT")
@@ -24,6 +24,7 @@
EndT_co = TypeVar("EndT_co", covariant=True)
NewEndT = TypeVar("NewEndT")
IntermediateT = TypeVar("IntermediateT")
HashT = TypeVar("HashT")


def make_step(function: Workflow[StartT, EndT]) -> Step[StartT, EndT]:
@@ -128,3 +129,41 @@ def __call__(self, inp: StartT) -> EndT:

def chain(self, step: Workflow[EndT, NewEndT]) -> CombinedStep[StartT, EndT, NewEndT]:
return CombinedStep(first=self, second=step)


@dataclasses.dataclass(frozen=True)
class CachedStep(Step[StartT, EndT], Generic[StartT, EndT, HashT]):
"""
Cached workflow of single-input callables.

Examples:
---------
>>> def heavy_computation(x: int) -> int:
... print("This might take a while...")
... return x

>>> cached_step = CachedStep(heavy_computation)

>>> cached_step(42)
This might take a while...
42

A subsequent invocation with the same argument is served from the cache:
>>> cached_step(42)
42

>>> cached_step(1)
This might take a while...
1
"""

workflow: Workflow[StartT, EndT]
hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment]

_cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict)

def __call__(self, inp: StartT) -> EndT:
hash_ = self.hash_function(inp)
if hash_ in self._cache:
return self._cache[hash_]
return self._cache.setdefault(hash_, self.workflow(inp))
Comment by @tehrengruber (Contributor Author), Mar 17, 2023:
Suggested change:

    if hash_ in self._cache:
        return self._cache[hash_]
    return self._cache.setdefault(hash_, self.workflow(inp))

becomes

    try:
        result = self._cache[hash_]
    except KeyError:
        result = self._cache[hash_] = self.workflow(inp)
    return result

(The try/except form performs a single dictionary lookup on a cache hit, whereas the membership test followed by indexing performs two.)

suggested by @egparedes

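A minimal usage sketch (not part of the diff) showing how `CachedStep` can cache a step whose input is unhashable by supplying a custom `hash_function`. The `Params` dataclass and the `render` step below are hypothetical, invented for illustration:

    import dataclasses

    from gt4py.next.otf.workflow import CachedStep


    @dataclasses.dataclass
    class Params:
        sizes: list[int]  # the list field makes instances unhashable


    def render(p: Params) -> str:
        print("rendering...")
        return "x".join(str(s) for s in p.sizes)


    # Key the cache on a hashable projection of the input.
    cached_render = CachedStep(workflow=render, hash_function=lambda p: tuple(p.sizes))

    cached_render(Params(sizes=[4, 4]))  # prints "rendering...", returns "4x4"
    cached_render(Params(sizes=[4, 4]))  # cache hit: returns "4x4" without printing
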
87 changes: 62 additions & 25 deletions src/gt4py/next/program_processors/runners/gtfn_cpu.py
@@ -18,14 +18,17 @@
import numpy as np
import numpy.typing as npt

from gt4py.eve.utils import content_hash
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import languages, stages, workflow
from gt4py.next.otf.binding import cpp_interface, pybind
from gt4py.next.otf.compilation import cache, compiler
from gt4py.next.otf.compilation.build_systems import compiledb
from gt4py.next.otf.workflow import CachedStep
from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.program_processors.codegens.gtfn import gtfn_module
from gt4py.next.type_system.type_translation import from_value


# TODO(ricoh): Add support for the whole range of arguments that can be passed to a fencil.
@@ -36,6 +39,18 @@ def convert_arg(arg: Any) -> Any:
return arg


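# Wraps a compiled program: at call time, each positional argument is converted via
# `convert_arg` and the connectivity buffers extracted from `offset_provider` are
# appended as trailing arguments.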
def convert_args(inp: Callable) -> Callable:
def decorated_program(
*args, offset_provider: dict[str, common.Connectivity | common.Dimension]
):
return inp(
*[convert_arg(arg) for arg in args],
*extract_connectivity_args(offset_provider),
)

return decorated_program


def extract_connectivity_args(
offset_provider: dict[str, common.Connectivity | common.Dimension]
) -> list[npt.NDArray]:
@@ -58,6 +73,21 @@ def extract_connectivity_args(
return args


def compilation_hash(otf_closure: stages.ProgramCall) -> int:
    """Given a closure, compute a hash that uniquely determines whether recompilation is needed."""
    offset_provider = otf_closure.kwargs["offset_provider"]
    return hash(
        (
            otf_closure.program,
            # As the frontend types contain lists, they are not hashable. As a workaround we
            # just use content_hash here.
            content_hash(tuple(from_value(arg) for arg in otf_closure.args)),
            id(offset_provider) if offset_provider else None,
            otf_closure.kwargs.get("column_axis", None),
        )
    )

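# A sketch (not part of the diff) of what this key captures: two calls that differ
# only in argument *values* but agree on the program, the arguments' frontend types,
# the offset_provider object, and column_axis map to the same key, so the compiled
# program is reused. `fencil`, `a`, `b`, and `offsets` below are hypothetical:
#
#     call_a = stages.ProgramCall(fencil, (a,), {"offset_provider": offsets})
#     call_b = stages.ProgramCall(fencil, (b,), {"offset_provider": offsets})
#     assert compilation_hash(call_a) == compilation_hash(call_b)  # same frontend types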

@dataclasses.dataclass(frozen=True)
class GTFNExecutor(ppi.ProgramExecutor):
    language_settings: languages.LanguageWithHeaderFilesSettings = cpp_interface.CPP_DEFAULT

@@ -67,44 +97,51 @@ class GTFNExecutor(ppi.ProgramExecutor):

    enable_itir_transforms: bool = True  # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135
    use_imperative_backend: bool = False

    # TODO(tehrengruber): Revisit default value. The hash function for the caching currently
    #  only uses a subset of the closure and implicitly relies on the workflow to only use
    #  that information. As this is dangerous, the caching is disabled by default and should
    #  be considered experimental.
    use_caching: bool = False
    caching_strategy = cache.Strategy.SESSION

    _otf_workflow: workflow.Workflow[
        stages.ProgramCall, stages.CompiledProgram
    ] = dataclasses.field(repr=False, init=False)

    def __post_init__(self):
        # TODO(tehrengruber): Restrict arguments of OTF workflow to the parts it actually
        #  needs to compile the program instead of the entire closure.
        otf_workflow: workflow.Workflow[stages.ProgramCall, stages.CompiledProgram] = (
            gtfn_module.GTFNTranslationStep(
                self.language_settings, self.enable_itir_transforms, self.use_imperative_backend
            )
            .chain(pybind.bind_source)
            .chain(
                compiler.Compiler(
                    cache_strategy=self.caching_strategy, builder_factory=self.builder_factory
                )
            )
            .chain(convert_args)
        )

        if self.use_caching:
            otf_workflow = CachedStep(workflow=otf_workflow, hash_function=compilation_hash)

        super().__setattr__("_otf_workflow", otf_workflow)

    def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None:
        """
        Execute the iterator IR program with the provided arguments.

        The program is compiled to machine code with C++ as an intermediate step,
        so the first execution is expected to have a significant overhead, while
        subsequent calls are very fast. Only scalar and buffer arguments are
        currently supported.

        See ``ProgramExecutorFunction`` for details.
        """
        compiled_runner = self._otf_workflow(stages.ProgramCall(program, args, kwargs))

        return compiled_runner(*args, offset_provider=kwargs["offset_provider"])

    @property
    def __name__(self) -> str:
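
For context, a minimal sketch (not part of the diff) of opting into the experimental cache. The fencil `my_fencil`, the field arguments, and `offsets` are hypothetical placeholders:

    from gt4py.next.program_processors.runners.gtfn_cpu import GTFNExecutor

    run_gtfn_cached = GTFNExecutor(use_caching=True)

    # The first call translates, compiles, and caches the program.
    run_gtfn_cached(my_fencil, field_a, field_b, offset_provider=offsets)
    # A repeated call with the same program and argument types skips recompilation.
    run_gtfn_cached(my_fencil, field_a, field_b, offset_provider=offsets)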