Skip to content

Commit

Permalink
dekaf: Always make sure we're using fresh tokens
Browse files Browse the repository at this point in the history
I was noticing a bunch of `ExpiredSignature` errors, and realized that if a client didn't close its connection on read errors, a new access token would never get created. This ensures that every time we get a `flow_client::Client`, we call `refresh_authorizations()` to make sure it's using valid tokens. This is cheap in the common case that the tokens aren't expired.
  • Loading branch information
jshearer committed Oct 8, 2024
1 parent 4670440 commit afb9428
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 50 deletions.
32 changes: 31 additions & 1 deletion crates/dekaf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,36 @@ pub struct DeprecatedConfigOptions {

pub struct Authenticated {
client: flow_client::Client,
refresh_token: RefreshToken,
access_token: String,
task_config: DekafConfig,
claims: models::authorizations::ControlClaims,
}

impl Authenticated {
pub async fn get_client(&mut self) -> anyhow::Result<&flow_client::Client> {
let (access, refresh) = refresh_authorizations(
&self.client,
Some(self.access_token.to_owned()),
Some(self.refresh_token.to_owned()),
)
.await?;

if access.ne(&self.access_token) {
self.access_token = access.clone();
self.refresh_token = refresh;

self.client = self
.client
.clone()
.with_creds(Some(access))
.with_fresh_gazette_client();
}

Ok(&self.client)
}
}

impl App {
#[tracing::instrument(level = "info", err(Debug, level = "warn"), skip(self, password))]
async fn authenticate(&self, username: &str, password: &str) -> anyhow::Result<Authenticated> {
Expand All @@ -72,14 +98,16 @@ impl App {
let client = self
.client_base
.clone()
.with_creds(Some(access), Some(refresh))
.with_creds(Some(access.clone()))
.with_fresh_gazette_client();

let claims = flow_client::client::client_claims(&client)?;

if models::Materialization::regex().is_match(username.as_ref()) {
Ok(Authenticated {
client,
access_token: access,
refresh_token: refresh,
task_config: todo!("Fetch and unseal task config"),
claims,
})
Expand All @@ -93,6 +121,8 @@ impl App {
strict_topic_names: config.strict_topic_names,
token: "".to_string(),
},
access_token: access,
refresh_token: refresh,
claims,
})
} else {
Expand Down
1 change: 0 additions & 1 deletion crates/dekaf/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ async fn main() -> anyhow::Result<()> {
api_key,
api_endpoint,
None,
None
)
});

Expand Down
81 changes: 45 additions & 36 deletions crates/dekaf/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::{App, Collection, Read};
use crate::{
connector::DekafConfig, from_downstream_topic_name, from_upstream_topic_name,
read::BatchResult, to_downstream_topic_name, to_upstream_topic_name,
topology::fetch_all_collection_names, Authenticated,
from_downstream_topic_name, from_upstream_topic_name, read::BatchResult,
to_downstream_topic_name, to_upstream_topic_name, topology::fetch_all_collection_names,
Authenticated,
};
use anyhow::Context;
use bytes::{BufMut, BytesMut};
Expand Down Expand Up @@ -32,22 +32,17 @@ struct PendingRead {

pub struct Session {
app: Arc<App>,
client: Option<flow_client::Client>,
reads: HashMap<(TopicName, i32), PendingRead>,
/// ID of the authenticated user
user_id: Option<String>,
task_config: Option<DekafConfig>,
secret: String,
auth: Option<Authenticated>,
}

impl Session {
pub fn new(app: Arc<App>, secret: String) -> Self {
Self {
app,
client: None,
reads: HashMap::new(),
user_id: None,
task_config: None,
auth: None,
secret,
}
}
Expand Down Expand Up @@ -82,14 +77,9 @@ impl Session {
let password = it.next().context("expected SASL passwd")??;

let response = match self.app.authenticate(authcid, password).await {
Ok(Authenticated {
client,
task_config,
claims,
}) => {
self.client.replace(client);
self.task_config.replace(task_config);
self.user_id.replace(claims.sub.to_string());
Ok(auth) => {
let claims = auth.claims.clone();
self.auth.replace(auth);

let mut response = messages::SaslAuthenticateResponse::default();
response.session_lifetime_ms = (1000
Expand Down Expand Up @@ -146,9 +136,11 @@ impl Session {
) -> anyhow::Result<IndexMap<TopicName, MetadataResponseTopic>> {
let collections = fetch_all_collection_names(
&self
.client
.as_ref()
.auth
.as_mut()
.ok_or(anyhow::anyhow!("Session not authenticated"))?
.get_client()
.await?
.pg_client(),
)
.await?;
Expand Down Expand Up @@ -177,10 +169,12 @@ impl Session {
&mut self,
requests: Vec<messages::metadata_request::MetadataRequestTopic>,
) -> anyhow::Result<IndexMap<TopicName, MetadataResponseTopic>> {
let client = &self
.client
.as_ref()
.ok_or(anyhow::anyhow!("Session not authenticated"))?;
let client = self
.auth
.as_mut()
.ok_or(anyhow::anyhow!("Session not authenticated"))?
.get_client()
.await?;

// Concurrently fetch Collection instances for all requested topics.
let collections: anyhow::Result<Vec<(TopicName, Option<Collection>)>> =
Expand Down Expand Up @@ -262,10 +256,12 @@ impl Session {
&mut self,
request: messages::ListOffsetsRequest,
) -> anyhow::Result<messages::ListOffsetsResponse> {
let client = &self
.client
.as_ref()
.ok_or(anyhow::anyhow!("Session not authenticated"))?;
let client = self
.auth
.as_mut()
.ok_or(anyhow::anyhow!("Session not authenticated"))?
.get_client()
.await?;

// Concurrently fetch Collection instances and offsets for all requested topics and partitions.
// Map each "topic" into Vec<(Partition Index, Option<(Journal Offset, Timestamp))>.
Expand Down Expand Up @@ -360,10 +356,12 @@ impl Session {
..
} = request;

let client = &self
.client
.as_ref()
.ok_or(anyhow::anyhow!("Session not authenticated"))?;
let client = self
.auth
.as_mut()
.ok_or(anyhow::anyhow!("Session not authenticated"))?
.get_client()
.await?;

let timeout_at =
std::time::Instant::now() + std::time::Duration::from_millis(max_wait_ms as u64);
Expand Down Expand Up @@ -1040,22 +1038,33 @@ impl Session {
to_upstream_topic_name(
name,
self.secret.to_owned(),
self.user_id.clone().expect("User ID should exist"),
self.auth
.as_ref()
.expect("Must be authenticated")
.claims
.sub
.to_string(),
)
}
fn decrypt_topic_name(&self, name: TopicName) -> TopicName {
from_upstream_topic_name(
name,
self.secret.to_owned(),
self.user_id.clone().expect("User ID should exist"),
self.auth
.as_ref()
.expect("Must be authenticated")
.claims
.sub
.to_string(),
)
}

fn encode_topic_name(&self, name: String) -> TopicName {
if self
.task_config
.auth
.as_ref()
.expect("should have config already")
.expect("Must be authenticated")
.task_config
.strict_topic_names
{
to_downstream_topic_name(TopicName(StrBytes::from_string(name)))
Expand Down
11 changes: 1 addition & 10 deletions crates/flow-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ pub struct Client {
http_client: reqwest::Client,
// User's access token, if authenticated.
user_access_token: Option<String>,
// User's refresh token, if authenticated.
user_refresh_token: Option<RefreshToken>,
// Base shard client which is cloned to build token-specific clients.
shard_client: gazette::shard::Client,
// Base journal client which is cloned to build token-specific clients.
Expand All @@ -32,7 +30,6 @@ impl Client {
pg_api_token: String,
pg_url: Url,
user_access_token: Option<String>,
user_refresh_token: Option<RefreshToken>,
) -> Self {
// Build journal and shard clients with an empty default service address.
// We'll use their with_endpoint_and_metadata() routines to cheaply clone
Expand All @@ -59,18 +56,12 @@ impl Client {
journal_client,
shard_client,
user_access_token,
user_refresh_token,
}
}

pub fn with_creds(
self,
user_access_token: Option<String>,
user_refresh_token: Option<RefreshToken>,
) -> Self {
pub fn with_creds(self, user_access_token: Option<String>) -> Self {
Self {
user_access_token: user_access_token.or(self.user_access_token),
user_refresh_token: user_refresh_token.or(self.user_refresh_token),
..self
}
}
Expand Down
1 change: 0 additions & 1 deletion crates/flowctl/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ impl Config {
self.get_pg_public_token().to_string(),
self.get_pg_url().clone(),
None,
None,
)
}

Expand Down
2 changes: 1 addition & 1 deletion crates/flowctl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl Cli {
config.user_access_token = Some(access.to_owned());
config.user_refresh_token = Some(refresh.to_owned());

anon_client.with_creds(Some(access), Some(refresh))
anon_client.with_creds(Some(access))
} else {
tracing::warn!("You are not authenticated. Run `auth login` to login to Flow.");
anon_client
Expand Down

0 comments on commit afb9428

Please sign in to comment.