Skip to content

Commit

Permalink
Change the way buffers are passed to the wasm-node (#396)
Browse files Browse the repository at this point in the history
* Change the way buffers are passed to the wasm-node

* Directly use Uint8Array instead of ArrayBuffer

* Free the buffers ASAP

* CHANGELOG

* Docs and spellcheck
  • Loading branch information
tomaka authored Apr 6, 2023
1 parent 42c6b75 commit e10dda4
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 243 deletions.
1 change: 1 addition & 0 deletions wasm-node/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

### Fixed

- Fix a potential undefined behavior in the way the Rust and JavaScript communicate. ([#396](https://github.com/smol-dot/smoldot/pull/396))
- Properly check whether Yamux substream IDs allocated by the remote are valid. ([#383](https://github.com/smol-dot/smoldot/pull/383))
- Fix the size of the data of Yamux frames with the `SYN` flag not being verified against the allowed credits. ([#383](https://github.com/smol-dot/smoldot/pull/383))
- Fix Yamux repeatedly sending empty data frames when the allowed window size is 0. ([#383](https://github.com/smol-dot/smoldot/pull/383))
Expand Down
61 changes: 39 additions & 22 deletions wasm-node/javascript/src/instance/bindings-smoldot-light.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ import type { SmoldotWasmInstance } from './bindings.js';
export interface Config {
instance?: SmoldotWasmInstance,

/**
* Array used to store the buffers provided to the Rust code.
*
* When `buffer_size` or `buffer_index` are called, the buffer is found here.
*/
bufferIndices: Array<Uint8Array>,

/**
* Returns the number of milliseconds since an arbitrary epoch.
*/
Expand Down Expand Up @@ -254,6 +261,19 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,
config.onPanic(message);
},

buffer_size: (bufferIndex: number) => {
const buf = config.bufferIndices[bufferIndex]!;
return buf.byteLength;
},

buffer_copy: (bufferIndex: number, targetPtr: number) => {
const instance = config.instance!;
targetPtr = targetPtr >>> 0;

const buf = config.bufferIndices[bufferIndex]!;
new Uint8Array(instance.exports.memory.buffer).set(buf, targetPtr);
},

// Used by the Rust side to notify that a JSON-RPC response or subscription notification
// is available in the queue of JSON-RPC responses.
json_rpc_responses_non_empty: (chainId: number) => {
Expand Down Expand Up @@ -322,12 +342,12 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,

// Must create a new connection object. This implementation stores the created object in
// `connections`.
connection_new: (connectionId: number, addrPtr: number, addrLen: number, errorPtrPtr: number) => {
connection_new: (connectionId: number, addrPtr: number, addrLen: number, errorBufferIndexPtr: number) => {
const instance = config.instance!;

addrPtr >>>= 0;
addrLen >>>= 0;
errorPtrPtr >>>= 0;
errorBufferIndexPtr >>>= 0;

if (!!connections[connectionId]) {
throw new Error("internal error: connection already allocated");
Expand All @@ -350,13 +370,13 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,
break
}
case 'multi-stream': {
const bufferLen = 1 + info.localTlsCertificateMultihash.length + info.remoteTlsCertificateMultihash.length;
const ptr = instance.exports.alloc(bufferLen) >>> 0;
const mem = new Uint8Array(instance.exports.memory.buffer);
buffer.writeUInt8(mem, ptr, 0);
mem.set(info.localTlsCertificateMultihash, ptr + 1)
mem.set(info.remoteTlsCertificateMultihash, ptr + 1 + info.localTlsCertificateMultihash.length)
instance.exports.connection_open_multi_stream(connectionId, ptr, bufferLen);
const handshakeTy = new Uint8Array(1 + info.localTlsCertificateMultihash.length + info.remoteTlsCertificateMultihash.length);
buffer.writeUInt8(handshakeTy, 0, 0);
handshakeTy.set(info.localTlsCertificateMultihash, 1)
handshakeTy.set(info.remoteTlsCertificateMultihash, 1 + info.localTlsCertificateMultihash.length)
config.bufferIndices[0] = handshakeTy;
instance.exports.connection_open_multi_stream(connectionId, 0);
delete config.bufferIndices[0]
break
}
}
Expand All @@ -365,10 +385,9 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,
onConnectionReset: (message: string) => {
if (killedTracked.killed) return;
try {
const encoded = new TextEncoder().encode(message)
const ptr = instance.exports.alloc(encoded.length) >>> 0;
new Uint8Array(instance.exports.memory.buffer).set(encoded, ptr);
instance.exports.connection_reset(connectionId, ptr, encoded.length);
config.bufferIndices[0] = new TextEncoder().encode(message);
instance.exports.connection_reset(connectionId, 0);
delete config.bufferIndices[0]
} catch(_error) {}
},
onWritableBytes: (numExtra, streamId) => {
Expand All @@ -384,9 +403,9 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,
onMessage: (message: Uint8Array, streamId?: number) => {
if (killedTracked.killed) return;
try {
const ptr = instance.exports.alloc(message.length) >>> 0;
new Uint8Array(instance.exports.memory.buffer).set(message, ptr)
instance.exports.stream_message(connectionId, streamId || 0, ptr, message.length);
config.bufferIndices[0] = message;
instance.exports.stream_message(connectionId, streamId || 0, 0);
delete config.bufferIndices[0]
} catch(_error) {}
},
onStreamOpened: (streamId: number, direction: 'inbound' | 'outbound', initialWritableBytes) => {
Expand Down Expand Up @@ -418,13 +437,11 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,
if (error instanceof Error) {
errorStr = error.toString();
}

const mem = new Uint8Array(instance.exports.memory.buffer);
const encoded = new TextEncoder().encode(errorStr)
const ptr = instance.exports.alloc(encoded.length) >>> 0;
mem.set(encoded, ptr);
buffer.writeUInt32LE(mem, errorPtrPtr, ptr);
buffer.writeUInt32LE(mem, errorPtrPtr + 4, encoded.length);
buffer.writeUInt8(mem, errorPtrPtr + 8, isBadAddress ? 1 : 0);
config.bufferIndices[0] = new TextEncoder().encode(errorStr)
buffer.writeUInt32LE(mem, errorBufferIndexPtr, 0);
buffer.writeUInt8(mem, errorBufferIndexPtr + 4, isBadAddress ? 1 : 0);
return 1;
}
},
Expand Down
11 changes: 5 additions & 6 deletions wasm-node/javascript/src/instance/bindings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,21 @@ export interface SmoldotWasmExports extends WebAssembly.Exports {
init: (maxLogLevel: number, enableCurrentTask: number, cpuRateLimit: number, periodicallyYield: number) => void,
set_periodically_yield: (periodicallyYield: number) => void,
start_shutdown: () => void,
alloc: (len: number) => number,
add_chain: (chainSpecPointer: number, chainSpecLen: number, databaseContentPointer: number, databaseContentLen: number, jsonRpcRunning: number, potentialRelayChainsPtr: number, potentialRelayChainsLen: number) => number;
add_chain: (chainSpecBufferIndex: number, databaseContentBufferIndex: number, jsonRpcRunning: number, potentialRelayChainsBufferIndex: number) => number;
remove_chain: (chainId: number) => void,
chain_is_ok: (chainId: number) => number,
chain_error_len: (chainId: number) => number,
chain_error_ptr: (chainId: number) => number,
json_rpc_send: (textPtr: number, textLen: number, chainId: number) => number,
json_rpc_send: (textBufferIndex: number, chainId: number) => number,
json_rpc_responses_peek: (chainId: number) => number,
json_rpc_responses_pop: (chainId: number) => void,
timer_finished: (timerId: number) => void,
connection_open_single_stream: (connectionId: number, handshakeTy: number, initialWritableBytes: number, writeClosable: number) => void,
connection_open_multi_stream: (connectionId: number, handshakeTyPtr: number, handshakeTyLen: number) => void,
connection_open_multi_stream: (connectionId: number, handshakeTyBufferIndex: number) => void,
stream_writable_bytes: (connectionId: number, streamId: number, numBytes: number) => void,
stream_message: (connectionId: number, streamId: number, ptr: number, len: number) => void,
stream_message: (connectionId: number, streamId: number, bufferIndex: number) => void,
connection_stream_opened: (connectionId: number, streamId: number, outbound: number, initialWritableBytes: number) => void,
connection_reset: (connectionId: number, ptr: number, len: number) => void,
connection_reset: (connectionId: number, bufferIndex: number) => void,
stream_reset: (connectionId: number, streamId: number) => void,
}

Expand Down
57 changes: 22 additions & 35 deletions wasm-node/javascript/src/instance/instance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export function start(configMessage: Config, platformBindings: instance.Platform
// - At initialization, it is a Promise containing the Wasm VM is still initializing.
// - After the Wasm VM has finished initialization, contains the `WebAssembly.Instance` object.
//
let state: { initialized: false, promise: Promise<SmoldotWasmInstance> } | { initialized: true, instance: SmoldotWasmInstance, unregisterCallback: () => void };
let state: { initialized: false, promise: Promise<[SmoldotWasmInstance, Array<Uint8Array>]> } | { initialized: true, instance: SmoldotWasmInstance, bufferIndices: Array<Uint8Array>, unregisterCallback: () => void };

const crashError: { error?: CrashError } = {};

Expand Down Expand Up @@ -127,7 +127,7 @@ export function start(configMessage: Config, platformBindings: instance.Platform
};

state = {
initialized: false, promise: instance.startInstance(config, platformBindings).then((instance) => {
initialized: false, promise: instance.startInstance(config, platformBindings).then(([instance, bufferIndices]) => {
// `config.cpuRateLimit` is a floating point that should be between 0 and 1, while the value
// to pass as parameter must be between `0` and `2^32-1`.
// The few lines of code below should handle all possible values of `number`, including
Expand All @@ -148,22 +148,22 @@ export function start(configMessage: Config, platformBindings: instance.Platform
});
instance.exports.init(configMessage.maxLogLevel, configMessage.enableCurrentTask ? 1 : 0, cpuRateLimit, periodicallyYield ? 1 : 0);

state = { initialized: true, instance, unregisterCallback };
return instance;
state = { initialized: true, instance, bufferIndices, unregisterCallback };
return [instance, bufferIndices];
})
};

async function queueOperation<T>(operation: (instance: SmoldotWasmInstance) => T): Promise<T> {
async function queueOperation<T>(operation: (instance: SmoldotWasmInstance, bufferIndices: Array<Uint8Array>) => T): Promise<T> {
// What to do depends on the type of `state`.
// See the documentation of the `state` variable for information.
if (!state.initialized) {
// A message has been received while the Wasm VM is still initializing. Queue it for when
// initialization is over.
return state.promise.then((instance) => operation(instance))
return state.promise.then(([instance, bufferIndices]) => operation(instance, bufferIndices))

} else {
// Everything is already initialized. Process the message synchronously.
return operation(state.instance)
return operation(state.instance, state.bufferIndices)
}
}

Expand All @@ -179,10 +179,8 @@ export function start(configMessage: Config, platformBindings: instance.Platform

let retVal;
try {
const encoded = new TextEncoder().encode(request)
const ptr = state.instance.exports.alloc(encoded.length) >>> 0;
new Uint8Array(state.instance.exports.memory.buffer).set(encoded, ptr);
retVal = state.instance.exports.json_rpc_send(ptr, encoded.length, chainId) >>> 0;
state.bufferIndices[0] = new TextEncoder().encode(request)
retVal = state.instance.exports.json_rpc_send(0, chainId) >>> 0;
} catch (_error) {
console.assert(crashError.error);
throw crashError.error
Expand Down Expand Up @@ -234,38 +232,27 @@ export function start(configMessage: Config, platformBindings: instance.Platform
},

addChain: (chainSpec: string, databaseContent: string, potentialRelayChains: number[], disableJsonRpc: boolean): Promise<{ success: true, chainId: number } | { success: false, error: string }> => {
return queueOperation((instance) => {
return queueOperation((instance, bufferIndices) => {
if (crashError.error)
throw crashError.error;

try {
// Write the chain specification into memory.
const chainSpecEncoded = new TextEncoder().encode(chainSpec)
const chainSpecPtr = instance.exports.alloc(chainSpecEncoded.length) >>> 0;
new Uint8Array(instance.exports.memory.buffer).set(chainSpecEncoded, chainSpecPtr);

// Write the database content into memory.
const databaseContentEncoded = new TextEncoder().encode(databaseContent)
const databaseContentPtr = instance.exports.alloc(databaseContentEncoded.length) >>> 0;
new Uint8Array(instance.exports.memory.buffer).set(databaseContentEncoded, databaseContentPtr);

// Write the potential relay chains into memory.
const potentialRelayChainsLen = potentialRelayChains.length;
const potentialRelayChainsPtr = instance.exports.alloc(potentialRelayChainsLen * 4) >>> 0;
for (let idx = 0; idx < potentialRelayChains.length; ++idx) {
buffer.writeUInt32LE(new Uint8Array(instance.exports.memory.buffer), potentialRelayChainsPtr + idx * 4, potentialRelayChains[idx]!);
}

// `add_chain` unconditionally allocates a chain id. If an error occurs, however, this chain
// id will refer to an *erroneous* chain. `chain_is_ok` is used below to determine whether it
// has succeeeded or not.
// Note that `add_chain` properly de-allocates buffers even if it failed.
const chainId = instance.exports.add_chain(
chainSpecPtr, chainSpecEncoded.length,
databaseContentPtr, databaseContentEncoded.length,
disableJsonRpc ? 0 : 1,
potentialRelayChainsPtr, potentialRelayChainsLen
);
bufferIndices[0] = new TextEncoder().encode(chainSpec)
bufferIndices[1] = new TextEncoder().encode(databaseContent)
const potentialRelayChainsEncoded = new Uint8Array(potentialRelayChains.length * 4)
for (let idx = 0; idx < potentialRelayChains.length; ++idx) {
buffer.writeUInt32LE(potentialRelayChainsEncoded, idx * 4, potentialRelayChains[idx]!);
}
bufferIndices[2] = potentialRelayChainsEncoded
const chainId = instance.exports.add_chain(0, 1, disableJsonRpc ? 0 : 1, 2);

delete bufferIndices[0]
delete bufferIndices[1]
delete bufferIndices[2]

if (instance.exports.chain_is_ok(chainId) != 0) {
console.assert(!chains.has(chainId));
Expand Down
7 changes: 5 additions & 2 deletions wasm-node/javascript/src/instance/raw-instance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ export interface PlatformBindings {
connect(config: ConnectionConfig): Connection;
}

export async function startInstance(config: Config, platformBindings: PlatformBindings): Promise<SmoldotWasmInstance> {
export async function startInstance(config: Config, platformBindings: PlatformBindings): Promise<[SmoldotWasmInstance, Array<Uint8Array>]> {
// The actual Wasm bytecode is base64-decoded then deflate-decoded from a constant found in a
// different file.
// This is suboptimal compared to using `instantiateStreaming`, but it is the most
Expand All @@ -100,8 +100,11 @@ export async function startInstance(config: Config, platformBindings: PlatformBi

let killAll: () => void;

const bufferIndices = new Array;

// Used to bind with the smoldot-light bindings. See the `bindings-smoldot-light.js` file.
const smoldotJsConfig: SmoldotBindingsConfig = {
bufferIndices,
performanceNow: platformBindings.performanceNow,
connect: platformBindings.connect,
onPanic: (message) => {
Expand Down Expand Up @@ -141,5 +144,5 @@ export async function startInstance(config: Config, platformBindings: PlatformBi
const instance = result.instance as SmoldotWasmInstance;
smoldotJsConfig.instance = instance;
wasiConfig.instance = instance;
return instance;
return [instance, bufferIndices];
}
Loading

0 comments on commit e10dda4

Please sign in to comment.