Skip to content

Add ability to view interpolation to rcp_viewer.py #416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 55 additions & 6 deletions mlperf_logging/rcp_checker/visualization_scripts/rcp_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,41 @@

from mlperf_logging.rcp_checker.rcp_checker import RCP_Checker

def print_rcp_record(record):
    '''Print one RCP record as a ``BS,Mean,Min`` comma-separated line.

    Reads the 'BS', 'RCP Mean' and 'Min Epochs' entries of ``record`` and
    writes them to stdout, in that order, on a single line.
    '''
    batch_size = record['BS']
    rcp_mean = record['RCP Mean']
    min_epochs = record['Min Epochs']
    print(batch_size, rcp_mean, min_epochs, sep=",")

# this should be a method of rcp_checker.RCP_Checker, but it's missing.
# Instead we derived it from _find_min_rcp()
def find_max_rcp(checker, rcp_pass_arg='pruned_rcps'):
    '''Find the RCP record with the largest batch size for a benchmark.

    Args:
        checker: object exposing ``_get_rcp_data(rcp_pass_arg)``, which
            returns a mapping whose values are RCP records carrying a
            'BS' entry.
        rcp_pass_arg: selector forwarded unchanged to ``_get_rcp_data``.

    Returns:
        The record whose 'BS' is largest (the first one encountered on
        ties), or None when there are no records.
    '''
    records = checker._get_rcp_data(rcp_pass_arg).values()
    # default=None keeps the "no records -> None" result instead of raising.
    return max(records, key=lambda record: record['BS'], default=None)

# this should be a method of rcp_checker.RCP_Checker, but it's missing.
# Instead we derived it by extracting parts of rcp_checker.check_directory()
def get_rcp_record_for_bs(bs, checker, rcp_pass_arg='pruned_rcps'):
    '''Return the RCP record to use for batch size ``bs``.

    Resolution order:
      1. the record ``checker._find_rcp`` returns for ``bs``, if any;
      2. otherwise, if ``checker._find_top_min_rcp`` finds nothing, the
         record returned by ``checker._find_min_rcp``;
      3. otherwise the record built by ``checker._create_interp_rcp``
         from the two neighbouring records.

    Args:
        bs: batch size to look up.
        checker: object providing the ``_find_rcp``, ``_find_bottom_max_rcp``,
            ``_find_top_min_rcp``, ``_find_min_rcp`` and
            ``_create_interp_rcp`` methods used here.
        rcp_pass_arg: selector forwarded unchanged to the checker lookups.

    Raises:
        RuntimeError: if there is no exact record for ``bs`` and
            ``checker._find_bottom_max_rcp`` returns None.
    '''
    exact_record = checker._find_rcp(bs, rcp_pass_arg)
    if exact_record is not None:
        return exact_record

    # bs is not one of the generated sizes, so need to interpolate:
    rcp_max = checker._find_bottom_max_rcp(bs, rcp_pass_arg)
    if rcp_max is None:
        raise RuntimeError("Error: no sufficiently large RCP bs found")

    rcp_min = checker._find_top_min_rcp(bs, rcp_pass_arg)
    if rcp_min is None:
        # bs is smaller than the smallest rcp, so just use smallest rcp
        return checker._find_min_rcp(rcp_pass_arg)

    # _create_interp_rcp returns a (name, record) pair; only the record is used.
    _, interp_record = checker._create_interp_rcp(bs, rcp_min, rcp_max)
    return interp_record

def main():
parser = argparse.ArgumentParser(
description='Parse rcps_.json file, prune, and print out rcp means and mins'
Expand All @@ -27,18 +62,32 @@ def main():
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--unpruned', action='store_true',
help='print the unpruned rcps instead of the pruned')
parser.add_argument('--no-header', action='store_true',
help='do not print the header line')
parser.add_argument('--custom_rcps', type=argparse.FileType('r'),
help='specify an RCP json file to use')
parser.add_argument('--interpolate', action='store_true',
help='generate interpolated rcp min/mean for all batch sizes')


args = parser.parse_args()
checker=RCP_Checker(args.usage, args.version, args.benchmark, args.verbose, args.custom_rcps)
data=checker.pruned_rcp_data
rcp_pass_arg='pruned_rcps'
if (args.unpruned):
data=checker.rcp_data
rcp_pass_arg='full_rcps'

checker=RCP_Checker(args.usage, args.version, args.benchmark, args.verbose, args.custom_rcps)

if not args.no_header:
print("BS,Mean,Min")

print("BS,Mean,Min")
for key, record in data.items():
print(f"{record['BS']},{record['RCP Mean']},{record['Min Epochs']}")
if not args.interpolate:
data=checker._get_rcp_data(rcp_pass_arg)
for key, record in data.items():
print_rcp_record(record)
else:
for bs in range(checker._find_min_rcp(rcp_pass_arg)['BS'], find_max_rcp(checker, rcp_pass_arg)['BS']+1):
record = get_rcp_record_for_bs(bs, checker, rcp_pass_arg)
print_rcp_record(record)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
main()