Skip to content

Commit

Permalink
fix dynamic engine generation
Browse files Browse the repository at this point in the history
  • Loading branch information
the-database committed Feb 26, 2024
1 parent 8ac6be7 commit 4209a43
Showing 1 changed file with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def create_dynamic_engine(onnx_name, width, height):
if not os.path.isfile(onnx_path):
raise FileNotFoundError(onnx_path)

engine_path = get_dynamic_engine_path(onnx_name, width, height)
engine_path = get_dynamic_engine_path(onnx_name)

commands = [os.path.join(plugin_path, "trtexec"), "--fp16", f"--onnx={onnx_path}",
"--minShapes=input:1x3x8x8", "--optShapes=input:1x3x1080x1920", "--maxShapes=input:1x3x1080x1920",
Expand Down Expand Up @@ -142,20 +142,20 @@ def upscale2x(clip, backend, engine_name, num_streams):
# if static engine already exists, use it
static_engine_path = get_static_engine_path(engine_name, clip.width, clip.height)
if os.path.isfile(static_engine_path):
# print(f'Static shapes engine already exists, use static shapes engine at {static_engine_path}', flush=True)
logger.debug(f'Static shapes engine already exists, use static shapes engine at {static_engine_path}')
return upscale2x_trt_static(clip, engine_name, num_streams)
# use dynamic engine if video is 1920x1080 or smaller
if use_dynamic_engine(clip.width, clip.height):
try:
# print('Trying dynamic shapes engine', flush=True)
logger.debug('Trying dynamic shapes engine')
return upscale2x_trt_dynamic(clip, engine_name, num_streams)
except:
# print('Failed to generate dynamic shapes engine; fall back to static shapes engine', flush=True)
logger.debug('Failed to generate dynamic shapes engine; fall back to static shapes engine')
# fall back to static engine since not all models support dynamic shapes
return upscale2x_trt_static(clip, engine_name, num_streams)

# use static engine if the video is larger than 1920x1080
# print('Using static shapes engine for video higher than 1080p', flush=True)
logger.debug('Using static shapes engine for video higher than 1080p')
return upscale2x_trt_static(clip, engine_name, num_streams)


Expand Down

0 comments on commit 4209a43

Please sign in to comment.