From 4184f5ea521941fcba6b90769b2d7f35b7262471 Mon Sep 17 00:00:00 2001 From: Ritesh Ghorse Date: Tue, 19 Sep 2023 22:40:15 -0400 Subject: [PATCH] Update HuggingFace api doc and add text2audio pipeline task (#28474) * update api doc and add text2audio pipeline task * update doc * update doc * fix indent * correct example snippet --- .../ml/inference/huggingface_inference.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py index aee613363781..3ec063808ae3 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py +++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py @@ -98,6 +98,7 @@ class PipelineTask(str, Enum): TextClassification = 'text-classification' TextGeneration = 'text-generation' Text2TextGeneration = 'text2text-generation' + TextToAudio = 'text-to-audio' TokenClassification = 'token-classification' Translation = 'translation' VideoClassification = 'video-classification' @@ -570,7 +571,7 @@ class HuggingFacePipelineModelHandler(ModelHandler[str, def __init__( self, task: Union[str, PipelineTask] = "", - model=None, + model: str = "", *, inference_fn: PipelineInferenceFn = _default_pipeline_inference_fn, load_pipeline_args: Optional[Dict[str, Any]] = None, @@ -593,9 +594,18 @@ def __init__( Args: task (str or enum.Enum): task supported by HuggingFace Pipelines. Accepts a string task or an enum.Enum from PipelineTask. - model : path to pretrained model on Hugging Face Models Hub to use custom - model for the chosen task. If the model already defines the task then - no need to specify the task parameter. + model (str): path to the pretrained *model-id* on Hugging Face Models Hub + to use custom model for the chosen task. If the `model` already defines + the task then no need to specify the `task` parameter. + Use the *model-id* string instead of an actual model here. + Model-specific kwargs for `from_pretrained(..., **model_kwargs)` can be + specified with `model_kwargs` using `load_pipeline_args`. + + Example Usage:: + model_handler = HuggingFacePipelineModelHandler( + task="text-generation", model="meta-llama/Llama-2-7b-hf", + load_pipeline_args={'model_kwargs':{'quantization_map':config}}) + inference_fn: the inference function to use during RunInference. Default is _default_pipeline_inference_fn. load_pipeline_args (Dict[str, Any]): keyword arguments to provide load