Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle c++ exception in TSFN callback #1345

Merged
merged 3 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions napi-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,12 @@ struct ThreadSafeFinalize {
template <typename ContextType, typename DataType, typename CallJs, CallJs call>
inline typename std::enable_if<call != static_cast<CallJs>(nullptr)>::type
CallJsWrapper(napi_env env, napi_value jsCallback, void* context, void* data) {
call(env,
Function(env, jsCallback),
static_cast<ContextType*>(context),
static_cast<DataType*>(data));
details::WrapVoidCallback([&]() {
call(env,
Function(env, jsCallback),
static_cast<ContextType*>(context),
static_cast<DataType*>(data));
});
}

template <typename ContextType, typename DataType, typename CallJs, CallJs call>
Expand All @@ -275,9 +277,11 @@ CallJsWrapper(napi_env env,
napi_value jsCallback,
void* /*context*/,
void* /*data*/) {
if (jsCallback != nullptr) {
Function(env, jsCallback).Call(0, nullptr);
}
details::WrapVoidCallback([&]() {
if (jsCallback != nullptr) {
Function(env, jsCallback).Call(0, nullptr);
}
});
}

#if NAPI_VERSION > 4
Expand Down Expand Up @@ -6135,13 +6139,15 @@ inline void ThreadSafeFunction::CallJS(napi_env env,
return;
}

if (data != nullptr) {
auto* callbackWrapper = static_cast<CallbackWrapper*>(data);
(*callbackWrapper)(env, Function(env, jsCallback));
delete callbackWrapper;
} else if (jsCallback != nullptr) {
Function(env, jsCallback).Call({});
}
details::WrapVoidCallback([&]() {
if (data != nullptr) {
auto* callbackWrapper = static_cast<CallbackWrapper*>(data);
(*callbackWrapper)(env, Function(env, jsCallback));
delete callbackWrapper;
} else if (jsCallback != nullptr) {
Function(env, jsCallback).Call({});
}
});
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
6 changes: 6 additions & 0 deletions test/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ Object InitPromise(Env env);
Object InitRunScript(Env env);
#if (NAPI_VERSION > 3)
Object InitThreadSafeFunctionCtx(Env env);
Object InitThreadSafeFunctionException(Env env);
Object InitThreadSafeFunctionExistingTsfn(Env env);
Object InitThreadSafeFunctionPtr(Env env);
Object InitThreadSafeFunctionSum(Env env);
Object InitThreadSafeFunctionUnref(Env env);
Object InitThreadSafeFunction(Env env);
Object InitTypedThreadSafeFunctionCtx(Env env);
Object InitTypedThreadSafeFunctionException(Env env);
Object InitTypedThreadSafeFunctionExistingTsfn(Env env);
Object InitTypedThreadSafeFunctionPtr(Env env);
Object InitTypedThreadSafeFunctionSum(Env env);
Expand Down Expand Up @@ -139,6 +141,8 @@ Object Init(Env env, Object exports) {
exports.Set("symbol", InitSymbol(env));
#if (NAPI_VERSION > 3)
exports.Set("threadsafe_function_ctx", InitThreadSafeFunctionCtx(env));
exports.Set("threadsafe_function_exception",
InitThreadSafeFunctionException(env));
exports.Set("threadsafe_function_existing_tsfn",
InitThreadSafeFunctionExistingTsfn(env));
exports.Set("threadsafe_function_ptr", InitThreadSafeFunctionPtr(env));
Expand All @@ -147,6 +151,8 @@ Object Init(Env env, Object exports) {
exports.Set("threadsafe_function", InitThreadSafeFunction(env));
exports.Set("typed_threadsafe_function_ctx",
InitTypedThreadSafeFunctionCtx(env));
exports.Set("typed_threadsafe_function_exception",
InitTypedThreadSafeFunctionException(env));
exports.Set("typed_threadsafe_function_existing_tsfn",
InitTypedThreadSafeFunctionExistingTsfn(env));
exports.Set("typed_threadsafe_function_ptr",
Expand Down
2 changes: 2 additions & 0 deletions test/binding.gyp
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,15 @@
'run_script.cc',
'symbol.cc',
'threadsafe_function/threadsafe_function_ctx.cc',
'threadsafe_function/threadsafe_function_exception.cc',
'threadsafe_function/threadsafe_function_existing_tsfn.cc',
'threadsafe_function/threadsafe_function_ptr.cc',
'threadsafe_function/threadsafe_function_sum.cc',
'threadsafe_function/threadsafe_function_unref.cc',
'threadsafe_function/threadsafe_function.cc',
'type_taggable.cc',
'typed_threadsafe_function/typed_threadsafe_function_ctx.cc',
'typed_threadsafe_function/typed_threadsafe_function_exception.cc',
'typed_threadsafe_function/typed_threadsafe_function_existing_tsfn.cc',
'typed_threadsafe_function/typed_threadsafe_function_ptr.cc',
'typed_threadsafe_function/typed_threadsafe_function_sum.cc',
Expand Down
50 changes: 50 additions & 0 deletions test/threadsafe_function/threadsafe_function_exception.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include <cstdlib>
#include "napi.h"
#include "test_helper.h"

#if (NAPI_VERSION > 3)

using namespace Napi;

namespace {

void CallJS(napi_env env, napi_value /* callback */, void* /*data*/) {
Napi::Error error = Napi::Error::New(env, "test-from-native");
NAPI_THROW_VOID(error);
}

void TestCall(const CallbackInfo& info) {
Napi::Env env = info.Env();

ThreadSafeFunction wrapped =
ThreadSafeFunction::New(env,
info[0].As<Napi::Function>(),
Object::New(env),
String::New(env, "Test"),
0,
1);
wrapped.BlockingCall(static_cast<void*>(nullptr));
wrapped.Release();
}

void TestCallWithNativeCallback(const CallbackInfo& info) {
Napi::Env env = info.Env();

ThreadSafeFunction wrapped = ThreadSafeFunction::New(
env, Napi::Function(), Object::New(env), String::New(env, "Test"), 0, 1);
wrapped.BlockingCall(static_cast<void*>(nullptr), CallJS);
wrapped.Release();
}

} // namespace

Object InitThreadSafeFunctionException(Env env) {
Object exports = Object::New(env);
exports["testCall"] = Function::New(env, TestCall);
exports["testCallWithNativeCallback"] =
Function::New(env, TestCallWithNativeCallback);

return exports;
}

#endif
55 changes: 55 additions & 0 deletions test/threadsafe_function/threadsafe_function_exception.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
'use strict';
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: For the tests, you could use common.runTestInChildProcess(). You need to extend it to also accept Node.js command line arguments, but I think it might make the code a little cleaner.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, thanks for the suggestion! It's definitely an improvement after #1325 landed.


const assert = require('assert');
const common = require('../common');
const { spawnSync } = require('../napi_child');

function test (bindingPath) {
const { status } = spawnSync(
process.execPath,
[
'--force-node-api-uncaught-exceptions-policy=true',
__filename,
'child',
bindingPath
],
{ stdio: 'inherit' }
);

assert.strictEqual(status, 0);
}

if (process.argv[2] === 'child') {
child(process.argv[3])
.catch(err => {
process.exitCode = 1;
console.error(err);
});
} else {
module.exports = common.runTestWithBindingPath(test);
}

async function child (bindingPath) {
const binding = require(bindingPath);
const { testCall, testCallWithNativeCallback } = binding.threadsafe_function_exception;

await new Promise(resolve => {
process.once('uncaughtException', common.mustCall(err => {
assert.strictEqual(err.message, 'test');
resolve();
}, 1));

testCall(common.mustCall(() => {
throw new Error('test');
}, 1));
});

await new Promise(resolve => {
process.once('uncaughtException', common.mustCall(err => {
assert.strictEqual(err.message, 'test-from-native');
resolve();
}, 1));

testCallWithNativeCallback();
});
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <cstdlib>
#include "napi.h"
#include "test_helper.h"

#if (NAPI_VERSION > 3)

using namespace Napi;

namespace {

void CallJS(Napi::Env env,
Napi::Function /* callback */,
std::nullptr_t* /* context */,
void* /*data*/) {
Napi::Error error = Napi::Error::New(env, "test-from-native");
NAPI_THROW_VOID(error);
}

using TSFN = TypedThreadSafeFunction<std::nullptr_t, void, CallJS>;

void TestCall(const CallbackInfo& info) {
Napi::Env env = info.Env();

TSFN wrapped = TSFN::New(
env, Napi::Function(), Object::New(env), String::New(env, "Test"), 0, 1);
wrapped.BlockingCall(static_cast<void*>(nullptr));
wrapped.Release();
}

} // namespace

Object InitTypedThreadSafeFunctionException(Env env) {
Object exports = Object::New(env);
exports["testCall"] = Function::New(env, TestCall);

return exports;
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
'use strict';

const assert = require('assert');
const common = require('../common');

const { spawnSync } = require('../napi_child');

function test (bindingPath) {
const { status } = spawnSync(
process.execPath,
[
'--force-node-api-uncaught-exceptions-policy=true',
__filename,
'child',
bindingPath
],
{ stdio: 'inherit' }
);

assert.strictEqual(status, 0);
}

if (process.argv[2] === 'child') {
child(process.argv[3])
.catch(err => {
process.exitCode = 1;
console.error(err);
});
} else {
module.exports = common.runTestWithBindingPath(test);
}

async function child (bindingPath) {
const binding = require(bindingPath);
const { testCall } = binding.typed_threadsafe_function_exception;

await new Promise(resolve => {
process.once('uncaughtException', common.mustCall(err => {
assert.strictEqual(err.message, 'test-from-native');
resolve();
}, 1));

testCall();
});
}