Skip to content

Commit

Permalink
Merge pull request taoyds#2 from ElementAI/dima_def_main
Browse files Browse the repository at this point in the history
add the main function
  • Loading branch information
rizar authored Dec 14, 2020
2 parents 4cd68be + e982f73 commit 4f9c20f
Showing 1 changed file with 6 additions and 21 deletions.
27 changes: 6 additions & 21 deletions evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ def evaluate(gold, predict, db_dir, etype, kmaps):
evaluator = Evaluator(db_dir, kmaps, etype)
results = []
for p, g in zip(plist, glist):
(predicted,) = p
predicted, db_name = p
gold, db_name = g
results.append(evaluator.evaluate_one(db_name, gold, predicted))
evaluator.finalize()
Expand Down Expand Up @@ -1078,27 +1078,12 @@ def build_foreign_key_map_from_json(table):
return tables


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--gold", dest="gold", type=str)
parser.add_argument("--pred", dest="pred", type=str)
parser.add_argument("--db", dest="db", type=str)
parser.add_argument("--table", dest="table", type=str)
parser.add_argument("--etype", dest="etype", type=str)
parser.add_argument("--output")
args = parser.parse_args()

gold = args.gold
pred = args.pred
db_dir = args.db
table = args.table
etype = args.etype

assert etype in ["all", "exec", "match"], "Unknown evaluation method"

def main(gold, pred, db_dir, table, etype, output):
if etype not in ['match', 'exec', 'all']:
raise ValueError()
kmaps = build_foreign_key_map_from_json(table)

results = evaluate(gold, pred, db_dir, etype, kmaps)
if args.output:
with open(args.output, "w") as f:
if output:
with open(output, "w") as f:
json.dump(results, f)

0 comments on commit 4f9c20f

Please sign in to comment.