diff --git a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java
index c85191a1e..1b54ee8a4 100644
--- a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java
+++ b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java
@@ -357,9 +357,9 @@ public void setClientName(String clientName) {
/**
* Authenticates the current connection using the provided credentials.
*
- * Unlike using dispatch of {@link RedisAsyncCommands#auth}, this method defers the {@code AUTH} command if the connection is within an active
- * transaction. The authentication command will only be dispatched after the enclosing {@code DISCARD} or {@code EXEC}
- * command is executed, ensuring that authentication does not interfere with ongoing transactions.
+ * Unlike using dispatch of {@link RedisAsyncCommands#auth}, this method defers the {@code AUTH} command if the connection
+ * is within an active transaction. The authentication command will only be dispatched after the enclosing {@code DISCARD}
+ * or {@code EXEC} command is executed, ensuring that authentication does not interfere with ongoing transactions.
*
*
* @param credentials the {@link RedisCredentials} to authenticate the connection. If {@code null}, no action is performed.
@@ -421,20 +421,29 @@ protected void dispatchAuth(RedisCredentials credentials) {
return;
}
- RedisFuture auth;
+ // dispatch directly to avoid AUTH preprocessing overrides credentials provider
+ RedisCommand auth = super.dispatch(authCommand(credentials));
+ if (auth instanceof CompleteableCommand) {
+ ((CompleteableCommand>) auth).onComplete((status, throwable) -> {
+ if (throwable != null) {
+ logger.error("Re-authentication failed {}.", getEpid(), throwable);
+ publishReauthFailedEvent(throwable);
+ } else {
+ logger.info("Re-authentication succeeded {}.", getEpid());
+ publishReauthEvent();
+ }
+ });
+ }
+ }
+
+ private AsyncCommand authCommand(RedisCredentials credentials) {
+ CommandArgs args = new CommandArgs<>(codec);
if (credentials.getUsername() != null) {
- auth = async().auth(credentials.getUsername(), String.valueOf(credentials.getPassword()));
+ args.add(credentials.getUsername()).add(credentials.getPassword());
} else {
- auth = async().auth(String.valueOf(credentials.getPassword()));
+ args.add(credentials.getPassword());
}
- auth.thenRun(() -> {
- publishReauthEvent();
- logger.info("Re-authentication succeeded {}.", getEpid());
- }).exceptionally(throwable -> {
- publishReauthFailedEvent(throwable);
- logger.error("Re-authentication failed {}.", getEpid(), throwable);
- return null;
- });
+ return new AsyncCommand<>(new Command<>(AUTH, new StatusOutput<>(codec), args));
}
private void publishReauthEvent() {
@@ -463,4 +472,5 @@ private String getEpid() {
return ((Endpoint) writer).getId();
}
+
}
diff --git a/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java b/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java
index 245eef09b..d418c4251 100644
--- a/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java
+++ b/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java
@@ -2,13 +2,17 @@
import io.lettuce.core.codec.StringCodec;
import io.lettuce.core.protocol.AsyncCommand;
+import io.lettuce.core.protocol.CommandType;
import io.lettuce.core.protocol.PushHandler;
+import io.lettuce.core.protocol.RedisCommand;
import io.lettuce.core.resource.ClientResources;
import io.lettuce.core.tracing.Tracing;
import io.lettuce.test.ReflectionTestUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentMatcher;
+import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
@@ -21,6 +25,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
@@ -31,10 +36,8 @@
public class StatefulRedisConnectionImplUnitTests extends TestSupport {
RedisCommandBuilder commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8);
- StatefulRedisConnectionImpl connection;
- @Mock
- RedisAsyncCommandsImpl asyncCommands;
+ StatefulRedisConnectionImpl connection;
@Mock
PushHandler pushHandler;
@@ -53,45 +56,26 @@ void setup() throws NoSuchFieldException, IllegalAccessException {
when(writer.getClientResources()).thenReturn(clientResources);
when(clientResources.tracing()).thenReturn(tracing);
when(tracing.isEnabled()).thenReturn(false);
- when(asyncCommands.auth(any(CharSequence.class)))
- .thenAnswer( invocation -> {
- String pass = invocation.getArgument(0);
- AsyncCommand auth = new AsyncCommand<>(commandBuilder.auth(pass));
- auth.complete();
- return auth;
- });
- when(asyncCommands.auth(anyString(), any(CharSequence.class)))
- .thenAnswer( invocation -> {
- String user = invocation.getArgument(0); // Capture username
- String pass = invocation.getArgument(1); // Capture password
- AsyncCommand auth = new AsyncCommand<>(commandBuilder.auth(user, pass));
- auth.complete();
- return auth;
- });
Field asyncField = StatefulRedisConnectionImpl.class.getDeclaredField("async");
asyncField.setAccessible(true);
-
connection = new StatefulRedisConnectionImpl<>(writer, pushHandler, StringCodec.UTF8, Duration.ofSeconds(1));
- asyncField.set(connection,asyncCommands);
}
@Test
public void testSetCredentialsWhenCredentialsAreNull() {
connection.setCredentials(null);
- verify(asyncCommands, never()).auth(any(CharSequence.class));
- verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class));
+ verify(writer, never()).write(ArgumentMatchers.> any());
}
@Test
void testSetCredentialsDispatchesAuthWhenNotInTransaction() {
connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));
- verify(asyncCommands).auth(eq("user"), eq("pass"));
+ verify(writer).write(argThat(isAuthCommand("user", "pass")));
}
-
@Test
void testSetCredentialsDoesNotDispatchAuthIfInTransaction() {
AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction");
@@ -99,11 +83,9 @@ void testSetCredentialsDoesNotDispatchAuthIfInTransaction() {
connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));
- verify(asyncCommands, never()).auth(any(CharSequence.class));
- verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class));
+ verify(writer, never()).write(ArgumentMatchers.> any());
}
-
@Test
void testSetCredentialsDispatchesAuthAfterTransaction() {
AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction");
@@ -116,7 +98,7 @@ void testSetCredentialsDispatchesAuthAfterTransaction() {
assertThat(inTransaction.get()).isFalse();
- verify(asyncCommands).auth(eq("user"), eq("pass"));
+ verify(writer).write(argThat(isAuthCommand("user", "pass")));
}
@Test
@@ -136,7 +118,30 @@ void testSetCredentialsDispatchesAuthAfterTransactionInAnotherThread() throws In
thread.join();
assertThat(inTransaction.get()).isFalse();
- verify(asyncCommands).auth(eq("user"), eq("pass"));
+ verify(writer).write(argThat(isAuthCommand("user", "pass")));
+ }
+
+ public static ArgumentMatcher> isAuthCommand(String expectedUsername,
+ String expectedPassword) {
+ return new ArgumentMatcher>() {
+
+ @Override
+ public boolean matches(RedisCommand command) {
+ if (command.getType() != CommandType.AUTH) {
+ return false;
+ }
+
+ // Retrieve arguments (adjust based on your RedisCommand implementation)
+ return command.getArgs().toCommandString().equals(expectedUsername + " " + expectedPassword);
+ }
+
+ @Override
+ public String toString() {
+ return String.format("Expected AUTH command with username=%s and password=%s", expectedUsername,
+ expectedPassword);
+ }
+
+ };
}
}