Skip to content

Commit

Permalink
tls: fix detection of the upstream connection close event. (#13858)
Browse files Browse the repository at this point in the history
Fixes #13856.

Signed-off-by: Piotr Sikora <[email protected]>
  • Loading branch information
PiotrSikora authored Nov 4, 2020
1 parent 01c4532 commit 359def3
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/root/version_history/current.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Bug Fixes
* http: fixed URL parsing for HTTP/1.1 fully qualified URLs and connect requests containing IPv6 addresses.
* http: sending CONNECT_ERROR for HTTP/2 where appropriate during CONNECT requests.
* proxy_proto: fixed a bug where the wrong downstream address got sent to upstream connections.
* tls: fix detection of the upstream connection close event.
* tls: fix read resumption after triggering buffer high-watermark and all remaining request/response bytes are stored in the SSL connection's internal buffers.
* watchdog: touch the watchdog before most event loop operations to avoid misses when handling bursts of callbacks.

Expand Down
2 changes: 1 addition & 1 deletion source/extensions/transport_sockets/tls/ssl_handshaker.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class SslHandshakerImpl : public Ssl::ConnectionInfo, public Ssl::Handshaker {
// Ssl::Handshaker
Network::PostIoAction doHandshake() override;

Ssl::SocketState state() { return state_; }
Ssl::SocketState state() const { return state_; }
void setState(Ssl::SocketState state) { state_ = state; }
SSL* ssl() const { return ssl_.get(); }
Ssl::HandshakeCallbacks* handshakeCallbacks() { return handshake_callbacks_; }
Expand Down
10 changes: 9 additions & 1 deletion source/extensions/transport_sockets/tls/ssl_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,18 @@ Network::IoResult SslSocket::doRead(Buffer::Instance& read_buffer) {
case SSL_ERROR_WANT_READ:
break;
case SSL_ERROR_ZERO_RETURN:
// Graceful shutdown using close_notify TLS alert.
end_stream = true;
break;
case SSL_ERROR_SYSCALL:
if (result.error_.value() == 0) {
// Non-graceful shutdown by closing the underlying socket.
end_stream = true;
break;
}
FALLTHRU;
case SSL_ERROR_WANT_WRITE:
// Renegotiation has started. We don't handle renegotiation so just fall through.
// Renegotiation has started. We don't handle renegotiation so just fall through.
default:
drainErrorQueue();
action = PostIoAction::Close;
Expand Down
179 changes: 179 additions & 0 deletions test/extensions/transport_sockets/tls/ssl_socket_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2540,6 +2540,185 @@ TEST_P(SslSocketTest, HalfClose) {
dispatcher_->run(Event::Dispatcher::RunType::Block);
}

TEST_P(SslSocketTest, ShutdownWithCloseNotify) {
const std::string server_ctx_yaml = R"EOF(
common_tls_context:
tls_certificates:
certificate_chain:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/unittest_cert.pem"
private_key:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/unittest_key.pem"
validation_context:
trusted_ca:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/ca_certificates.pem"
)EOF";

envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context;
TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context);
auto server_cfg = std::make_unique<ServerContextConfigImpl>(server_tls_context, factory_context_);
ContextManagerImpl manager(time_system_);
Stats::TestUtil::TestStore server_stats_store;
ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager,
server_stats_store, std::vector<std::string>{});

auto socket = std::make_shared<Network::TcpListenSocket>(
Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true);
Network::MockTcpListenerCallbacks listener_callbacks;
Network::MockConnectionHandler connection_handler;
Network::ListenerPtr listener =
dispatcher_->createListener(socket, listener_callbacks, true, ENVOY_TCP_BACKLOG_SIZE);
std::shared_ptr<Network::MockReadFilter> server_read_filter(new Network::MockReadFilter());
std::shared_ptr<Network::MockReadFilter> client_read_filter(new Network::MockReadFilter());

const std::string client_ctx_yaml = R"EOF(
common_tls_context:
)EOF";

envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context;
TestUtility::loadFromYaml(TestEnvironment::substitute(client_ctx_yaml), tls_context);
auto client_cfg = std::make_unique<ClientContextConfigImpl>(tls_context, factory_context_);
Stats::TestUtil::TestStore client_stats_store;
ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager,
client_stats_store);
Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection(
socket->localAddress(), Network::Address::InstanceConstSharedPtr(),
client_ssl_socket_factory.createTransportSocket(nullptr), nullptr);
Network::MockConnectionCallbacks client_connection_callbacks;
client_connection->enableHalfClose(true);
client_connection->addReadFilter(client_read_filter);
client_connection->addConnectionCallbacks(client_connection_callbacks);
client_connection->connect();

Network::ConnectionPtr server_connection;
Network::MockConnectionCallbacks server_connection_callbacks;
EXPECT_CALL(listener_callbacks, onAccept_(_))
.WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void {
server_connection = dispatcher_->createServerConnection(
std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr),
stream_info_);
server_connection->enableHalfClose(true);
server_connection->addReadFilter(server_read_filter);
server_connection->addConnectionCallbacks(server_connection_callbacks);
}));
EXPECT_CALL(*server_read_filter, onNewConnection());
EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::Connected))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
Buffer::OwnedImpl data("hello");
server_connection->write(data, true);
EXPECT_EQ(data.length(), 0);
}));

EXPECT_CALL(*client_read_filter, onNewConnection())
.WillOnce(Return(Network::FilterStatus::Continue));
EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*client_read_filter, onData(BufferStringEqual("hello"), true))
.WillOnce(Invoke([&](Buffer::Instance& read_buffer, bool) -> Network::FilterStatus {
read_buffer.drain(read_buffer.length());
client_connection->close(Network::ConnectionCloseType::NoFlush);
return Network::FilterStatus::StopIteration;
}));
EXPECT_CALL(*server_read_filter, onData(_, true));

EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose));
EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::RemoteClose))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
server_connection->close(Network::ConnectionCloseType::NoFlush);
dispatcher_->exit();
}));

dispatcher_->run(Event::Dispatcher::RunType::Block);
}

TEST_P(SslSocketTest, ShutdownWithoutCloseNotify) {
const std::string server_ctx_yaml = R"EOF(
common_tls_context:
tls_certificates:
certificate_chain:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/unittest_cert.pem"
private_key:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/unittest_key.pem"
validation_context:
trusted_ca:
filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/ca_certificates.pem"
)EOF";

envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context;
TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context);
auto server_cfg = std::make_unique<ServerContextConfigImpl>(server_tls_context, factory_context_);
ContextManagerImpl manager(time_system_);
Stats::TestUtil::TestStore server_stats_store;
ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager,
server_stats_store, std::vector<std::string>{});

auto socket = std::make_shared<Network::TcpListenSocket>(
Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true);
Network::MockTcpListenerCallbacks listener_callbacks;
Network::MockConnectionHandler connection_handler;
Network::ListenerPtr listener =
dispatcher_->createListener(socket, listener_callbacks, true, ENVOY_TCP_BACKLOG_SIZE);
std::shared_ptr<Network::MockReadFilter> server_read_filter(new Network::MockReadFilter());
std::shared_ptr<Network::MockReadFilter> client_read_filter(new Network::MockReadFilter());

const std::string client_ctx_yaml = R"EOF(
common_tls_context:
)EOF";

envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context;
TestUtility::loadFromYaml(TestEnvironment::substitute(client_ctx_yaml), tls_context);
auto client_cfg = std::make_unique<ClientContextConfigImpl>(tls_context, factory_context_);
Stats::TestUtil::TestStore client_stats_store;
ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager,
client_stats_store);
Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection(
socket->localAddress(), Network::Address::InstanceConstSharedPtr(),
client_ssl_socket_factory.createTransportSocket(nullptr), nullptr);
Network::MockConnectionCallbacks client_connection_callbacks;
client_connection->enableHalfClose(true);
client_connection->addReadFilter(client_read_filter);
client_connection->addConnectionCallbacks(client_connection_callbacks);
client_connection->connect();

Network::ConnectionPtr server_connection;
Network::MockConnectionCallbacks server_connection_callbacks;
EXPECT_CALL(listener_callbacks, onAccept_(_))
.WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void {
server_connection = dispatcher_->createServerConnection(
std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr),
stream_info_);
server_connection->enableHalfClose(true);
server_connection->addReadFilter(server_read_filter);
server_connection->addConnectionCallbacks(server_connection_callbacks);
}));
EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::Connected))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
Buffer::OwnedImpl data("hello");
server_connection->write(data, false);
EXPECT_EQ(data.length(), 0);
// Close without sending close_notify alert.
const SslHandshakerImpl* ssl_socket =
dynamic_cast<const SslHandshakerImpl*>(server_connection->ssl().get());
EXPECT_EQ(ssl_socket->state(), Ssl::SocketState::HandshakeComplete);
SSL_set_quiet_shutdown(ssl_socket->ssl(), 1);
server_connection->close(Network::ConnectionCloseType::NoFlush);
}));

EXPECT_CALL(*client_read_filter, onNewConnection())
.WillOnce(Return(Network::FilterStatus::Continue));
EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*client_read_filter, onData(BufferStringEqual("hello"), true))
.WillOnce(Invoke([&](Buffer::Instance& read_buffer, bool) -> Network::FilterStatus {
read_buffer.drain(read_buffer.length());
client_connection->close(Network::ConnectionCloseType::NoFlush);
return Network::FilterStatus::StopIteration;
}));

EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose));
EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void { dispatcher_->exit(); }));

dispatcher_->run(Event::Dispatcher::RunType::Block);
}

TEST_P(SslSocketTest, ClientAuthMultipleCAs) {
const std::string server_ctx_yaml = R"EOF(
common_tls_context:
Expand Down

0 comments on commit 359def3

Please sign in to comment.