From 1964a911ace0dc040c2cce3ff8a9228140a05a68 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Tue, 18 Jun 2024 14:59:00 -0700 Subject: [PATCH 01/11] [ENH] handle panics in query service --- Cargo.lock | 36 +++++++++++++--- rust/worker/Cargo.toml | 8 +++- rust/worker/src/catch_panic.rs | 75 ++++++++++++++++++++++++++++++++++ rust/worker/src/lib.rs | 1 + rust/worker/src/server.rs | 75 ++++++++++++++++++++++++++++++++++ 5 files changed, 189 insertions(+), 6 deletions(-) create mode 100644 rust/worker/src/catch_panic.rs diff --git a/Cargo.lock b/Cargo.lock index f5dc7ea915a..965e3c20e21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -315,9 +315,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.77" +version = "0.1.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", @@ -1717,12 +1717,12 @@ dependencies = [ [[package]] name = "http-body-util" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0475f8b2ac86659c21b64320d5d653f9efe42acd2a4e560073ec61a155a34f1d" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", - "futures-core", + "futures-util", "http 1.1.0", "http-body 1.0.0", "pin-project-lite", @@ -2301,6 +2301,18 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b" +[[package]] +name = "network-interface" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a43439bf756eed340bdf8feba761e2d50c7d47175d87545cd5cbe4a137c4d1" +dependencies = [ + "cc", + "libc", + "thiserror", + "winapi", +] + [[package]] name = "nom" version = "7.1.3" @@ -2985,6 +2997,17 @@ dependencies = [ "rand_core", ] +[[package]] +name = "random-port" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52b7d0e298a1b2f2f46c8d5da944c80ed1e5e6b032521cc44ee2b1dcbe2b94a" +dependencies = [ + "network-interface", + "rand", + "thiserror", +] + [[package]] name = "rayon" version = "1.9.0" @@ -4637,6 +4660,7 @@ dependencies = [ "criterion", "figment", "futures", + "hyper", "k8s-openapi", "kube", "murmur3", @@ -4649,6 +4673,7 @@ dependencies = [ "prost 0.12.3", "prost-types", "rand", + "random-port", "rayon", "roaring", "schemars", @@ -4662,6 +4687,7 @@ dependencies = [ "tokio-util", "tonic 0.10.2", "tonic-build", + "tower", "tracing", "tracing-bunyan-formatter", "tracing-opentelemetry", diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml index 18c84123e50..5c155468091 100644 --- a/rust/worker/Cargo.toml +++ b/rust/worker/Cargo.toml @@ -48,9 +48,15 @@ tracing = "0.1" tracing-bunyan-formatter = "0.3.3" tracing-opentelemetry = "0.19.0" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -opentelemetry = { version = "0.19.0", default-features = false, features = ["trace", "rt-tokio"] } +opentelemetry = { version = "0.19.0", default-features = false, features = [ + "trace", + "rt-tokio", +] } opentelemetry-otlp = "0.12.0" shuttle = "0.7.1" +tower = "0.4.13" +random-port = "0.1.1" +hyper = "0.14" [dev-dependencies] diff --git a/rust/worker/src/catch_panic.rs b/rust/worker/src/catch_panic.rs new file mode 100644 index 00000000000..f1dfb65b6f1 --- /dev/null +++ b/rust/worker/src/catch_panic.rs @@ -0,0 +1,75 @@ +use std::{ + panic::AssertUnwindSafe, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::FutureExt; +use hyper::body::Body; +use tonic::body::BoxBody; +use tower::{Layer, Service}; + +#[derive(Debug, Clone, Default)] +pub struct CatchPanicLayer; + +impl Layer for CatchPanicLayer { + type Service = CatchPanicMiddleware; + + fn layer(&self, service: S) -> Self::Service { + CatchPanicMiddleware { inner: service } + } +} + +#[derive(Debug, Clone)] +pub struct CatchPanicMiddleware { + inner: S, +} + +type BoxFuture<'a, T> = Pin + Send + 'a>>; + +impl Service> for CatchPanicMiddleware +where + S: Service, Response = hyper::Response> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: hyper::Request) -> Self::Future { + // This is necessary because tonic internally uses `tower::buffer::Buffer`. + // See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149 + // for details on why this is necessary. + // (We could also probably avoid a clone using an approach similar to https://docs.rs/tower-http/latest/tower_http/catch_panic, but wrangling the types between hyper/tower/tonic is very tricky.) + let clone = self.inner.clone(); + let mut inner = std::mem::replace(&mut self.inner, clone); + + Box::pin(async move { + // See https://doc.rust-lang.org/core/panic/trait.UnwindSafe.html for details on unwind safety. + // tl;dr: it's not guaranteed to be safe to continue execution after a panic as the world may be in an inconsistent state. + // Many types *are* unwind safe and marked as such with the UnwindSafe trait. In our case, since we want a generic wrapper around any service, we need to manually assert that the service is unwind safe. + // Note that this can lead to unexpected behavior if the service is not actually unwind safe and it panics. + match AssertUnwindSafe(inner.call(req)).catch_unwind().await { + Ok(response) => response, + Err(err) => { + let message = if let Some(s) = err.downcast_ref::() { + format!("Service panicked: {}", s) + } else if let Some(s) = err.downcast_ref::<&str>() { + format!("Service panicked: {}", s) + } else { + "Service panicked but `CatchPanic` was unable to downcast the panic info" + .to_string() + }; + tracing::error!("{}", message); + + let response = tonic::Status::internal(message).to_http(); + Ok(response) + } + } + }) + } +} diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs index f043665ef9a..5521b513847 100644 --- a/rust/worker/src/lib.rs +++ b/rust/worker/src/lib.rs @@ -1,5 +1,6 @@ mod assignment; mod blockstore; +mod catch_panic; mod compactor; mod config; pub mod distance; diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 658a472fc4f..4e555078f8b 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::path::PathBuf; use crate::blockstore::provider::BlockfileProvider; +use crate::catch_panic::CatchPanicLayer; use crate::chroma_proto::{ self, CountRecordsRequest, CountRecordsResponse, QueryMetadataRequest, QueryMetadataResponse, }; @@ -89,6 +90,7 @@ impl WorkerServer { let addr = format!("[::]:{}", worker.port).parse().unwrap(); println!("Worker listening on {}", addr); let server = Server::builder() + .layer(CatchPanicLayer::default()) .add_service(chroma_proto::vector_reader_server::VectorReaderServer::new( worker.clone(), )) @@ -131,6 +133,13 @@ impl WorkerServer { } }; + #[cfg(debug_assertions)] + { + if segment_uuid == Uuid::nil() { + panic!("Invalid Segment UUID (throwing panic for testing)"); + } + } + let mut proto_results_for_all = Vec::new(); let parse_vectors_span = trace_span!("Input vectors parsing"); @@ -522,3 +531,69 @@ impl chroma_proto::metadata_reader_server::MetadataReader for WorkerServer { .await } } + +#[cfg(test)] +mod tests { + use crate::execution::dispatcher; + use crate::log::log::InMemoryLog; + use crate::storage::local::LocalStorage; + use crate::storage::Storage; + use crate::sysdb::test_sysdb::TestSysDb; + use crate::system; + + use super::*; + use chroma_proto::vector_reader_client::VectorReaderClient; + use tempfile::tempdir; + + #[tokio::test] + async fn foo() { + let sysdb = TestSysDb::new(); + let log = InMemoryLog::new(); + let tmp_dir = tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + + let port = random_port::PortPicker::new().pick().unwrap(); + let mut server = WorkerServer { + dispatcher: None, + system: None, + sysdb: Box::new(SysDb::Test(sysdb)), + log: Box::new(Log::InMemory(log)), + hnsw_index_provider: HnswIndexProvider::new( + storage.clone(), + tmp_dir.path().to_path_buf(), + ), + blockfile_provider: BlockfileProvider::new_arrow(storage), + port, + }; + + let system: system::System = system::System::new(); + let dispatcher = dispatcher::Dispatcher::new(4, 10, 10); + let dispatcher_handle = system.start_component(dispatcher); + + server.set_system(system.clone()); + server.set_dispatcher(dispatcher_handle.receiver()); + + tokio::spawn(async move { + let _ = crate::server::WorkerServer::run(server).await; + }); + + let err_response = VectorReaderClient::connect(format!("http://localhost:{}", port)) + .await + .unwrap() + .query_vectors(Request::new(QueryVectorsRequest { + segment_id: "00000000-0000-0000-0000-000000000000".to_string(), + vectors: vec![], + k: 10, + allowed_ids: vec![], + include_embeddings: false, + })) + .await + .unwrap_err(); + + assert_eq!(err_response.code(), tonic::Code::Internal); + assert!(err_response.message().contains("Service panicked")); + assert!(err_response + .message() + .contains("throwing panic for testing")); + } +} From 1005510c53a841e76716d70095438f802a3bc02e Mon Sep 17 00:00:00 2001 From: Max Isom Date: Tue, 18 Jun 2024 15:54:08 -0700 Subject: [PATCH 02/11] Fix TestSysDb filtering --- rust/worker/src/sysdb/test_sysdb.rs | 11 ++--------- rust/worker/src/types/segment.rs | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/rust/worker/src/sysdb/test_sysdb.rs b/rust/worker/src/sysdb/test_sysdb.rs index 67dd0ac6edd..11a19b80bc0 100644 --- a/rust/worker/src/sysdb/test_sysdb.rs +++ b/rust/worker/src/sysdb/test_sysdb.rs @@ -91,15 +91,8 @@ impl TestSysDb { if id.is_some() && id.unwrap() != segment.id { return false; } - if r#type.is_some() { - match r#type.unwrap().as_str() { - "hnsw" => { - if segment.r#type != SegmentType::HnswDistributed { - return false; - } - } - _ => return false, - } + if let Some(r#type) = r#type { + return segment.r#type == SegmentType::try_from(r#type.as_str()).unwrap(); } if scope.is_some() && scope.unwrap() != segment.scope { return false; diff --git a/rust/worker/src/types/segment.rs b/rust/worker/src/types/segment.rs index 951850b9877..988a88f956d 100644 --- a/rust/worker/src/types/segment.rs +++ b/rust/worker/src/types/segment.rs @@ -28,6 +28,20 @@ impl From for String { } } +impl TryFrom<&str> for SegmentType { + type Error = SegmentConversionError; + + fn try_from(segment_type: &str) -> Result { + match segment_type { + "urn:chroma:segment/vector/hnsw-distributed" => Ok(SegmentType::HnswDistributed), + "urn:chroma:segment/record/blockfile" => Ok(SegmentType::BlockfileRecord), + "urn:chroma:segment/metadata/sqlite" => Ok(SegmentType::Sqlite), + "urn:chroma:segment/metadata/blockfile" => Ok(SegmentType::BlockfileMetadata), + _ => Err(SegmentConversionError::InvalidSegmentType), + } + } +} + #[derive(Clone, Debug, PartialEq)] pub(crate) struct Segment { pub(crate) id: Uuid, From 4e90ec839a35bbe701784bcf47e334c34b9d57ea Mon Sep 17 00:00:00 2001 From: Max Isom Date: Tue, 18 Jun 2024 15:55:20 -0700 Subject: [PATCH 03/11] Test that gRPC server still works after panic is thrown --- rust/worker/src/catch_panic.rs | 2 +- rust/worker/src/server.rs | 51 +++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/rust/worker/src/catch_panic.rs b/rust/worker/src/catch_panic.rs index f1dfb65b6f1..4a2c5da8fce 100644 --- a/rust/worker/src/catch_panic.rs +++ b/rust/worker/src/catch_panic.rs @@ -61,7 +61,7 @@ where } else if let Some(s) = err.downcast_ref::<&str>() { format!("Service panicked: {}", s) } else { - "Service panicked but `CatchPanic` was unable to downcast the panic info" + "Service panicked but `CatchPanicMiddleware` was unable to downcast the panic info" .to_string() }; tracing::error!("{}", message); diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 4e555078f8b..116006e247e 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -540,6 +540,7 @@ mod tests { use crate::storage::Storage; use crate::sysdb::test_sysdb::TestSysDb; use crate::system; + use crate::types::{Collection, Segment}; use super::*; use chroma_proto::vector_reader_client::VectorReaderClient; @@ -547,7 +548,42 @@ mod tests { #[tokio::test] async fn foo() { - let sysdb = TestSysDb::new(); + let mut sysdb = TestSysDb::new(); + + // Add some data for testing + let collection_uuid = Uuid::new_v4(); + let collection = Collection { + id: collection_uuid, + name: "foo".to_string(), + metadata: None, + dimension: Some(1), + tenant: "foo".to_string(), + database: "foo".to_string(), + log_position: -1, + version: 0, + }; + sysdb.add_collection(collection); + + let record_segment = Segment { + id: Uuid::new_v4(), + r#type: crate::types::SegmentType::BlockfileRecord, + scope: crate::types::SegmentScope::RECORD, + collection: Some(collection_uuid), + metadata: None, + file_path: HashMap::new(), + }; + sysdb.add_segment(record_segment.clone()); + + let hnsw_segment = Segment { + id: Uuid::new_v4(), + r#type: crate::types::SegmentType::HnswDistributed, + scope: crate::types::SegmentScope::VECTOR, + collection: Some(collection_uuid), + metadata: None, + file_path: HashMap::new(), + }; + sysdb.add_segment(hnsw_segment.clone()); + let log = InMemoryLog::new(); let tmp_dir = tempdir().unwrap(); let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); @@ -577,6 +613,7 @@ mod tests { let _ = crate::server::WorkerServer::run(server).await; }); + // Test response when handler panics let err_response = VectorReaderClient::connect(format!("http://localhost:{}", port)) .await .unwrap() @@ -595,5 +632,17 @@ mod tests { assert!(err_response .message() .contains("throwing panic for testing")); + + // A well-formatted request should still work, even after a panic was thrown + let response = VectorReaderClient::connect(format!("http://localhost:{}", port)) + .await + .unwrap() + .get_vectors(Request::new(GetVectorsRequest { + segment_id: hnsw_segment.id.to_string(), + ids: vec![], + })) + .await; + + assert!(response.is_ok()); } } From ebb451dc5d538863cfaaff36bb8434bd4a40376b Mon Sep 17 00:00:00 2001 From: Max Isom Date: Tue, 18 Jun 2024 16:04:59 -0700 Subject: [PATCH 04/11] Cleanup --- rust/worker/Cargo.toml | 2 +- rust/worker/src/server.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml index 5c155468091..f994f99c793 100644 --- a/rust/worker/Cargo.toml +++ b/rust/worker/Cargo.toml @@ -55,7 +55,6 @@ opentelemetry = { version = "0.19.0", default-features = false, features = [ opentelemetry-otlp = "0.12.0" shuttle = "0.7.1" tower = "0.4.13" -random-port = "0.1.1" hyper = "0.14" @@ -65,6 +64,7 @@ proptest-state-machine = "0.1.0" "rand" = "0.8.5" rayon = "1.8.0" criterion = "0.3" +random-port = "0.1.1" [build-dependencies] tonic-build = "0.10" diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 116006e247e..e0a08349cc6 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -613,10 +613,12 @@ mod tests { let _ = crate::server::WorkerServer::run(server).await; }); - // Test response when handler panics - let err_response = VectorReaderClient::connect(format!("http://localhost:{}", port)) + let mut client = VectorReaderClient::connect(format!("http://localhost:{}", port)) .await - .unwrap() + .unwrap(); + + // Test response when handler panics + let err_response = client .query_vectors(Request::new(QueryVectorsRequest { segment_id: "00000000-0000-0000-0000-000000000000".to_string(), vectors: vec![], @@ -634,9 +636,7 @@ mod tests { .contains("throwing panic for testing")); // A well-formatted request should still work, even after a panic was thrown - let response = VectorReaderClient::connect(format!("http://localhost:{}", port)) - .await - .unwrap() + let response = client .get_vectors(Request::new(GetVectorsRequest { segment_id: hnsw_segment.id.to_string(), ids: vec![], From 47388d7f3118d31a0d429bffa628440abe316cdd Mon Sep 17 00:00:00 2001 From: Max Isom Date: Tue, 18 Jun 2024 16:10:23 -0700 Subject: [PATCH 05/11] Rename --- rust/worker/src/lib.rs | 1 - rust/worker/src/server.rs | 4 ++-- .../src/{catch_panic.rs => utils/catch_panic_middleware.rs} | 1 + rust/worker/src/utils/mod.rs | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) rename rust/worker/src/{catch_panic.rs => utils/catch_panic_middleware.rs} (96%) diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs index 5521b513847..f043665ef9a 100644 --- a/rust/worker/src/lib.rs +++ b/rust/worker/src/lib.rs @@ -1,6 +1,5 @@ mod assignment; mod blockstore; -mod catch_panic; mod compactor; mod config; pub mod distance; diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index e0a08349cc6..0667575d380 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use std::path::PathBuf; use crate::blockstore::provider::BlockfileProvider; -use crate::catch_panic::CatchPanicLayer; use crate::chroma_proto::{ self, CountRecordsRequest, CountRecordsResponse, QueryMetadataRequest, QueryMetadataResponse, }; @@ -23,6 +22,7 @@ use crate::system::{Receiver, System}; use crate::tracing::util::wrap_span_with_parent_context; use crate::types::MetadataValue; use crate::types::ScalarEncoding; +use crate::utils::catch_panic_middleware::CatchPanicLayer; use async_trait::async_trait; use tokio::signal::unix::{signal, SignalKind}; use tonic::{transport::Server, Request, Response, Status}; @@ -547,7 +547,7 @@ mod tests { use tempfile::tempdir; #[tokio::test] - async fn foo() { + async fn gracefully_handles_panics() { let mut sysdb = TestSysDb::new(); // Add some data for testing diff --git a/rust/worker/src/catch_panic.rs b/rust/worker/src/utils/catch_panic_middleware.rs similarity index 96% rename from rust/worker/src/catch_panic.rs rename to rust/worker/src/utils/catch_panic_middleware.rs index 4a2c5da8fce..697159ae4e5 100644 --- a/rust/worker/src/catch_panic.rs +++ b/rust/worker/src/utils/catch_panic_middleware.rs @@ -9,6 +9,7 @@ use hyper::body::Body; use tonic::body::BoxBody; use tower::{Layer, Service}; +/// Middleware layer for Tonic services that catches panics and returns an internal server error. #[derive(Debug, Clone, Default)] pub struct CatchPanicLayer; diff --git a/rust/worker/src/utils/mod.rs b/rust/worker/src/utils/mod.rs index c8b933c8776..c101a5bb4d5 100644 --- a/rust/worker/src/utils/mod.rs +++ b/rust/worker/src/utils/mod.rs @@ -1,3 +1,4 @@ +pub mod catch_panic_middleware; mod vec; pub(crate) use vec::*; From d787855c682e88eed661252ddbcd8cf0895cb262 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Tue, 18 Jun 2024 16:38:39 -0700 Subject: [PATCH 06/11] Separate .proto --- idl/chromadb/proto/debug.proto | 14 ++++ rust/worker/build.rs | 18 ++--- rust/worker/src/server.rs | 125 ++++++++++++--------------------- 3 files changed, 69 insertions(+), 88 deletions(-) create mode 100644 idl/chromadb/proto/debug.proto diff --git a/idl/chromadb/proto/debug.proto b/idl/chromadb/proto/debug.proto new file mode 100644 index 00000000000..6cd6bc73184 --- /dev/null +++ b/idl/chromadb/proto/debug.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package chroma; + +import "google/protobuf/empty.proto"; + +message GetInfoResponse { + string version = 1; +} + +service Debug { + rpc GetInfo(google.protobuf.Empty) returns (GetInfoResponse) {} + rpc TriggerPanic(google.protobuf.Empty) returns (google.protobuf.Empty) {} +} diff --git a/rust/worker/build.rs b/rust/worker/build.rs index c98a96de683..6d356930518 100644 --- a/rust/worker/build.rs +++ b/rust/worker/build.rs @@ -1,15 +1,17 @@ fn main() -> Result<(), Box> { // Compile the protobuf files in the chromadb proto directory. + let mut proto_paths = vec![ + "../../idl/chromadb/proto/chroma.proto", + "../../idl/chromadb/proto/coordinator.proto", + "../../idl/chromadb/proto/logservice.proto", + ]; + + #[cfg(debug_assertions)] + proto_paths.push("../../idl/chromadb/proto/debug.proto"); + tonic_build::configure() .emit_rerun_if_changed(true) - .compile( - &[ - "../../idl/chromadb/proto/chroma.proto", - "../../idl/chromadb/proto/coordinator.proto", - "../../idl/chromadb/proto/logservice.proto", - ], - &["../../idl/"], - )?; + .compile(&proto_paths, &["../../idl/"])?; // Compile the hnswlib bindings. cc::Build::new() diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 0667575d380..069b3591e10 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -96,18 +96,23 @@ impl WorkerServer { )) .add_service( chroma_proto::metadata_reader_server::MetadataReaderServer::new(worker.clone()), - ) - .serve_with_shutdown(addr, async { - let mut sigterm = match signal(SignalKind::terminate()) { - Ok(sigterm) => sigterm, - Err(e) => { - tracing::error!("Failed to create signal handler: {:?}", e); - return; - } - }; - sigterm.recv().await; - tracing::info!("Received SIGTERM, shutting down"); - }); + ); + + #[cfg(debug_assertions)] + let server = + server.add_service(chroma_proto::debug_server::DebugServer::new(worker.clone())); + + let server = server.serve_with_shutdown(addr, async { + let mut sigterm = match signal(SignalKind::terminate()) { + Ok(sigterm) => sigterm, + Err(e) => { + tracing::error!("Failed to create signal handler: {:?}", e); + return; + } + }; + sigterm.recv().await; + tracing::info!("Received SIGTERM, shutting down"); + }); server.await?; Ok(()) @@ -133,13 +138,6 @@ impl WorkerServer { } }; - #[cfg(debug_assertions)] - { - if segment_uuid == Uuid::nil() { - panic!("Invalid Segment UUID (throwing panic for testing)"); - } - } - let mut proto_results_for_all = Vec::new(); let parse_vectors_span = trace_span!("Input vectors parsing"); @@ -532,6 +530,25 @@ impl chroma_proto::metadata_reader_server::MetadataReader for WorkerServer { } } +#[tonic::async_trait] +impl chroma_proto::debug_server::Debug for WorkerServer { + async fn get_info( + &self, + _: Request<()>, + ) -> Result, Status> { + let response = chroma_proto::GetInfoResponse { + version: option_env!("CARGO_PKG_VERSION") + .unwrap_or("unknown") + .to_string(), + }; + Ok(Response::new(response)) + } + + async fn trigger_panic(&self, _: Request<()>) -> Result, Status> { + panic!("Intentional panic triggered"); + } +} + #[cfg(test)] mod tests { use crate::execution::dispatcher; @@ -540,50 +557,14 @@ mod tests { use crate::storage::Storage; use crate::sysdb::test_sysdb::TestSysDb; use crate::system; - use crate::types::{Collection, Segment}; use super::*; - use chroma_proto::vector_reader_client::VectorReaderClient; + use chroma_proto::debug_client::DebugClient; use tempfile::tempdir; #[tokio::test] async fn gracefully_handles_panics() { - let mut sysdb = TestSysDb::new(); - - // Add some data for testing - let collection_uuid = Uuid::new_v4(); - let collection = Collection { - id: collection_uuid, - name: "foo".to_string(), - metadata: None, - dimension: Some(1), - tenant: "foo".to_string(), - database: "foo".to_string(), - log_position: -1, - version: 0, - }; - sysdb.add_collection(collection); - - let record_segment = Segment { - id: Uuid::new_v4(), - r#type: crate::types::SegmentType::BlockfileRecord, - scope: crate::types::SegmentScope::RECORD, - collection: Some(collection_uuid), - metadata: None, - file_path: HashMap::new(), - }; - sysdb.add_segment(record_segment.clone()); - - let hnsw_segment = Segment { - id: Uuid::new_v4(), - r#type: crate::types::SegmentType::HnswDistributed, - scope: crate::types::SegmentScope::VECTOR, - collection: Some(collection_uuid), - metadata: None, - file_path: HashMap::new(), - }; - sysdb.add_segment(hnsw_segment.clone()); - + let sysdb = TestSysDb::new(); let log = InMemoryLog::new(); let tmp_dir = tempdir().unwrap(); let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); @@ -613,36 +594,20 @@ mod tests { let _ = crate::server::WorkerServer::run(server).await; }); - let mut client = VectorReaderClient::connect(format!("http://localhost:{}", port)) + let mut client = DebugClient::connect(format!("http://localhost:{}", port)) .await .unwrap(); // Test response when handler panics - let err_response = client - .query_vectors(Request::new(QueryVectorsRequest { - segment_id: "00000000-0000-0000-0000-000000000000".to_string(), - vectors: vec![], - k: 10, - allowed_ids: vec![], - include_embeddings: false, - })) - .await - .unwrap_err(); - + let err_response = client.trigger_panic(Request::new(())).await.unwrap_err(); assert_eq!(err_response.code(), tonic::Code::Internal); - assert!(err_response.message().contains("Service panicked")); - assert!(err_response - .message() - .contains("throwing panic for testing")); + assert_eq!( + err_response.message(), + "Service panicked: Intentional panic triggered" + ); // A well-formatted request should still work, even after a panic was thrown - let response = client - .get_vectors(Request::new(GetVectorsRequest { - segment_id: hnsw_segment.id.to_string(), - ids: vec![], - })) - .await; - + let response = client.get_info(Request::new(())).await; assert!(response.is_ok()); } } From c825fe0646fe70f716edc7afd0c7f366bd62cebd Mon Sep 17 00:00:00 2001 From: Max Isom Date: Tue, 18 Jun 2024 16:45:06 -0700 Subject: [PATCH 07/11] Clarify --- rust/worker/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 069b3591e10..6bd96a89d02 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -606,7 +606,7 @@ mod tests { "Service panicked: Intentional panic triggered" ); - // A well-formatted request should still work, even after a panic was thrown + // The server should still work, even after a panic was thrown let response = client.get_info(Request::new(())).await; assert!(response.is_ok()); } From ca28c3504142848dc0a7a48340993c040a03a886 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Wed, 19 Jun 2024 09:36:03 -0700 Subject: [PATCH 08/11] pub(crate) --- rust/worker/src/server.rs | 2 +- rust/worker/src/utils/mod.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 6bd96a89d02..48bfd2ccc20 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -22,7 +22,7 @@ use crate::system::{Receiver, System}; use crate::tracing::util::wrap_span_with_parent_context; use crate::types::MetadataValue; use crate::types::ScalarEncoding; -use crate::utils::catch_panic_middleware::CatchPanicLayer; +use crate::utils::CatchPanicLayer; use async_trait::async_trait; use tokio::signal::unix::{signal, SignalKind}; use tonic::{transport::Server, Request, Response, Status}; diff --git a/rust/worker/src/utils/mod.rs b/rust/worker/src/utils/mod.rs index c101a5bb4d5..cd6ba5689b5 100644 --- a/rust/worker/src/utils/mod.rs +++ b/rust/worker/src/utils/mod.rs @@ -1,4 +1,5 @@ -pub mod catch_panic_middleware; +mod catch_panic_middleware; mod vec; +pub(crate) use catch_panic_middleware::*; pub(crate) use vec::*; From a04aa76e88f2396c6c42bba5fe65c73222625346 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Wed, 19 Jun 2024 09:59:22 -0700 Subject: [PATCH 09/11] Remove catch panic middleware --- rust/worker/src/server.rs | 8 +- .../src/utils/catch_panic_middleware.rs | 76 ------------------- rust/worker/src/utils/mod.rs | 2 - 3 files changed, 1 insertion(+), 85 deletions(-) delete mode 100644 rust/worker/src/utils/catch_panic_middleware.rs diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 48bfd2ccc20..4ec705532a5 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -22,7 +22,6 @@ use crate::system::{Receiver, System}; use crate::tracing::util::wrap_span_with_parent_context; use crate::types::MetadataValue; use crate::types::ScalarEncoding; -use crate::utils::CatchPanicLayer; use async_trait::async_trait; use tokio::signal::unix::{signal, SignalKind}; use tonic::{transport::Server, Request, Response, Status}; @@ -90,7 +89,6 @@ impl WorkerServer { let addr = format!("[::]:{}", worker.port).parse().unwrap(); println!("Worker listening on {}", addr); let server = Server::builder() - .layer(CatchPanicLayer::default()) .add_service(chroma_proto::vector_reader_server::VectorReaderServer::new( worker.clone(), )) @@ -600,11 +598,7 @@ mod tests { // Test response when handler panics let err_response = client.trigger_panic(Request::new(())).await.unwrap_err(); - assert_eq!(err_response.code(), tonic::Code::Internal); - assert_eq!( - err_response.message(), - "Service panicked: Intentional panic triggered" - ); + assert_eq!(err_response.code(), tonic::Code::Cancelled); // The server should still work, even after a panic was thrown let response = client.get_info(Request::new(())).await; diff --git a/rust/worker/src/utils/catch_panic_middleware.rs b/rust/worker/src/utils/catch_panic_middleware.rs deleted file mode 100644 index 697159ae4e5..00000000000 --- a/rust/worker/src/utils/catch_panic_middleware.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::{ - panic::AssertUnwindSafe, - pin::Pin, - task::{Context, Poll}, -}; - -use futures::FutureExt; -use hyper::body::Body; -use tonic::body::BoxBody; -use tower::{Layer, Service}; - -/// Middleware layer for Tonic services that catches panics and returns an internal server error. -#[derive(Debug, Clone, Default)] -pub struct CatchPanicLayer; - -impl Layer for CatchPanicLayer { - type Service = CatchPanicMiddleware; - - fn layer(&self, service: S) -> Self::Service { - CatchPanicMiddleware { inner: service } - } -} - -#[derive(Debug, Clone)] -pub struct CatchPanicMiddleware { - inner: S, -} - -type BoxFuture<'a, T> = Pin + Send + 'a>>; - -impl Service> for CatchPanicMiddleware -where - S: Service, Response = hyper::Response> + Clone + Send + 'static, - S::Future: Send + 'static, -{ - type Response = S::Response; - type Error = S::Error; - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, req: hyper::Request) -> Self::Future { - // This is necessary because tonic internally uses `tower::buffer::Buffer`. - // See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149 - // for details on why this is necessary. - // (We could also probably avoid a clone using an approach similar to https://docs.rs/tower-http/latest/tower_http/catch_panic, but wrangling the types between hyper/tower/tonic is very tricky.) - let clone = self.inner.clone(); - let mut inner = std::mem::replace(&mut self.inner, clone); - - Box::pin(async move { - // See https://doc.rust-lang.org/core/panic/trait.UnwindSafe.html for details on unwind safety. - // tl;dr: it's not guaranteed to be safe to continue execution after a panic as the world may be in an inconsistent state. - // Many types *are* unwind safe and marked as such with the UnwindSafe trait. In our case, since we want a generic wrapper around any service, we need to manually assert that the service is unwind safe. - // Note that this can lead to unexpected behavior if the service is not actually unwind safe and it panics. - match AssertUnwindSafe(inner.call(req)).catch_unwind().await { - Ok(response) => response, - Err(err) => { - let message = if let Some(s) = err.downcast_ref::() { - format!("Service panicked: {}", s) - } else if let Some(s) = err.downcast_ref::<&str>() { - format!("Service panicked: {}", s) - } else { - "Service panicked but `CatchPanicMiddleware` was unable to downcast the panic info" - .to_string() - }; - tracing::error!("{}", message); - - let response = tonic::Status::internal(message).to_http(); - Ok(response) - } - } - }) - } -} diff --git a/rust/worker/src/utils/mod.rs b/rust/worker/src/utils/mod.rs index cd6ba5689b5..c8b933c8776 100644 --- a/rust/worker/src/utils/mod.rs +++ b/rust/worker/src/utils/mod.rs @@ -1,5 +1,3 @@ -mod catch_panic_middleware; mod vec; -pub(crate) use catch_panic_middleware::*; pub(crate) use vec::*; From 73112f1f30efa62f0b9f7c69e577877a744ba3bf Mon Sep 17 00:00:00 2001 From: Max Isom Date: Wed, 19 Jun 2024 11:24:18 -0700 Subject: [PATCH 10/11] Add comment --- rust/worker/build.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rust/worker/build.rs b/rust/worker/build.rs index 6d356930518..22bd8b44f64 100644 --- a/rust/worker/build.rs +++ b/rust/worker/build.rs @@ -6,6 +6,8 @@ fn main() -> Result<(), Box> { "../../idl/chromadb/proto/logservice.proto", ]; + // Can't use #[cfg(test)] here because a build for tests is technically a regular debug build, meaning that #[cfg(test)] is useless in build.rs. + // See https://github.com/rust-lang/cargo/issues/1581 #[cfg(debug_assertions)] proto_paths.push("../../idl/chromadb/proto/debug.proto"); From e01be96206b9df2d45597eaa169652f265ebac5a Mon Sep 17 00:00:00 2001 From: Max Isom Date: Wed, 19 Jun 2024 11:25:18 -0700 Subject: [PATCH 11/11] Revert unnecessary dependency changes --- Cargo.lock | 2 -- rust/worker/Cargo.toml | 2 -- 2 files changed, 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 965e3c20e21..654fa108a01 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4660,7 +4660,6 @@ dependencies = [ "criterion", "figment", "futures", - "hyper", "k8s-openapi", "kube", "murmur3", @@ -4687,7 +4686,6 @@ dependencies = [ "tokio-util", "tonic 0.10.2", "tonic-build", - "tower", "tracing", "tracing-bunyan-formatter", "tracing-opentelemetry", diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml index f994f99c793..52b0cf5351a 100644 --- a/rust/worker/Cargo.toml +++ b/rust/worker/Cargo.toml @@ -54,8 +54,6 @@ opentelemetry = { version = "0.19.0", default-features = false, features = [ ] } opentelemetry-otlp = "0.12.0" shuttle = "0.7.1" -tower = "0.4.13" -hyper = "0.14" [dev-dependencies]