Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add option to disable logger while compiling to avoid graph breaks #6496

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions deepspeed/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import sys
import os
import torch

log_levels = {
"debug": logging.DEBUG,
Expand All @@ -19,6 +20,31 @@

class LoggerFactory:

def create_warning_filter(logger):
warn = False

def warn_once(record):
nonlocal warn
if torch.compiler.is_compiling() and not warn:
warn = True
logger.warning("It is recommended to set DISABLE_LOGS_WHILE_COMPILING=1,"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rephrase the warning to reflect that it applies to compile-mode execution and control by env var. For example,

Suggested change
logger.warning("It is recommended to set DISABLE_LOGS_WHILE_COMPILING=1,"
logger.warning("To avoid graph breaks in compile-mode, it is recommended to disable logging by setting env var DISABLE_LOGS_WHILE_COMPILING=1,"

"to avoid graph breaks caused by logger.")
return True

return warn_once

@staticmethod
def logging_decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
if torch._dynamo.is_compiling():
return
else:
return func(*args, **kwargs)

return wrapper

@staticmethod
def create_logger(name=None, level=logging.INFO):
"""create a logger
Expand All @@ -44,6 +70,12 @@ def create_logger(name=None, level=logging.INFO):
ch.setLevel(level)
ch.setFormatter(formatter)
logger_.addHandler(ch)
if os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1":
for method in ['info', 'debug', 'error', 'warning', 'critical', 'exception']:
original_logger = getattr(logger_, method)
setattr(logger_, method, LoggerFactory.logging_decorator(original_logger))
else:
logger_.addFilter(LoggerFactory.create_warning_filter(logger_))
return logger_


Expand Down