diff --git a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java index c85191a1e..1b54ee8a4 100644 --- a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java +++ b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java @@ -357,9 +357,9 @@ public void setClientName(String clientName) { /** * Authenticates the current connection using the provided credentials. *

- * Unlike using dispatch of {@link RedisAsyncCommands#auth}, this method defers the {@code AUTH} command if the connection is within an active - * transaction. The authentication command will only be dispatched after the enclosing {@code DISCARD} or {@code EXEC} - * command is executed, ensuring that authentication does not interfere with ongoing transactions. + * Unlike using dispatch of {@link RedisAsyncCommands#auth}, this method defers the {@code AUTH} command if the connection + * is within an active transaction. The authentication command will only be dispatched after the enclosing {@code DISCARD} + * or {@code EXEC} command is executed, ensuring that authentication does not interfere with ongoing transactions. *

* * @param credentials the {@link RedisCredentials} to authenticate the connection. If {@code null}, no action is performed. @@ -421,20 +421,29 @@ protected void dispatchAuth(RedisCredentials credentials) { return; } - RedisFuture auth; + // dispatch directly to avoid AUTH preprocessing overrides credentials provider + RedisCommand auth = super.dispatch(authCommand(credentials)); + if (auth instanceof CompleteableCommand) { + ((CompleteableCommand) auth).onComplete((status, throwable) -> { + if (throwable != null) { + logger.error("Re-authentication failed {}.", getEpid(), throwable); + publishReauthFailedEvent(throwable); + } else { + logger.info("Re-authentication succeeded {}.", getEpid()); + publishReauthEvent(); + } + }); + } + } + + private AsyncCommand authCommand(RedisCredentials credentials) { + CommandArgs args = new CommandArgs<>(codec); if (credentials.getUsername() != null) { - auth = async().auth(credentials.getUsername(), String.valueOf(credentials.getPassword())); + args.add(credentials.getUsername()).add(credentials.getPassword()); } else { - auth = async().auth(String.valueOf(credentials.getPassword())); + args.add(credentials.getPassword()); } - auth.thenRun(() -> { - publishReauthEvent(); - logger.info("Re-authentication succeeded {}.", getEpid()); - }).exceptionally(throwable -> { - publishReauthFailedEvent(throwable); - logger.error("Re-authentication failed {}.", getEpid(), throwable); - return null; - }); + return new AsyncCommand<>(new Command<>(AUTH, new StatusOutput<>(codec), args)); } private void publishReauthEvent() { @@ -463,4 +472,5 @@ private String getEpid() { return ((Endpoint) writer).getId(); } + } diff --git a/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java b/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java index 245eef09b..d418c4251 100644 --- a/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java +++ b/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java @@ -2,13 +2,17 @@ import io.lettuce.core.codec.StringCodec; import io.lettuce.core.protocol.AsyncCommand; +import io.lettuce.core.protocol.CommandType; import io.lettuce.core.protocol.PushHandler; +import io.lettuce.core.protocol.RedisCommand; import io.lettuce.core.resource.ClientResources; import io.lettuce.core.tracing.Tracing; import io.lettuce.test.ReflectionTestUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentMatcher; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; @@ -21,6 +25,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -31,10 +36,8 @@ public class StatefulRedisConnectionImplUnitTests extends TestSupport { RedisCommandBuilder commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8); - StatefulRedisConnectionImpl connection; - @Mock - RedisAsyncCommandsImpl asyncCommands; + StatefulRedisConnectionImpl connection; @Mock PushHandler pushHandler; @@ -53,45 +56,26 @@ void setup() throws NoSuchFieldException, IllegalAccessException { when(writer.getClientResources()).thenReturn(clientResources); when(clientResources.tracing()).thenReturn(tracing); when(tracing.isEnabled()).thenReturn(false); - when(asyncCommands.auth(any(CharSequence.class))) - .thenAnswer( invocation -> { - String pass = invocation.getArgument(0); - AsyncCommand auth = new AsyncCommand<>(commandBuilder.auth(pass)); - auth.complete(); - return auth; - }); - when(asyncCommands.auth(anyString(), any(CharSequence.class))) - .thenAnswer( invocation -> { - String user = invocation.getArgument(0); // Capture username - String pass = invocation.getArgument(1); // Capture password - AsyncCommand auth = new AsyncCommand<>(commandBuilder.auth(user, pass)); - auth.complete(); - return auth; - }); Field asyncField = StatefulRedisConnectionImpl.class.getDeclaredField("async"); asyncField.setAccessible(true); - connection = new StatefulRedisConnectionImpl<>(writer, pushHandler, StringCodec.UTF8, Duration.ofSeconds(1)); - asyncField.set(connection,asyncCommands); } @Test public void testSetCredentialsWhenCredentialsAreNull() { connection.setCredentials(null); - verify(asyncCommands, never()).auth(any(CharSequence.class)); - verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class)); + verify(writer, never()).write(ArgumentMatchers.> any()); } @Test void testSetCredentialsDispatchesAuthWhenNotInTransaction() { connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray())); - verify(asyncCommands).auth(eq("user"), eq("pass")); + verify(writer).write(argThat(isAuthCommand("user", "pass"))); } - @Test void testSetCredentialsDoesNotDispatchAuthIfInTransaction() { AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction"); @@ -99,11 +83,9 @@ void testSetCredentialsDoesNotDispatchAuthIfInTransaction() { connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray())); - verify(asyncCommands, never()).auth(any(CharSequence.class)); - verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class)); + verify(writer, never()).write(ArgumentMatchers.> any()); } - @Test void testSetCredentialsDispatchesAuthAfterTransaction() { AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction"); @@ -116,7 +98,7 @@ void testSetCredentialsDispatchesAuthAfterTransaction() { assertThat(inTransaction.get()).isFalse(); - verify(asyncCommands).auth(eq("user"), eq("pass")); + verify(writer).write(argThat(isAuthCommand("user", "pass"))); } @Test @@ -136,7 +118,30 @@ void testSetCredentialsDispatchesAuthAfterTransactionInAnotherThread() throws In thread.join(); assertThat(inTransaction.get()).isFalse(); - verify(asyncCommands).auth(eq("user"), eq("pass")); + verify(writer).write(argThat(isAuthCommand("user", "pass"))); + } + + public static ArgumentMatcher> isAuthCommand(String expectedUsername, + String expectedPassword) { + return new ArgumentMatcher>() { + + @Override + public boolean matches(RedisCommand command) { + if (command.getType() != CommandType.AUTH) { + return false; + } + + // Retrieve arguments (adjust based on your RedisCommand implementation) + return command.getArgs().toCommandString().equals(expectedUsername + " " + expectedPassword); + } + + @Override + public String toString() { + return String.format("Expected AUTH command with username=%s and password=%s", expectedUsername, + expectedPassword); + } + + }; } }