Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle c++ exception in TSFN callback #1345

Merged
merged 3 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions napi-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,12 @@ struct ThreadSafeFinalize {
template <typename ContextType, typename DataType, typename CallJs, CallJs call>
inline typename std::enable_if<call != static_cast<CallJs>(nullptr)>::type
CallJsWrapper(napi_env env, napi_value jsCallback, void* context, void* data) {
call(env,
Function(env, jsCallback),
static_cast<ContextType*>(context),
static_cast<DataType*>(data));
details::WrapVoidCallback([&]() {
call(env,
Function(env, jsCallback),
static_cast<ContextType*>(context),
static_cast<DataType*>(data));
});
}

template <typename ContextType, typename DataType, typename CallJs, CallJs call>
Expand All @@ -275,9 +277,11 @@ CallJsWrapper(napi_env env,
napi_value jsCallback,
void* /*context*/,
void* /*data*/) {
if (jsCallback != nullptr) {
Function(env, jsCallback).Call(0, nullptr);
}
details::WrapVoidCallback([&]() {
if (jsCallback != nullptr) {
Function(env, jsCallback).Call(0, nullptr);
}
});
}

#if NAPI_VERSION > 4
Expand Down Expand Up @@ -6135,13 +6139,15 @@ inline void ThreadSafeFunction::CallJS(napi_env env,
return;
}

if (data != nullptr) {
auto* callbackWrapper = static_cast<CallbackWrapper*>(data);
(*callbackWrapper)(env, Function(env, jsCallback));
delete callbackWrapper;
} else if (jsCallback != nullptr) {
Function(env, jsCallback).Call({});
}
details::WrapVoidCallback([&]() {
if (data != nullptr) {
auto* callbackWrapper = static_cast<CallbackWrapper*>(data);
(*callbackWrapper)(env, Function(env, jsCallback));
delete callbackWrapper;
} else if (jsCallback != nullptr) {
Function(env, jsCallback).Call({});
}
});
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
6 changes: 6 additions & 0 deletions test/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ Object InitPromise(Env env);
Object InitRunScript(Env env);
#if (NAPI_VERSION > 3)
Object InitThreadSafeFunctionCtx(Env env);
Object InitThreadSafeFunctionException(Env env);
Object InitThreadSafeFunctionExistingTsfn(Env env);
Object InitThreadSafeFunctionPtr(Env env);
Object InitThreadSafeFunctionSum(Env env);
Object InitThreadSafeFunctionUnref(Env env);
Object InitThreadSafeFunction(Env env);
Object InitTypedThreadSafeFunctionCtx(Env env);
Object InitTypedThreadSafeFunctionException(Env env);
Object InitTypedThreadSafeFunctionExistingTsfn(Env env);
Object InitTypedThreadSafeFunctionPtr(Env env);
Object InitTypedThreadSafeFunctionSum(Env env);
Expand Down Expand Up @@ -139,6 +141,8 @@ Object Init(Env env, Object exports) {
exports.Set("symbol", InitSymbol(env));
#if (NAPI_VERSION > 3)
exports.Set("threadsafe_function_ctx", InitThreadSafeFunctionCtx(env));
exports.Set("threadsafe_function_exception",
InitThreadSafeFunctionException(env));
exports.Set("threadsafe_function_existing_tsfn",
InitThreadSafeFunctionExistingTsfn(env));
exports.Set("threadsafe_function_ptr", InitThreadSafeFunctionPtr(env));
Expand All @@ -147,6 +151,8 @@ Object Init(Env env, Object exports) {
exports.Set("threadsafe_function", InitThreadSafeFunction(env));
exports.Set("typed_threadsafe_function_ctx",
InitTypedThreadSafeFunctionCtx(env));
exports.Set("typed_threadsafe_function_exception",
InitTypedThreadSafeFunctionException(env));
exports.Set("typed_threadsafe_function_existing_tsfn",
InitTypedThreadSafeFunctionExistingTsfn(env));
exports.Set("typed_threadsafe_function_ptr",
Expand Down
2 changes: 2 additions & 0 deletions test/binding.gyp
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,15 @@
'run_script.cc',
'symbol.cc',
'threadsafe_function/threadsafe_function_ctx.cc',
'threadsafe_function/threadsafe_function_exception.cc',
'threadsafe_function/threadsafe_function_existing_tsfn.cc',
'threadsafe_function/threadsafe_function_ptr.cc',
'threadsafe_function/threadsafe_function_sum.cc',
'threadsafe_function/threadsafe_function_unref.cc',
'threadsafe_function/threadsafe_function.cc',
'type_taggable.cc',
'typed_threadsafe_function/typed_threadsafe_function_ctx.cc',
'typed_threadsafe_function/typed_threadsafe_function_exception.cc',
'typed_threadsafe_function/typed_threadsafe_function_existing_tsfn.cc',
'typed_threadsafe_function/typed_threadsafe_function_ptr.cc',
'typed_threadsafe_function/typed_threadsafe_function_sum.cc',
Expand Down
33 changes: 33 additions & 0 deletions test/child_processes/threadsafe_function_exception.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
'use strict';

const assert = require('assert');
const common = require('../common');

module.exports = {
testCall: async binding => {
const { testCall } = binding.threadsafe_function_exception;

await new Promise(resolve => {
process.once('uncaughtException', common.mustCall(err => {
assert.strictEqual(err.message, 'test');
resolve();
}, 1));

testCall(common.mustCall(() => {
throw new Error('test');
}, 1));
});
},
testCallWithNativeCallback: async binding => {
const { testCallWithNativeCallback } = binding.threadsafe_function_exception;

await new Promise(resolve => {
process.once('uncaughtException', common.mustCall(err => {
assert.strictEqual(err.message, 'test-from-native');
resolve();
}, 1));

testCallWithNativeCallback();
});
}
};
19 changes: 19 additions & 0 deletions test/child_processes/typed_threadsafe_function_exception.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
'use strict';

const assert = require('assert');
const common = require('../common');

module.exports = {
testCall: async binding => {
const { testCall } = binding.typed_threadsafe_function_exception;

await new Promise(resolve => {
process.once('uncaughtException', common.mustCall(err => {
assert.strictEqual(err.message, 'test-from-native');
resolve();
}, 1));

testCall();
});
}
};
3 changes: 2 additions & 1 deletion test/common/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,15 @@ exports.runTestWithBuildType = async function (test, buildType) {
// in the main process. Two examples are addon and addon_data, both of which
// use Napi::Env::SetInstanceData(). This helper function provides a common
// approach for running such tests.
exports.runTestInChildProcess = function ({ suite, testName, expectedStderr }) {
exports.runTestInChildProcess = function ({ suite, testName, expectedStderr, execArgv }) {
return exports.runTestWithBindingPath((bindingName) => {
return new Promise((resolve) => {
bindingName = escapeBackslashes(bindingName);
// Test suites are assumed to be located here.
const suitePath = escapeBackslashes(path.join(__dirname, '..', 'child_processes', suite));
const child = spawn(process.execPath, [
'--expose-gc',
...(execArgv ?? []),
'-e',
`require('${suitePath}').${testName}(require('${bindingName}'))`
]);
Expand Down
50 changes: 50 additions & 0 deletions test/threadsafe_function/threadsafe_function_exception.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include <cstdlib>
#include "napi.h"
#include "test_helper.h"

#if (NAPI_VERSION > 3)

using namespace Napi;

namespace {

void CallJS(napi_env env, napi_value /* callback */, void* /*data*/) {
Napi::Error error = Napi::Error::New(env, "test-from-native");
NAPI_THROW_VOID(error);
}

void TestCall(const CallbackInfo& info) {
Napi::Env env = info.Env();

ThreadSafeFunction wrapped =
ThreadSafeFunction::New(env,
info[0].As<Napi::Function>(),
Object::New(env),
String::New(env, "Test"),
0,
1);
wrapped.BlockingCall(static_cast<void*>(nullptr));
wrapped.Release();
}

void TestCallWithNativeCallback(const CallbackInfo& info) {
Napi::Env env = info.Env();

ThreadSafeFunction wrapped = ThreadSafeFunction::New(
env, Napi::Function(), Object::New(env), String::New(env, "Test"), 0, 1);
wrapped.BlockingCall(static_cast<void*>(nullptr), CallJS);
wrapped.Release();
}

} // namespace

Object InitThreadSafeFunctionException(Env env) {
Object exports = Object::New(env);
exports["testCall"] = Function::New(env, TestCall);
exports["testCallWithNativeCallback"] =
Function::New(env, TestCallWithNativeCallback);

return exports;
}

#endif
20 changes: 20 additions & 0 deletions test/threadsafe_function/threadsafe_function_exception.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
'use strict';
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: For the tests, you could use common.runTestInChildProcess(). You need to extend it to also accept Node.js command line arguments, but I think it might make the code a little cleaner.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, thanks for the suggestion! It's definitely an improvement after #1325 landed.


const common = require('../common');

module.exports = common.runTest(test);

const execArgv = ['--force-node-api-uncaught-exceptions-policy=true'];
async function test () {
await common.runTestInChildProcess({
suite: 'threadsafe_function_exception',
testName: 'testCall',
execArgv
});

await common.runTestInChildProcess({
suite: 'threadsafe_function_exception',
testName: 'testCallWithNativeCallback',
execArgv
});
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <cstdlib>
#include "napi.h"
#include "test_helper.h"

#if (NAPI_VERSION > 3)

using namespace Napi;

namespace {

void CallJS(Napi::Env env,
Napi::Function /* callback */,
std::nullptr_t* /* context */,
void* /*data*/) {
Napi::Error error = Napi::Error::New(env, "test-from-native");
NAPI_THROW_VOID(error);
}

using TSFN = TypedThreadSafeFunction<std::nullptr_t, void, CallJS>;

void TestCall(const CallbackInfo& info) {
Napi::Env env = info.Env();

TSFN wrapped = TSFN::New(
env, Napi::Function(), Object::New(env), String::New(env, "Test"), 0, 1);
wrapped.BlockingCall(static_cast<void*>(nullptr));
wrapped.Release();
}

} // namespace

Object InitTypedThreadSafeFunctionException(Env env) {
Object exports = Object::New(env);
exports["testCall"] = Function::New(env, TestCall);

return exports;
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
'use strict';

const common = require('../common');

module.exports = common.runTest(test);

async function test () {
await common.runTestInChildProcess({
suite: 'typed_threadsafe_function_exception',
testName: 'testCall',
execArgv: ['--force-node-api-uncaught-exceptions-policy=true']
});
}