feat: Support weight-stripped engine and REFIT_IDENTICAL flag #3167

Open: wants to merge 1 commit into base: main

zewenli98
Collaborator

Description

  1. Added support for weight-stripped engines in the Python runtime
  2. Added the REFIT_IDENTICAL flag

Fixes #3146

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

zewenli98 self-assigned this Sep 19, 2024
github-actions bot added the labels component: tests, component: conversion, component: api [Python], component: runtime, and component: dynamo on Sep 19, 2024
Comment on lines 79 to +82
name: str = "",
settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed
weight_name_map: Optional[dict[Any, Any]] = None,
graph_module: Optional[torch.fx.GraphModule] = None,
Collaborator Author

zewenli98, Sep 19, 2024

@narendasan I tried to implement refitting for the C++ runtime the same way as for the Python runtime, but it didn't work. Any suggestions? Should I do it in C++ or in Python?

Collaborator

Doesn't refit already work with both APIs?

Collaborator

Also, why do we need the graph module in this module?

Collaborator Author

  1. In this PR I moved the refitting logic into TRTModule, so it only works for the Python runtime.

  2. The graph module is used for refitting.
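As a rough illustration of why the graph module is needed, here is a minimal sketch of the weight-copy step of a refit. It assumes `weight_name_map` maps TensorRT engine weight names to keys of the module's `state_dict()` (an assumption about its layout), and it drives any object exposing `set_named_weights(name, weights)` and `refit_cuda_engine()`, which are the method names of TensorRT's Python `trt.Refitter`. `refit_from_state_dict` and `FakeRefitter` are hypothetical; the fake only records calls so the sketch runs without TensorRT or a GPU.

```python
from typing import Any, List, Mapping, Protocol, Tuple


class RefitterLike(Protocol):
    """Subset of the trt.Refitter Python interface used in this sketch."""

    def set_named_weights(self, name: str, weights: Any) -> bool: ...

    def refit_cuda_engine(self) -> bool: ...


def refit_from_state_dict(
    refitter: RefitterLike,
    weight_name_map: Mapping[str, str],
    state_dict: Mapping[str, Any],
) -> None:
    """Hypothetical helper: copy every mapped weight, then refit once."""
    for engine_name, sd_key in weight_name_map.items():
        if sd_key not in state_dict:
            raise KeyError(f"no weight for engine tensor {engine_name!r} (key {sd_key!r})")
        # With real TensorRT this would be a contiguous NumPy array or trt.Weights.
        if not refitter.set_named_weights(engine_name, state_dict[sd_key]):
            raise RuntimeError(f"failed to set weights for {engine_name!r}")
    # One refit call applies all weights set above.
    if not refitter.refit_cuda_engine():
        raise RuntimeError("refit failed; some weights may still be missing")


class FakeRefitter:
    """Stand-in that records calls instead of touching a real engine."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, Any]] = []
        self.refitted = False

    def set_named_weights(self, name: str, weights: Any) -> bool:
        self.calls.append((name, weights))
        return True

    def refit_cuda_engine(self) -> bool:
        self.refitted = True
        return True


if __name__ == "__main__":
    fake = FakeRefitter()
    refit_from_state_dict(fake, {"conv1_w": "conv1.weight"}, {"conv1.weight": [1.0, 2.0]})
    print(fake.calls, fake.refitted)
```

The point of the sketch is that the source of truth for the weight values is the PyTorch module, which is why a refit path needs access to it (or to its state dict).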

@@ -619,27 +609,32 @@ def run(
builder_config, self.compilation_settings.timing_cache_path
)

- serialized_engine = self.builder.build_serialized_network(
+ # if strip_engine_weights is True, the serialized engine needs to be refitted before use
+ maybe_unrefitted_serialized_engine = self.builder.build_serialized_network(
Collaborator

Why is this a "maybe unrefitted" engine?

Collaborator Author

Please see the design in the comment below. If compilation_settings.strip_engine_weights is True, the engine needs to be refitted; otherwise it doesn't. Hence "maybe".

), "weight-stripped engines must be refittable, please set make_refittable=True"

# Refit the weights
refitter = trt.Refitter(self.engine, TRT_LOGGER)
Collaborator

Can you use this function?

def _refit_single_trt_engine_with_gm(

Collaborator Author

The function requires input_list, which is not available to the caller.

@@ -121,6 +124,52 @@ def setup_engine(self) -> None:
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
self.context = self.engine.create_execution_context()

if self.settings.strip_engine_weights:
Collaborator

narendasan, Sep 19, 2024

We likely shouldn't be doing the refit in these modules.

I think there are three workflows for weight stripping:

  1. A user just wants a weight-stripped engine. They should use convert_exported_program_to_trt_engine with the strip_weights setting. The choice of make_refittable can be used to decide between kREFIT and kREFIT_IDENTICAL (though this might not be entirely clear, so we may want to rethink that setting).
  2. We want to use weight stripping to get a lighter-weight cache. Here the choice is opaque to the user: the user's choice of make_refittable controls whether we use kREFIT or kREFIT_IDENTICAL. But once the engine is loaded or pulled from the cache, we immediately refit it (before passing the engine to the TRTModule), same as we do today.
  3. The user wants a compiled program with stripped weights (I'm not sure why, or whether this is a real use case). This is basically the same as lazy engine loading: we would require users to run refit_engine_weights before executing.
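Workflows 1 and 2 both hinge on mapping user settings to TensorRT builder flags. Below is a minimal sketch of one reading of that proposal, assuming `strip_engine_weights` adds `STRIP_PLAN` and `make_refittable` picks `REFIT` (weights may change later) over `REFIT_IDENTICAL` (stripped weights are restored with the same values). TensorRT 10's Python API exposes these as `trt.BuilderFlag.STRIP_PLAN`, `trt.BuilderFlag.REFIT`, and `trt.BuilderFlag.REFIT_IDENTICAL`; a local enum stands in for `trt.BuilderFlag` so the sketch runs without TensorRT, and `select_refit_flags` is hypothetical.

```python
from enum import Enum, auto
from typing import Set


class BuilderFlag(Enum):
    """Local stand-in for the subset of trt.BuilderFlag used here."""

    REFIT = auto()
    REFIT_IDENTICAL = auto()
    STRIP_PLAN = auto()


def select_refit_flags(make_refittable: bool, strip_engine_weights: bool) -> Set[BuilderFlag]:
    """Hypothetical mapping from compilation settings to builder flags."""
    flags: Set[BuilderFlag] = set()
    if make_refittable:
        # The user may later refit with different weight values.
        flags.add(BuilderFlag.REFIT)
    elif strip_engine_weights:
        # Weights are only stripped and later restored unchanged, so the
        # builder can still optimize for these exact values.
        flags.add(BuilderFlag.REFIT_IDENTICAL)
    if strip_engine_weights:
        # Leave the refittable weights out of the serialized plan.
        flags.add(BuilderFlag.STRIP_PLAN)
    return flags


if __name__ == "__main__":
    for refittable in (False, True):
        for strip in (False, True):
            names = sorted(f.name for f in select_refit_flags(refittable, strip))
            print(f"make_refittable={refittable}, strip={strip}: {names}")
```

With TensorRT installed, each returned member could be applied to an `IBuilderConfig` via `config.set_flag(getattr(trt.BuilderFlag, flag.name))`. The sketch never sets `REFIT` and `REFIT_IDENTICAL` together, since they express contradictory promises about the future weights.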

Collaborator Author

Got it. My initial idea/design is in the comment below. I'll move the refitting logic back into TRTInterpreter.run().

"The choice of make_refittable can be used to decide between kREFIT and kREFIT_IDENTICAL"

Do you mean we use make_refittable to control both kREFIT and kREFIT_IDENTICAL?

Collaborator

narendasan left a comment

@zewenli98 do you have a design for this feature?

zewenli98
Collaborator Author

@narendasan OK, my initial overall design was as follows:

In TRTInterpreter.run():

if compilation_settings.strip_engine_weights is True:
    if engine_cache not hit:
        1. build a weight-stripped engine
        2. save the weight-stripped engine if engine_cache is set
        3. return the weight-stripped engine (not yet refit)
    else:
        load and return the weight-stripped engine (not yet refit)
else:
    if engine_cache not hit:
        1. build a weight-included engine
        2. save the weight-included engine if engine_cache is set
        3. return the weight-included engine (don't need to refit)
    else:
        load and return the weight-included engine (not yet refit)
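The branching above can be written as a small runnable sketch that keeps only the control flow. Assumptions: the engine cache is a plain dict keyed by a graph hash, `build_engine` stands in for `builder.build_serialized_network`, and the returned flag tells the caller (TRTModule, in this design) whether it must refit before inference. `get_serialized_engine` and its parameters are hypothetical. Following the pseudocode, every cache hit is returned "not yet refit", and only a freshly built weight-included engine skips the refit.

```python
from typing import Callable, Dict, Optional, Tuple


def get_serialized_engine(
    graph_hash: str,
    strip_engine_weights: bool,
    build_engine: Callable[[bool], bytes],
    engine_cache: Optional[Dict[str, bytes]] = None,
) -> Tuple[bytes, bool]:
    """Return (serialized_engine, needs_refit) following the proposed design."""
    if engine_cache is not None and graph_hash in engine_cache:
        # Cache hit: the pseudocode marks the loaded engine "not yet refit"
        # in both branches, so the caller must refit before inference.
        return engine_cache[graph_hash], True

    # Cache miss: build (weight-stripped if requested) and save if caching is on.
    engine = build_engine(strip_engine_weights)
    if engine_cache is not None:
        engine_cache[graph_hash] = engine
    # A freshly built weight-included engine already holds the right weights;
    # a freshly built weight-stripped one still needs them refitted.
    return engine, strip_engine_weights


if __name__ == "__main__":
    cache: Dict[str, bytes] = {}

    def fake_build(strip: bool) -> bytes:
        return b"stripped" if strip else b"full"

    print(get_serialized_engine("g1", False, fake_build, cache))  # miss
    print(get_serialized_engine("g1", False, fake_build, cache))  # hit
```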

Then, in TRTModule, refit if necessary before inference.
I didn't put the refitting logic into TRTInterpreter.run() because I wanted to avoid repeated serialization and deserialization of TRT engines: (1) deserializing in TRTInterpreter.run() to refit and then serializing again, and (2) deserializing once more in TRTModule.
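To make the trade-off concrete, here is a toy tally of engine (de)serializations per load for each placement of the refit step. The step lists paraphrase the comment above; they are not measured from the code, and `PIPELINES` and `count_ops` are hypothetical names.

```python
from collections import Counter
from typing import Dict, List

# Engine steps per placement of the refit, as described in the comment above.
PIPELINES: Dict[str, List[str]] = {
    # Refit inside TRTInterpreter.run(), then hand re-serialized bytes to TRTModule.
    "refit_in_interpreter": ["deserialize", "refit", "serialize", "deserialize"],
    # Hand the (maybe unrefitted) bytes straight to TRTModule and refit there.
    "refit_in_trt_module": ["deserialize", "refit"],
}


def count_ops(steps: List[str]) -> Counter:
    """Count how often each step occurs in one pipeline."""
    return Counter(steps)


if __name__ == "__main__":
    for name, steps in PIPELINES.items():
        ops = count_ops(steps)
        print(f"{name}: deserialize={ops['deserialize']}, serialize={ops['serialize']}")
```

Refitting inside the interpreter costs one extra serialize/deserialize round trip per engine, which is the overhead the original design tried to avoid.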

Labels
cla signed, component: api [Python], component: conversion, component: dynamo, component: runtime, component: tests
Development

Successfully merging this pull request may close these issues.

✨[Feature] Weight specific engine caching
4 participants