diff --git a/src/udf/README.md b/src/udf/README.md deleted file mode 100644 index 7b0eaa97b194f..0000000000000 --- a/src/udf/README.md +++ /dev/null @@ -1,20 +0,0 @@ -# Python UDF Support - -🚧 Working in progress. - -# Usage - -```sh -pip3 install pyarrow -# run server -python3 python/example.py -# run client (test client for the arrow flight UDF client-server protocol) -cargo run --example client -``` - -Risingwave client: - -```sql -dev=> create function gcd(int, int) returns int language python as gcd using link 'http://localhost:8815'; -dev=> select gcd(25, 15); -``` diff --git a/src/udf/python/README.md b/src/udf/python/README.md new file mode 100644 index 0000000000000..8650f301d3ec6 --- /dev/null +++ b/src/udf/python/README.md @@ -0,0 +1,75 @@ +# RisingWave Python API + +This library provides a Python API for creating user-defined functions (UDF) in RisingWave. + +Currently, RisingWave supports user-defined functions implemented as external functions. +Users need to define functions using the API provided by this library, and then start a Python process as a UDF server. +RisingWave calls the function remotely by accessing the UDF server at a given address. + +## Installation + +```sh +pip install risingwave +``` + +## Usage + +Define functions in a Python file: + +```python +# udf.py +from risingwave.udf import udf, udtf, UdfServer + +# Define a scalar function +@udf(input_types=['INT', 'INT'], result_type='INT') +def gcd(x, y): + while y != 0: + (x, y) = (y, x % y) + return x + +# Define a table function +@udtf(input_types='INT', result_types='INT') +def series(n): + for i in range(n): + yield i + +# Start a UDF server +if __name__ == '__main__': + server = UdfServer(location="0.0.0.0:8815") + server.add_function(gcd) + server.add_function(series) + server.serve() +``` + +Start the UDF server: + +```sh +python3 udf.py +``` + +To create functions in RisingWave, use the following syntax: + +```sql +create function <name> ( <arg_type> [, ...] 
) + [ returns <ret_type> | returns table ( <column_name> <column_type> [, ...] ) ] + language python as <name_defined_in_server> + using link '<udf_server_address>'; +``` + +- The `language` parameter must be set to `python`. +- The `as` parameter specifies the function name defined in the UDF server. +- The `link` parameter specifies the address of the UDF server. + +For example: + +```sql +create function gcd(int, int) returns int +language python as gcd using link 'http://localhost:8815'; + +create function series(int) returns table (x int) +language python as series using link 'http://localhost:8815'; + +select gcd(25, 15); + +select * from series(10); +``` diff --git a/src/udf/python/example.py b/src/udf/python/example.py index 366ecda593e2c..86c4b8716d794 100644 --- a/src/udf/python/example.py +++ b/src/udf/python/example.py @@ -33,7 +33,7 @@ def series2(n: int) -> Iterator[tuple[int, str]]: if __name__ == '__main__': - server = UdfServer() + server = UdfServer(location="0.0.0.0:8815") server.add_function(random_int) server.add_function(gcd) server.add_function(gcd3) diff --git a/src/udf/python/risingwave/udf.py b/src/udf/python/risingwave/udf.py index 7134bc17014da..55ff4bb99d2cd 100644 --- a/src/udf/python/risingwave/udf.py +++ b/src/udf/python/risingwave/udf.py @@ -119,7 +119,21 @@ def udf(input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType result_type: Union[str, pa.DataType], name: Optional[str] = None,) -> Union[Callable, UserDefinedFunction]: """ - Annotation for creating a user-defined function. + Annotation for creating a user-defined scalar function. + + Parameters: + - input_types: A list of strings or Arrow data types that specifies the input data types. + - result_type: A string or an Arrow data type that specifies the return value type. + - name: An optional string specifying the function name. If not provided, the original name will be used. 
+ + Example: + ``` + @udf(input_types=['INT', 'INT'], result_type='INT') + def gcd(x, y): + while y != 0: + (x, y) = (y, x % y) + return x + ``` """ return lambda f: UserDefinedScalarFunctionWrapper(f, input_types, result_type, name) @@ -130,6 +144,19 @@ def udtf(input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataTyp name: Optional[str] = None,) -> Union[Callable, UserDefinedFunction]: """ Annotation for creating a user-defined table function. + + Parameters: + - input_types: A list of strings or Arrow data types that specifies the input data types. + - result_types: A list of strings or Arrow data types that specifies the return value types. + - name: An optional string specifying the function name. If not provided, the original name will be used. + + Example: + ``` + @udtf(input_types='INT', result_types='INT') + def series(n): + for i in range(n): + yield i + ``` """ return lambda f: UserDefinedTableFunctionWrapper(f, input_types, result_types, name) @@ -137,13 +164,22 @@ def udtf(input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataTyp class UdfServer(pa.flight.FlightServerBase): """ - UDF server based on Apache Arrow Flight protocol. - Reference: https://arrow.apache.org/cookbook/py/flight.html#simple-parquet-storage-service-with-arrow-flight + A server that provides user-defined functions to clients. + + Example: + ``` + server = UdfServer(location="0.0.0.0:8815") + server.add_function(my_udf) + server.serve() + ``` """ + # UDF server based on Apache Arrow Flight protocol. 
+ # Reference: https://arrow.apache.org/cookbook/py/flight.html#simple-parquet-storage-service-with-arrow-flight + _functions: Dict[str, UserDefinedFunction] - def __init__(self, location="grpc://0.0.0.0:8815", **kwargs): - super(UdfServer, self).__init__(location, **kwargs) + def __init__(self, location="0.0.0.0:8815", **kwargs): + super(UdfServer, self).__init__('grpc://' + location, **kwargs) self._functions = {} def get_flight_info(self, context, descriptor): diff --git a/src/udf/python/setup.py b/src/udf/python/setup.py new file mode 100644 index 0000000000000..0bc282d4516a4 --- /dev/null +++ b/src/udf/python/setup.py @@ -0,0 +1,15 @@ +from setuptools import find_packages, setup + +setup( + name="risingwave", + version="0.0.1", + author="RisingWave Labs", + description="RisingWave Python API", + url="https://github.com/risingwavelabs/risingwave", + packages=find_packages(), + classifiers=[ + "Programming Language :: Python", + "License :: OSI Approved :: Apache Software License" + ], + python_requires=">=3.10", +)