Skip to content

Commit

Permalink
feat: add protogen.test module
Browse files Browse the repository at this point in the history
  • Loading branch information
fischor committed Mar 14, 2023
1 parent 4e754dd commit 0198bf8
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 4 deletions.
66 changes: 62 additions & 4 deletions protogen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def generate(gen: protogen.Plugin):
import enum
import keyword
import sys
from typing import BinaryIO, Callable, Dict, List, Optional, Set
from typing import BinaryIO, Callable, Dict, List, Optional, Set, Tuple
from operator import ior
from functools import reduce

Expand Down Expand Up @@ -1350,6 +1350,59 @@ def _indent(s: str, width: int) -> str:
return "\n".join(prefix_lines)


class CodeGeneratorResponse:
"""A code generator response.
This is the ``protogen`` equivalent to a protobuf CodeGeneratorResponse.
Attributes
----------
proto : google.protobuf.descriptor_pb2.CodeGeneratorResponse
The raw CodeGeneratorResponse.
"""

def __init__(
self, proto: google.protobuf.compiler.plugin_pb2.CodeGeneratorResponse
) -> None:
self.proto = proto

def has_file(self, filename: str) -> bool:
"""Checks if a file in the CodeGeneratorResponse.
Arguments
---------
filename : str
Name of the file to check.
Returns
-------
ok: bool
`True`, if the file is contained in the response, `False` otherwise.
"""
return any([file.name == filename for file in self.proto.file])

def file_content(self, filename) -> Tuple[str, bool]:
"""Returns the content of a file from the CodeGeneratorResponse.
Arguments
---------
filename : str
Name of the file to get the content for.
Returns
-------
content: Tuple[bool, str]
Returns `True` and the content of the file if a file with that name
exists in the CodeGeneratorResponse. Otherwise `False` and the empty
string is returned.
"""
for file in self.proto.file:
if file.name == filename:
return file.content, True

return "", False


class GeneratedFile:
"""An output buffer to write generated code to.
Expand Down Expand Up @@ -1670,10 +1723,10 @@ def __init__(
to use :func:`default_py_import_func`.
input : BinaryIO, optional
The input stream to read the CodeGeneratorRequest from. Defaults
to :attr:`sys.stdin.buffer`.
to :attr:`sys.stdin.buffer` if set as None.
output : BinaryIO, optional
The output stream to write the CodeGeneratorResponse to.
Defaults to :attr:`sys.stdout.buffer`.
Defaults to :attr:`sys.stdout.buffer` if set as None.
supported_features : List[str]
List of features that are supported by the plugin. This list will be
delegated to protoc via the CodeGeneratorresponse.supported_features
Expand All @@ -1683,7 +1736,11 @@ def __init__(
in the list.
"""
self._input = input
if input is None:
self._input = sys.stdin.buffer
self._output = output
if output is None:
self._output = sys.stdout.buffer
self._py_import_func = py_import_func
self._supported_features = supported_features

Expand Down Expand Up @@ -1763,5 +1820,6 @@ def run(self, f: Callable[[Plugin], None]):
# Write response.
resp = plugin._response()
if len(self._supported_features) > 0:
resp.supported_features = reduce(ior, self._supported_features)
resp.supported_features = reduce(ior, self._supported_features)
self._output.write(resp.SerializeToString())
self._output.flush()
136 changes: 136 additions & 0 deletions protogen/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""The test module helps writing tests protogen protoc plugins."""

from importlib import import_module
from typing import Dict, List, Tuple
import io
import os.path
import subprocess
import sys
import tempfile

import protogen

import google.protobuf.compiler.plugin_pb2
import google.protobuf.descriptor_pb2


def run_plugin(
proto_paths: List[str],
files_to_generate: List[str],
plugin: str,
additional_protoc_args: List[str] = ["--experimental_allow_proto3_optional"],
parameter: Dict[str, str] = {},
) -> protogen.CodeGeneratorResponse:
"""Run a protoc plugin python module.
Runs the protoc plugin module `plugin` within the current process (as
opposed to protoc which calls the plugin in a subprocess). This makes it
easy to attach debugger to the current process to debug the plugin.
Arguments
---------
proto_paths : List[str]
List of directories that act as proto paths (will be passed via
--proto_path/-I flag to protoc).
files_to_generate : List[str]
List of proto files to generate output for (will be passed as positional
arguments to protoc).
plugin : str
Python module name of the plugin to run. The plugin will be called and
must be loadable and executable via `importlib.import_module(plugin)`.
additional_protc_args : List[str]
Additional arguments that will be passed to protoc.
parameter : Dict[str, str]
Parameter that will be passed to the plugin.
Returns
-------
response: protogen.CodeGeneratorResponse
The response generated by the plugin.
"""

req = _prepare_code_generator_request(
proto_paths, files_to_generate, additional_protoc_args, parameter
)

# Open stdin and stdout for the plugin.
# We need SpoledTemporaryFile(mode="w+t")._file because its a TextIOWrapper
# which sys.stdin and sys.stdout also are.
fake_stdin = tempfile.SpooledTemporaryFile(mode="w+t")._file
fake_stdin.buffer.write(req.SerializeToString())
fake_stdin.flush()
fake_stdin.seek(0)
fake_stdout = tempfile.SpooledTemporaryFile(mode="w+t")._file

_stdin, sys.stdin = sys.stdin, fake_stdin
_stdout, sys.stdout = sys.stdout, fake_stdout

# Call the plugin under test.
import_module(plugin)

fake_stdout.seek(0)
resp = google.protobuf.compiler.plugin_pb2.CodeGeneratorResponse.FromString(
fake_stdout.buffer.read()
)

fake_stdin.close() # will remove tmp files
fake_stdout.close()

# Reset stdin and stdout.
sys.stdin = _stdin
sys.stdout = _stdout

return protogen.CodeGeneratorResponse(resp)


def _prepare_code_generator_request(
proto_paths: List[str],
files_to_generate: List[str],
additional_protoc_args: List[str],
parameter: Dict[str, str],
):
req = google.protobuf.compiler.plugin_pb2.CodeGeneratorRequest(
file_to_generate=files_to_generate,
parameter=",".join([f"{k}={v}" for (k, v) in parameter.items()]),
proto_file=[],
compiler_version=google.protobuf.compiler.plugin_pb2.Version(
major=1, minor=2, patch=3, suffix=""
),
)

with tempfile.TemporaryDirectory() as tmpdirname:
f = os.path.join(tmpdirname, "descriptor_set.pb")

cmd = ["protoc"]
for proto_path in proto_paths:
cmd.extend(["-I", proto_path])
cmd.append(f"--descriptor_set_out={f}")
cmd.append("--include_imports")
cmd.append("--include_source_info")
cmd.extend(additional_protoc_args)
cmd.extend(files_to_generate)

code, output = _run_protoc(cmd)
if code != 0:
raise Exception(output)

ff = io.open(f, "rb")
desc_set = google.protobuf.descriptor_pb2.FileDescriptorSet.FromString(
ff.read()
)
req.proto_file.extend(desc_set.file)

return req


def _run_protoc(args: List[str]) -> Tuple[int, str]:
proc = subprocess.Popen(
args, text=True, stderr=subprocess.PIPE, stdout=subprocess.PIPE
)
# Executed from current directory (repo root)
code = proc.wait()
if code == 0:
output = proc.stdout.read()
else:
output = proc.stderr.read()
return code, output
22 changes: 22 additions & 0 deletions test/test_run_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from protogen.test import run_plugin


def test_run_plugin():
x = 1
resp = run_plugin(
proto_paths=["test/vendor"],
files_to_generate=["google/api/annotations.proto", "google/api/http.proto"],
plugin="test.plugin.main",
)

assert len(resp.file) == 2


def test_run_plugin_with_proto3_optionals():
resp = run_plugin(
proto_paths=["test/optional"],
files_to_generate=["optional.proto"],
plugin="test.plugin.main",
)

assert len(resp.file) == 1

0 comments on commit 0198bf8

Please # to comment.