-
Notifications
You must be signed in to change notification settings - Fork 3
/
distill3b_ic_classifier_gpt.py
126 lines (107 loc) · 4.02 KB
/
distill3b_ic_classifier_gpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Distill3: Given triples, filter the events in the triples:
- remove OOC utterances (anything not classified as in-character with >80% confidence)
Input: {"before": [Message...], "commands": [Event...], "after": [Message...]}
Output: {
"before": [Message...],
"commands": [Event...],
"after": [Message...],
}
- with `after` filtered to only include IC-classified utterances (maybe `before` too - see how you feel about it)
"""
import glob
import logging
import math
import pathlib
import time

import openai
import tqdm.contrib.concurrent
import tqdm.contrib.logging

from dataset.utils import read_gzipped_file, write_jsonl
# --- input / output locations ------------------------------------------------
DATA_DIR = pathlib.Path("data/")
# Input triples are read from here (an earlier run used extract/experiment2/).
IN_DIR = pathlib.Path("extract/experiment3a/")
# Filtered triples are written here, one gzipped JSONL file per input file.
OUT_DIR = pathlib.Path("extract/experiment3b/")

# --- classifier --------------------------------------------------------------
# Fine-tuned OpenAI model queried by get_ooc_ic_label.
CLASSIFIER_FINETUNE = "ada:ft-ccb-lab-members-2022-11-28-18-29-25"

# --- logging -----------------------------------------------------------------
log = logging.getLogger("distill3")
# Level handed to logging.basicConfig in the __main__ block.
loglevel = logging.INFO
# Raise the openai client's logger threshold so only WARNING and above pass.
logging.getLogger("openai").setLevel(logging.WARNING)
def get_ooc_ic_label(text, finetuned_model=CLASSIFIER_FINETUNE):
    """Classify one utterance as in-character, out-of-character, or mixed.

    Empty text and text with an obvious out-of-character marker ("OOC",
    "OOG", or a leading "(") are labelled without calling the API. Everything
    else is sent to the fine-tuned completion model.

    :param text: the utterance to classify.
    :param finetuned_model: name of the OpenAI fine-tuned model to query.
    :returns: a ``(label, prob)`` pair.
        - ``label`` is ``"in-character"``, ``"out-of-character"``, ``"mixed"``,
          or ``None`` if no attempt produced a recognized label.
        - ``prob`` is the probability of the first generated token; it is
          ``1`` on the heuristic and fallback paths.
    """
    if not text:
        return "out-of-character", 1
    if "OOC" in text or "OOG" in text or text.startswith("("):
        return "out-of-character", 1

    # Cap the prompt at the first 200 space-separated words.
    words = text.split(" ")
    if len(words) > 200:
        text = " ".join(words[:200])

    # Up to 3 attempts to get a recognized label. NOTE(review): with
    # temperature=0 the output is typically deterministic, so retries may
    # rarely change the result.
    for _ in range(3):
        response = openai.Completion.create(
            model=finetuned_model,
            prompt=text + "\nlabel: ",
            temperature=0,
            max_tokens=7,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=["###", "\n"],
            logprobs=1,
        )
        time.sleep(0.05)
        choice = response["choices"][0]
        label = choice["text"].strip()
        if label in ("in-character", "out-of-character", "mixed"):
            # OpenAI logprobs are natural logarithms, so exp() recovers the
            # probability (2 ** logprob overstated it). Only the first token's
            # probability is used, even if the label spans several tokens.
            prob = math.exp(choice["logprobs"]["token_logprobs"][0])
            return label, prob
    return None, 1
def process_triple(triple) -> dict | None:
    """Keep only confidently in-character messages in ``triple["after"]``.

    Each message in ``triple["after"]`` is classified with
    ``get_ooc_ic_label``. A message survives only if it is labelled
    ``"in-character"`` with probability strictly above 0.8. The triple is
    updated in place.

    :param triple: a dict with an ``"after"`` list of message dicts, each
        having a ``"content"`` string, and a ``"before"`` entry (read only
        when the filtered ``after`` list is empty).
    :returns: the same triple, or ``None`` when both the filtered ``after``
        and ``before`` are empty.
    """
    kept = []
    for message in triple["after"]:
        text = message["content"].strip()
        label, confidence = get_ooc_ic_label(text)
        log.info(f"{text}\n---\n{label} {confidence:.2%}\n=====\n")
        if label == "in-character" and confidence > 0.8:
            kept.append(message)

    triple["after"] = kept
    total_chars = sum(len(message["content"]) for message in kept)
    log.info(f"after content: {total_chars} in {len(kept)} events")

    return triple if (kept or triple["before"]) else None
def process_file(fp: pathlib.Path):
    """Filter every triple in one gzipped input file and save the survivors.

    The surviving triples are written to ``OUT_DIR/<id>.jsonl.gz``, where
    ``<id>`` is the file name up to its first dot. Nothing is written when no
    triple survives.

    :param fp: path to an input file readable by ``read_gzipped_file``.
    :returns: ``(n_triples_in, n_triples_out)``.
    """
    stream = read_gzipped_file(fp)
    combat_id = fp.stem.split(".")[0]

    n_in = 0
    kept = []
    for triple in stream:
        n_in += 1
        result = process_triple(triple)
        if result is not None:
            kept.append(result)

    if not kept:
        # Skip writing an empty output file.
        log.info("nothing was processed")
        return n_in, 0

    write_jsonl(OUT_DIR / f"{combat_id}.jsonl.gz", kept)
    return n_in, len(kept)
if __name__ == "__main__":
    logging.basicConfig(level=loglevel, format="%(name)s:%(levelname)s: %(message)s")
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    # Each gzipped file in IN_DIR is one instance; process them in sorted order.
    filenames = sorted(glob.glob("*.gz", root_dir=IN_DIR))
    files = [IN_DIR / name for name in filenames]

    # Send log output through tqdm so it does not garble the progress bar.
    with tqdm.contrib.logging.logging_redirect_tqdm():
        results = [process_file(path) for path in tqdm.tqdm(files)]

    # results holds one (n_triples_in, n_triples_out) pair per input file.
    kept_distill_count = sum(1 for _, n_out in results if n_out)
    n_triples_in = sum(n_in for n_in, _ in results)
    n_triples_out = sum(n_out for _, n_out in results)

    summary = "\n".join(
        [
            "Distill complete!",
            f"Instances: {len(filenames)} -> {kept_distill_count}",
            f"Triples: {n_triples_in} -> {n_triples_out}",
        ]
    )
    print(summary)