Skip to content

Commit

Permalink
Update FieldTransformation syntax
Browse files Browse the repository at this point in the history
Now field_transformations.when matches on FeildDescriptor
and field_transformation.set accepts FieldOptions,
however it only allows setting ScalaPB options for the time
being.

For #1007
  • Loading branch information
thesamet committed Jan 20, 2021
1 parent 811b978 commit 50f75bd
Show file tree
Hide file tree
Showing 8 changed files with 447 additions and 218 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package scalapb.compiler

import com.google.protobuf.Message
import scalapb.options.Scalapb
import scalapb.options.Scalapb.FieldTransformation.MatchType
import com.google.protobuf.Descriptors.FieldDescriptor
import com.google.protobuf.Descriptors.FileDescriptor
import scala.jdk.CollectionConverters._
import com.google.protobuf.DescriptorProtos.FieldOptions
import com.google.protobuf.DescriptorProtos.{FieldOptions, FieldDescriptorProto}
import scalapb.options.Scalapb.FieldTransformation
import com.google.protobuf.DynamicMessage
import com.google.protobuf.Descriptors.FieldDescriptor.Type
Expand All @@ -18,59 +19,125 @@ import java.util.regex.Matcher
private[compiler] case class ResolvedFieldTransformation(
whenFields: Map[FieldDescriptor, Any],
set: scalapb.options.Scalapb.FieldOptions,
matchType: MatchType
matchType: MatchType,
extensions: Set[FieldDescriptor]
)

private[compiler] case class ExtensionResolutionContext(
currentFile: String,
extensions: Set[FieldDescriptor]
)

private[compiler] object ResolvedFieldTransformation {
def apply(
currentFile: String,
ft: FieldTransformation,
extensions: Set[FieldDescriptor]
file: FileDescriptor,
ft: FieldTransformation
): ResolvedFieldTransformation = {
val context =
ExtensionResolutionContext(
file.getFullName(),
FieldTransformations.fieldExtensionsForFile(file)
)
if (!ft.getSet().getAllFields().keySet().asScala.subsetOf(Set(Scalapb.field.getDescriptor()))
|| !ft
.getSet()
.getUnknownFields()
.asMap()
.keySet()
.asScala
.subsetOf(Set(Scalapb.field.getNumber()))) {
throw new GeneratorException(
s"${file.getFullName}: FieldTransformation.set must contain only [scalapb.field] field"
)
}
ResolvedFieldTransformation(
FieldTransformations.fieldMap(
currentFile,
FieldOptions.parseFrom(ft.getWhen.toByteArray()),
extensions = extensions
FieldDescriptorProto.parseFrom(ft.getWhen.toByteArray()),
context = context
),
ft.getSet(),
ft.getMatchType()
ft.getSet().getExtension(Scalapb.field),
ft.getMatchType(),
context.extensions
)
}
}

private[compiler] object FieldTransformations {
def matches[T <: Message](
def matches(
currentFile: String,
input: Map[FieldDescriptor, Any],
input: FieldDescriptorProto,
transformation: ResolvedFieldTransformation
): Boolean = {
transformation.matchType match {
case MatchType.CONTAINS =>
matchContains(currentFile, input, transformation.whenFields)
case MatchType.EXACT => input == transformation.whenFields
case MatchType.PRESENCE => matchPresence(input, transformation.whenFields)
matchContains(
input,
transformation.whenFields,
ExtensionResolutionContext(currentFile, transformation.extensions)
)
case MatchType.EXACT => input == transformation.whenFields
case MatchType.PRESENCE =>
matchPresence(
input,
transformation.whenFields,
ExtensionResolutionContext(currentFile, transformation.extensions)
)
}
}

def matchContains(
currentFile: String,
input: Map[FieldDescriptor, Any],
pattern: Map[FieldDescriptor, Any]
input: Message,
pattern: Map[FieldDescriptor, Any],
context: ExtensionResolutionContext
): Boolean = {
pattern.forall { case (fd, v) =>
input.get(fd) match {
case None => false
case Some(u) =>
pattern.forall {
case (fd, v) =>
if (!fd.isExtension()) {
if (fd.getType() != Type.MESSAGE)
u == v
else
matchContains(
currentFile,
fieldMap(currentFile, u.asInstanceOf[Message], Set.empty),
fieldMap(currentFile, v.asInstanceOf[Message], Set.empty)
input.hasField(fd) && input.getField(fd) == v
else {
input.hasField(fd) && matchContains(
input.getField(fd).asInstanceOf[Message],
v.asInstanceOf[Map[FieldDescriptor, Any]],
context
)
}
}
} else {
input.getUnknownFields().hasField(fd.getNumber) &&
matchContains(
getExtensionField(input, fd),
v.asInstanceOf[Map[FieldDescriptor, Any]],
context
)
}
}
}

def matchPresence(
input: Message,
pattern: Map[FieldDescriptor, Any],
context: ExtensionResolutionContext
): Boolean = {
pattern.forall {
case (fd, v) =>
if (!fd.isExtension()) {
if (fd.getType() != Type.MESSAGE)
input.hasField(fd)
else {
input.hasField(fd) && matchPresence(
input.getField(fd).asInstanceOf[Message],
v.asInstanceOf[Map[FieldDescriptor, Any]],
context
)
}
} else {
input.getUnknownFields().hasField(fd.getNumber) &&
matchPresence(
getExtensionField(input, fd),
v.asInstanceOf[Map[FieldDescriptor, Any]],
context
)
}
}
}

Expand All @@ -91,14 +158,14 @@ private[compiler] object FieldTransformations {

def processField(fd: FieldDescriptor): Seq[AuxFieldOptions] =
if (transforms.nonEmpty) {
val noReg = FieldOptions.parseFrom(fd.getOptions().toByteArray())
val input = fieldMap(f.getFullName(), noReg, extensions)
val noReg = FieldDescriptorProto.parseFrom(fd.toProto().toByteArray())
val context = ExtensionResolutionContext(f.getFullName(), extensions)
transforms.flatMap { transform =>
if (matches(f.getFullName(), input, transform))
if (matches(f.getFullName(), noReg, transform))
Seq(
AuxFieldOptions.newBuilder
.setTarget(fd.getFullName())
.setOptions(interpolateStrings(transform.set, fd.getOptions(), extensions))
.setOptions(interpolateStrings(transform.set, fd.toProto(), context))
.build
)
else Seq.empty
Expand All @@ -107,25 +174,6 @@ private[compiler] object FieldTransformations {
processFile
}

def matchPresence(
input: Map[FieldDescriptor, Any],
pattern: Map[FieldDescriptor, Any]
): Boolean = {
pattern.forall { case (fd, value) =>
if (fd.isRepeated())
throw new GeneratorException(
"Presence matching on repeated fields is not supported"
)
else if (fd.getType() == Type.MESSAGE && input.contains(fd))
matchPresence(
input(fd).asInstanceOf[Message].getAllFields().asScala.toMap,
value.asInstanceOf[Message].getAllFields().asScala.toMap
)
else
input.contains(fd)
}
}

def fieldExtensionsForFile(f: FileDescriptor): Set[FieldDescriptor] = {
(f.getExtensions()
.asScala
Expand All @@ -140,28 +188,42 @@ private[compiler] object FieldTransformations {
// Like m.getAllFields(), but also resolves unknown fields from extensions available in the scope
// of the message.
def fieldMap(
currentFile: String,
m: Message,
extensions: Set[FieldDescriptor]
context: ExtensionResolutionContext
): Map[FieldDescriptor, Any] = {
val unknownFields = for {
number <- m.getUnknownFields().asMap().keySet().asScala
} yield {
val ext = extensions
val ext = context.extensions
.find(_.getNumber == number)
.getOrElse(
throw new GeneratorException(
s"$currentFile: Could not find extension number $number when processing a field " +
s"${context.currentFile}: Could not find extension number $number when processing a field " +
"transformation. A proto file defining this extension needs to be imported directly or transitively in this file."
)
)
ext -> getExtensionField(m, ext)
ext -> fieldMap(getExtensionField(m, ext), context)
}

unknownFields.toMap ++ m.getAllFields().asScala
val knownFields = m.getAllFields().asScala.map {
case (field, value) =>
if (field.getType() == Type.MESSAGE && !field.isOptional()) {
throw new GeneratorException(
s"${context.currentFile}: matching is supported only for scalar types and optional message fields."
)
}
(field -> (if (field.getType() == Type.MESSAGE)
fieldMap(value.asInstanceOf[Message], context)
else value))
}

unknownFields.toMap ++ knownFields
}

def getExtensionField(m: Message, ext: FieldDescriptor): Message = {
def getExtensionField(
m: Message,
ext: FieldDescriptor
): Message = {
if (ext.getType != Type.MESSAGE || !ext.isOptional) {
throw new GeneratorException(
s"Unknown extension fields must be optional message types: ${ext}"
Expand Down Expand Up @@ -198,11 +260,11 @@ private[compiler] object FieldTransformations {
def fieldByPath(
message: Message,
path: String,
extensions: Set[FieldDescriptor]
context: ExtensionResolutionContext
): String =
if (path.isEmpty()) throw new GeneratorException("Got an empty path")
else
fieldByPath(message, splitPath(path), path, extensions) match {
fieldByPath(message, splitPath(path), path, context) match {
case Left(error) => throw new GeneratorException(error)
case Right(value) => value
}
Expand All @@ -211,31 +273,29 @@ private[compiler] object FieldTransformations {
message: Message,
path: List[String],
allPath: String,
extensions: Set[FieldDescriptor]
context: ExtensionResolutionContext
): Either[String, String] = {
for {
fieldName <- path.headOption.toRight("Got an empty path")
fd <-
if (fieldName.startsWith("["))
extensions
.find(_.getFullName == fieldName.substring(1, fieldName.length() - 1))
.toRight(
s"Could not find extension $fieldName when resolving $allPath"
)
else
Option(message.getDescriptorForType().findFieldByName(fieldName))
.toRight(
s"Could not find field named $fieldName when resolving $allPath"
)
_ <-
if (fd.isRepeated()) Left("Repeated fields are not supported")
else Right(())
fd <- if (fieldName.startsWith("["))
context.extensions
.find(_.getFullName == fieldName.substring(1, fieldName.length() - 1))
.toRight(
s"Could not find extension $fieldName when resolving $allPath"
)
else
Option(message.getDescriptorForType().findFieldByName(fieldName))
.toRight(
s"Could not find field named $fieldName when resolving $allPath"
)
_ <- if (fd.isRepeated()) Left("Repeated fields are not supported")
else Right(())
v = if (fd.isExtension) getExtensionField(message, fd) else message.getField(fd)
res <- path match {
case _ :: Nil => Right(v.toString())
case _ :: tail =>
if (fd.getType() == Type.MESSAGE)
fieldByPath(v.asInstanceOf[Message], tail, allPath, extensions)
fieldByPath(v.asInstanceOf[Message], tail, allPath, context)
else
Left(
s"Type ${fd.getType.toString} does not have a field ${tail.head} in $allPath"
Expand All @@ -250,28 +310,28 @@ private[compiler] object FieldTransformations {
private[compiler] def interpolateStrings[T <: Message](
msg: T,
data: Message,
extensions: Set[FieldDescriptor]
context: ExtensionResolutionContext
): T = {
val b = msg.toBuilder()
for {
(field, value) <- msg.getAllFields().asScala
} field.getType() match {
case Type.STRING if (!field.isRepeated()) =>
b.setField(field, interpolate(value.asInstanceOf[String], data, extensions))
b.setField(field, interpolate(value.asInstanceOf[String], data, context))
case Type.MESSAGE =>
if (field.isRepeated())
b.setField(
field,
value
.asInstanceOf[java.util.List[Message]]
.asScala
.map(interpolateStrings(_, data, extensions))
.map(interpolateStrings(_, data, context))
.asJava
)
else
b.setField(
field,
interpolateStrings(value.asInstanceOf[Message], data, extensions)
interpolateStrings(value.asInstanceOf[Message], data, context)
)
case _ =>
}
Expand All @@ -285,9 +345,9 @@ private[compiler] object FieldTransformations {
private[compiler] def interpolate(
value: String,
data: Message,
extensions: Set[FieldDescriptor]
context: ExtensionResolutionContext
): String =
replaceAll(value, FieldPath, m => fieldByPath(data, m.group(1), extensions))
replaceAll(value, FieldPath, m => fieldByPath(data, m.group(1), context))

// Matcher.replaceAll appeared on Java 9, so we have this Java 8 compatible version instead. Adapted
// from https://stackoverflow.com/a/43372206/97524
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,8 @@ object FileOptionsCache {
.asScala
.map(t =>
ResolvedFieldTransformation(
file.getFullName(),
t,
FieldTransformations.fieldExtensionsForFile(file)
file,
t
)
)
.toSeq
Expand Down
Loading

0 comments on commit 50f75bd

Please sign in to comment.