Skip to content

Commit

Permalink
feat(protoc): forward protoc arguments to protoc generator
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Mar 19, 2023
1 parent 2cc748c commit 2a0e96a
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 26 deletions.
23 changes: 8 additions & 15 deletions protoletariat/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import sys
from pathlib import Path
from typing import IO
from typing import IO, Iterable

import click

Expand Down Expand Up @@ -100,7 +100,10 @@ def main(
)


@main.command(help="Use protoc to generate the FileDescriptorSet blob")
@main.command(
context_settings=dict(ignore_unknown_options=True),
help="Use protoc to generate the FileDescriptorSet blob",
)
@click.option(
"--protoc-path",
envvar="PROTOC_PATH",
Expand All @@ -124,28 +127,18 @@ def main(
),
help="Protobuf file search path(s). Accepts multiple values.",
)
@click.argument(
"proto_files",
nargs=-1,
required=True,
type=click.Path(
file_okay=True,
dir_okay=False,
exists=True,
path_type=Path,
),
)
@click.argument("protoc_args", nargs=-1, type=click.UNPROCESSED)
@click.pass_context
def protoc(
ctx: click.Context,
protoc_path: str,
proto_paths: list[Path],
proto_files: list[Path],
protoc_args: Iterable[str],
) -> None:
Protoc(
protoc_path=os.fsdecode(protoc_path),
proto_files=[Path(os.fsdecode(proto_file)) for proto_file in proto_files],
proto_paths=[Path(os.fsdecode(proto_path)) for proto_path in proto_paths],
protoc_args=list(protoc_args),
).fix_imports(**ctx.obj)


Expand Down
21 changes: 10 additions & 11 deletions protoletariat/fdsetgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,25 +143,24 @@ def __init__(
self,
*,
protoc_path: str,
proto_files: Iterable[Path],
proto_paths: Iterable[Path],
protoc_args: Iterable[str],
) -> None:
self.protoc_path = protoc_path
self.proto_files = proto_files
self.proto_paths = proto_paths
self.protoc_args = protoc_args

def generate_file_descriptor_set_bytes(self) -> bytes:
with tempfile.NamedTemporaryFile(delete=False) as f:
filename = Path(f.name)
subprocess.check_output(
[
*shlex.split(self.protoc_path),
"--include_imports",
f"--descriptor_set_out={filename}",
*map("--proto_path={}".format, self.proto_paths),
*map(str, self.proto_files),
]
)
args = [
*shlex.split(self.protoc_path),
"--include_imports",
f"--descriptor_set_out={filename}",
*map("--proto_path={}".format, self.proto_paths),
*self.protoc_args,
]
subprocess.check_output(args)

try:
return filename.read_bytes()
Expand Down
116 changes: 116 additions & 0 deletions protoletariat/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,83 @@ class GrpcIoToolsFixture(ProtocFixture):
protoc_exe = sys.executable, "-m", "grpc_tools.protoc"


class RawProtocFixture(ProtoletariatFixture):
protoc_exe = ("protoc",)

def __init__(
self,
*,
base_dir: Path,
package: str,
proto_texts: Iterable[ProtoFile],
monkeypatch: pytest.MonkeyPatch,
grpc: bool = False,
mypy: bool = False,
mypy_grpc: bool = False,
) -> None:
super().__init__(
base_dir=base_dir,
package=package,
proto_texts=proto_texts,
monkeypatch=monkeypatch,
)
self.grpc = grpc
self.mypy = mypy
self.mypy_grpc = mypy_grpc

def do_generate(self, cli: CliRunner, *, args: Iterable[str] = ()) -> Result:
with tempfile.NamedTemporaryFile(delete=False) as f:
filename = f.name

protoc_args = [
"protoc",
"--include_imports",
f"--descriptor_set_out={filename}",
"--proto_path",
str(self.base_dir),
"--python_out",
str(self.package_dir),
*(str(fn) for fn, _ in self.proto_texts),
]

if self.grpc:
# XXX: why isn't this found? PATH is set properly
grpc_python_plugin = shutil.which("grpc_python_plugin")
protoc_args.extend(
(
f"--plugin=protoc-gen-grpc_python={grpc_python_plugin}",
"--grpc_python_out",
str(self.package_dir),
)
)
if self.mypy:
protoc_args.extend(("--mypy_out", str(self.package_dir)))
if self.mypy_grpc:
protoc_args.extend(("--mypy_grpc_out", str(self.package_dir)))

subprocess.check_call(protoc_args)

try:
return cli.invoke(
main,
[
"--python-out",
str(self.package_dir),
*args,
"protoc",
"--protoc-path",
shlex.join(self.protoc_exe),
"--proto-path",
str(self.base_dir),
f"--descriptor_set_in={filename}",
*(str(filename) for filename, _ in self.proto_texts),
],
catch_exceptions=False,
)
finally:
os.unlink(filename)


class RawFixture(ProtoletariatFixture):
def __init__(
self,
Expand Down Expand Up @@ -337,6 +414,10 @@ def basic_cli_texts(request: SubRequest) -> list[ProtoFile]:
partial(RawFixture, package="basic_cli"),
id="basic_cli_raw",
),
pytest.param(
partial(RawProtocFixture, package="basic_cli"),
id="basic_cli_raw_protoc",
),
]
)
def basic_cli(
Expand Down Expand Up @@ -420,6 +501,10 @@ def thing_service_texts(request: SubRequest) -> list[ProtoFile]:
partial(RawFixture, package="thing_service", grpc=True),
id="thing_service_raw",
),
pytest.param(
partial(RawProtocFixture, package="thing_service", grpc=True),
id="thing_service_raw_protoc",
),
]
)
def thing_service(
Expand Down Expand Up @@ -475,6 +560,9 @@ def nested_texts() -> list[ProtoFile]:
partial(GrpcIoToolsFixture, package="nested"), id="nested_grpc_io_tools"
),
pytest.param(partial(RawFixture, package="nested"), id="nested_raw"),
pytest.param(
partial(RawProtocFixture, package="nested"), id="nested_raw_protoc"
),
]
)
def nested(
Expand Down Expand Up @@ -553,6 +641,16 @@ def no_imports_service_texts(request: SubRequest) -> list[ProtoFile]:
),
id="no_imports_service_raw",
),
pytest.param(
partial(
RawProtocFixture,
package="no_imports_service",
grpc=True,
mypy=True,
mypy_grpc=True,
),
id="no_imports_service_raw_protoc",
),
]
)
def no_imports_service(
Expand Down Expand Up @@ -652,6 +750,16 @@ def imports_service_texts(request: SubRequest) -> list[ProtoFile]:
),
id="imports_service_raw",
),
pytest.param(
partial(
RawProtocFixture,
package="imports_service",
grpc=True,
mypy=True,
mypy_grpc=True,
),
id="imports_service_raw_protoc",
),
]
)
def grpc_imports(
Expand Down Expand Up @@ -708,6 +816,10 @@ def long_names_texts() -> list[ProtoFile]:
partial(RawFixture, package="long_names", mypy=True),
id="long_names_raw",
),
pytest.param(
partial(RawProtocFixture, package="long_names", mypy=True),
id="long_names_raw_protoc",
),
]
)
def long_names(
Expand Down Expand Up @@ -763,6 +875,10 @@ def ignored_import_texts(request: SubRequest) -> list[ProtoFile]:
partial(RawFixture, package="ignored_imports"),
id="ignored_imports_raw",
),
pytest.param(
partial(RawProtocFixture, package="ignored_imports"),
id="ignored_imports_raw_protoc",
),
]
)
def ignored_imports(
Expand Down

0 comments on commit 2a0e96a

Please sign in to comment.