From 36589a605585afde46ceec89bea790bff13c02fb Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 30 Nov 2021 17:36:56 -0500 Subject: [PATCH] feat: add the ability to bring your own (FileDescriptorSet) bytes --- protoletariat/__main__.py | 18 +++++- protoletariat/fdsetgen.py | 22 +++++-- protoletariat/tests/conftest.py | 106 ++++++++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 5 deletions(-) diff --git a/protoletariat/__main__.py b/protoletariat/__main__.py index 57966fce..444d1cd4 100644 --- a/protoletariat/__main__.py +++ b/protoletariat/__main__.py @@ -2,11 +2,13 @@ from __future__ import annotations +import sys from pathlib import Path +from typing import IO import click -from .fdsetgen import Buf, Protoc +from .fdsetgen import Buf, Protoc, Raw def _overwrite(python_file: Path, code: str) -> None: @@ -155,5 +157,19 @@ def buf(ctx: click.Context, buf_path: str) -> None: Buf(buf_path).fix_imports(**ctx.obj) +@main.command(help="Rewrite imports using FileDescriptorSet bytes from a file or stdin") +@click.option( + "--descriptor-set", + type=click.File("rb"), + default=sys.stdin, + show_default=True, + show_envvar=True, + help="Path to the `buf` executable", +) +@click.pass_context +def raw(ctx: click.Context, descriptor_set: IO[bytes]) -> None: + Raw(descriptor_set.read()).fix_imports(**ctx.obj) + + if __name__ == "__main__": main() diff --git a/protoletariat/fdsetgen.py b/protoletariat/fdsetgen.py index 480ad756..c00c0453 100644 --- a/protoletariat/fdsetgen.py +++ b/protoletariat/fdsetgen.py @@ -28,9 +28,6 @@ def _should_ignore(fd_name: str, patterns: Sequence[str]) -> bool: class FileDescriptorSetGenerator(abc.ABC): """Base class that implements fixing imports.""" - def __init__(self, fdset_generator_binary: str) -> None: - self.fdset_generator_binary = fdset_generator_binary - @abc.abstractmethod def generate_file_descriptor_set_bytes(self) -> bytes: """Generate the bytes of a `FileDescriptorSet`""" @@ -90,13 +87,15 @@ def fix_imports( class Protoc(FileDescriptorSetGenerator): + """Generate the FileDescriptorSet using `protoc`.""" + def __init__( self, protoc_path: str, proto_files: Iterable[Path], proto_paths: Iterable[Path], ) -> None: - super().__init__(protoc_path) + self.fdset_generator_binary = protoc_path self.proto_files = list(proto_files) self.proto_paths = list(proto_paths) @@ -120,6 +119,11 @@ def generate_file_descriptor_set_bytes(self) -> bytes: class Buf(FileDescriptorSetGenerator): + """Generate the FileDescriptorSet using `buf`.""" + + def __init__(self, fdset_generator_binary: str) -> None: + self.fdset_generator_binary = fdset_generator_binary + def generate_file_descriptor_set_bytes(self) -> bytes: return subprocess.check_output( [ @@ -131,3 +135,13 @@ def generate_file_descriptor_set_bytes(self) -> bytes: "-", ] ) + + +class Raw(FileDescriptorSetGenerator): + """Generate the FileDescriptorSet using user-provided bytes.""" + + def __init__(self, fdset_bytes: bytes) -> None: + self.fdset_bytes = fdset_bytes + + def generate_file_descriptor_set_bytes(self) -> bytes: + return self.fdset_bytes diff --git a/protoletariat/tests/conftest.py b/protoletariat/tests/conftest.py index 6ad69da8..a3ebdd0a 100644 --- a/protoletariat/tests/conftest.py +++ b/protoletariat/tests/conftest.py @@ -6,6 +6,7 @@ import os import shutil import subprocess +import tempfile from functools import partial from pathlib import Path from typing import Generator, Iterable, NamedTuple, Sequence @@ -196,6 +197,74 @@ def do_generate(self, cli: CliRunner, *, args: Iterable[str] = ()) -> Result: ) +class RawFixture(ProtoletariatFixture): + def __init__( + self, + *, + base_dir: Path, + package: str, + proto_texts: Iterable[ProtoFile], + monkeypatch: pytest.MonkeyPatch, + grpc: bool = False, + mypy: bool = False, + mypy_grpc: bool = False, + ) -> None: + super().__init__( + base_dir=base_dir, + package=package, + proto_texts=proto_texts, + monkeypatch=monkeypatch, + ) + self.grpc = grpc + self.mypy = mypy + self.mypy_grpc = mypy_grpc + + def do_generate(self, cli: CliRunner, *, args: Iterable[str] = ()) -> Result: + # TODO: refactor this, it duplicates a lot of what's in ProtocFixture + with tempfile.NamedTemporaryFile(delete=False) as f: + filename = f.name + + protoc_args = [ + "protoc", + "--include_imports", + f"--descriptor_set_out={filename}", + "--proto_path", + str(self.base_dir), + "--python_out", + str(self.package_dir), + *(str(fn) for fn, _ in self.proto_texts), + ] + + if self.grpc: + # XXX: why isn't this found? PATH is set properly + grpc_python_plugin = shutil.which("grpc_python_plugin") + protoc_args.extend( + ( + f"--plugin=protoc-gen-grpc_python={grpc_python_plugin}", + "--grpc_python_out", + str(self.package_dir), + ) + ) + if self.mypy: + protoc_args.extend(("--mypy_out", str(self.package_dir))) + if self.mypy_grpc: + protoc_args.extend(("--mypy_grpc_out", str(self.package_dir))) + + subprocess.check_call(protoc_args) + + protol_args = [ + "--python-out", + str(self.package_dir), + *args, + "raw", + f"--descriptor-set={filename}", + ] + try: + return cli.invoke(main, protol_args, catch_exceptions=False) + finally: + os.unlink(filename) + + @pytest.fixture def cli() -> CliRunner: return CliRunner() @@ -250,6 +319,10 @@ def basic_cli_texts() -> list[ProtoFile]: partial(ProtocFixture, package="basic_cli"), id="basic_cli_protoc", ), + pytest.param( + partial(RawFixture, package="basic_cli"), + id="basic_cli_raw", + ), ] ) def basic_cli( @@ -324,6 +397,10 @@ def thing_service_texts() -> list[ProtoFile]: partial(ProtocFixture, package="thing_service", grpc=True), id="thing_service_protoc", ), + pytest.param( + partial(RawFixture, package="thing_service", grpc=True), + id="thing_service_raw", + ), ] ) def thing_service( @@ -375,6 +452,7 @@ def nested_texts() -> list[ProtoFile]: id="nested_buf", ), pytest.param(partial(ProtocFixture, package="nested"), id="nested_protoc"), + pytest.param(partial(RawFixture, package="nested"), id="nested_raw"), ] ) def nested( @@ -432,6 +510,16 @@ def no_imports_service_texts() -> list[ProtoFile]: ), id="no_imports_service_protoc", ), + pytest.param( + partial( + RawFixture, + package="no_imports_service", + grpc=True, + mypy=True, + mypy_grpc=True, + ), + id="no_imports_service_raw", + ), ] ) def no_imports_service( @@ -510,6 +598,16 @@ def imports_service_texts() -> list[ProtoFile]: ), id="imports_service_protoc", ), + pytest.param( + partial( + RawFixture, + package="imports_service", + grpc=True, + mypy=True, + mypy_grpc=True, + ), + id="imports_service_raw", + ), ] ) def grpc_imports( @@ -558,6 +656,10 @@ def long_names_texts() -> list[ProtoFile]: partial(ProtocFixture, package="long_names", mypy=True), id="long_names_protoc", ), + pytest.param( + partial(RawFixture, package="long_names", mypy=True), + id="long_names_raw", + ), ] ) def long_names( @@ -604,6 +706,10 @@ def ignored_import_texts() -> list[ProtoFile]: partial(ProtocFixture, package="ignored_imports"), id="ignored_imports_protoc", ), + pytest.param( + partial(RawFixture, package="ignored_imports"), + id="ignored_imports_raw", + ), ] ) def ignored_imports(