
[Bugfix] Add custom Triton cache manager to resolve MoE MP issue #6140

Merged · 10 commits into main · Jul 15, 2024

Conversation

tdoublep (Member) commented Jul 4, 2024

Fixes #6103

We have been using this fix via our fork (see here) for a while and it seems stable.

Note: this will only resolve the problem if you are using vLLM from the Docker image. A better approach might be to bundle the custom cache manager code inside the vllm package, so that it ships via pip install too, and the user could still set an env variable to enable it.

Update: I've now implemented it by including the custom cache manager inside vLLM and setting the necessary env variable via code.

cc @jeejeelee
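
For context, a minimal sketch of the mechanism this relies on: Triton's runtime reads the TRITON_CACHE_MANAGER environment variable roughly as a "module.path:ClassName" string and imports that class in place of its default FileCacheManager. The vllm module path below is illustrative, not quoted from this PR's diff.

```python
import importlib
import os

# Point Triton at a custom cache manager class ("module.path:ClassName").
# The exact vLLM module path here is an assumption for illustration.
os.environ.setdefault(
    "TRITON_CACHE_MANAGER",
    "vllm.triton_utils.custom_cache_manager:CustomCacheManager")

# Triton's runtime does roughly the following when it needs a cache manager:
module_path, cls_name = os.environ["TRITON_CACHE_MANAGER"].split(":")
cache_manager_cls = getattr(importlib.import_module(module_path), cls_name)
```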

tdoublep and others added 2 commits July 4, 2024 09:18
Co-authored-by: Chih-Chieh-Yang <chih.chieh.yang@ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
jeejeelee (Contributor):

Thanks.

IMHO, this issue should be addressed by bundling the custom cache manager code inside the vllm package.

cc @simon-mo @youkaichao @Yard1

njhill (Member) commented Jul 4, 2024

Thanks @tdoublep, I had mentioned this to @youkaichao previously but kept forgetting to open a PR.

It's not immediately obvious why this seems to affect only the mp case and not ray.

I agree that it would be better for this to be incorporated into the library if possible.

I wonder if we could open a PR or issue in the Triton repo for this (if one doesn't already exist).

```python
        else:
            raise RuntimeError("Could not create or locate cache dir")

        print(f"Triton cache dir: {self.cache_dir=}")
```
Review comment (Member): This should probably be a debug log instead; it produces a lot of output.

tdoublep (Member, Author): Have just removed it for now.

tdoublep changed the title from [Bugfix] Install custom cache manager in Dockerfile to resolve Triton MoE MP issue to [Bugfix] Add custom Triton cache manager to resolve MoE MP issue on Jul 4, 2024
tdoublep (Member, Author) commented Jul 4, 2024

@njhill @jeejeelee I have re-implemented it as part of the vllm library.

One thing I'm not sure about is whether setting the env variable from fused_moe code is sufficient, or whether there are other parts of the code where this fix would be needed. Maybe it's OK for now.

tdoublep (Member, Author) commented Jul 4, 2024

After reading the conversation here, it sounds like we would also need to set this env variable accordingly when using the Triton punica kernel (e.g., once we merge that PR).

jeejeelee (Contributor):

> After reading the conversation here, it sounds like we would also need to set this env variable accordingly when using the Triton punica kernel (e.g., once we merge that PR).

Even if we don't consider #5036, prefix_prefill and triton_flash_attention still need this fix.

```python
def maybe_set_triton_cache_manager(module: str) -> None:
    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
    if cache_manager != module:
        os.environ["TRITON_CACHE_MANAGER"] = module
```
Review comment (Contributor): If the user manually sets this env var, can we modify it? Additionally, I suggest adding a log message for clarification.

tdoublep (Member, Author) replied Jul 4, 2024: Have changed it so that we only set it if the user has not; also added a log message.
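
A minimal sketch of that revised behavior, assuming the helper name and signature from the snippet above: only override the env var when the user has not already chosen a cache manager, and log the change.

```python
import logging
import os

logger = logging.getLogger(__name__)


def maybe_set_triton_cache_manager(module: str) -> None:
    """Point Triton at `module` unless the user already set a cache manager."""
    if os.environ.get("TRITON_CACHE_MANAGER") is None:
        logger.info("Setting Triton cache manager to: %s", module)
        os.environ["TRITON_CACHE_MANAGER"] = module
```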

tdoublep (Member, Author) commented Jul 4, 2024

> Even if we don't consider #5036, prefix_prefill and triton_flash_attention still need this fix.

@jeejeelee OK, in that case I guess it makes sense to call maybe_set_triton_cache_manager in one single place rather than in each individual place where we use Triton. Perhaps we can do it if we detect tp>1 with multiprocessing as the distributed backend?

tdoublep (Member, Author) commented Jul 4, 2024

@njhill @jeejeelee I've moved the call to maybe_set_triton_cache_manager to the MultiprocessingGPUExecutor. I guess this is safer and should cover all cases we need.
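
Roughly, that placement might look like the sketch below; the executor attributes and the import path are assumptions for illustration, not quoted from the diff.

```python
# Sketch only: module path and attribute names are assumed, not confirmed.
from vllm.triton_utils.custom_cache_manager import (
    maybe_set_triton_cache_manager)


class MultiprocessingGPUExecutor:
    def _init_executor(self) -> None:
        # tp>1 with the multiprocessing backend is exactly the case where
        # Triton cache-dir collisions occur, so install the override once
        # here instead of next to every Triton kernel (fused_moe,
        # prefix_prefill, triton_flash_attention, ...).
        world_size = self.parallel_config.tensor_parallel_size
        if world_size > 1:
            maybe_set_triton_cache_manager(
                "vllm.triton_utils.custom_cache_manager:CustomCacheManager")
```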

tdoublep (Member, Author) commented Jul 5, 2024

CI test failures look like network blips (Read timed out).

```python
        os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
```
Review comment (Collaborator): Document why we need this?

tdoublep (Member, Author): Added some docstrings.

tdoublep (Member, Author) commented Jul 5, 2024

CI failure looks unrelated:

FAILED distributed/test_multimodal_broadcast.py::test_models[5-128-half-2] - huggingface_hub.utils._errors.LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on

youkaichao (Member):

> It's not immediately obvious why this seems to affect only the mp case and not ray.

I'm also wondering this. cc anyscale folks @cadedaniel @Yard1 for visibility.

Comment on lines +22 to +25:

```python
    """Re-implements Triton's cache manager, ensuring that a
    unique cache directory is created for each process. This is
    needed to avoid collisions when running with tp>1 and
    using multi-processing as the distributed backend.
```
Review comment (Collaborator): If Triton 3.0.0 could solve this problem, it'd be better to note here that this custom cache manager can be removed when we upgrade Triton.

tdoublep (Member, Author): The fix for the issue is not yet in v3.0.0, but I guess it will be in whatever version comes after that (see my summary here). I will add a comment to that effect.

tdoublep (Member, Author): Done.
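
To make the quoted docstring concrete, here is a minimal sketch of the per-process cache directory idea: subclass Triton's FileCacheManager and suffix the cache directory with the pid, so concurrent tp workers never race on the same files. The override/dump handling of the real FileCacheManager is omitted, and the default directory is an assumption mirroring Triton's usual ~/.triton/cache convention.

```python
import os

from triton.runtime.cache import FileCacheManager


class CustomCacheManager(FileCacheManager):
    """Sketch: like FileCacheManager, but with a pid-suffixed cache dir
    so each worker process compiles into its own directory."""

    def __init__(self, key, override=False, dump=False):
        self.key = key
        self.lock_path = None
        # Base location (assumed default, matching Triton's convention).
        base = os.getenv(
            "TRITON_CACHE_DIR",
            os.path.join(os.path.expanduser("~"), ".triton", "cache"))
        # Per-process suffix: the whole point of the customization.
        self.cache_dir = os.path.join(f"{base}_{os.getpid()}", self.key)
        self.lock_path = os.path.join(self.cache_dir, "lock")
        os.makedirs(self.cache_dir, exist_ok=True)
```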

LSC527 mentioned this pull request Jul 11, 2024
tdoublep (Member, Author):

All comments have been addressed. Is there anything else you would like to see? @comaniac @njhill @jeejeelee @simon-mo

I think it would be good to get this one in since there are quite a few people struggling with this issue.

comaniac (Collaborator) left a review:

LGTM. cc @Yard1 to take a final pass.

simon-mo (Collaborator):

Merging to unblock release.

simon-mo merged commit eaec4b9 into vllm-project:main on Jul 15, 2024. 71 checks passed.
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 17, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
dtrifiro added a commit to dtrifiro/vllm that referenced this pull request Jul 24, 2024
dtrifiro added a commit to opendatahub-io/vllm that referenced this pull request Jul 25, 2024
dtrifiro added a commit to opendatahub-io/vllm that referenced this pull request Aug 6, 2024
dtrifiro added a commit to opendatahub-io/vllm that referenced this pull request Sep 13, 2024

Successfully merging this pull request may close these issues:

[Bug]: fused_moe_kernel compile bug

6 participants