From e10dda42a44a342a5a1f1ae840ea70772b1e8074 Mon Sep 17 00:00:00 2001 From: Pierre Krieger Date: Thu, 6 Apr 2023 10:04:45 +0000 Subject: [PATCH] Change the way buffers are passed to the wasm-node (#396) * Change the way buffers are passed to the wasm-node * Directly use Uint8Array instead of ArrayBuffer * Free the buffers ASAP * CHANGELOG * Docs and spellcheck --- wasm-node/CHANGELOG.md | 1 + .../src/instance/bindings-smoldot-light.ts | 61 +++--- wasm-node/javascript/src/instance/bindings.ts | 11 +- wasm-node/javascript/src/instance/instance.ts | 57 +++--- .../javascript/src/instance/raw-instance.ts | 7 +- wasm-node/rust/src/bindings.rs | 173 ++++++++++-------- wasm-node/rust/src/lib.rs | 65 ++----- wasm-node/rust/src/platform.rs | 61 ++---- 8 files changed, 193 insertions(+), 243 deletions(-) diff --git a/wasm-node/CHANGELOG.md b/wasm-node/CHANGELOG.md index ef3488f98d..5c87a8d288 100644 --- a/wasm-node/CHANGELOG.md +++ b/wasm-node/CHANGELOG.md @@ -9,6 +9,7 @@ ### Fixed +- Fix a potential undefined behavior in the way the Rust and JavaScript communicate. ([#396](https://github.com/smol-dot/smoldot/pull/396)) - Properly check whether Yamux substream IDs allocated by the remote are valid. ([#383](https://github.com/smol-dot/smoldot/pull/383)) - Fix the size of the data of Yamux frames with the `SYN` flag not being verified against the allowed credits. ([#383](https://github.com/smol-dot/smoldot/pull/383)) - Fix Yamux repeatedly sending empty data frames when the allowed window size is 0. 
([#383](https://github.com/smol-dot/smoldot/pull/383)) diff --git a/wasm-node/javascript/src/instance/bindings-smoldot-light.ts b/wasm-node/javascript/src/instance/bindings-smoldot-light.ts index 0cc50067ba..2b16b0f7b1 100644 --- a/wasm-node/javascript/src/instance/bindings-smoldot-light.ts +++ b/wasm-node/javascript/src/instance/bindings-smoldot-light.ts @@ -26,6 +26,13 @@ import type { SmoldotWasmInstance } from './bindings.js'; export interface Config { instance?: SmoldotWasmInstance, + /** + * Array used to store the buffers provided to the Rust code. + * + * When `buffer_size` or `buffer_copy` is called, the buffer is found here. + */ + bufferIndices: Array, + /** * Returns the number of milliseconds since an arbitrary epoch. */ @@ -254,6 +261,19 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports, config.onPanic(message); }, + buffer_size: (bufferIndex: number) => { + const buf = config.bufferIndices[bufferIndex]!; + return buf.byteLength; + }, + + buffer_copy: (bufferIndex: number, targetPtr: number) => { + const instance = config.instance!; + targetPtr = targetPtr >>> 0; + + const buf = config.bufferIndices[bufferIndex]!; + new Uint8Array(instance.exports.memory.buffer).set(buf, targetPtr); + }, + // Used by the Rust side to notify that a JSON-RPC response or subscription notification // is available in the queue of JSON-RPC responses. json_rpc_responses_non_empty: (chainId: number) => { @@ -322,12 +342,12 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports, // Must create a new connection object. This implementation stores the created object in // `connections`.
- connection_new: (connectionId: number, addrPtr: number, addrLen: number, errorPtrPtr: number) => { + connection_new: (connectionId: number, addrPtr: number, addrLen: number, errorBufferIndexPtr: number) => { const instance = config.instance!; addrPtr >>>= 0; addrLen >>>= 0; - errorPtrPtr >>>= 0; + errorBufferIndexPtr >>>= 0; if (!!connections[connectionId]) { throw new Error("internal error: connection already allocated"); @@ -350,13 +370,13 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports, break } case 'multi-stream': { - const bufferLen = 1 + info.localTlsCertificateMultihash.length + info.remoteTlsCertificateMultihash.length; - const ptr = instance.exports.alloc(bufferLen) >>> 0; - const mem = new Uint8Array(instance.exports.memory.buffer); - buffer.writeUInt8(mem, ptr, 0); - mem.set(info.localTlsCertificateMultihash, ptr + 1) - mem.set(info.remoteTlsCertificateMultihash, ptr + 1 + info.localTlsCertificateMultihash.length) - instance.exports.connection_open_multi_stream(connectionId, ptr, bufferLen); + const handshakeTy = new Uint8Array(1 + info.localTlsCertificateMultihash.length + info.remoteTlsCertificateMultihash.length); + buffer.writeUInt8(handshakeTy, 0, 0); + handshakeTy.set(info.localTlsCertificateMultihash, 1) + handshakeTy.set(info.remoteTlsCertificateMultihash, 1 + info.localTlsCertificateMultihash.length) + config.bufferIndices[0] = handshakeTy; + instance.exports.connection_open_multi_stream(connectionId, 0); + delete config.bufferIndices[0] break } } @@ -365,10 +385,9 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports, onConnectionReset: (message: string) => { if (killedTracked.killed) return; try { - const encoded = new TextEncoder().encode(message) - const ptr = instance.exports.alloc(encoded.length) >>> 0; - new Uint8Array(instance.exports.memory.buffer).set(encoded, ptr); - instance.exports.connection_reset(connectionId, ptr, encoded.length); + config.bufferIndices[0] = new 
TextEncoder().encode(message); + instance.exports.connection_reset(connectionId, 0); + delete config.bufferIndices[0] } catch(_error) {} }, onWritableBytes: (numExtra, streamId) => { @@ -384,9 +403,9 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports, onMessage: (message: Uint8Array, streamId?: number) => { if (killedTracked.killed) return; try { - const ptr = instance.exports.alloc(message.length) >>> 0; - new Uint8Array(instance.exports.memory.buffer).set(message, ptr) - instance.exports.stream_message(connectionId, streamId || 0, ptr, message.length); + config.bufferIndices[0] = message; + instance.exports.stream_message(connectionId, streamId || 0, 0); + delete config.bufferIndices[0] } catch(_error) {} }, onStreamOpened: (streamId: number, direction: 'inbound' | 'outbound', initialWritableBytes) => { @@ -418,13 +437,11 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports, if (error instanceof Error) { errorStr = error.toString(); } + const mem = new Uint8Array(instance.exports.memory.buffer); - const encoded = new TextEncoder().encode(errorStr) - const ptr = instance.exports.alloc(encoded.length) >>> 0; - mem.set(encoded, ptr); - buffer.writeUInt32LE(mem, errorPtrPtr, ptr); - buffer.writeUInt32LE(mem, errorPtrPtr + 4, encoded.length); - buffer.writeUInt8(mem, errorPtrPtr + 8, isBadAddress ? 1 : 0); + config.bufferIndices[0] = new TextEncoder().encode(errorStr) + buffer.writeUInt32LE(mem, errorBufferIndexPtr, 0); + buffer.writeUInt8(mem, errorBufferIndexPtr + 4, isBadAddress ? 
1 : 0); return 1; } }, diff --git a/wasm-node/javascript/src/instance/bindings.ts b/wasm-node/javascript/src/instance/bindings.ts index b3baff7065..8bc46799c9 100644 --- a/wasm-node/javascript/src/instance/bindings.ts +++ b/wasm-node/javascript/src/instance/bindings.ts @@ -26,22 +26,21 @@ export interface SmoldotWasmExports extends WebAssembly.Exports { init: (maxLogLevel: number, enableCurrentTask: number, cpuRateLimit: number, periodicallyYield: number) => void, set_periodically_yield: (periodicallyYield: number) => void, start_shutdown: () => void, - alloc: (len: number) => number, - add_chain: (chainSpecPointer: number, chainSpecLen: number, databaseContentPointer: number, databaseContentLen: number, jsonRpcRunning: number, potentialRelayChainsPtr: number, potentialRelayChainsLen: number) => number; + add_chain: (chainSpecBufferIndex: number, databaseContentBufferIndex: number, jsonRpcRunning: number, potentialRelayChainsBufferIndex: number) => number; remove_chain: (chainId: number) => void, chain_is_ok: (chainId: number) => number, chain_error_len: (chainId: number) => number, chain_error_ptr: (chainId: number) => number, - json_rpc_send: (textPtr: number, textLen: number, chainId: number) => number, + json_rpc_send: (textBufferIndex: number, chainId: number) => number, json_rpc_responses_peek: (chainId: number) => number, json_rpc_responses_pop: (chainId: number) => void, timer_finished: (timerId: number) => void, connection_open_single_stream: (connectionId: number, handshakeTy: number, initialWritableBytes: number, writeClosable: number) => void, - connection_open_multi_stream: (connectionId: number, handshakeTyPtr: number, handshakeTyLen: number) => void, + connection_open_multi_stream: (connectionId: number, handshakeTyBufferIndex: number) => void, stream_writable_bytes: (connectionId: number, streamId: number, numBytes: number) => void, - stream_message: (connectionId: number, streamId: number, ptr: number, len: number) => void, + stream_message: 
(connectionId: number, streamId: number, bufferIndex: number) => void, connection_stream_opened: (connectionId: number, streamId: number, outbound: number, initialWritableBytes: number) => void, - connection_reset: (connectionId: number, ptr: number, len: number) => void, + connection_reset: (connectionId: number, bufferIndex: number) => void, stream_reset: (connectionId: number, streamId: number) => void, } diff --git a/wasm-node/javascript/src/instance/instance.ts b/wasm-node/javascript/src/instance/instance.ts index fb1c4a8c8f..2109c41321 100644 --- a/wasm-node/javascript/src/instance/instance.ts +++ b/wasm-node/javascript/src/instance/instance.ts @@ -76,7 +76,7 @@ export function start(configMessage: Config, platformBindings: instance.Platform // - At initialization, it is a Promise containing the Wasm VM is still initializing. // - After the Wasm VM has finished initialization, contains the `WebAssembly.Instance` object. // - let state: { initialized: false, promise: Promise } | { initialized: true, instance: SmoldotWasmInstance, unregisterCallback: () => void }; + let state: { initialized: false, promise: Promise<[SmoldotWasmInstance, Array]> } | { initialized: true, instance: SmoldotWasmInstance, bufferIndices: Array, unregisterCallback: () => void }; const crashError: { error?: CrashError } = {}; @@ -127,7 +127,7 @@ export function start(configMessage: Config, platformBindings: instance.Platform }; state = { - initialized: false, promise: instance.startInstance(config, platformBindings).then((instance) => { + initialized: false, promise: instance.startInstance(config, platformBindings).then(([instance, bufferIndices]) => { // `config.cpuRateLimit` is a floating point that should be between 0 and 1, while the value // to pass as parameter must be between `0` and `2^32-1`. 
// The few lines of code below should handle all possible values of `number`, including @@ -148,22 +148,22 @@ export function start(configMessage: Config, platformBindings: instance.Platform }); instance.exports.init(configMessage.maxLogLevel, configMessage.enableCurrentTask ? 1 : 0, cpuRateLimit, periodicallyYield ? 1 : 0); - state = { initialized: true, instance, unregisterCallback }; - return instance; + state = { initialized: true, instance, bufferIndices, unregisterCallback }; + return [instance, bufferIndices]; }) }; - async function queueOperation(operation: (instance: SmoldotWasmInstance) => T): Promise { + async function queueOperation(operation: (instance: SmoldotWasmInstance, bufferIndices: Array) => T): Promise { // What to do depends on the type of `state`. // See the documentation of the `state` variable for information. if (!state.initialized) { // A message has been received while the Wasm VM is still initializing. Queue it for when // initialization is over. - return state.promise.then((instance) => operation(instance)) + return state.promise.then(([instance, bufferIndices]) => operation(instance, bufferIndices)) } else { // Everything is already initialized. Process the message synchronously. 
- return operation(state.instance) + return operation(state.instance, state.bufferIndices) } } @@ -179,10 +179,8 @@ export function start(configMessage: Config, platformBindings: instance.Platform let retVal; try { - const encoded = new TextEncoder().encode(request) - const ptr = state.instance.exports.alloc(encoded.length) >>> 0; - new Uint8Array(state.instance.exports.memory.buffer).set(encoded, ptr); - retVal = state.instance.exports.json_rpc_send(ptr, encoded.length, chainId) >>> 0; + state.bufferIndices[0] = new TextEncoder().encode(request) + retVal = state.instance.exports.json_rpc_send(0, chainId) >>> 0; } catch (_error) { console.assert(crashError.error); throw crashError.error @@ -234,38 +232,27 @@ export function start(configMessage: Config, platformBindings: instance.Platform }, addChain: (chainSpec: string, databaseContent: string, potentialRelayChains: number[], disableJsonRpc: boolean): Promise<{ success: true, chainId: number } | { success: false, error: string }> => { - return queueOperation((instance) => { + return queueOperation((instance, bufferIndices) => { if (crashError.error) throw crashError.error; try { - // Write the chain specification into memory. - const chainSpecEncoded = new TextEncoder().encode(chainSpec) - const chainSpecPtr = instance.exports.alloc(chainSpecEncoded.length) >>> 0; - new Uint8Array(instance.exports.memory.buffer).set(chainSpecEncoded, chainSpecPtr); - - // Write the database content into memory. - const databaseContentEncoded = new TextEncoder().encode(databaseContent) - const databaseContentPtr = instance.exports.alloc(databaseContentEncoded.length) >>> 0; - new Uint8Array(instance.exports.memory.buffer).set(databaseContentEncoded, databaseContentPtr); - - // Write the potential relay chains into memory. 
- const potentialRelayChainsLen = potentialRelayChains.length; - const potentialRelayChainsPtr = instance.exports.alloc(potentialRelayChainsLen * 4) >>> 0; - for (let idx = 0; idx < potentialRelayChains.length; ++idx) { - buffer.writeUInt32LE(new Uint8Array(instance.exports.memory.buffer), potentialRelayChainsPtr + idx * 4, potentialRelayChains[idx]!); - } - // `add_chain` unconditionally allocates a chain id. If an error occurs, however, this chain // id will refer to an *erroneous* chain. `chain_is_ok` is used below to determine whether it // has succeeeded or not. // Note that `add_chain` properly de-allocates buffers even if it failed. - const chainId = instance.exports.add_chain( - chainSpecPtr, chainSpecEncoded.length, - databaseContentPtr, databaseContentEncoded.length, - disableJsonRpc ? 0 : 1, - potentialRelayChainsPtr, potentialRelayChainsLen - ); + bufferIndices[0] = new TextEncoder().encode(chainSpec) + bufferIndices[1] = new TextEncoder().encode(databaseContent) + const potentialRelayChainsEncoded = new Uint8Array(potentialRelayChains.length * 4) + for (let idx = 0; idx < potentialRelayChains.length; ++idx) { + buffer.writeUInt32LE(potentialRelayChainsEncoded, idx * 4, potentialRelayChains[idx]!); + } + bufferIndices[2] = potentialRelayChainsEncoded + const chainId = instance.exports.add_chain(0, 1, disableJsonRpc ? 
0 : 1, 2); + + delete bufferIndices[0] + delete bufferIndices[1] + delete bufferIndices[2] if (instance.exports.chain_is_ok(chainId) != 0) { console.assert(!chains.has(chainId)); diff --git a/wasm-node/javascript/src/instance/raw-instance.ts b/wasm-node/javascript/src/instance/raw-instance.ts index 55addf1604..789288c3e9 100644 --- a/wasm-node/javascript/src/instance/raw-instance.ts +++ b/wasm-node/javascript/src/instance/raw-instance.ts @@ -91,7 +91,7 @@ export interface PlatformBindings { connect(config: ConnectionConfig): Connection; } -export async function startInstance(config: Config, platformBindings: PlatformBindings): Promise { +export async function startInstance(config: Config, platformBindings: PlatformBindings): Promise<[SmoldotWasmInstance, Array]> { // The actual Wasm bytecode is base64-decoded then deflate-decoded from a constant found in a // different file. // This is suboptimal compared to using `instantiateStreaming`, but it is the most @@ -100,8 +100,11 @@ export async function startInstance(config: Config, platformBindings: PlatformBi let killAll: () => void; + const bufferIndices = new Array; + // Used to bind with the smoldot-light bindings. See the `bindings-smoldot-light.js` file. const smoldotJsConfig: SmoldotBindingsConfig = { + bufferIndices, performanceNow: platformBindings.performanceNow, connect: platformBindings.connect, onPanic: (message) => { @@ -141,5 +144,5 @@ export async function startInstance(config: Config, platformBindings: PlatformBi const instance = result.instance as SmoldotWasmInstance; smoldotJsConfig.instance = instance; wasiConfig.instance = instance; - return instance; + return [instance, bufferIndices]; } diff --git a/wasm-node/rust/src/bindings.rs b/wasm-node/rust/src/bindings.rs index 2daaaef4dd..c1179f4914 100644 --- a/wasm-node/rust/src/bindings.rs +++ b/wasm-node/rust/src/bindings.rs @@ -60,11 +60,13 @@ //! the guest. //! //! 
It is, however, important when the value needs to be interpreted from the host side, such as -//! for example the return value of [`alloc`]. When using JavaScript as the host, you must do -//! `>>> 0` on all the `u32` values before interpreting them, in order to be certain than they -//! are treated as unsigned integers by the JavaScript. +//! for example the `message_ptr` parameter of [`panic()`]. When using JavaScript as the host, you +//! must do `>>> 0` on all the `u32` values before interpreting them, in order to be certain that +//! they are treated as unsigned integers by the JavaScript. //! +use core::mem; + #[link(wasm_import_module = "smoldot")] extern "C" { /// Must stop the execution immediately. The message is a UTF-8 string found in the memory of @@ -96,6 +98,21 @@ extern "C" { /// behave like `abort` and prevent any further execution. pub fn panic(message_ptr: u32, message_len: u32); + /// Copies the entire content of the buffer with the given index to the memory of the + /// WebAssembly at offset `target_pointer`. + /// + /// In situations where a buffer must be provided from the JavaScript to the Rust code, the + /// JavaScript must (prior to calling the Rust function that requires the buffer) assign a + /// "buffer index" to the buffer it wants to provide. The Rust code then calls the + /// [`buffer_size`] and [`buffer_copy`] functions in order to obtain the length and content + /// of the buffer. + pub fn buffer_copy(buffer_index: u32, target_pointer: u32); + + /// Returns the size (in bytes) of the buffer with the given index. + /// + /// See the documentation of [`buffer_copy`] for context. + pub fn buffer_size(buffer_index: u32) -> u32; + /// The queue of JSON-RPC responses of the given chain is no longer empty.
/// /// Use [`json_rpc_responses_peek`] in order to obtain information about the responses in the @@ -176,14 +193,14 @@ extern "C" { /// > **Note**: If you implement this function using for example `new WebSocket()`, please /// > keep in mind that exceptions should be caught and turned into an error code. /// - /// The `error_ptr_ptr` parameter should be treated as a pointer to two consecutive - /// little-endian 32-bits unsigned numbers and a 8-bits unsigned number. If an error happened, - /// call [`alloc`] to allocate memory, write a UTF-8 error message in that given location, - /// then write that location at the location indicated by `error_ptr_ptr` and the length of - /// that string at the location `error_ptr_ptr + 4`. The buffer will be de-allocated by the - /// client. Then, write at location `error_ptr_ptr + 8` a `1` if the error is caused by the - /// address being forbidden or unsupported, and `0` otherwise. If no error happens, nothing - /// should be written to `error_ptr_ptr`. + /// If an error happened, assign a so-called "buffer index" (a `u32`) representing the buffer + /// containing the UTF-8 error message, then write this buffer index as little-endian to the + /// memory of the WebAssembly indicated by `error_buffer_index_ptr`. The Rust code will call + /// [`buffer_size`] and [`buffer_copy`] in order to obtain the content of this buffer. The + /// buffer index should remain assigned and the buffer alive until the next time the JavaScript + /// code regains control. Then, write at location `error_buffer_index_ptr + 4` a `1` if the + /// error is caused by the address being forbidden or unsupported, and `0` otherwise. If no + /// error happens, nothing should be written to `error_buffer_index_ptr`. /// /// At any time, a connection can be in one of the three following states: /// @@ -205,7 +222,12 @@ extern "C" { /// multiplexing are handled internally by smoldot.
Multi-stream connections open and close /// streams over time using [`connection_stream_opened`] and [`stream_reset`], and the /// encryption and multiplexing are handled by the user of these bindings. - pub fn connection_new(id: u32, addr_ptr: u32, addr_len: u32, error_ptr_ptr: u32) -> u32; + pub fn connection_new( + id: u32, + addr_ptr: u32, + addr_len: u32, + error_buffer_index_ptr: u32, + ) -> u32; /// Abruptly close a connection previously initialized with [`connection_new`]. /// @@ -359,40 +381,22 @@ pub extern "C" fn start_shutdown() { super::advance_execution(); } -/// Allocates a buffer of the given length, with an alignment of 1. -/// -/// This must be used in the context of [`add_chain`] and other functions that similarly require -/// passing data of variable length. -/// -/// > **Note**: If using JavaScript as the host, you likely need to perform `>>> 0` on the return -/// > value. See the module-level documentation. -#[no_mangle] -pub extern "C" fn alloc(len: u32) -> u32 { - let len = usize::try_from(len).unwrap(); - let mut vec = Vec::::with_capacity(len); - unsafe { - vec.set_len(len); - } - let ptr: *mut [u8] = Box::into_raw(vec.into_boxed_slice()); - u32::try_from(ptr as *mut u8 as usize).unwrap() -} - /// Adds a chain to the client. The client will try to stay connected and synchronize this chain. /// -/// Use [`alloc`] to allocate a buffer for the spec and the database of the chain that needs to -/// be started. Write the chain spec and database content in these buffers as UTF-8. Then, pass -/// the pointers and lengths (in bytes) as parameter to this function. +/// Assign a so-called "buffer index" (a `u32`) representing the chain specification, database +/// content, and list of potential relay chains, then provide these buffer indices to the function. +/// The Rust code will call [`buffer_size`] and [`buffer_copy`] in order to obtain the content of +/// these buffers. 
The buffer indices can be de-assigned and buffers destroyed once this function +/// returns. +/// +/// The chain specification and the database content must be encoded in UTF-8. /// /// > **Note**: The database content is an opaque string that can be obtained by calling /// > the `chainHead_unstable_finalizedDatabase` JSON-RPC function. /// -/// Similarly, use [`alloc`] to allocate a buffer containing a list of 32-bits-little-endian chain -/// ids. Pass the pointer and number of chain ids (*not* length in bytes of the buffer) to this -/// function. If the chain specification refer to a parachain, these chain ids are the ones that -/// will be looked up to find the corresponding relay chain. -/// -/// These three buffers **must** have been allocated with [`alloc`]. They are freed when this -/// function is called, even if an error code is returned. +/// The list of potential relay chains is a buffer containing a list of 32-bits-little-endian chain +/// ids. If the chain specification refers to a parachain, these chain ids are the ones that will be +/// looked up to find the corresponding relay chain. /// /// If `json_rpc_running` is 0, then no JSON-RPC service will be started and it is forbidden to /// send JSON-RPC requests targeting this chain. This can be used to save up resources. @@ -404,22 +408,16 @@ pub extern "C" fn alloc(len: u32) -> u32 { /// message.
#[no_mangle] pub extern "C" fn add_chain( - chain_spec_pointer: u32, - chain_spec_len: u32, - database_content_pointer: u32, - database_content_len: u32, + chain_spec_buffer_index: u32, + database_content_buffer_index: u32, json_rpc_running: u32, - potential_relay_chains_ptr: u32, - potential_relay_chains_len: u32, + potential_relay_chains_buffer_index: u32, ) -> u32 { let success_code = super::add_chain( - chain_spec_pointer, - chain_spec_len, - database_content_pointer, - database_content_len, + get_buffer(chain_spec_buffer_index), + get_buffer(database_content_buffer_index), json_rpc_running, - potential_relay_chains_ptr, - potential_relay_chains_len, + get_buffer(potential_relay_chains_buffer_index), ); super::advance_execution(); success_code @@ -472,8 +470,10 @@ pub extern "C" fn chain_error_ptr(chain_id: u32) -> u32 { /// format of the JSON-RPC requests and notifications is described in /// [the standard JSON-RPC 2.0 specification](https://www.jsonrpc.org/specification). /// -/// The buffer passed as parameter **must** have been allocated with [`alloc`]. It is freed when -/// this function is called. +/// Assign a so-called "buffer index" (a `u32`) representing the buffer containing the UTF-8 +/// request, then provide this buffer index to the function. The Rust code will call +/// [`buffer_size`] and [`buffer_copy`] in order to obtain the content of this buffer. The buffer +/// index can be de-assigned and buffer destroyed once this function returns. /// /// Responses and notifications are notified using [`json_rpc_responses_non_empty`], and can /// be read with [`json_rpc_responses_peek`]. @@ -488,8 +488,8 @@ pub extern "C" fn chain_error_ptr(chain_id: u32) -> u32 { /// one. 
/// #[no_mangle] -pub extern "C" fn json_rpc_send(text_ptr: u32, text_len: u32, chain_id: u32) -> u32 { - let success_code = super::json_rpc_send(text_ptr, text_len, chain_id); +pub extern "C" fn json_rpc_send(text_buffer_index: u32, chain_id: u32) -> u32 { + let success_code = super::json_rpc_send(get_buffer(text_buffer_index), chain_id); super::advance_execution(); success_code } @@ -550,9 +550,8 @@ pub extern "C" fn timer_finished(timer_id: u32) { /// /// See also [`connection_new`]. /// -/// When in the `Open` state, the connection can receive messages. When a message is received, -/// [`alloc`] must be called in order to allocate memory for this message, then -/// [`stream_message`] must be called with the pointer returned by [`alloc`]. +/// When in the `Open` state, the connection can receive messages. Use [`stream_message`] in order +/// to provide to the Rust code the messages received by the connection. /// /// The `handshake_ty` parameter indicates the type of handshake. It must always be 0 at the /// moment, indicating a multistream-select+Noise+Yamux handshake. @@ -582,25 +581,19 @@ pub extern "C" fn connection_open_single_stream( /// /// See also [`connection_new`]. /// -/// When in the `Open` state, the connection can receive messages. When a message is received, -/// [`alloc`] must be called in order to allocate memory for this message, then -/// [`stream_message`] must be called with the pointer returned by [`alloc`]. +/// Assign a so-called "buffer index" (a `u32`) representing the buffer containing the handshake +/// type, then provide this buffer index to the function. The Rust code will call [`buffer_size`] +/// and [`buffer_copy`] in order to obtain the content of this buffer. The buffer index can be +/// de-assigned and buffer destroyed once this function returns. /// -/// A "handshake type" must be provided. To do so, allocate a buffer with [`alloc`] and pass a -/// pointer to it. This buffer is freed when this function is called. 
/// The buffer must contain a single 0 byte (indicating WebRTC), followed with the multihash /// representation of the hash of the local node's TLS certificate, followed with the multihash /// representation of the hash of the remote node's TLS certificate. #[no_mangle] -pub extern "C" fn connection_open_multi_stream( - connection_id: u32, - handshake_ty_ptr: u32, - handshake_ty_len: u32, -) { +pub extern "C" fn connection_open_multi_stream(connection_id: u32, handshake_ty_buffer_index: u32) { crate::platform::connection_open_multi_stream( connection_id, - handshake_ty_ptr, - handshake_ty_len, + get_buffer(handshake_ty_buffer_index), ); super::advance_execution(); } @@ -608,17 +601,19 @@ pub extern "C" fn connection_open_multi_stream( /// Notify of a message being received on the stream. The connection associated with that stream /// (and, in the case of a multi-stream connection, the stream itself) must be in the `Open` state. /// +/// Assign a so-called "buffer index" (a `u32`) representing the buffer containing the message, +/// then provide this buffer index to the function. The Rust code will call [`buffer_size`] and +/// [`buffer_copy`] in order to obtain the content of this buffer. The buffer index can be +/// de-assigned and buffer destroyed once this function returns. +/// /// If `connection_id` is a single-stream connection, then the value of `stream_id` is ignored. /// If `connection_id` is a multi-stream connection, then `stream_id` corresponds to the stream /// on which the data was received, as was provided to [`connection_stream_opened`]. /// /// See also [`connection_open_single_stream`] and [`connection_open_multi_stream`]. -/// -/// The buffer **must** have been allocated with [`alloc`]. It is freed when this function is -/// called. 
#[no_mangle] -pub extern "C" fn stream_message(connection_id: u32, stream_id: u32, ptr: u32, len: u32) { - crate::platform::stream_message(connection_id, stream_id, ptr, len); +pub extern "C" fn stream_message(connection_id: u32, stream_id: u32, buffer_index: u32) { + crate::platform::stream_message(connection_id, stream_id, get_buffer(buffer_index)); super::advance_execution(); } @@ -667,13 +662,15 @@ pub extern "C" fn connection_stream_opened( /// Must only be called once per connection object. /// Must never be called if [`reset_connection`] has been called on that object in the past. /// -/// Must be passed a UTF-8 string indicating the reason for closing. The buffer **must** have -/// been allocated with [`alloc`]. It is freed when this function is called. +/// Assign a so-called "buffer index" (a `u32`) representing the buffer containing the UTF-8 +/// reason for closing, then provide this buffer index to the function. The Rust code will call +/// [`buffer_size`] and [`buffer_copy`] in order to obtain the content of this buffer. The buffer +/// index can be de-assigned and buffer destroyed once this function returns. /// /// See also [`connection_new`]. 
#[no_mangle] -pub extern "C" fn connection_reset(connection_id: u32, ptr: u32, len: u32) { - crate::platform::connection_reset(connection_id, ptr, len); +pub extern "C" fn connection_reset(connection_id: u32, buffer_index: u32) { + crate::platform::connection_reset(connection_id, get_buffer(buffer_index)); super::advance_execution(); } @@ -693,3 +690,19 @@ pub extern "C" fn stream_reset(connection_id: u32, stream_id: u32) { crate::platform::stream_reset(connection_id, stream_id); super::advance_execution(); } + +pub(crate) fn get_buffer(buffer_index: u32) -> Vec<u8> { + unsafe { + let len = usize::try_from(buffer_size(buffer_index)).unwrap(); + + // TODO: consider rewriting this in a better way after all the currently unstable functions are stable: https://github.com/rust-lang/rust/issues/63291 + let mut buffer = Vec::<mem::MaybeUninit<u8>>::with_capacity(len); + buffer_copy( + buffer_index, + buffer.spare_capacity_mut().as_mut_ptr() as usize as u32, + ); + buffer.set_len(len); + + mem::transmute::<Vec<mem::MaybeUninit<u8>>, Vec<u8>>(buffer) + } +} diff --git a/wasm-node/rust/src/lib.rs b/wasm-node/rust/src/lib.rs index 668cb00179..b1206208f0 100644 --- a/wasm-node/rust/src/lib.rs +++ b/wasm-node/rust/src/lib.rs @@ -24,7 +24,7 @@ use core::{ cmp::Ordering, ops::{Add, Sub}, pin::Pin, - slice, str, + str, sync::atomic, time::Duration, }; @@ -154,13 +154,10 @@ fn start_shutdown() { } fn add_chain( - chain_spec_pointer: u32, - chain_spec_len: u32, - database_content_pointer: u32, - database_content_len: u32, + chain_spec: Vec<u8>, + database_content: Vec<u8>, json_rpc_running: u32, - potential_relay_chains_ptr: u32, - potential_relay_chains_len: u32, + potential_relay_chains: Vec<u8>, ) -> u32 { let mut client_lock = CLIENT.lock().unwrap(); @@ -182,43 +179,10 @@ fn add_chain( return u32::try_from(chain_id).unwrap(); } - // Retrieve the chain spec parameter passed through the FFI layer.
- let chain_spec: Box<[u8]> = { - let chain_spec_pointer = usize::try_from(chain_spec_pointer).unwrap(); - let chain_spec_len = usize::try_from(chain_spec_len).unwrap(); - unsafe { - Box::from_raw(slice::from_raw_parts_mut( - chain_spec_pointer as *mut u8, - chain_spec_len, - )) - } - }; - - // Retrieve the database content parameter passed through the FFI layer. - let database_content: Box<[u8]> = { - let database_content_pointer = usize::try_from(database_content_pointer).unwrap(); - let database_content_len = usize::try_from(database_content_len).unwrap(); - unsafe { - Box::from_raw(slice::from_raw_parts_mut( - database_content_pointer as *mut u8, - database_content_len, - )) - } - }; - // Retrieve the potential relay chains parameter passed through the FFI layer. let potential_relay_chains: Vec<_> = { - let allowed_relay_chains_ptr = usize::try_from(potential_relay_chains_ptr).unwrap(); - let allowed_relay_chains_len = usize::try_from(potential_relay_chains_len).unwrap(); - - let raw_data = unsafe { - Box::from_raw(slice::from_raw_parts_mut( - allowed_relay_chains_ptr as *mut u8, - allowed_relay_chains_len * 4, - )) - }; - - raw_data + assert_eq!(potential_relay_chains.len() % 4, 0); + potential_relay_chains .chunks(4) .map(|c| u32::from_le_bytes(<[u8; 4]>::try_from(c).unwrap())) .filter_map(|c| { @@ -248,8 +212,10 @@ fn add_chain( .smoldot .add_chain(smoldot_light::AddChainConfig { user_data: (), - specification: str::from_utf8(&chain_spec).unwrap(), - database_content: str::from_utf8(&database_content).unwrap(), + specification: str::from_utf8(&chain_spec) + .unwrap_or_else(|_| panic!("non-utf8 chain spec")), + database_content: str::from_utf8(&database_content) + .unwrap_or_else(|_| panic!("non-utf8 database content")), disable_json_rpc: json_rpc_running == 0, potential_relay_chains: potential_relay_chains.into_iter(), }) { @@ -402,15 +368,10 @@ fn chain_error_ptr(chain_id: u32) -> u32 { } } -fn json_rpc_send(ptr: u32, len: u32, chain_id: u32) -> u32 { - let 
json_rpc_request: Box<[u8]> = { - let ptr = usize::try_from(ptr).unwrap(); - let len = usize::try_from(len).unwrap(); - unsafe { Box::from_raw(slice::from_raw_parts_mut(ptr as *mut u8, len)) } - }; - +fn json_rpc_send(json_rpc_request: Vec<u8>, chain_id: u32) -> u32 { // As mentioned in the documentation, the bytes *must* be valid UTF-8. - let json_rpc_request: String = String::from_utf8(json_rpc_request.into()).unwrap(); + let json_rpc_request: String = String::from_utf8(json_rpc_request.into()) + .unwrap_or_else(|_| panic!("non-UTF-8 JSON-RPC request")); let mut client_lock = CLIENT.lock().unwrap(); let client_chain_id = match client_lock diff --git a/wasm-node/rust/src/platform.rs b/wasm-node/rust/src/platform.rs index 26aa15a728..a7660e9d61 100644 --- a/wasm-node/rust/src/platform.rs +++ b/wasm-node/rust/src/platform.rs @@ -20,7 +20,7 @@ use crate::{bindings, timers::Delay}; use smoldot::libp2p::multihash; use smoldot_light::platform::{ConnectError, PlatformSubstreamDirection}; -use core::{mem, pin, slice, str, task, time::Duration}; +use core::{mem, pin, str, task, time::Duration}; use futures::prelude::*; use std::{ collections::{BTreeMap, VecDeque}, @@ -106,30 +106,24 @@ impl smoldot_light::platform::Platform for Platform { let connection_id = lock.next_connection_id; lock.next_connection_id += 1; - let mut error_ptr = [0u8; 9]; + let mut error_buffer_index = [0u8; 5]; let ret_code = unsafe { bindings::connection_new( connection_id, u32::try_from(url.as_bytes().as_ptr() as usize).unwrap(), u32::try_from(url.as_bytes().len()).unwrap(), - u32::try_from(&mut error_ptr as *mut [u8; 9] as usize).unwrap(), + u32::try_from(&mut error_buffer_index as *mut [u8; 5] as usize).unwrap(), ) }; let result = if ret_code != 0 { - let ptr = u32::from_le_bytes(<[u8; 4]>::try_from(&error_ptr[0..4]).unwrap()); - let len = u32::from_le_bytes(<[u8; 4]>::try_from(&error_ptr[4..8]).unwrap()); - let error_message: Box<[u8]> = unsafe { - Box::from_raw(slice::from_raw_parts_mut( -
usize::try_from(ptr).unwrap() as *mut u8, - usize::try_from(len).unwrap(), - )) - }; - + let error_message = bindings::get_buffer(u32::from_le_bytes( + <[u8; 4]>::try_from(&error_buffer_index[0..4]).unwrap(), + )); Err(ConnectError { message: str::from_utf8(&error_message).unwrap().to_owned(), - is_bad_addr: error_ptr[8] != 0, + is_bad_addr: error_buffer_index[4] != 0, }) } else { let _prev_value = lock.connections.insert( @@ -693,22 +687,7 @@ pub(crate) fn connection_open_single_stream( connection.something_happened.notify(usize::max_value()); } -pub(crate) fn connection_open_multi_stream( - connection_id: u32, - handshake_ty_ptr: u32, - handshake_ty_len: u32, -) { - let handshake_ty: Box<[u8]> = { - let handshake_ty_ptr = usize::try_from(handshake_ty_ptr).unwrap(); - let handshake_ty_len = usize::try_from(handshake_ty_len).unwrap(); - unsafe { - Box::from_raw(slice::from_raw_parts_mut( - handshake_ty_ptr as *mut u8, - handshake_ty_len, - )) - } - }; - +pub(crate) fn connection_open_multi_stream(connection_id: u32, handshake_ty: Vec<u8>) { let (_, (local_tls_certificate_multihash, remote_tls_certificate_multihash)) = nom::sequence::preceded( nom::bytes::complete::tag::<_, _, nom::error::Error<&[u8]>>(&[0]), @@ -777,7 +756,7 @@ pub(crate) fn stream_writable_bytes(connection_id: u32, stream_id: u32, bytes: u stream.something_happened.notify(usize::max_value()); } -pub(crate) fn stream_message(connection_id: u32, stream_id: u32, ptr: u32, len: u32) { +pub(crate) fn stream_message(connection_id: u32, stream_id: u32, message: Vec<u8>) { let mut lock = STATE.try_lock().unwrap(); let connection = lock.connections.get_mut(&connection_id).unwrap(); @@ -796,13 +775,7 @@ pub(crate) fn stream_message(connection_id: u32, stream_id: u32, ptr: u32, len: .unwrap(); debug_assert!(!stream.reset); - let ptr = usize::try_from(ptr).unwrap(); - let len_usize = usize::try_from(len).unwrap(); - - TOTAL_BYTES_RECEIVED.fetch_add(u64::from(len), Ordering::Relaxed); - - let message: Box<[u8]> = -
unsafe { Box::from_raw(slice::from_raw_parts_mut(ptr as *mut u8, len_usize)) }; + TOTAL_BYTES_RECEIVED.fetch_add(u64::try_from(message.len()).unwrap(), Ordering::Relaxed); // Ignore empty message to avoid all sorts of problems. if message.is_empty() { @@ -830,7 +803,7 @@ pub(crate) fn stream_message(connection_id: u32, stream_id: u32, ptr: u32, len: } stream.messages_queue_total_size += message.len(); - stream.messages_queue.push_back(message); + stream.messages_queue.push_back(message.into_boxed_slice()); stream.something_happened.notify(usize::max_value()); } @@ -880,7 +853,7 @@ pub(crate) fn connection_stream_opened( } } -pub(crate) fn connection_reset(connection_id: u32, ptr: u32, len: u32) { +pub(crate) fn connection_reset(connection_id: u32, message: Vec<u8>) { let mut lock = STATE.try_lock().unwrap(); let connection = lock.connections.get_mut(&connection_id).unwrap(); @@ -896,13 +869,9 @@ pub(crate) fn connection_reset(connection_id: u32, ptr: u32, len: u32) { connection.inner = ConnectionInner::Reset { connection_handles_alive, - message: { - let ptr = usize::try_from(ptr).unwrap(); - let len = usize::try_from(len).unwrap(); - let message: Box<[u8]> = - unsafe { Box::from_raw(slice::from_raw_parts_mut(ptr as *mut u8, len)) }; - str::from_utf8(&message).unwrap().to_owned() - }, + message: str::from_utf8(&message) + .unwrap_or_else(|_| panic!("non-UTF-8 message")) + .to_owned(), }; connection.something_happened.notify(usize::max_value());