Merge pull request #29 from sam-writer/master
More verbose logging output for compare m2
chrisjbryant authored Apr 14, 2022
2 parents bd99745 + f93a0c8 commit 0dc0848
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions errant/commands/compare_m2.py
@@ -22,9 +22,11 @@ def main():
         # Process the edits for detection/correction based on args
         hyp_dict = process_edits(hyp_edits, args)
         ref_dict = process_edits(ref_edits, args)
+        # original sentence for logging
+        original_sentence = sent[0][2:].split("\nA")[0]
         # Evaluate edits and get best TP, FP, FN hyp+ref combo.
         count_dict, cat_dict = evaluate_edits(
-            hyp_dict, ref_dict, best_dict, sent_id, args)
+            hyp_dict, ref_dict, best_dict, sent_id, original_sentence, args)
         # Merge these dicts with best_dict and best_cats
         best_dict += Counter(count_dict)
         best_cats = merge_dict(best_cats, cat_dict)
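
For context on the new original_sentence line: each M2 block begins with an "S " source line followed by "A " annotation lines, so stripping the first two characters and cutting at the first "\nA" recovers the original sentence. A minimal sketch, using a made-up M2 block:

# Illustrative M2 block; the sentence and annotation are invented for this sketch.
m2_block = "S This are a sentence\nA 1 2|||R:VERB:SVA|||is|||REQUIRED|||-NONE-|||0"
# Same slicing as the added line above: drop the "S " prefix, keep text before the first "\nA".
original_sentence = m2_block[2:].split("\nA")[0]
print(original_sentence)  # -> This are a sentence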
@@ -198,7 +200,11 @@ def process_edits(edits, args):
 # Input 5: Command line args
 # Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
 # Output 2: The corresponding error type dict for the above dict.
-def evaluate_edits(hyp_dict, ref_dict, best, sent_id, args):
+def evaluate_edits(hyp_dict, ref_dict, best, sent_id, original_sentence, args):
+    # Verbose output: display the original sentence
+    if args.verbose:
+        print('{:-^40}'.format(""))
+        print("Original sentence " + str(sent_id) + ": " + original_sentence)
     # Store the best sentence level scores and hyp+ref combination IDs
     # best_f is initialised as -1 cause 0 is a valid result.
     best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
@@ -230,6 +236,11 @@ def evaluate_edits(hyp_dict, ref_dict, best, sent_id, args):
                 # Prepare verbose output edits.
                 hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
                 ref_verb = list(sorted(ref_dict[ref_id].keys()))
+                # add categories
+                # hyp_dict[hyp_id] looks like (0, 1, "str")
+                # hyp_dict[hyp_id][h] is a list, always length one, of the corresponding category
+                hyp_verb = [h + (hyp_dict[hyp_id][h][0],) for h in hyp_verb]
+                ref_verb = [r + (ref_dict[ref_id][r][0],) for r in ref_verb]
                 # Ignore noop edits
                 if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
                 if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
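
As a quick illustration of the comprehension above (the edit keys and categories below are hypothetical): each verbose edit key is a (start, end, correction) tuple, and the single category stored for it is appended as a fourth element.

# Hypothetical stand-in for one hyp_dict[hyp_id] entry.
edits = {(0, 1, "the"): ["M:DET"], (3, 4, ""): ["U:PREP"]}
hyp_verb = list(sorted(edits.keys()))
# Append each edit's single category, as in the committed comprehension.
hyp_verb = [h + (edits[h][0],) for h in hyp_verb]
print(hyp_verb)  # -> [(0, 1, 'the', 'M:DET'), (3, 4, '', 'U:PREP')]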
@@ -246,6 +257,10 @@ def evaluate_edits(hyp_dict, ref_dict, best, sent_id, args):
     if args.verbose:
         print('{:-^40}'.format(""))
         print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))
+        print("Local results:")
+        header = ["Category", "TP", "FP", "FN"]
+        body = [[k, *v] for k, v in best_cat.items()]
+        print_table([header] + body)
     # Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
     best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
     return best_dict, best_cat
@@ -254,7 +269,7 @@ def evaluate_edits(hyp_dict, ref_dict, best, sent_id, args):
 # Input 2: A dictionary of reference edits for a single annotator.
 # Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
 # Output 4: A dictionary of the error type counts.
-def compareEdits(hyp_edits, ref_edits):
+def compareEdits(hyp_edits, ref_edits):
     tp = 0 # True Positives
     fp = 0 # False Positives
     fn = 0 # False Negatives
@@ -360,22 +375,31 @@ def print_results(best, best_cats, args):
         best_cats = processCategories(best_cats, args.cat)
         print("")
         print('{:=^66}'.format(title))
-        print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
+        print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
             "P".ljust(8), "R".ljust(8), "F"+str(args.beta))
         for cat, cnts in sorted(best_cats.items()):
             cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
-            print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
+            print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
                 str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)

     # Print the overall results.
     print("")
     print('{:=^46}'.format(title))
     print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
-    print("\t".join(map(str, [best["tp"], best["fp"],
+    print("\t".join(map(str, [best["tp"], best["fp"],
         best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
     print('{:=^46}'.format(""))
     print("")

+def print_table(table):
+    longest_cols = [
+        (max([len(str(row[i])) for row in table]) + 3)
+        for i in range(len(table[0]))
+    ]
+    row_format = "".join(["{:>" + str(longest_col) + "}" for longest_col in longest_cols])
+    for row in table:
+        print(row_format.format(*row))
+
 if __name__ == "__main__":
     # Run the program
-    main()
+    main()

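For reference, the print_table helper added above can be exercised on its own; it right-aligns every column to its widest cell plus three spaces of padding. The per-category counts below are made up purely to show the [Category, TP, FP, FN] rows that the verbose "Local results" block builds from best_cat.

def print_table(table):
    # Column width = longest cell in that column, plus 3 spaces of padding.
    longest_cols = [
        (max([len(str(row[i])) for row in table]) + 3)
        for i in range(len(table[0]))
    ]
    row_format = "".join(["{:>" + str(longest_col) + "}" for longest_col in longest_cols])
    for row in table:
        print(row_format.format(*row))

# Hypothetical per-category counts: category -> [tp, fp, fn].
best_cat = {"M:DET": [2, 0, 1], "R:VERB:SVA": [1, 1, 0]}
header = ["Category", "TP", "FP", "FN"]
body = [[k, *v] for k, v in best_cat.items()]
print_table([header] + body)
# Prints something like:
#      Category   TP   FP   FN
#         M:DET    2    0    1
#    R:VERB:SVA    1    1    0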