Skip to content

Commit

Permalink
feat: fix selection to avoid too small dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
AJDERS committed Jan 10, 2024
1 parent 5fb63cc commit b28fbc6
Showing 1 changed file with 169 additions and 58 deletions.
227 changes: 169 additions & 58 deletions src/scripts/select_testset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ def get_recordings_length(recordings: pd.DataFrame) -> tuple[float, float]:
def select_by_region(
current_selection: pd.DataFrame,
) -> pd.DataFrame:
"""Selects speakers such that the selection has 10% from each region.
"""Selects speakers such that the selection has 15% from each region.
If the current selection already has 15% from each region, the selection
is returned as is. If the current selection has a region which accounts for less
than 15%, then the remaining regions are reduced such that the smallest region
accounts for 15% of the selection.
Args:
current_selection (pd.DataFrame):
Expand All @@ -57,13 +62,26 @@ def select_by_region(
pd.DataFrame:
The updated selection of speakers.
"""
# Sort by length and take the region with the least amount of recordings
# as the threshold.

# Check if all of the regions are above the threshold
threshold = 0.15 * current_selection.length.sum()
regions = [
(region, region.length.sum())
for _, region in current_selection.groupby("region")
]
new_threshold = sorted(regions, key=lambda x: x[1])[0][1]
if all([threshold < length for _, length in regions]):
return current_selection

# Define new threshold as the size of the smallest region
regions_dist = sorted(regions, key=lambda x: x[1])

# Get the length of speakers from smallest regions
least_repr_region_length = regions_dist[0][1]

# We remove speakers from the overrepresented regions, uniformly from each
# region.
new_total_length = least_repr_region_length / 0.15
new_threshold = new_total_length * (0.85 / 4) # 4 remaining regions

# Remove recordings from the regions until they are below the threshold.
# We remove the recordings with the longest duration first.
Expand All @@ -72,7 +90,7 @@ def select_by_region(
for region, length in regions:
for _, row in region.sort_values(by="length", ascending=False).iterrows():
# Check if removing the speaker will bring the selection below the threshold
if length > new_threshold:
if length >= new_threshold:
length -= row.length
deselected_speakers.append(row.speaker_id)
else:
Expand All @@ -91,6 +109,11 @@ def select_by_accent(
) -> pd.DataFrame:
"""Remove speakers such 10% of speakers have accent and 90% do not.
If the current selection already has 10% speakers with accent, the selection
is returned as is. If the current selection has less than 10% speakers with accent,
then the amount of speakers without accent is reduced such that speakers with
accent accounts for 10% of the selection.
Args:
current_selection (pd.DataFrame):
The current selection of speakers.
Expand All @@ -100,56 +123,65 @@ def select_by_accent(
The updated selection of speakers.
"""

# Remove speakers from the current such that the selection has 10% speakers
# with accent and 90% without accent
total_recordings = current_selection.length.sum()
threshold = 0.1 * total_recordings

# Get the distribution of speakers with and without accent
# Check if all the selection has more than 10% with accent
accent_dist = []
for _, accent in current_selection.groupby("accent"):
length = accent.conversation_length.sum() + accent.read_aloud_length.sum()
length = accent.length.sum()
accent_dist.append((accent, length))

# Sort by length
threshold = 0.1 * current_selection.length.sum()
if all([threshold < length for _, length in accent_dist]):
return current_selection

# Either there are too few speakers with accent, or too many. We remove
# speakers accordingly.
accent_dist = sorted(accent_dist, key=lambda x: x[1])

# Get the length of speakers with and without accent
least_repr_accent_group = accent_dist[0][0]
least_repr_accent_group_length = accent_dist[0][1]
most_repr_accent_group = accent_dist[1][1]
most_repr_accent_group_length = accent_dist[1][0]

# We remove speakers with accent until we have 10% speakers with accent
# and 90% without accent, this is done by scaling the threshold.

# If the smallest accent group is larger than the threshold, we remove
# speakers from the smallest accent group, otherwise the largest accent
# group is removed from. This is done to ensure that we remove the least
# amount of speakers. We remove uniformly from each region.
if least_repr_accent_group_length > threshold:
new_total_length = most_repr_accent_group_length / 0.9
new_threshold = new_total_length * 0.1
region_groups = least_repr_accent_group.groupby("region")
group_to_remove_from_length = least_repr_accent_group_length
else:
new_total_length = least_repr_accent_group_length / 0.1
new_threshold = new_total_length * 0.9
region_groups = most_repr_accent_group.groupby("region")
group_to_remove_from_length = most_repr_accent_group_length

most_repr_accent_group = accent_dist[1][0]
most_repr_accent_group_length = accent_dist[1][1]

# We remove speakers without accent until we have 10% speakers with accent.
# This is done by adjusting the threshold.
new_total_length = least_repr_accent_group_length / 0.1
new_threshold = new_total_length * 0.9
region_groups = most_repr_accent_group.groupby("region")
deselected_speakers = []
index = 0
while group_to_remove_from_length > new_threshold:

# We wish to remove speakers with accent uniformly over regions, but
# when we adjusted to have the correct amount of data from each of the
# regions, we removed the speakers with the longest recordings first from the
# overrepresented regions. This means that the region which started out being
# most underrepresented now still has the speakers with the longest recordings.
# Hence when we remove speakers from this region, we are removing far more data
# than when we remove speakers from the other regions. This is not ideal, as we
# will skew the region distribution. Hence we skip the smallest region when
# removing speakers with accent.
smallest_region = (
current_selection.groupby("region").length.sum().sort_values().keys()[0]
)
while most_repr_accent_group_length > new_threshold:
# Remove index-th speaker from each region, as long as the selection
# is above the threshold
for _, region in region_groups:
is_smallest_region = region.region.values[0] == smallest_region
no_more_speakers_to_remove = region.shape[0] <= index
group_is_a_single_speaker = region.shape[0] <= 1
if (
no_more_speakers_to_remove
or group_is_a_single_speaker
or is_smallest_region
):
continue

row = region.sort_values(by="length", ascending=False).iloc[index]

# Check if removing the speaker will bring the selection below the
# threshold
if group_to_remove_from_length > new_threshold:
group_to_remove_from_length -= row.length
if most_repr_accent_group_length > new_threshold:
most_repr_accent_group_length -= row.length
deselected_speakers.append(row.speaker_id)
else:
break
Expand All @@ -166,7 +198,7 @@ def select_by_accent(
def select_by_gender(
current_selection: pd.DataFrame,
) -> pd.DataFrame:
"""Remove speakers such that gender distribution is 50/50.
"""Remove speakers such that each gender is represented 45%.
Args:
current_selection (pd.DataFrame):
Expand All @@ -177,8 +209,7 @@ def select_by_gender(
The updated selection of speakers.
"""

# Get the distribution of male and female, and remove speakers from the
# remaining genders from the test set.
# Check if all the selection has more than 45% of each gender
deselected_speakers = []
gender_dist = []
for _, gender in current_selection.groupby("gender"):
Expand All @@ -187,6 +218,10 @@ def select_by_gender(
else:
deselected_speakers.extend(gender.speaker_id.values.tolist())

threshold = 0.45 * current_selection.length.sum()
if all([threshold < length for _, length in gender_dist]):
return current_selection

# Sort by length
gender_dist = sorted(gender_dist, key=lambda x: x[1])

Expand All @@ -195,23 +230,19 @@ def select_by_gender(
most_repr_gender = gender_dist[1][0]
most_repr_gender_length = gender_dist[1][1]

# We remove speakers from the overrepresented gender until we have 50/50
# gender distribution, this is done by scaling the threshold
new_total_length = least_repr_gender_length / 0.5
new_threshold = new_total_length * 0.5

# Remove speakers with overrepresented gender, uniformly from each region
# and accent group, as long as the selection is above the threshold.
# This can't be done uniformly across the whole selection, as some region
# and accent groups has too few speakers.
# We remove speakers from the overrepresented gender until we have 45%
# of each gender. We remove speakers uniformly from each region and accent.
new_total_length = least_repr_gender_length / 0.45
new_threshold = new_total_length * 0.55
region_accent_groups = most_repr_gender.groupby(["region", "accent"])
index = 0
while most_repr_gender_length > new_threshold:
# Remove index-th speaker from each region, as long as the selection
# is above the threshold
for _, region_accent in region_accent_groups:
if region_accent.shape[0] <= index:
# Skip if there is only one speaker in the region and accent group
no_more_speakers_to_remove = region_accent.shape[0] <= index
group_is_a_single_speaker = region_accent.shape[0] <= 1
if no_more_speakers_to_remove or group_is_a_single_speaker:
continue

row = region_accent.sort_values(by="length", ascending=False).iloc[index]
Expand All @@ -233,17 +264,90 @@ def select_by_gender(
return current_selection


def select_by_length(
current_selection: pd.DataFrame, type_of_recording: str
def select_by_age(
    current_selection: pd.DataFrame,
) -> pd.DataFrame:
    """Remove speakers from overrepresented age groups.

    If every age group already accounts for more than 20% of the total
    recording length, the selection is returned as is. Otherwise the larger
    age groups are reduced, removing speakers round-robin over the
    (region, accent, gender) sub-groups, longest recordings first.

    Args:
        current_selection (pd.DataFrame):
            The current selection of speakers. The columns `speaker_id`, `age`,
            `region`, `accent`, `gender` and `length` are read.

    Returns:
        pd.DataFrame:
            The updated selection of speakers.
    """
    age_groups = [
        (age_group, age_group["length"].sum())
        for _, age_group in current_selection.groupby("age")
    ]

    # Nothing to do if every age group is above 20% of the total length. This
    # also covers an empty selection, where there are no age groups at all.
    threshold = 0.2 * current_selection["length"].sum()
    if all(threshold < length for _, length in age_groups):
        return current_selection

    # The smallest age group defines the size of the rescaled selection.
    least_repr_age_group_length = min(length for _, length in age_groups)

    # NOTE(review): the early-exit check above uses 20%, while the rescaling
    # below lets the smallest age group make up 10% and caps each other group at
    # 45% (0.90 split over 2 groups, i.e. it assumes 3 age groups in total).
    # Confirm that this mismatch is intentional.
    new_total_length = least_repr_age_group_length / 0.1
    new_threshold = new_total_length * (0.90 / 2)  # 2 remaining age groups

    deselected_speakers = []
    for age_group, length in age_groups:
        if length <= new_threshold:
            continue

        # Sort each sub-group once, longest recordings first. Sub-groups with a
        # single speaker are never removed from, so they are left out entirely.
        sub_groups = [
            sub_group.sort_values(by="length", ascending=False)
            for _, sub_group in age_group.groupby(["region", "accent", "gender"])
            if sub_group.shape[0] > 1
        ]

        # Bound the number of rounds by the largest sub-group. Without this
        # bound the loop would never terminate when all removable speakers are
        # exhausted while the age group is still above the threshold.
        max_rounds = max((sub_group.shape[0] for sub_group in sub_groups), default=0)
        for index in range(max_rounds):
            if length <= new_threshold:
                break

            # Remove the index-th speaker from each sub-group, as long as the
            # age group is still above the threshold.
            for sub_group in sub_groups:
                if length <= new_threshold:
                    break
                if sub_group.shape[0] <= index:
                    continue
                row = sub_group.iloc[index]
                length -= row["length"]
                deselected_speakers.append(row["speaker_id"])

    # Update the current selection
    current_selection = current_selection[
        ~current_selection["speaker_id"].isin(deselected_speakers)
    ]

    return current_selection


def select_by_length(current_selection: pd.DataFrame) -> pd.DataFrame:
"""Selects speakers based on the length of their recordings.
Args:
current_selection (pd.DataFrame):
The current selection of speakers.
type_of_recording (str):
The type of recordings to select on.
Either `conversation_length` or `read_aloud_length`.
Returns:
pd.DataFrame:
Expand Down Expand Up @@ -312,6 +416,10 @@ def main(cfg: DictConfig) -> None:
"sjælland",
)

dialect = speaker_metadata[
speaker_metadata.speaker_id == speaker_id
].dialect.values[0]

# Get the gender of the speaker
gender = speaker_metadata[
speaker_metadata.speaker_id == speaker_id
Expand Down Expand Up @@ -339,6 +447,7 @@ def main(cfg: DictConfig) -> None:
"gender": gender,
"age": age,
"accent": accent,
"dialect": dialect,
"conversation_length": conversation_length,
"read_aloud_length": read_aloud_length,
"length": conversation_length + read_aloud_length,
Expand Down Expand Up @@ -372,10 +481,12 @@ def main(cfg: DictConfig) -> None:
current_selection = select_by_gender(
current_selection=current_selection,
)
current_selection = select_by_age(
current_selection=current_selection,
)

current_selection = select_by_length(
current_selection=current_selection,
type_of_recording="conversation_length",
)


Expand Down

0 comments on commit b28fbc6

Please sign in to comment.