Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change the way buffers are passed to the wasm-node #396

Merged
merged 5 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions wasm-node/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

### Fixed

- Fix a potential undefined behavior in the way the Rust and JavaScript communicate. ([#396](https://github.com/smol-dot/smoldot/pull/396))
- Properly check whether Yamux substream IDs allocated by the remote are valid. ([#383](https://github.com/smol-dot/smoldot/pull/383))
- Fix the size of the data of Yamux frames with the `SYN` flag not being verified against the allowed credits. ([#383](https://github.com/smol-dot/smoldot/pull/383))
- Fix Yamux repeatedly sending empty data frames when the allowed window size is 0. ([#383](https://github.com/smol-dot/smoldot/pull/383))
Expand Down
61 changes: 39 additions & 22 deletions wasm-node/javascript/src/instance/bindings-smoldot-light.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ import type { SmoldotWasmInstance } from './bindings.js';
export interface Config {
instance?: SmoldotWasmInstance,

/**
* Array used to store the buffers provided to the Rust code.
*
* When `buffer_size` or `buffer_index` are called, the buffer is found here.
*/
bufferIndices: Array<Uint8Array>,

/**
* Returns the number of milliseconds since an arbitrary epoch.
*/
Expand Down Expand Up @@ -254,6 +261,19 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,
config.onPanic(message);
},

buffer_size: (bufferIndex: number) => {
const buf = config.bufferIndices[bufferIndex]!;
return buf.byteLength;
},

buffer_copy: (bufferIndex: number, targetPtr: number) => {
const instance = config.instance!;
targetPtr = targetPtr >>> 0;

const buf = config.bufferIndices[bufferIndex]!;
new Uint8Array(instance.exports.memory.buffer).set(buf, targetPtr);
},

// Used by the Rust side to notify that a JSON-RPC response or subscription notification
// is available in the queue of JSON-RPC responses.
json_rpc_responses_non_empty: (chainId: number) => {
Expand Down Expand Up @@ -322,12 +342,12 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,

// Must create a new connection object. This implementation stores the created object in
// `connections`.
connection_new: (connectionId: number, addrPtr: number, addrLen: number, errorPtrPtr: number) => {
connection_new: (connectionId: number, addrPtr: number, addrLen: number, errorBufferIndexPtr: number) => {
const instance = config.instance!;

addrPtr >>>= 0;
addrLen >>>= 0;
errorPtrPtr >>>= 0;
errorBufferIndexPtr >>>= 0;

if (!!connections[connectionId]) {
throw new Error("internal error: connection already allocated");
Expand All @@ -350,13 +370,13 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,
break
}
case 'multi-stream': {
const bufferLen = 1 + info.localTlsCertificateMultihash.length + info.remoteTlsCertificateMultihash.length;
const ptr = instance.exports.alloc(bufferLen) >>> 0;
const mem = new Uint8Array(instance.exports.memory.buffer);
buffer.writeUInt8(mem, ptr, 0);
mem.set(info.localTlsCertificateMultihash, ptr + 1)
mem.set(info.remoteTlsCertificateMultihash, ptr + 1 + info.localTlsCertificateMultihash.length)
instance.exports.connection_open_multi_stream(connectionId, ptr, bufferLen);
const handshakeTy = new Uint8Array(1 + info.localTlsCertificateMultihash.length + info.remoteTlsCertificateMultihash.length);
buffer.writeUInt8(handshakeTy, 0, 0);
handshakeTy.set(info.localTlsCertificateMultihash, 1)
handshakeTy.set(info.remoteTlsCertificateMultihash, 1 + info.localTlsCertificateMultihash.length)
config.bufferIndices[0] = handshakeTy;
instance.exports.connection_open_multi_stream(connectionId, 0);
delete config.bufferIndices[0]
break
}
}
Expand All @@ -365,10 +385,9 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,
onConnectionReset: (message: string) => {
if (killedTracked.killed) return;
try {
const encoded = new TextEncoder().encode(message)
const ptr = instance.exports.alloc(encoded.length) >>> 0;
new Uint8Array(instance.exports.memory.buffer).set(encoded, ptr);
instance.exports.connection_reset(connectionId, ptr, encoded.length);
config.bufferIndices[0] = new TextEncoder().encode(message);
instance.exports.connection_reset(connectionId, 0);
delete config.bufferIndices[0]
} catch(_error) {}
},
onWritableBytes: (numExtra, streamId) => {
Expand All @@ -384,9 +403,9 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,
onMessage: (message: Uint8Array, streamId?: number) => {
if (killedTracked.killed) return;
try {
const ptr = instance.exports.alloc(message.length) >>> 0;
new Uint8Array(instance.exports.memory.buffer).set(message, ptr)
instance.exports.stream_message(connectionId, streamId || 0, ptr, message.length);
config.bufferIndices[0] = message;
instance.exports.stream_message(connectionId, streamId || 0, 0);
delete config.bufferIndices[0]
} catch(_error) {}
},
onStreamOpened: (streamId: number, direction: 'inbound' | 'outbound', initialWritableBytes) => {
Expand Down Expand Up @@ -418,13 +437,11 @@ export default function (config: Config): { imports: WebAssembly.ModuleImports,
if (error instanceof Error) {
errorStr = error.toString();
}

const mem = new Uint8Array(instance.exports.memory.buffer);
const encoded = new TextEncoder().encode(errorStr)
const ptr = instance.exports.alloc(encoded.length) >>> 0;
mem.set(encoded, ptr);
buffer.writeUInt32LE(mem, errorPtrPtr, ptr);
buffer.writeUInt32LE(mem, errorPtrPtr + 4, encoded.length);
buffer.writeUInt8(mem, errorPtrPtr + 8, isBadAddress ? 1 : 0);
config.bufferIndices[0] = new TextEncoder().encode(errorStr)
buffer.writeUInt32LE(mem, errorBufferIndexPtr, 0);
buffer.writeUInt8(mem, errorBufferIndexPtr + 4, isBadAddress ? 1 : 0);
return 1;
}
},
Expand Down
11 changes: 5 additions & 6 deletions wasm-node/javascript/src/instance/bindings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,21 @@ export interface SmoldotWasmExports extends WebAssembly.Exports {
init: (maxLogLevel: number, enableCurrentTask: number, cpuRateLimit: number, periodicallyYield: number) => void,
set_periodically_yield: (periodicallyYield: number) => void,
start_shutdown: () => void,
alloc: (len: number) => number,
add_chain: (chainSpecPointer: number, chainSpecLen: number, databaseContentPointer: number, databaseContentLen: number, jsonRpcRunning: number, potentialRelayChainsPtr: number, potentialRelayChainsLen: number) => number;
add_chain: (chainSpecBufferIndex: number, databaseContentBufferIndex: number, jsonRpcRunning: number, potentialRelayChainsBufferIndex: number) => number;
remove_chain: (chainId: number) => void,
chain_is_ok: (chainId: number) => number,
chain_error_len: (chainId: number) => number,
chain_error_ptr: (chainId: number) => number,
json_rpc_send: (textPtr: number, textLen: number, chainId: number) => number,
json_rpc_send: (textBufferIndex: number, chainId: number) => number,
json_rpc_responses_peek: (chainId: number) => number,
json_rpc_responses_pop: (chainId: number) => void,
timer_finished: (timerId: number) => void,
connection_open_single_stream: (connectionId: number, handshakeTy: number, initialWritableBytes: number, writeClosable: number) => void,
connection_open_multi_stream: (connectionId: number, handshakeTyPtr: number, handshakeTyLen: number) => void,
connection_open_multi_stream: (connectionId: number, handshakeTyBufferIndex: number) => void,
stream_writable_bytes: (connectionId: number, streamId: number, numBytes: number) => void,
stream_message: (connectionId: number, streamId: number, ptr: number, len: number) => void,
stream_message: (connectionId: number, streamId: number, bufferIndex: number) => void,
connection_stream_opened: (connectionId: number, streamId: number, outbound: number, initialWritableBytes: number) => void,
connection_reset: (connectionId: number, ptr: number, len: number) => void,
connection_reset: (connectionId: number, bufferIndex: number) => void,
stream_reset: (connectionId: number, streamId: number) => void,
}

Expand Down
57 changes: 22 additions & 35 deletions wasm-node/javascript/src/instance/instance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export function start(configMessage: Config, platformBindings: instance.Platform
// - At initialization, it is a Promise containing the Wasm VM is still initializing.
// - After the Wasm VM has finished initialization, contains the `WebAssembly.Instance` object.
//
let state: { initialized: false, promise: Promise<SmoldotWasmInstance> } | { initialized: true, instance: SmoldotWasmInstance, unregisterCallback: () => void };
let state: { initialized: false, promise: Promise<[SmoldotWasmInstance, Array<Uint8Array>]> } | { initialized: true, instance: SmoldotWasmInstance, bufferIndices: Array<Uint8Array>, unregisterCallback: () => void };

const crashError: { error?: CrashError } = {};

Expand Down Expand Up @@ -127,7 +127,7 @@ export function start(configMessage: Config, platformBindings: instance.Platform
};

state = {
initialized: false, promise: instance.startInstance(config, platformBindings).then((instance) => {
initialized: false, promise: instance.startInstance(config, platformBindings).then(([instance, bufferIndices]) => {
// `config.cpuRateLimit` is a floating point that should be between 0 and 1, while the value
// to pass as parameter must be between `0` and `2^32-1`.
// The few lines of code below should handle all possible values of `number`, including
Expand All @@ -148,22 +148,22 @@ export function start(configMessage: Config, platformBindings: instance.Platform
});
instance.exports.init(configMessage.maxLogLevel, configMessage.enableCurrentTask ? 1 : 0, cpuRateLimit, periodicallyYield ? 1 : 0);

state = { initialized: true, instance, unregisterCallback };
return instance;
state = { initialized: true, instance, bufferIndices, unregisterCallback };
return [instance, bufferIndices];
})
};

async function queueOperation<T>(operation: (instance: SmoldotWasmInstance) => T): Promise<T> {
async function queueOperation<T>(operation: (instance: SmoldotWasmInstance, bufferIndices: Array<Uint8Array>) => T): Promise<T> {
// What to do depends on the type of `state`.
// See the documentation of the `state` variable for information.
if (!state.initialized) {
// A message has been received while the Wasm VM is still initializing. Queue it for when
// initialization is over.
return state.promise.then((instance) => operation(instance))
return state.promise.then(([instance, bufferIndices]) => operation(instance, bufferIndices))

} else {
// Everything is already initialized. Process the message synchronously.
return operation(state.instance)
return operation(state.instance, state.bufferIndices)
}
}

Expand All @@ -179,10 +179,8 @@ export function start(configMessage: Config, platformBindings: instance.Platform

let retVal;
try {
const encoded = new TextEncoder().encode(request)
const ptr = state.instance.exports.alloc(encoded.length) >>> 0;
new Uint8Array(state.instance.exports.memory.buffer).set(encoded, ptr);
retVal = state.instance.exports.json_rpc_send(ptr, encoded.length, chainId) >>> 0;
state.bufferIndices[0] = new TextEncoder().encode(request)
retVal = state.instance.exports.json_rpc_send(0, chainId) >>> 0;
} catch (_error) {
console.assert(crashError.error);
throw crashError.error
Expand Down Expand Up @@ -234,38 +232,27 @@ export function start(configMessage: Config, platformBindings: instance.Platform
},

addChain: (chainSpec: string, databaseContent: string, potentialRelayChains: number[], disableJsonRpc: boolean): Promise<{ success: true, chainId: number } | { success: false, error: string }> => {
return queueOperation((instance) => {
return queueOperation((instance, bufferIndices) => {
if (crashError.error)
throw crashError.error;

try {
// Write the chain specification into memory.
const chainSpecEncoded = new TextEncoder().encode(chainSpec)
const chainSpecPtr = instance.exports.alloc(chainSpecEncoded.length) >>> 0;
new Uint8Array(instance.exports.memory.buffer).set(chainSpecEncoded, chainSpecPtr);

// Write the database content into memory.
const databaseContentEncoded = new TextEncoder().encode(databaseContent)
const databaseContentPtr = instance.exports.alloc(databaseContentEncoded.length) >>> 0;
new Uint8Array(instance.exports.memory.buffer).set(databaseContentEncoded, databaseContentPtr);

// Write the potential relay chains into memory.
const potentialRelayChainsLen = potentialRelayChains.length;
const potentialRelayChainsPtr = instance.exports.alloc(potentialRelayChainsLen * 4) >>> 0;
for (let idx = 0; idx < potentialRelayChains.length; ++idx) {
buffer.writeUInt32LE(new Uint8Array(instance.exports.memory.buffer), potentialRelayChainsPtr + idx * 4, potentialRelayChains[idx]!);
}

// `add_chain` unconditionally allocates a chain id. If an error occurs, however, this chain
// id will refer to an *erroneous* chain. `chain_is_ok` is used below to determine whether it
// has succeeeded or not.
// Note that `add_chain` properly de-allocates buffers even if it failed.
const chainId = instance.exports.add_chain(
chainSpecPtr, chainSpecEncoded.length,
databaseContentPtr, databaseContentEncoded.length,
disableJsonRpc ? 0 : 1,
potentialRelayChainsPtr, potentialRelayChainsLen
);
bufferIndices[0] = new TextEncoder().encode(chainSpec)
bufferIndices[1] = new TextEncoder().encode(databaseContent)
const potentialRelayChainsEncoded = new Uint8Array(potentialRelayChains.length * 4)
for (let idx = 0; idx < potentialRelayChains.length; ++idx) {
buffer.writeUInt32LE(potentialRelayChainsEncoded, idx * 4, potentialRelayChains[idx]!);
}
bufferIndices[2] = potentialRelayChainsEncoded
const chainId = instance.exports.add_chain(0, 1, disableJsonRpc ? 0 : 1, 2);

delete bufferIndices[0]
delete bufferIndices[1]
delete bufferIndices[2]

if (instance.exports.chain_is_ok(chainId) != 0) {
console.assert(!chains.has(chainId));
Expand Down
7 changes: 5 additions & 2 deletions wasm-node/javascript/src/instance/raw-instance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ export interface PlatformBindings {
connect(config: ConnectionConfig): Connection;
}

export async function startInstance(config: Config, platformBindings: PlatformBindings): Promise<SmoldotWasmInstance> {
export async function startInstance(config: Config, platformBindings: PlatformBindings): Promise<[SmoldotWasmInstance, Array<Uint8Array>]> {
// The actual Wasm bytecode is base64-decoded then deflate-decoded from a constant found in a
// different file.
// This is suboptimal compared to using `instantiateStreaming`, but it is the most
Expand All @@ -100,8 +100,11 @@ export async function startInstance(config: Config, platformBindings: PlatformBi

let killAll: () => void;

const bufferIndices = new Array;

// Used to bind with the smoldot-light bindings. See the `bindings-smoldot-light.js` file.
const smoldotJsConfig: SmoldotBindingsConfig = {
bufferIndices,
performanceNow: platformBindings.performanceNow,
connect: platformBindings.connect,
onPanic: (message) => {
Expand Down Expand Up @@ -141,5 +144,5 @@ export async function startInstance(config: Config, platformBindings: PlatformBi
const instance = result.instance as SmoldotWasmInstance;
smoldotJsConfig.instance = instance;
wasiConfig.instance = instance;
return instance;
return [instance, bufferIndices];
}
Loading