
Update HuggingFace api doc and add text2audio pipeline task (#28474)
* update api doc and add text2audio pipeline task

* update doc

* update doc

* fix indent

* correct example snippet
riteshghorse authored Sep 20, 2023
1 parent b390898 commit 4184f5e
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions sdks/python/apache_beam/ml/inference/huggingface_inference.py
@@ -98,6 +98,7 @@ class PipelineTask(str, Enum):
   TextClassification = 'text-classification'
   TextGeneration = 'text-generation'
   Text2TextGeneration = 'text2text-generation'
+  TextToAudio = 'text-to-audio'
   TokenClassification = 'token-classification'
   Translation = 'translation'
   VideoClassification = 'video-classification'
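
For orientation, a minimal sketch (not part of this diff) of constructing a handler with the newly added task; the model id "suno/bark-small" is an illustrative assumption, not something this commit prescribes:

    from apache_beam.ml.inference.huggingface_inference import (
        HuggingFacePipelineModelHandler, PipelineTask)

    # PipelineTask.TextToAudio is equivalent to passing the raw
    # string 'text-to-audio' as the task.
    model_handler = HuggingFacePipelineModelHandler(
        task=PipelineTask.TextToAudio,
        model="suno/bark-small")  # illustrative text-to-audio model id
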
@@ -570,7 +571,7 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
   def __init__(
       self,
       task: Union[str, PipelineTask] = "",
-      model=None,
+      model: str = "",
       *,
       inference_fn: PipelineInferenceFn = _default_pipeline_inference_fn,
       load_pipeline_args: Optional[Dict[str, Any]] = None,
@@ -593,9 +594,18 @@ def __init__(
     Args:
       task (str or enum.Enum): task supported by HuggingFace Pipelines.
         Accepts a string task or an enum.Enum from PipelineTask.
-      model : path to pretrained model on Hugging Face Models Hub to use custom
-        model for the chosen task. If the model already defines the task then
-        no need to specify the task parameter.
+      model (str): *model-id* of a pretrained model on the Hugging Face
+        Models Hub, used to load a custom model for the chosen task. If the
+        `model` already defines the task, there is no need to specify the
+        `task` parameter.
+        Pass the *model-id* string here rather than an actual model instance.
+        Model-specific kwargs for `from_pretrained(..., **model_kwargs)` can
+        be specified with `model_kwargs` via `load_pipeline_args`.
+        Example Usage::
+          model_handler = HuggingFacePipelineModelHandler(
+            task="text-generation", model="meta-llama/Llama-2-7b-hf",
+            load_pipeline_args={'model_kwargs':{'quantization_map':config}})
       inference_fn: the inference function to use during RunInference.
         Default is _default_pipeline_inference_fn.
       load_pipeline_args (Dict[str, Any]): keyword arguments to provide load
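
As a usage note, a minimal end-to-end sketch (assumed, not taken from this commit) of wiring the handler into a Beam pipeline via RunInference; the "gpt2" model id and the prompt are illustrative choices:

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.huggingface_inference import (
        HuggingFacePipelineModelHandler)

    # task can be omitted when the model's Hub metadata already defines it.
    model_handler = HuggingFacePipelineModelHandler(
        task="text-generation", model="gpt2")  # illustrative model id

    with beam.Pipeline() as pipeline:
      _ = (
          pipeline
          | beam.Create(["Apache Beam is"])
          | RunInference(model_handler)  # emits PredictionResult elements
          | beam.Map(print))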
