From a5c02c29f809df3cb48e825a6da3c7c555504cc5 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Mon, 20 Apr 2020 20:23:31 +0100 Subject: [PATCH] Pass handles as fields in Protobuf messages. Also includes: - Custom derive macro to visit all handles in a struct/enum (crate oak_derive) - New example "injection" to showcase passing handles around. --- BUILD | 35 -- examples/Cargo.lock | 20 ++ examples/Cargo.toml | 2 + examples/chat/module/rust/build.rs | 2 +- examples/chat/module/rust/src/backend.rs | 21 +- examples/chat/module/rust/src/command.rs | 65 ---- examples/chat/module/rust/src/lib.rs | 21 +- examples/chat/proto/BUILD | 6 +- examples/chat/proto/chat.proto | 15 + examples/injection/client/BUILD | 34 ++ examples/injection/client/injection.cc | 77 +++++ examples/injection/config/BUILD | 31 ++ examples/injection/config/config.toml | 8 + examples/injection/module/rust/Cargo.toml | 16 + examples/injection/module/rust/build.rs | 22 ++ examples/injection/module/rust/src/lib.rs | 289 ++++++++++++++++ examples/injection/proto/BUILD | 46 +++ examples/injection/proto/injection.proto | 92 +++++ oak/proto/BUILD | 6 + oak/proto/handle.proto | 24 ++ oak_abi/build.rs | 5 +- oak_utils/src/lib.rs | 26 ++ sdk/Cargo.lock | 11 + sdk/Cargo.toml | 3 +- sdk/rust/oak/Cargo.toml | 1 + sdk/rust/oak/build.rs | 20 +- sdk/rust/oak/src/handle.rs | 233 +++++++++++++ sdk/rust/oak/src/io/decodable.rs | 9 +- sdk/rust/oak/src/io/encodable.rs | 8 +- sdk/rust/oak/src/io/receiver.rs | 55 +++ sdk/rust/oak/src/io/sender.rs | 55 +++ sdk/rust/oak/src/lib.rs | 1 + .../oak/tests/handle_extract_inject.proto | 52 +++ sdk/rust/oak/tests/handle_extract_inject.rs | 314 ++++++++++++++++++ sdk/rust/oak_derive/Cargo.toml | 16 + sdk/rust/oak_derive/src/lib.rs | 122 +++++++ sdk/rust/oak_derive/tests/visit.rs | 127 +++++++ 37 files changed, 1756 insertions(+), 134 deletions(-) delete mode 100644 BUILD delete mode 100644 examples/chat/module/rust/src/command.rs create mode 100644 examples/injection/client/BUILD create mode 100644 examples/injection/client/injection.cc create mode 100644 examples/injection/config/BUILD create mode 100644 examples/injection/config/config.toml create mode 100644 examples/injection/module/rust/Cargo.toml create mode 100644 examples/injection/module/rust/build.rs create mode 100644 examples/injection/module/rust/src/lib.rs create mode 100644 examples/injection/proto/BUILD create mode 100644 examples/injection/proto/injection.proto create mode 100644 oak/proto/handle.proto create mode 100644 sdk/rust/oak/src/handle.rs create mode 100644 sdk/rust/oak/tests/handle_extract_inject.proto create mode 100644 sdk/rust/oak/tests/handle_extract_inject.rs create mode 100644 sdk/rust/oak_derive/Cargo.toml create mode 100644 sdk/rust/oak_derive/src/lib.rs create mode 100644 sdk/rust/oak_derive/tests/visit.rs diff --git a/BUILD b/BUILD deleted file mode 100644 index 0443acbd5cd..00000000000 --- a/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -# -# Copyright 2019 The Project Oak Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# An empty BUILD file in the project root is required for `bazel-gazelle` that is -# loaded by `rules_docker`: -# https://github.com/bazelbuild/bazel-gazelle/issues/609 - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -# Export LICENSE file for projects that reference Oak in Bazel as an external dependency. -exports_files(["LICENSE"]) - -# These files are built via cargo outside of Bazel. -exports_files(srcs = glob(["target/x86_64-unknown-linux-musl/release/*oak_loader"])) - -exports_files(srcs = glob(["examples/target/wasm32-unknown-unknown/release/*.wasm"])) - -# These files are necessary for the backend server in the Aggregator example application. -exports_files(srcs = glob(["examples/target/x86_64-unknown-linux-gnu/release/aggregator_*"])) diff --git a/examples/Cargo.lock b/examples/Cargo.lock index 5bdd354efbc..f5f0eaa4a5f 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -786,6 +786,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "injection" +version = "0.1.0" +dependencies = [ + "log", + "oak", + "oak_utils", + "prost", +] + [[package]] name = "iovec" version = "0.1.4" @@ -1099,6 +1109,7 @@ dependencies = [ "fmt", "log", "oak_abi", + "oak_derive", "oak_utils", "prost", "prost-types", @@ -1117,6 +1128,15 @@ dependencies = [ "prost-types", ] +[[package]] +name = "oak_derive" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "oak_runtime" version = "0.1.0" diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 519ab633621..301793c7568 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -12,6 +12,7 @@ members = [ "chat/module/rust", "hello_world/grpc", "hello_world/module/rust", + "injection/module/rust", "machine_learning/module/rust", "trusted_information_retrieval/backend", "trusted_information_retrieval/client/rust", @@ -35,6 +36,7 @@ members = [ # Oak. oak = { path = "../sdk/rust/oak" } oak_abi = { path = "../oak_abi" } +oak_derive = { path = "../sdk/rust/oak_derive" } oak_runtime = { path = "../oak/server/rust/oak_runtime" } oak_tests = { path = "../sdk/rust/oak_tests" } oak_utils = { path = "../oak_utils" } diff --git a/examples/chat/module/rust/build.rs b/examples/chat/module/rust/build.rs index 9a0135101a8..e231723c346 100644 --- a/examples/chat/module/rust/build.rs +++ b/examples/chat/module/rust/build.rs @@ -17,6 +17,6 @@ fn main() { oak_utils::compile_protos( &["../../proto/chat.proto"], - &["../../proto", "../../third_party"], + &["../../proto", "../../../../"], ); } diff --git a/examples/chat/module/rust/src/backend.rs b/examples/chat/module/rust/src/backend.rs index d487f2984e5..70071274595 100644 --- a/examples/chat/module/rust/src/backend.rs +++ b/examples/chat/module/rust/src/backend.rs @@ -14,10 +14,12 @@ // limitations under the License. // -use crate::{command::Command, proto::Message}; -use log::info; +use crate::proto::{ + command::Command::{JoinRoom, SendMessage}, + Command, Message, +}; +use log::{info, warn}; use oak::Node; -use prost::Message as _; oak::entrypoint!(backend_oak_main => |in_channel| { oak::logger::init_default(); @@ -32,16 +34,13 @@ struct Room { impl Node for Room { fn handle_command(&mut self, command: Command) -> Result<(), oak::OakError> { - match command { - Command::Join(h) => { - let sender = oak::io::Sender::new(h); + match command.command { + Some(JoinRoom(sender)) => { self.clients .push(oak::grpc::ChannelResponseWriter::new(sender)); Ok(()) } - Command::SendMessage(message_bytes) => { - let message = Message::decode(message_bytes.as_slice()) - .expect("could not parse message from bytes"); + Some(SendMessage(message)) => { self.messages.push(message.clone()); info!("fan out message to {} clients", self.clients.len()); for writer in &mut self.clients { @@ -52,6 +51,10 @@ impl Node for Room { } Ok(()) } + None => { + warn!("Received empty command"); + Err(oak::OakError::OakStatus(oak::OakStatus::ErrInvalidArgs)) + } } } } diff --git a/examples/chat/module/rust/src/command.rs b/examples/chat/module/rust/src/command.rs deleted file mode 100644 index 654a19aed3d..00000000000 --- a/examples/chat/module/rust/src/command.rs +++ /dev/null @@ -1,65 +0,0 @@ -// -// Copyright 2019 The Project Oak Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -use serde::{Deserialize, Serialize}; - -#[derive(Serialize, Deserialize)] -pub enum Command { - Join(oak::WriteHandle), - // TODO(#426): Embed native Message struct here when we support proto - // serialization via Serde. See - // https://github.com/stepancheg/rust-protobuf#serde_derive-support - SendMessage(Vec), -} - -// TODO(#389): Automatically generate this code. -// -// Currently we use [bincode](https://github.com/servo/bincode) to serialize data together with a -// tag that allows to reconstruct the enum variant on the other side. We then send the tag+data as -// bytes, and separately we send the handles, which we have to manually re-assemble on the other -// side. -// -// FIDL implements something similar to this in -// https://fuchsia.googlesource.com/fuchsia/+/refs/heads/master/garnet/public/lib/fidl/rust/fidl/src/encoding.rs. -impl oak::io::Encodable for Command { - fn encode(&self) -> Result { - // TODO(#746): Propagate more details about the source error. - let bytes = bincode::serialize(self).map_err(|_| oak::OakStatus::ErrInvalidArgs)?; - // Serialize handles separately. - let handles = match self { - Command::Join(h) => vec![h.handle], - Command::SendMessage(_) => vec![], - }; - Ok(oak::io::Message { bytes, handles }) - } -} - -// TODO(#389): Automatically generate this code. -impl oak::io::Decodable for Command { - fn decode(message: &oak::io::Message) -> Result { - // TODO(#746): Propagate more details about the source error. - let command: Command = - bincode::deserialize(&message.bytes).map_err(|_| oak::OakStatus::ErrInvalidArgs)?; - // Restore handles in the received message. - let command = match command { - Command::Join(_) => Command::Join(oak::WriteHandle { - handle: message.handles[0], - }), - Command::SendMessage(message_bytes) => Command::SendMessage(message_bytes), - }; - Ok(command) - } -} diff --git a/examples/chat/module/rust/src/lib.rs b/examples/chat/module/rust/src/lib.rs index a34ee3685e3..a1d1ba683b5 100644 --- a/examples/chat/module/rust/src/lib.rs +++ b/examples/chat/module/rust/src/lib.rs @@ -14,18 +14,16 @@ // limitations under the License. // -use command::Command; use log::info; -use oak::grpc; -use prost::Message; +use oak::{grpc, io::Sender}; use proto::{ - Chat, ChatDispatcher, CreateRoomRequest, DestroyRoomRequest, SendMessageRequest, + command::Command::{JoinRoom, SendMessage}, + Chat, ChatDispatcher, Command, CreateRoomRequest, DestroyRoomRequest, SendMessageRequest, SubscribeRequest, }; use std::collections::{hash_map::Entry, HashMap}; mod backend; -mod command; mod proto { include!(concat!(env!("OUT_DIR"), "/oak.examples.chat.rs")); } @@ -126,7 +124,9 @@ impl Chat for Node { } Some(room) => { info!("new subscription to room {:?}", req.room_id); - let command = Command::Join(writer.handle()); + let command = Command { + command: Some(JoinRoom(Sender::new(writer.handle()))), + }; room.sender .send(&command) .expect("could not send command to room Node"); @@ -140,12 +140,9 @@ impl Chat for Node { None => room_id_not_found_err(), Some(room) => { info!("new message to room {:?}", req.room_id); - let mut message_bytes = Vec::new(); - req.message - .unwrap_or_default() - .encode(&mut message_bytes) - .expect("could not convert message to bytes"); - let command = Command::SendMessage(message_bytes); + let command = Command { + command: req.message.map(SendMessage), + }; room.sender .send(&command) .expect("could not send command to room Node"); diff --git a/examples/chat/proto/BUILD b/examples/chat/proto/BUILD index 1a220f0fb36..57a315e0995 100644 --- a/examples/chat/proto/BUILD +++ b/examples/chat/proto/BUILD @@ -27,13 +27,17 @@ proto_library( name = "chat_proto", srcs = ["chat.proto"], deps = [ + "//oak/proto:grpc_encap_proto", + "//oak/proto:handle_proto", "@com_google_protobuf//:empty_proto", ], ) cc_proto_library( name = "chat_cc_proto", - deps = [":chat_proto"], + deps = [ + ":chat_proto", + ], ) cc_grpc_library( diff --git a/examples/chat/proto/chat.proto b/examples/chat/proto/chat.proto index 1dfcdc332aa..bc138cfb97b 100644 --- a/examples/chat/proto/chat.proto +++ b/examples/chat/proto/chat.proto @@ -19,6 +19,8 @@ syntax = "proto3"; package oak.examples.chat; import "google/protobuf/empty.proto"; +import "oak/proto/handle.proto"; +import "oak/proto/grpc_encap.proto"; message CreateRoomRequest { // ID used to identify the room; knowledge of this value allows entry to the room. The client @@ -66,3 +68,16 @@ service Chat { // Send a message to a chat room. rpc SendMessage(SendMessageRequest) returns (google.protobuf.Empty); } + +// Command sent to room nodes. +// +// This message is only used for inter-node communication, it is not exposed through gRPC. +message Command { + oneof command { + // Sent when a new subscriber joins the room. + oak.handle.Sender join_room = 1 [(oak.handle.message_type) = ".oak.encap.GrpcResponse"]; + + // Command to send a message to the room (and thus to all subscribers). + Message send_message = 2; + } +} diff --git a/examples/injection/client/BUILD b/examples/injection/client/BUILD new file mode 100644 index 00000000000..57c5f0a2971 --- /dev/null +++ b/examples/injection/client/BUILD @@ -0,0 +1,34 @@ +# +# Copyright 2019 The Project Oak Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +load("@rules_cc//cc:defs.bzl", "cc_binary") + +package( + licenses = ["notice"], +) + +cc_binary( + name = "client", + srcs = ["injection.cc"], + deps = [ + "//examples/injection/proto:injection_cc_grpc", + "//oak/client:application_client", + "@com_github_google_glog//:glog", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + ], +) diff --git a/examples/injection/client/injection.cc b/examples/injection/client/injection.cc new file mode 100644 index 00000000000..767fb68f850 --- /dev/null +++ b/examples/injection/client/injection.cc @@ -0,0 +1,77 @@ +/* + * Copyright 2019 The Project Oak Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "examples/injection/proto/injection.grpc.pb.h" +#include "examples/injection/proto/injection.pb.h" +#include "glog/logging.h" +#include "include/grpcpp/grpcpp.h" +#include "oak/client/application_client.h" +#include "oak/common/label.h" + +ABSL_FLAG(std::string, address, "localhost:8080", "Address of the Oak application to connect to"); +ABSL_FLAG(std::string, ca_cert, "", "Path to the PEM-encoded CA root certificate"); + +using ::oak::examples::injection::BlobResponse; +using ::oak::examples::injection::BlobStore; +using ::oak::examples::injection::GetBlobRequest; +using ::oak::examples::injection::PutBlobRequest; + +int main(int argc, char** argv) { + absl::ParseCommandLine(argc, argv); + + std::string address = absl::GetFlag(FLAGS_address); + std::string ca_cert = oak::ApplicationClient::LoadRootCert(absl::GetFlag(FLAGS_ca_cert)); + LOG(INFO) << "Connecting to Oak Application: " << address; + + // TODO(#1066): Use a more restrictive Label. + oak::label::Label label = oak::PublicUntrustedLabel(); + // Connect to the Oak Application. + auto stub = BlobStore::NewStub(oak::ApplicationClient::CreateChannel( + address, oak::ApplicationClient::GetTlsChannelCredentials(ca_cert), label)); + if (stub == nullptr) { + LOG(FATAL) << "Failed to create application stub"; + } + + PutBlobRequest putRequest; + putRequest.set_blob("Hello, blob store!"); + grpc::ClientContext putContext; + BlobResponse putResponse; + grpc::Status putStatus = stub->PutBlob(&putContext, putRequest, &putResponse); + if (!putStatus.ok()) { + LOG(FATAL) << "PutBlob failed: " << putStatus.error_code() << ": " << putStatus.error_message(); + } + LOG(INFO) << "Blob stored at id: " << putResponse.id(); + + GetBlobRequest getRequest; + getRequest.set_id(putResponse.id()); + grpc::ClientContext getContext; + BlobResponse getResponse; + grpc::Status getStatus = stub->GetBlob(&getContext, getRequest, &getResponse); + if (!getStatus.ok()) { + LOG(FATAL) << "GetBlob failed: " << getStatus.error_code() << ": " << getStatus.error_message(); + } + LOG(INFO) << "Successfully retrieved Blob"; + + if (putRequest.blob() != getResponse.blob()) { + LOG(FATAL) << "Blobs were different. Original: '" << putRequest.blob() << "', retrieved: '" + << getResponse.blob() << "'"; + } + LOG(INFO) << "Blobs match!"; + + return EXIT_SUCCESS; +} diff --git a/examples/injection/config/BUILD b/examples/injection/config/BUILD new file mode 100644 index 00000000000..1b97e7f233f --- /dev/null +++ b/examples/injection/config/BUILD @@ -0,0 +1,31 @@ +# +# Copyright 2020 The Project Oak Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +load("//oak/common:app_config.bzl", "serialized_config") + +package( + licenses = ["notice"], +) + +exports_files(srcs = glob(["*.textproto"])) + +serialized_config( + name = "config", + modules = { + "app": "//:examples/target/wasm32-unknown-unknown/release/injection.wasm", + }, + textproto = ":config.textproto", +) diff --git a/examples/injection/config/config.toml b/examples/injection/config/config.toml new file mode 100644 index 00000000000..5cc69dc09ec --- /dev/null +++ b/examples/injection/config/config.toml @@ -0,0 +1,8 @@ +name = "injection" + +[modules] +app = { path = "examples/injection/bin/injection.wasm" } + +[initial_node_configuration] +wasm_module_name = "app" +wasm_entrypoint_name = "grpc_fe" diff --git a/examples/injection/module/rust/Cargo.toml b/examples/injection/module/rust/Cargo.toml new file mode 100644 index 00000000000..58599f38903 --- /dev/null +++ b/examples/injection/module/rust/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "injection" +version = "0.1.0" +authors = ["Daan de Graaf "] +edition = "2018" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +log = "*" +oak = "=0.1.0" +prost = "*" + +[build-dependencies] +oak_utils = "*" diff --git a/examples/injection/module/rust/build.rs b/examples/injection/module/rust/build.rs new file mode 100644 index 00000000000..fea5bfd91df --- /dev/null +++ b/examples/injection/module/rust/build.rs @@ -0,0 +1,22 @@ +// +// Copyright 2020 The Project Oak Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +fn main() { + oak_utils::compile_protos( + &["../../proto/injection.proto"], + &["../../proto", "../../../../"], + ); +} diff --git a/examples/injection/module/rust/src/lib.rs b/examples/injection/module/rust/src/lib.rs new file mode 100644 index 00000000000..91007d05e65 --- /dev/null +++ b/examples/injection/module/rust/src/lib.rs @@ -0,0 +1,289 @@ +// +// Copyright 2020 The Project Oak Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//! Example module showcasing handle passing inside of Protocol Buffers. +//! +//! The following diagram illustrates how the system stores client blobs: +//! +//! ```text +//! Client +//! /\ (1) \/ PutBlob request +//! || (7) /\ BlobResponse (from (6)) +//! \/ +//! gRPC service +//! /\ +//! || Proxies messages to and from the frontend +//! \/ +//! BlobStore frontend +//! /\ /\ (2) \/ Request a BlobStoreInterface +//! || || (4) /\ Return BlobStoreInterface to frontend +//! || \/ +//! || BlobStore provider +//! || | +//! || | (3) Start a new BlobStore impl +//! || | +//! || | (5) \/ PutBlob request (from (1)) +//! || | (6) /\ BlobResponse +//! \/ | +//! BlobStore impl +//! ``` +//! +//! Note that this will only ever create one BlobStore impl, as the handles to the store are +//! cached by the frontend. Subsequent requests to store or receive blobs are forwarded to the +//! existing BlobStore by the BlobStore frontend: +//! +//! ```text +//! Client +//! /\ (1) \/ GetBlob request +//! || (4) /\ BlobResponse (from (3)) +//! \/ +//! gRPC service +//! /\ +//! || Proxies messages to and from the frontend +//! \/ +//! BlobStore frontend +//! /\ +//! || (2) \/ GetBlob request (from (1)) +//! \/ (3) /\ BlobResponse +//! BlobStore impl +//! ``` + +mod proto { + include!(concat!(env!("OUT_DIR"), "/oak.examples.injection.rs")); +} + +use oak::{ + grpc, + io::{Receiver, Sender}, +}; + +use proto::{ + blob_request::Request, BlobRequest, BlobResponse, BlobStore, BlobStoreDispatcher, + BlobStoreInterface, BlobStoreProviderSender, BlobStoreRequest, BlobStoreSender, GetBlobRequest, + PutBlobRequest, +}; + +oak::entrypoint!(grpc_fe => |_in_channel| { + oak::logger::init_default(); + let (to_provider_write_handle, to_provider_read_handle) = oak::channel_create().unwrap(); + let (from_provider_write_handle, from_provider_read_handle) = oak::channel_create().unwrap(); + + oak::node_create(&oak::node_config::wasm("app", "provider"), to_provider_read_handle) + .expect("Failed to create provider"); + oak::channel_close(to_provider_read_handle.handle).expect("Failed to close channel"); + + Sender::new(to_provider_write_handle) + .send(&BlobStoreProviderSender { sender: Some(Sender::new(from_provider_write_handle)) }) + .expect("Failed to send handle to provider"); + oak::channel_close(from_provider_write_handle.handle).expect("Failed to close channel"); + + let frontend = BlobStoreFrontend::new( + Sender::new(to_provider_write_handle), + Receiver::new(from_provider_read_handle)); + let dispatcher = BlobStoreDispatcher::new(frontend); + let grpc_channel = oak::grpc::server::init("[::]:8080") + .expect("could not create gRPC server pseudo-Node"); + + oak::run_event_loop(dispatcher, grpc_channel); +}); + +oak::entrypoint!(provider => |frontend_read| { + oak::logger::init_default(); + let frontend_sender = + Receiver::::new(frontend_read).receive() + .expect("Did not receive a decodable message") + .sender + .expect("No sender in received message"); + oak::run_event_loop(BlobStoreProvider::new(frontend_sender), + Receiver::::new(frontend_read)); +}); + +oak::entrypoint!(store => |reader| { + oak::logger::init_default(); + let sender = + Receiver::::new(reader).receive() + .expect("Did not receive a write handle") + .sender + .expect("No write handle in received message"); + oak::run_event_loop(BlobStoreImpl::new(sender), + Receiver::::new(reader)); +}); + +enum BlobStoreAccess { + BlobStoreProvider { + sender: Sender, + receiver: Receiver, + }, + BlobStore(BlobStoreInterface), +} + +struct BlobStoreFrontend { + store: BlobStoreAccess, +} + +impl BlobStoreFrontend { + pub fn new( + sender: Sender, + receiver: Receiver, + ) -> BlobStoreFrontend { + BlobStoreFrontend { + store: BlobStoreAccess::BlobStoreProvider { sender, receiver }, + } + } + + fn get_interface(&mut self) -> &BlobStoreInterface { + // Make sure it is cached + if let BlobStoreAccess::BlobStoreProvider { sender, receiver } = &self.store { + sender + .send(&BlobStoreRequest {}) + .expect("Failed to send BlobStoreRequest"); + sender.close().expect("Failed to close sender"); + + let iface = receiver + .receive() + .expect("Failed to receive BlobStoreInterface"); + receiver.close().expect("Failed to close receiver"); + + self.store = BlobStoreAccess::BlobStore(iface); + }; + + match &self.store { + BlobStoreAccess::BlobStore(iface) => &iface, + _ => unreachable!(), + } + } + + fn send(&mut self, request: &BlobRequest) -> BlobResponse { + let iface = self.get_interface(); + iface + .sender + .as_ref() + .expect("No sender present on interface") + .send(request) + .expect("Could not send request"); + iface + .receiver + .as_ref() + .expect("No receiver present on interface") + .receive() + .expect("Could not receive response") + } +} + +impl BlobStore for BlobStoreFrontend { + fn get_blob(&mut self, request: GetBlobRequest) -> grpc::Result { + Ok(self.send(&BlobRequest { + request: Some(Request::Get(request)), + })) + } + + fn put_blob(&mut self, request: PutBlobRequest) -> grpc::Result { + Ok(self.send(&BlobRequest { + request: Some(Request::Put(request)), + })) + } +} + +struct BlobStoreProvider { + sender: Sender, +} + +impl BlobStoreProvider { + pub fn new(sender: Sender) -> BlobStoreProvider { + BlobStoreProvider { sender } + } +} + +impl oak::Node for BlobStoreProvider { + fn handle_command(&mut self, _command: BlobStoreRequest) -> Result<(), oak::OakError> { + // Create new BlobStore + let (to_store_write_handle, to_store_read_handle) = oak::channel_create().unwrap(); + let (from_store_write_handle, from_store_read_handle) = oak::channel_create().unwrap(); + oak::node_create( + &oak::node_config::wasm("app", "store"), + to_store_read_handle, + )?; + oak::channel_close(to_store_read_handle.handle).expect("Failed to close channel"); + + Sender::new(to_store_write_handle).send(&BlobStoreSender { + sender: Some(Sender::new(from_store_write_handle)), + })?; + oak::channel_close(from_store_write_handle.handle).expect("Failed to close channel"); + + self.sender.send(&BlobStoreInterface { + sender: Some(Sender::new(to_store_write_handle)), + receiver: Some(Receiver::new(from_store_read_handle)), + }) + } +} + +struct BlobStoreImpl { + sender: Sender, + blobs: Vec>, +} + +impl BlobStoreImpl { + pub fn new(sender: Sender) -> BlobStoreImpl { + BlobStoreImpl { + sender, + blobs: Vec::new(), + } + } + + fn get_blob(&mut self, request: GetBlobRequest) -> BlobResponse { + self.blobs + .get(blob_index(request.id)) + .map(|blob| BlobResponse { + blob: blob.clone(), + id: request.id, + }) + // Return the default instance if the blob was not found. + .unwrap_or_default() + } + + fn put_blob(&mut self, request: PutBlobRequest) -> BlobResponse { + if request.id == 0 { + // Insert a new blob + self.blobs.push(request.blob.clone()); + BlobResponse { + id: self.blobs.len() as u64, + blob: request.blob, + } + } else if let Some(blob) = self.blobs.get_mut(blob_index(request.id)) { + *blob = request.blob.clone(); + BlobResponse { + id: request.id, + blob: request.blob, + } + } else { + BlobResponse::default() + } + } +} + +fn blob_index(id: u64) -> usize { + (id - 1) as usize +} + +impl oak::Node for BlobStoreImpl { + fn handle_command(&mut self, request: BlobRequest) -> Result<(), oak::OakError> { + let response = match request.request { + Some(Request::Get(req)) => self.get_blob(req), + Some(Request::Put(req)) => self.put_blob(req), + None => panic!("No inner request"), + }; + self.sender.send(&response) + } +} diff --git a/examples/injection/proto/BUILD b/examples/injection/proto/BUILD new file mode 100644 index 00000000000..a146f7f35d8 --- /dev/null +++ b/examples/injection/proto/BUILD @@ -0,0 +1,46 @@ +# +# Copyright 2020 The Project Oak Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +proto_library( + name = "injection_proto", + srcs = ["injection.proto"], + deps = [ + "//oak/proto:handle_proto", + "@com_google_protobuf//:empty_proto", + ], +) + +cc_proto_library( + name = "injection_cc_proto", + deps = [":injection_proto"], +) + +cc_grpc_library( + name = "injection_cc_grpc", + srcs = ["injection_proto"], + grpc_only = True, + well_known_protos = True, + deps = ["injection_cc_proto"], +) diff --git a/examples/injection/proto/injection.proto b/examples/injection/proto/injection.proto new file mode 100644 index 00000000000..28a124e1a61 --- /dev/null +++ b/examples/injection/proto/injection.proto @@ -0,0 +1,92 @@ +// +// Copyright 2020 The Project Oak Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +syntax = "proto3"; + +import "oak/proto/handle.proto"; + +package oak.examples.injection; + +// ======================== +// gRPC messages / services +// ======================== + +// Request the contents of an existing blob in the store. +message GetBlobRequest { + fixed64 id = 1; +} + +// Request to insert or update a blob in the store. +message PutBlobRequest { + // Leave unset to store a new blob. + fixed64 id = 1; + bytes blob = 2; +} + +// Returned in response to a GetBlobRequest or PutBlobRequest. +message BlobResponse { + fixed64 id = 1; + bytes blob = 2; +} + +service BlobStore { + // Retrieve a stored blob + rpc GetBlob(GetBlobRequest) returns (BlobResponse); + // Insert or update a blob + rpc PutBlob(PutBlobRequest) returns (BlobResponse); +} + +// ============================================== +// Inter-node messages, not exposed through gRPC. +// ============================================== + +// Generic request to a BlobStore. +message BlobRequest { + oneof request { + GetBlobRequest get = 1; + PutBlobRequest put = 2; + } +} + +// Request handles to a blob store. +message BlobStoreRequest { +} + +// Response to a BlobStoreRequest +message BlobStoreResponse { + BlobStoreInterface interface = 1; +} + +// Container message for a sender that can write BlobStoreInterface messages. +message BlobStoreProviderSender { + oak.handle.Sender sender = 1 + [(oak.handle.message_type) = ".oak.examples.injection.BlobStoreInterface"]; +} + +// Container message for a sender that can write BlobResponse messages. +message BlobStoreSender { + oak.handle.Sender sender = 1 [(oak.handle.message_type) = ".oak.examples.injection.BlobResponse"]; +} + +// Handles to and from a BlobStore. +// +// This is equivalent to the BlobStore service definition, but is encoded as a sender and receiver. +// Eventually we should automatically generate code for this, but for now we do it by hand. +message BlobStoreInterface { + oak.handle.Sender sender = 1 [(oak.handle.message_type) = ".oak.examples.injection.BlobRequest"]; + oak.handle.Receiver receiver = 2 + [(oak.handle.message_type) = ".oak.examples.injection.BlobResponse"]; +} diff --git a/oak/proto/BUILD b/oak/proto/BUILD index 6a719e62ab8..3f6ef7f1eb2 100644 --- a/oak/proto/BUILD +++ b/oak/proto/BUILD @@ -137,3 +137,9 @@ cc_proto_library( name = "roughtime_service_cc_proto", deps = [":roughtime_service_proto"], ) + +proto_library( + name = "handle_proto", + srcs = ["handle.proto"], + deps = ["@com_google_protobuf//:descriptor_proto"], +) diff --git a/oak/proto/handle.proto b/oak/proto/handle.proto new file mode 100644 index 00000000000..39837b76201 --- /dev/null +++ b/oak/proto/handle.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package oak.handle; + +import "google/protobuf/descriptor.proto"; + +extend google.protobuf.FieldOptions { + // Fully qualified path to the protobuf message sent or received from a `Sender` or `Receiver`. + string message_type = 79658; +} + +// Sender handle for an Oak channel. +// This type is sent over the wire, but in types generated by prost it is replaced with +// [`oak::io::Sender`](../io/struct.Sender.html). +message Sender { + fixed64 id = 1; +} + +// Receiver handle for an Oak channel. +// This type is sent over the wire, but in types generated by prost it is replaced with +// [`oak::io::Receiver`](../io/struct.Receiver.html). +message Receiver { + fixed64 id = 1; +} diff --git a/oak_abi/build.rs b/oak_abi/build.rs index 746b7ab7cd7..4cf1c1b7f6f 100644 --- a/oak_abi/build.rs +++ b/oak_abi/build.rs @@ -30,9 +30,10 @@ fn main() { ], &[".."], oak_utils::ProtoOptions { - // Exclude generation of service code, as it would require a reference to the Oak SDK to - // compile. + // Exclude generation of service code and HandleVisit auto-derive, as it would require a + // reference to the Oak SDK to compile. generate_services: false, + derive_handle_visit: false, ..Default::default() }, ); diff --git a/oak_utils/src/lib.rs b/oak_utils/src/lib.rs index a817a6ae4b7..37bca163995 100644 --- a/oak_utils/src/lib.rs +++ b/oak_utils/src/lib.rs @@ -171,6 +171,18 @@ pub struct ProtoOptions { /// /// Generated code depends on the `oak` SDK crate. pub generate_services: bool, + + /// Automatically derive [`HandleVisit`](../oak/handle/trait.HandleVisit.html) for generated + /// Protobuf types. If this is enabled, the generated types can contain handles and can be used + /// to exchange handles with other nodes using inter-node communication over Protocol + /// Buffers. + /// + /// Default: **true**. + /// + /// Generated code depends on the `oak` SDK crate. + pub derive_handle_visit: bool, + + pub out_dir_override: Option, } /// The default option values. @@ -178,6 +190,8 @@ impl Default for ProtoOptions { fn default() -> ProtoOptions { ProtoOptions { generate_services: true, + derive_handle_visit: true, + out_dir_override: None, } } } @@ -210,6 +224,18 @@ where if options.generate_services { prost_config.service_generator(Box::new(OakServiceGenerator)); } + if options.derive_handle_visit { + prost_config + // Auto-derive the HandleVisit trait + .type_attribute(".", "#[derive(::oak::handle::HandleVisit)]") + // Link relevant Oak protos to the Oak SDK types. + .extern_path(".oak.handle", "::oak::handle") + .extern_path(".oak.encap.GrpcRequest", "::oak::grpc::GrpcRequest") + .extern_path(".oak.encap.GrpcResponse", "::oak::grpc::GrpcResponse"); + } + if let Some(out_dir) = options.out_dir_override { + prost_config.out_dir(out_dir); + } prost_config // We require label-related types to be comparable and hashable so that they can be used in // hash-based collections. diff --git a/sdk/Cargo.lock b/sdk/Cargo.lock index f342194ec6c..359fa4fc8f8 100644 --- a/sdk/Cargo.lock +++ b/sdk/Cargo.lock @@ -786,6 +786,7 @@ dependencies = [ "fmt", "log", "oak_abi", + "oak_derive", "oak_tests", "oak_utils", "prost", @@ -824,6 +825,16 @@ dependencies = [ "toml", ] +[[package]] +name = "oak_derive" +version = "0.1.0" +dependencies = [ + "oak", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "oak_runtime" version = "0.1.0" diff --git a/sdk/Cargo.toml b/sdk/Cargo.toml index 3f46a857be9..a52c678f0cf 100644 --- a/sdk/Cargo.toml +++ b/sdk/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["rust/oak", "rust/oak_config_serializer", "rust/oak_tests"] +members = ["rust/oak", "rust/oak_config_serializer", "rust/oak_derive", "rust/oak_tests"] # Patch dependencies on oak crates so that they refer to the versions within this same repository. # @@ -11,6 +11,7 @@ members = ["rust/oak", "rust/oak_config_serializer", "rust/oak_tests"] oak = { path = "rust/oak" } oak_abi = { path = "../oak_abi" } oak_config_serializer = { path = "rust/oak_config_serializer" } +oak_derive = { path = "rust/oak_derive" } oak_runtime = { path = "../oak/server/rust/oak_runtime" } oak_tests = { path = "rust/oak_tests" } oak_utils = { path = "../oak_utils" } diff --git a/sdk/rust/oak/Cargo.toml b/sdk/rust/oak/Cargo.toml index 6d14ef4587c..b3bc45b66ae 100644 --- a/sdk/rust/oak/Cargo.toml +++ b/sdk/rust/oak/Cargo.toml @@ -7,6 +7,7 @@ license = "Apache-2.0" [dependencies] oak_abi = "=0.1.0" +oak_derive = "=0.1.0" prost = "*" prost-types = "*" fmt = "*" diff --git a/sdk/rust/oak/build.rs b/sdk/rust/oak/build.rs index 3d2ba103d2c..ac1be407955 100644 --- a/sdk/rust/oak/build.rs +++ b/sdk/rust/oak/build.rs @@ -15,11 +15,29 @@ // fn main() { - oak_utils::compile_protos( + oak_utils::compile_protos_with_options( &[ "../../../oak/proto/storage_service.proto", "../../../oak/proto/roughtime_service.proto", + "../../../oak/proto/handle.proto", ], &["../../.."], + oak_utils::ProtoOptions { + // We can't derive the HandleVisit trait as we are defining it in this crate. + derive_handle_visit: false, + ..Default::default() + }, + ); + + let mut handle_tests_out = std::path::PathBuf::from(std::env::var("OUT_DIR").unwrap()); + handle_tests_out.push("handle_tests"); + std::fs::create_dir_all(&handle_tests_out).unwrap(); + oak_utils::compile_protos_with_options( + &["tests/handle_extract_inject.proto"], + &["tests/", "../../../oak/proto"], + oak_utils::ProtoOptions { + out_dir_override: Some(handle_tests_out), + ..Default::default() + }, ); } diff --git a/sdk/rust/oak/src/handle.rs b/sdk/rust/oak/src/handle.rs new file mode 100644 index 00000000000..d6674ad5c5f --- /dev/null +++ b/sdk/rust/oak/src/handle.rs @@ -0,0 +1,233 @@ +// +// Copyright 2020 The Project Oak Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +//! Utilities for visiting, extract and injecting handles. +//! +//! Applications will usually not interact with the types in this module directly, as the +//! `HandleVisit` trait is automatically derived for all protobuf types compiled with `oak_utils`, +//! and extracting and injecting handles is taken care of by +//! [`oak::io::Receiver`](../io/struct.Receiver.html) and +//! [`oak::io::Sender`](../io/struct.Sender.html). +include!(concat!(env!("OUT_DIR"), "/oak.handle.rs")); + +use crate::OakError; + +/// Raw handle identifier. +/// +/// This is only meant to be used with the [HandleVisit](trait.HandleVisit.html) trait, for +/// anything else [Handle](../struct.Handle.html) should be preferred. +// Note to maintainers: users may not add oak_abi as an explicit dependency, so we re-export this +// type here for the custom derive to use. +pub type Handle = oak_abi::Handle; + +/// Visit all handles present in a type. +/// +/// The most common types that contains handles are +/// [`oak::io::Receiver`](../io/struct.Receiver.html) and +/// [`oak::io::Sender`](../io/struct.Sender.html). +/// +/// This trait can be automatically derived: +/// +/// ``` +/// use oak::handle::HandleVisit; +/// +/// #[derive(HandleVisit)] +/// struct Thing { +/// receiver: oak::io::Receiver, +/// sender: oak::io::Sender, +/// } +/// ``` +/// +/// Alternatively, you can implement the trait manually. This is required if you want the trait to +/// visit handles that are directly contained in a type. +/// +/// ``` +/// use oak::handle::{Handle, HandleVisit}; +/// +/// struct Thing { +/// handle: Handle, +/// } +/// +/// impl HandleVisit for Thing { +/// fn visit(&mut self, mut visitor: F) -> F { +/// visitor(&mut self.handle); +/// visitor +/// } +/// } +/// ``` +pub trait HandleVisit { + /// Invokes the provided closure on every handle contained in `self`. + /// + /// The mutable reference allows modifying the handles. + fn visit(&mut self, visitor: F) -> F; +} + +// A default implementation of the HandleVisit trait that does nothing +macro_rules! handle_visit_blanket_impl { + ($($t:ty),+) => { + $( + impl HandleVisit for $t { + fn visit(&mut self, visitor: F) -> F { + visitor + } + } + )+ + }; +} + +// Provide an implementation for all scalar types in Prost. +// See: https://github.com/danburkert/prost#scalar-values +handle_visit_blanket_impl!((), f64, f32, i32, i64, u32, u64, bool, String, Vec); + +// Provide an implementation for oak_abi types that an implementation cannot be derived for. +// These do not contains handles, so a blanket impl is sufficient. +handle_visit_blanket_impl!( + oak_abi::proto::oak::encap::GrpcResponse, + oak_abi::proto::oak::encap::GrpcRequest, + oak_abi::proto::oak::log::LogMessage, + oak_abi::proto::oak::application::ConfigMap +); + +/// Return all handles in `T`. +/// +/// Also marks all handles in `T` invalid. +/// +/// The original message can be reconstructed by calling +/// [`inject_handles`](fn.inject_handles.html). +/// +/// ``` +/// use oak::handle::{extract_handles, Handle, HandleVisit}; +/// +/// struct Thing { +/// handle: Handle, +/// } +/// # impl HandleVisit for Thing { +/// # fn visit(&mut self, mut visitor: F) -> F { +/// # visitor(&mut self.handle); +/// # visitor +/// # } +/// # } +/// +/// let mut thing = Thing { handle: 42 }; +/// +/// let handles = extract_handles(&mut thing); +/// +/// assert_eq!(handles, vec![42]); +/// ``` +pub fn extract_handles(msg: &mut T) -> Vec { + let mut handles = Vec::new(); + msg.visit(|handle: &mut Handle| { + handles.push(*handle); + *handle = oak_abi::INVALID_HANDLE; + }); + handles +} + +/// Inject handles into a message. +/// +/// If the number of handles provided is not exactly equal to the number of handles needed to fill +/// the message, an error is returned. +/// +/// Order is significant: handles are injected starting at the first field, recursing +/// into nested structs before moving on to the next field. +/// +/// ``` +/// use oak::handle::{inject_handles, Handle, HandleVisit}; +/// +/// # #[derive(Debug, PartialEq)] +/// struct Thing { +/// handle: Handle, +/// } +/// # impl HandleVisit for Thing { +/// # fn visit(&mut self, mut visitor: F) -> F { +/// # visitor(&mut self.handle); +/// # visitor +/// # } +/// # } +/// +/// let handles = vec![42]; +/// let mut thing = Thing { handle: 0 }; +/// +/// inject_handles(&mut thing, &handles).unwrap(); +/// +/// assert_eq!(thing, Thing { handle: 42 }); +/// ``` +pub fn inject_handles(msg: &mut T, handles: &[Handle]) -> Result<(), OakError> { + let mut handles = handles.iter(); + let mut result = Ok(()); + msg.visit(|handle| { + if let Some(to_inject) = handles.next() { + *handle = *to_inject; + } else { + result = Err(OakError::ProtobufDecodeError(None)); + } + }); + if handles.next().is_some() { + result = Err(OakError::ProtobufDecodeError(None)); + } + result +} + +// Import the procedural macro that automatically derives implementations of the trait. +pub use oak_derive::HandleVisit; + +// Implementations for the types generated from different field modifiers +// (https://github.com/danburkert/prost#scalar-values). + +// Optional fields +impl HandleVisit for Option { + fn visit(&mut self, visitor: F) -> F { + if let Some(inner) = self { + inner.visit(visitor) + } else { + visitor + } + } +} + +// For repeated fields. +impl HandleVisit for Vec { + fn visit(&mut self, visitor: F) -> F { + self.iter_mut() + .fold(visitor, |visitor, item| item.visit(visitor)) + } +} + +// For recursive messages. +impl HandleVisit for Box { + fn visit(&mut self, visitor: F) -> F { + self.as_mut().visit(visitor) + } +} + +// For maps. This is only supported for maps that have a key implementing `Ord`, because we need to +// be able to define an order in which to inject/extract handles. Since protobuf only supports +// integral and string types for keys, having this constraint is fine. +// +// See https://developers.google.com/protocol-buffers/docs/proto3#maps for more details. +impl HandleVisit + for std::collections::HashMap +{ + fn visit(&mut self, visitor: F) -> F { + let mut entries: Vec<(&K, &mut V)> = self.iter_mut().collect(); + // Can be unstable because keys are guaranteed to be unique. + entries.sort_unstable_by_key(|&(k, _)| k); + entries + .into_iter() + .map(|(_, v)| v) + .fold(visitor, |visitor, value| value.visit(visitor)) + } +} diff --git a/sdk/rust/oak/src/io/decodable.rs b/sdk/rust/oak/src/io/decodable.rs index 85190267e75..4a27b150788 100644 --- a/sdk/rust/oak/src/io/decodable.rs +++ b/sdk/rust/oak/src/io/decodable.rs @@ -21,12 +21,11 @@ pub trait Decodable: Sized { fn decode(message: &Message) -> Result; } -impl Decodable for T { +impl Decodable for T { fn decode(message: &Message) -> Result { - if !message.handles.is_empty() { - return Err(OakError::ProtobufDecodeError(None)); - } - let value = T::decode(message.bytes.as_slice())?; + let mut value = T::decode(message.bytes.as_slice())?; + let handles: Vec = message.handles.iter().map(|h| h.id).collect(); + crate::handle::inject_handles(&mut value, &handles)?; Ok(value) } } diff --git a/sdk/rust/oak/src/io/encodable.rs b/sdk/rust/oak/src/io/encodable.rs index be3e7b34849..a208db9f564 100644 --- a/sdk/rust/oak/src/io/encodable.rs +++ b/sdk/rust/oak/src/io/encodable.rs @@ -21,11 +21,15 @@ pub trait Encodable { fn encode(&self) -> Result; } -impl Encodable for T { +impl Encodable for T { fn encode(&self) -> Result { + let mut msg = self.clone(); + let handles = crate::handle::extract_handles(&mut msg) + .into_iter() + .map(crate::Handle::from_raw) + .collect(); let mut bytes = Vec::new(); self.encode(&mut bytes)?; - let handles = Vec::new(); Ok(crate::io::Message { bytes, handles }) } } diff --git a/sdk/rust/oak/src/io/receiver.rs b/sdk/rust/oak/src/io/receiver.rs index c8689ea0b0c..4de9ac577fc 100644 --- a/sdk/rust/oak/src/io/receiver.rs +++ b/sdk/rust/oak/src/io/receiver.rs @@ -16,6 +16,10 @@ use crate::{io::Decodable, ChannelReadStatus, OakError, OakStatus, ReadHandle}; use log::error; +use prost::{ + bytes::{Buf, BufMut}, + encoding::{DecodeContext, WireType}, +}; use serde::{Deserialize, Serialize}; /// Wrapper for a handle to the read half of a channel, allowing to receive data that can be decoded @@ -114,3 +118,54 @@ impl Receiver { } } } + +impl crate::handle::HandleVisit for Receiver { + fn visit(&mut self, mut visitor: F) -> F { + visitor(&mut self.handle.handle.id); + visitor + } +} + +impl Receiver { + pub fn as_proto_handle(&self) -> crate::handle::Receiver { + crate::handle::Receiver { + id: self.handle.handle.id, + } + } +} + +// Lean on the auto-generated impl of oak::handle::Receiver. +impl prost::Message for Receiver { + fn encoded_len(&self) -> usize { + self.as_proto_handle().encoded_len() + } + + fn clear(&mut self) { + self.handle.handle.id = 0; + } + + fn encode_raw(&self, buf: &mut B) { + self.as_proto_handle().encode_raw(buf); + } + + fn merge_field( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), prost::DecodeError> { + let mut proto = self.as_proto_handle(); + proto.merge_field(tag, wire_type, buf, ctx)?; + self.handle.handle.id = proto.id; + Ok(()) + } +} + +impl Default for Receiver { + fn default() -> Receiver { + Receiver::new(ReadHandle { + handle: crate::Handle::invalid(), + }) + } +} diff --git a/sdk/rust/oak/src/io/sender.rs b/sdk/rust/oak/src/io/sender.rs index 8e19c7e807f..21b4b215d36 100644 --- a/sdk/rust/oak/src/io/sender.rs +++ b/sdk/rust/oak/src/io/sender.rs @@ -15,6 +15,10 @@ // use crate::{io::Encodable, OakError, OakStatus, WriteHandle}; +use prost::{ + bytes::{Buf, BufMut}, + encoding::{DecodeContext, WireType}, +}; use serde::{Deserialize, Serialize}; /// Wrapper for a handle to the send half of a channel, allowing to send data that can be encoded as @@ -53,3 +57,54 @@ impl Sender { Ok(()) } } + +impl crate::handle::HandleVisit for Sender { + fn visit(&mut self, mut visitor: F) -> F { + visitor(&mut self.handle.handle.id); + visitor + } +} + +impl Sender { + pub fn as_proto_handle(&self) -> crate::handle::Sender { + crate::handle::Sender { + id: self.handle.handle.id, + } + } +} + +// Lean on the auto-generated impl of oak::handle::Sender. +impl prost::Message for Sender { + fn encoded_len(&self) -> usize { + self.as_proto_handle().encoded_len() + } + + fn clear(&mut self) { + self.handle.handle.id = 0; + } + + fn encode_raw(&self, buf: &mut B) { + self.as_proto_handle().encode_raw(buf); + } + + fn merge_field( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), prost::DecodeError> { + let mut proto = self.as_proto_handle(); + proto.merge_field(tag, wire_type, buf, ctx)?; + self.handle.handle.id = proto.id; + Ok(()) + } +} + +impl Default for Sender { + fn default() -> Sender { + Sender::new(WriteHandle { + handle: crate::Handle::invalid(), + }) + } +} diff --git a/sdk/rust/oak/src/lib.rs b/sdk/rust/oak/src/lib.rs index c3e7ec60db1..65fdf650502 100644 --- a/sdk/rust/oak/src/lib.rs +++ b/sdk/rust/oak/src/lib.rs @@ -32,6 +32,7 @@ mod stubs; pub use error::OakError; pub mod grpc; +pub mod handle; pub mod io; pub mod logger; pub mod node_config; diff --git a/sdk/rust/oak/tests/handle_extract_inject.proto b/sdk/rust/oak/tests/handle_extract_inject.proto new file mode 100644 index 00000000000..46386b2c9c0 --- /dev/null +++ b/sdk/rust/oak/tests/handle_extract_inject.proto @@ -0,0 +1,52 @@ +syntax = "proto3"; + +package tests; + +import "handle.proto"; + +message TestMessage { + string other_arbitrary_field = 1; + oak.handle.Sender test_sender = 2 [(oak.handle.message_type) = ".tests.TestMessageType"]; + oak.handle.Receiver test_receiver = 3 [(oak.handle.message_type) = ".tests.TestMessageType"]; +} + +message TestMessageWithEnum { + oneof either { + oak.handle.Sender either_sender = 1 [(oak.handle.message_type) = ".tests.TestMessageType"]; + oak.handle.Receiver either_receiver = 2 [(oak.handle.message_type) = ".tests.TestMessageType"]; + } +} + +message RecursiveMessage { + oak.handle.Sender sender = 1 [(oak.handle.message_type) = ".tests.TestMessageType"]; + RecursiveMessage recursive_message = 2; +} + +message RepeatedMessage { + repeated oak.handle.Sender sender = 1 [(oak.handle.message_type) = ".tests.TestMessageType"]; +} + +message TestMessageType { + string body = 1; +} + +message LookMaNoHandles { + string a = 1; + int64 b = 2; +} + +message SaneHandleOrder { + oak.handle.Sender sender = 1 [(oak.handle.message_type) = ".tests.TestMessageType"]; + repeated SaneHandleOrder children = 2; + oak.handle.Receiver receiver = 3 [(oak.handle.message_type) = ".tests.TestMessageType"]; +} + +// This message should contain all other test messages. It is used to test roundtrips +message RoundtripContainer { + TestMessage test_message = 1; + TestMessageWithEnum test_message_with_enum = 2; + RecursiveMessage recursive_message = 3; + RepeatedMessage repeated_message = 4; + LookMaNoHandles look_ma_no_handles = 5; + SaneHandleOrder sane_handle_order = 6; +} diff --git a/sdk/rust/oak/tests/handle_extract_inject.rs b/sdk/rust/oak/tests/handle_extract_inject.rs new file mode 100644 index 00000000000..03739ceeb0b --- /dev/null +++ b/sdk/rust/oak/tests/handle_extract_inject.rs @@ -0,0 +1,314 @@ +// +// Copyright 2020 The Project Oak Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +include!(concat!(env!("OUT_DIR"), "/handle_tests/tests.rs")); + +use oak::{ + handle::{extract_handles, inject_handles}, + io::{Receiver, Sender}, +}; + +#[test] +fn extract_nothing() { + let mut message = LookMaNoHandles { + a: "Hello, world!".to_string(), + b: 42, + }; + + let handles = extract_handles(&mut message); + + assert_eq!(handles, vec![]); +} + +#[test] +fn extract_struct() { + let mut message = TestMessage { + other_arbitrary_field: "Test".to_string(), + test_sender: Some(sender(42)), + test_receiver: Some(receiver(1337)), + }; + + let handles = extract_handles(&mut message); + + assert_eq!(handles, vec![42, 1337]); + assert_eq!( + message, + TestMessage { + other_arbitrary_field: "Test".to_string(), + test_sender: Some(sender(0)), + test_receiver: Some(receiver(0)), + } + ); +} + +#[test] +fn enum_extract_sender() { + let mut message = TestMessageWithEnum { + either: Some(test_message_with_enum::Either::EitherSender(sender(42))), + }; + + let handles = extract_handles(&mut message); + + assert_eq!(handles, vec![42]); + assert_eq!( + message, + TestMessageWithEnum { + either: Some(test_message_with_enum::Either::EitherSender(sender(0))), + } + ); +} + +#[test] +fn enum_extract_receiver() { + let mut message = TestMessageWithEnum { + either: Some(test_message_with_enum::Either::EitherReceiver(receiver(42))), + }; + + let handles = extract_handles(&mut message); + + assert_eq!(handles, vec![42]); + assert_eq!( + message, + TestMessageWithEnum { + either: Some(test_message_with_enum::Either::EitherReceiver(receiver(0))), + } + ); +} + +#[test] +fn enum_inject_sender() { + let mut message = TestMessageWithEnum { + either: Some(test_message_with_enum::Either::EitherSender(sender(0))), + }; + + inject_handles(&mut message, &[42]).unwrap(); + + assert_eq!( + message, + TestMessageWithEnum { + either: Some(test_message_with_enum::Either::EitherSender(sender(42))), + } + ); +} + +#[test] +fn map_extract() { + use dummy_hash::DummyBuildHasher; + use std::collections::HashMap; + let mut map: HashMap, DummyBuildHasher> = + HashMap::with_hasher(DummyBuildHasher); + map.insert(1, sender(10)); + map.insert(2, sender(20)); + // DummyHasher should yield elements in reverse order. + assert_eq!( + map.values().cloned().collect::>>(), + vec![sender(20), sender(10)] + ); + + let handles = extract_handles(&mut map); + + // Even though the hashmap returns the values in reverse order, we expect the values to be + // extracted in the order of their keys. + assert_eq!(handles, vec![10, 20]); +} + +#[test] +fn map_inject() { + use dummy_hash::DummyBuildHasher; + use std::collections::HashMap; + let mut map: HashMap, DummyBuildHasher> = + HashMap::with_hasher(DummyBuildHasher); + map.insert(1, sender(0)); + map.insert(2, sender(0)); + + inject_handles(&mut map, &[10, 20]).unwrap(); + + assert_eq!(map.get(&1).cloned(), Some(sender(10))); + assert_eq!(map.get(&2).cloned(), Some(sender(20))); +} + +#[test] +fn recursive_extract() { + let mut msg = RecursiveMessage { + sender: None, + recursive_message: Some(Box::new(RecursiveMessage { + sender: Some(sender(42)), + recursive_message: None, + })), + }; + + let handles = extract_handles(&mut msg); + + assert_eq!(handles, vec![42]); +} + +#[test] +fn repeated_extract() { + let mut msg = RepeatedMessage { + sender: vec![sender(1), sender(2), sender(3)], + }; + + let handles = extract_handles(&mut msg); + + assert_eq!(handles, vec![1, 2, 3]); +} + +#[test] +fn inject_too_many_fails() { + let mut message = TestMessage { + other_arbitrary_field: "Test".to_string(), + test_sender: Some(sender(0)), + test_receiver: Some(receiver(0)), + }; + + let handles = vec![1, 2, 3]; + + assert!(inject_handles(&mut message, &handles).is_err()); +} + +#[test] +fn inject_too_few_fails() { + let mut message = TestMessage { + other_arbitrary_field: "Test".to_string(), + test_sender: Some(sender(0)), + test_receiver: Some(receiver(0)), + }; + + let handles = vec![1]; + + assert!(inject_handles(&mut message, &handles).is_err()); +} + +#[test] +fn sane_handle_order() { + let reference = SaneHandleOrder { + sender: Some(sender(1)), + children: vec![ + SaneHandleOrder { + sender: Some(sender(2)), + children: vec![SaneHandleOrder { + sender: Some(sender(3)), + children: vec![], + receiver: Some(receiver(4)), + }], + receiver: Some(receiver(5)), + }, + SaneHandleOrder { + sender: Some(sender(6)), + children: vec![], + receiver: Some(receiver(7)), + }, + ], + receiver: Some(receiver(8)), + }; + let mut message = reference.clone(); + + let handles = extract_handles(&mut message); + inject_handles(&mut message, &handles).expect("Failed to re-inject handles"); + + assert_eq!(handles, vec![1, 2, 3, 4, 5, 6, 7, 8]); + assert_eq!(reference, message); +} + +#[test] +fn roundtrip() { + let reference = RoundtripContainer { + look_ma_no_handles: Some(LookMaNoHandles { + a: "test".to_string(), + b: 2, + }), + recursive_message: Some(RecursiveMessage { + sender: Some(sender(1)), + recursive_message: Some(Box::new(RecursiveMessage { + sender: None, + recursive_message: Some(Box::new(RecursiveMessage { + sender: Some(sender(2)), + recursive_message: None, + })), + })), + }), + repeated_message: Some(RepeatedMessage { + sender: vec![sender(10), sender(9), sender(8)], + }), + sane_handle_order: Some(SaneHandleOrder { + sender: Some(sender(10)), + children: vec![SaneHandleOrder { + sender: Some(sender(9)), + children: vec![], + receiver: Some(receiver(2)), + }], + receiver: Some(receiver(1)), + }), + test_message: Some(TestMessage { + other_arbitrary_field: "test".to_string(), + test_sender: Some(sender(1)), + test_receiver: Some(receiver(2)), + }), + test_message_with_enum: Some(TestMessageWithEnum { + either: Some(test_message_with_enum::Either::EitherReceiver(receiver(1))), + }), + }; + + let mut message = reference.clone(); + + let handles = extract_handles(&mut message); + inject_handles(&mut message, &handles).expect("Failed to inject handles back"); + + assert_eq!(reference, message); +} + +fn sender(id: u64) -> Sender { + Sender::new(oak::WriteHandle { + handle: oak::Handle::from_raw(id), + }) +} + +fn receiver(id: u64) -> Receiver { + Receiver::new(oak::ReadHandle { + handle: oak::Handle::from_raw(id), + }) +} + +// Dummy hashing utilities to make the order of elements returned from a HashMap deterministic +// (reverse sorted order by key). +mod dummy_hash { + use std::hash::{BuildHasher, Hasher}; + + pub struct DummyHasher(u64); + + impl Hasher for DummyHasher { + fn finish(&self) -> u64 { + // Reverse the order + core::u64::MAX - self.0 + } + + fn write(&mut self, bytes: &[u8]) { + for b in bytes { + self.0 += *b as u64; + } + } + } + + pub struct DummyBuildHasher; + + impl BuildHasher for DummyBuildHasher { + type Hasher = DummyHasher; + + fn build_hasher(&self) -> Self::Hasher { + DummyHasher(0) + } + } +} diff --git a/sdk/rust/oak_derive/Cargo.toml b/sdk/rust/oak_derive/Cargo.toml new file mode 100644 index 00000000000..e9c1232f003 --- /dev/null +++ b/sdk/rust/oak_derive/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "oak_derive" +version = "0.1.0" +authors = ["Daan de Graaf "] +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +syn = "1.0" +quote = "1.0" +proc-macro2 = "1.0" + +[dev-dependencies] +oak = "*" diff --git a/sdk/rust/oak_derive/src/lib.rs b/sdk/rust/oak_derive/src/lib.rs new file mode 100644 index 00000000000..7e2a58091fb --- /dev/null +++ b/sdk/rust/oak_derive/src/lib.rs @@ -0,0 +1,122 @@ +// +// Copyright 2020 The Project Oak Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens}; +use syn::{Data, Fields, Ident}; + +/// Automatically derives the [`HandleVisit`](../oak/handle/trait.HandleVisit.html) trait for +/// structs and enums generated by prost. +#[proc_macro_derive(HandleVisit)] +pub fn handle_visit(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let ast = syn::parse_macro_input!(input as syn::DeriveInput); + + let name = &ast.ident; + match &ast.data { + Data::Struct(data) => struct_impls(name, data), + Data::Enum(data) => enum_impls(name, data), + Data::Union(_) => panic!("HandleVisit cannot be derived for unions"), + } + .into() +} + +fn struct_impls(name: &Ident, data: &syn::DataStruct) -> TokenStream { + let accessors: Vec = match &data.fields { + Fields::Named(named) => named + .named + .iter() + .flat_map(|f| f.ident.clone()) + .map(|i| quote!(self.#i)) + .collect(), + Fields::Unnamed(unnamed) => unnamed + .unnamed + .iter() + .enumerate() + .map(|(i, _)| { + let index = syn::Index::from(i); + quote!(self.#index) + }) + .collect(), + Fields::Unit => Vec::new(), + }; + let body = accessors_visit(&accessors); + + quote! { + impl ::oak::handle::HandleVisit for #name { + fn visit(&mut self, visitor: F) -> F { + #body + } + } + } +} + +fn enum_impls(name: &Ident, data: &syn::DataEnum) -> TokenStream { + let variants: Vec = data.variants.iter().map(variant_impl).collect(); + quote! { + impl ::oak::handle::HandleVisit for #name { + fn visit(&mut self, visitor: F) -> F { + match self { + #( + #name::#variants, + )* + } + } + } + } +} + +fn variant_impl(variant: &syn::Variant) -> TokenStream { + let variant_ident = &variant.ident; + match &variant.fields { + Fields::Named(fields) => { + let fields: Vec = fields.named.iter().flat_map(|f| f.ident.clone()).collect(); + let body = accessors_visit(&fields); + quote! { + #variant_ident { #( #fields ),* } => { + #body + } + } + } + Fields::Unnamed(fields) => { + // Name the fields _0, _1, ... + let accessors: Vec = fields + .unnamed + .iter() + .enumerate() + .map(|(i, _)| format_ident!("_{}", i)) + .collect(); + let body = accessors_visit(&accessors); + quote! { + #variant_ident( #( #accessors ),* ) => { + #body + } + } + } + Fields::Unit => quote! { + #variant => visitor + }, + } +} + +fn accessors_visit(accessors: &[T]) -> TokenStream { + quote! { + let mut _v = visitor; + #( + _v = #accessors.visit(_v); + )* + _v + } +} diff --git a/sdk/rust/oak_derive/tests/visit.rs b/sdk/rust/oak_derive/tests/visit.rs new file mode 100644 index 00000000000..95363cbfbe0 --- /dev/null +++ b/sdk/rust/oak_derive/tests/visit.rs @@ -0,0 +1,127 @@ +// +// Copyright 2020 The Project Oak Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +use oak::handle::{Handle, HandleVisit}; + +#[derive(Default)] +struct Visited(Handle); + +impl HandleVisit for Visited { + fn visit(&mut self, mut f: F) -> F { + f(&mut self.0); + f + } +} + +#[test] +fn named() { + #[derive(Default, HandleVisit)] + struct Named { + a: Visited, + b: u64, + } + + assert_visit(Named::default(), 1); +} + +#[test] +fn unnamed() { + #[derive(Default, HandleVisit)] + struct Unnamed(u64, Visited); + + assert_visit(Unnamed::default(), 1); +} + +#[test] +fn unit() { + #[derive(HandleVisit)] + struct Unit; + + assert_visit(Unit, 0); +} + +#[test] +fn multiple() { + #[derive(Default, HandleVisit)] + struct Multiple { + a: Visited, + b: Visited, + c: Visited, + d: u64, + } + + assert_visit(Multiple::default(), 3); +} + +#[test] +fn nested() { + #[derive(Default, HandleVisit)] + struct Inner { + a: Visited, + b: u64, + } + + #[derive(Default, HandleVisit)] + struct Outer { + a: Inner, + b: Visited, + c: u64, + } + + assert_visit(Outer::default(), 2); +} + +mod enums { + use super::*; + #[derive(HandleVisit)] + enum Enum { + Named { a: Visited, b: u64 }, + Unnamed(u64, Visited), + Unit, + } + + #[test] + fn named() { + assert_visit( + Enum::Named { + a: Visited::default(), + b: 0, + }, + 1, + ); + } + + #[test] + fn unnamed() { + assert_visit(Enum::Unnamed(0, Visited::default()), 1); + } + + #[test] + fn unit() { + assert_visit(Enum::Unit, 0); + } +} + +// Asserts that `t` is visited exactly `count` times when calling `HandleVisit::visit`. +fn assert_visit(mut t: T, count: usize) { + let mut counter = 0; + + t.visit(|_| { + counter += 1; + }); + + assert_eq!(counter, count); +}