Skip to content

Commit

Permalink
fix: black & update black
Browse files Browse the repository at this point in the history
  • Loading branch information
AJDERS committed Feb 28, 2024
1 parent e355e46 commit f27802d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/ambv/black
rev: 23.7.0
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
Expand Down
36 changes: 12 additions & 24 deletions src/coral_models/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@


class Processor(Protocol):
def __call__(self, *args, **kwargs) -> BatchEncoding:
...
def __call__(self, *args, **kwargs) -> BatchEncoding: ...

def decode(
self,
Expand All @@ -37,15 +36,12 @@ def decode(
lm_score_boundary: bool | None = None,
output_word_offsets: bool = False,
n_best: int = 1,
) -> Wav2Vec2DecoderWithLMOutput | str:
...
) -> Wav2Vec2DecoderWithLMOutput | str: ...

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "Processor":
...
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "Processor": ...

def save_pretrained(self, save_directory: str) -> None:
...
def save_pretrained(self, save_directory: str) -> None: ...


@dataclass
Expand All @@ -57,26 +53,18 @@ class PreTrainedModelData:


class ModelSetup(Protocol):
def __init__(self, cfg: DictConfig) -> None:
...
def __init__(self, cfg: DictConfig) -> None: ...

def load_processor(self) -> Processor:
...
def load_processor(self) -> Processor: ...

def load_model(self) -> PreTrainedModel:
...
def load_model(self) -> PreTrainedModel: ...

def load_data_collator(self) -> DataCollatorMixin:
...
def load_data_collator(self) -> DataCollatorMixin: ...

def load_trainer_class(self) -> Type[Trainer]:
...
def load_trainer_class(self) -> Type[Trainer]: ...

def load_compute_metrics(self) -> Callable[[EvalPrediction], dict]:
...
def load_compute_metrics(self) -> Callable[[EvalPrediction], dict]: ...

def load_training_arguments(self) -> TrainingArguments:
...
def load_training_arguments(self) -> TrainingArguments: ...

def load_saved(self) -> PreTrainedModelData:
...
def load_saved(self) -> PreTrainedModelData: ...

0 comments on commit f27802d

Please sign in to comment.