Merge corpus rewrite to python (#851)
gregtatum authored Oct 17, 2024
1 parent db41329 commit f1668c1
Showing 15 changed files with 737 additions and 213 deletions.
318 changes: 318 additions & 0 deletions pipeline/clean/merge-corpus.py
@@ -0,0 +1,318 @@
"""
Merges multiple parallel corpora into a single "source" language file and a single
"target" language file.
For instance:
dataset1.en.zst dataset1.ru.zst
dataset2.en.zst dataset2.ru.zst
dataset3.en.zst dataset3.ru.zst
Gets merged into:
corpus.en.zst
corpus.ru.zst
"""

import argparse
from contextlib import ExitStack
from glob import glob
from pathlib import Path
from typing import Generator, Optional
from pipeline.common.datasets import (
FilteringStep,
Statistics,
WeakStringSet,
shuffle_with_max_lines,
)
from pipeline.common.downloads import get_human_readable_file_size, read_lines, write_lines
from pipeline.common.logging import get_logger

logger = get_logger(__file__)

# TODO(CJK) - Issue #424
MAX_WORDS_IN_SENTENCE = 100


class FilteringStatistics(Statistics):
"""
Gather statistics about the filtering process.
"""

def __init__(self, dataset_path: Path) -> None:
super().__init__(dataset_path)
self.parallel_corpus = FilteringStep(
"The parallel corpora are merged and deduplicated",
)
self.final_truncated = FilteringStep("The final result can be truncated by max_lines")
self.datasets = []

def add_parallel_dataset(self, location: str):
# e.g. /path/to/ada83_v1.en.zst
path = Path(location)
# e.g. ada83_v1
dataset_stem = Path(path.stem).stem
step = FilteringStep(dataset_stem)
self.datasets.append(step)
return step


def log_dataset(location: str):
logger.info(f"Reading dataset {location}")


class DeduplicateCorpus:
def __init__(
self,
datasets_src: list[Path],
datasets_trg: list[Path],
src_outpath: Path,
trg_outpath: Path,
stats: FilteringStatistics,
) -> None:
self.datasets_src: list[Path] = datasets_src
self.datasets_trg: list[Path] = datasets_trg
self.src_outpath: Path = src_outpath
self.trg_outpath: Path = trg_outpath
self.stats: FilteringStatistics = stats
        self.dataset_stats: Optional[FilteringStep] = None

def run(
self,
total_corpus_bytes: int,
max_lines: Optional[int],
):
stats = self.stats
with ExitStack() as stack:
src_outfile = stack.enter_context(write_lines(self.src_outpath))
trg_outfile = stack.enter_context(write_lines(self.trg_outpath))

if max_lines:
for line in shuffle_with_max_lines(
line_stream=self.yield_lines_string(stack),
seed=38540735095,
max_lines=max_lines,
max_words_in_sentence=MAX_WORDS_IN_SENTENCE,
total_byte_size=total_corpus_bytes,
):
src_line, trg_line = line.split("\t")
src_outfile.write(src_line)
trg_outfile.write(trg_line)

stats.final_truncated.visited = stats.parallel_corpus.kept
stats.final_truncated.kept = min(max_lines, stats.parallel_corpus.kept)
else:
for src_line, trg_line in self.yield_lines_tuple(stack):
src_outfile.write(src_line)
trg_outfile.write(trg_line)

stats.final_truncated.kept = stats.parallel_corpus.kept
stats.final_truncated.visited = stats.parallel_corpus.kept

def yield_lines_tuple(self, stack: ExitStack) -> Generator[tuple[str, str], None, None]:
strings_seen = WeakStringSet()
stats = self.stats
src_lines: Generator[str, None, None] = stack.enter_context(
read_lines(self.datasets_src, on_enter_location=self.on_enter_location)
)
trg_lines: Generator[str, None, None] = stack.enter_context(
read_lines(self.datasets_trg, on_enter_location=log_dataset)
)

for src_line, trg_line in zip(src_lines, trg_lines):
# No separator is needed as the newline is included.
line = src_line + trg_line
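            # Deduplicate on the concatenated pair, so a repeated source sentence is
            # still kept when it is paired with a different target translation.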

if line in strings_seen:
stats.parallel_corpus.filtered += 1
self.dataset_stats.filtered += 1
else:
stats.parallel_corpus.kept += 1
self.dataset_stats.kept += 1

strings_seen.add(line)

yield src_line, trg_line

def yield_lines_string(self, stack: ExitStack) -> Generator[str, None, None]:
for src_line, trg_line in self.yield_lines_tuple(stack):
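            # The pair is joined with a tab so that run() can split it back apart after
            # shuffling; a line that already contains a tab would corrupt that split.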
if "\t" in src_line or "\t" in trg_line:
logger.error("A line contained a tab character, skipping:")
logger.error(f" src: {src_line}")
                logger.error(f" trg: {trg_line}")
else:
yield f"{src_line}\t{trg_line}"

def on_enter_location(self, location):
log_dataset(location)
self.dataset_stats = self.stats.add_parallel_dataset(location)


def sample_corpus(
artifacts: Path, name: str, sample_size: int, src_outpath: Path, trg_outpath: Path
):
"""
Generate a sample of the corpus data with the following format:
e.g.
> cat artifacts/corpus.sample.txt
Sentence 1 in source language
Sentence 1 in target language
Sentence 2 in source language
Sentence 2 in target language
Sentence 3 in source language
Sentence 3 in target language
...
"""
total_byte_size = src_outpath.stat().st_size + trg_outpath.stat().st_size

with ExitStack() as stack:
sample_path = artifacts / f"{name}.sample.txt"

src_lines = stack.enter_context(read_lines(src_outpath))
trg_lines = stack.enter_context(read_lines(trg_outpath))
sample_outfile = stack.enter_context(
write_lines(
sample_path,
# The browser won't know the encoding when viewing this sample without including
# a "byte order mark", which python can do via this encoding.
encoding="utf-8-sig",
)
)

def join_src_trg():
for src_line, trg_line in zip(src_lines, trg_lines):
# The src and trg line each have a newline at the end. This means that
            # each sentence pair will be separated by a blank line to make for easy
# scanning of datasets.
yield f"{src_line}{trg_line}\n"

logger.info("Stream in:")
logger.info(f" - {src_outpath}")
logger.info(f" - {trg_outpath}")
logger.info(f"Write a {sample_size:,} line sample of the merged corpus:")
logger.info(f" - {sample_path}")

for line in shuffle_with_max_lines(
line_stream=join_src_trg(),
seed=9834523434,
max_lines=sample_size,
max_words_in_sentence=MAX_WORDS_IN_SENTENCE,
total_byte_size=total_byte_size,
):
sample_outfile.write(line)


def get_datasets(src: str, trg: str, datasets_glob: str):
dataset_paths: list[str] = glob(datasets_glob)
datasets_src: list[Path] = []
datasets_trg: list[Path] = []
dataset_paths.sort()

total_corpus_bytes = 0

for dataset in dataset_paths:
path = Path(dataset)
if dataset.endswith(f"{src}.zst"):
datasets_src.append(path)
elif dataset.endswith(f"{trg}.zst"):
datasets_trg.append(path)
else:
raise Exception(f"Dataset does not match naming scheme: {dataset}")

formatted_size, bytes = get_human_readable_file_size(path)
logger.info(f" - {path} ({formatted_size})")
total_corpus_bytes += bytes

return datasets_src, datasets_trg, total_corpus_bytes


def main() -> None:
parser = argparse.ArgumentParser(
description=__doc__,
# Preserves whitespace in the help text.
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--src",
type=str,
help="The source locale",
)

parser.add_argument(
"--trg",
type=str,
help="The target locale",
)

parser.add_argument(
"--datasets_glob",
type=str,
help="A glob-style path to the mono datasets, e.g. /path/to/*.zst",
)

parser.add_argument(
"--max_lines",
type=str,
default="None",
help="The (optionally) maximum number of sentences that will be merged.",
)

parser.add_argument(
"--sample_size", type=int, default=10_000, help="Generate a random sample of sentences."
)

parser.add_argument(
"--artifacts",
type=Path,
help="The path to the artifacts directory.",
)

parser.add_argument(
"--name",
type=str,
help='The final corpus name, e.g. "corpus" will output a "corpus.en.zst" file.',
)

args = parser.parse_args()

datasets_src, datasets_trg, total_corpus_bytes = get_datasets(
args.src, args.trg, args.datasets_glob
)

logger.info("Parallel datasets:")

src_outpath = args.artifacts / f"{args.name}.{args.src}.zst"
trg_outpath = args.artifacts / f"{args.name}.{args.trg}.zst"

stats = FilteringStatistics(args.artifacts / args.name)

max_lines: Optional[int] = None
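    # --max_lines arrives as a string so that a literal "None" can mean "do not
    # truncate"; any other value is parsed as an integer line count.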
if args.max_lines != "None":
max_lines = int(args.max_lines)

deduplicate_corpus = DeduplicateCorpus(
datasets_src,
datasets_trg,
src_outpath,
trg_outpath,
stats,
)

deduplicate_corpus.run(total_corpus_bytes, max_lines)

sample_corpus(
artifacts=args.artifacts,
name=args.name,
sample_size=args.sample_size,
src_outpath=src_outpath,
trg_outpath=trg_outpath,
)

stats.save_json()


if __name__ == "__main__":
main()
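
For context, a minimal sketch of how this script might be invoked; the flag names come from the argparse definitions above, while the locales, glob, and directory paths are hypothetical.

import subprocess

# Merge every parallel .zst dataset matched by the glob into
# artifacts/corpus.en.zst and artifacts/corpus.ru.zst, then write a
# 10,000 line sample alongside them.
subprocess.run(
    [
        "python",
        "pipeline/clean/merge-corpus.py",
        "--src", "en",
        "--trg", "ru",
        "--datasets_glob", "fetches/*.zst",  # hypothetical download directory
        "--artifacts", "artifacts",
        "--name", "corpus",
        "--max_lines", "None",  # pass an integer string to truncate the merged corpus
        "--sample_size", "10000",
    ],
    check=True,
)
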
42 changes: 0 additions & 42 deletions pipeline/clean/merge-corpus.sh

This file was deleted.

7 changes: 6 additions & 1 deletion pipeline/clean/merge-mono.py
@@ -150,7 +150,12 @@ def deduplicate_lines(lines: Generator[str, None, None]) -> Generator[str, None,
log_memory(gc_collect=True)
sample_path = output_path.parent / f"{output_path.stem}.sample.txt"
logger.info(f"Write a 10,000 line sample of the final: {sample_path}")
with write_lines(sample_path) as outfile:
with write_lines(
sample_path,
        # The browser won't know the encoding when viewing this sample without including
        # a "byte order mark", which Python can do via this encoding.
        encoding="utf-8-sig",
) as outfile:
for line in shuffle_with_max_lines(
line_stream=final_lines,
seed=9834523434,
File renamed without changes.
File renamed without changes.
