Skip to content

Commit

Permalink
Skip preProcessing of auth command to avoid replacing the credential …
Browse files Browse the repository at this point in the history
…provider with static one provider

Add unit tests for setCredentials
  • Loading branch information
ggivo committed Dec 13, 2024
1 parent 61158f2 commit 9a0e513
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 43 deletions.
38 changes: 24 additions & 14 deletions src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -357,9 +357,9 @@ public void setClientName(String clientName) {
/**
* Authenticates the current connection using the provided credentials.
* <p>
* Unlike using dispatch of {@link RedisAsyncCommands#auth}, this method defers the {@code AUTH} command if the connection is within an active
* transaction. The authentication command will only be dispatched after the enclosing {@code DISCARD} or {@code EXEC}
* command is executed, ensuring that authentication does not interfere with ongoing transactions.
* Unlike using dispatch of {@link RedisAsyncCommands#auth}, this method defers the {@code AUTH} command if the connection
* is within an active transaction. The authentication command will only be dispatched after the enclosing {@code DISCARD}
* or {@code EXEC} command is executed, ensuring that authentication does not interfere with ongoing transactions.
* </p>
*
* @param credentials the {@link RedisCredentials} to authenticate the connection. If {@code null}, no action is performed.
Expand Down Expand Up @@ -421,20 +421,29 @@ protected void dispatchAuth(RedisCredentials credentials) {
return;
}

RedisFuture<String> auth;
// dispatch directly to avoid AUTH preprocessing overrides credentials provider
RedisCommand<K, V, ?> auth = super.dispatch(authCommand(credentials));
if (auth instanceof CompleteableCommand) {
((CompleteableCommand<?>) auth).onComplete((status, throwable) -> {
if (throwable != null) {
logger.error("Re-authentication failed {}.", getEpid(), throwable);
publishReauthFailedEvent(throwable);
} else {
logger.info("Re-authentication succeeded {}.", getEpid());
publishReauthEvent();
}
});
}
}

private AsyncCommand<K, V, String> authCommand(RedisCredentials credentials) {
CommandArgs<K, V> args = new CommandArgs<>(codec);
if (credentials.getUsername() != null) {
auth = async().auth(credentials.getUsername(), String.valueOf(credentials.getPassword()));
args.add(credentials.getUsername()).add(credentials.getPassword());
} else {
auth = async().auth(String.valueOf(credentials.getPassword()));
args.add(credentials.getPassword());
}
auth.thenRun(() -> {
publishReauthEvent();
logger.info("Re-authentication succeeded {}.", getEpid());
}).exceptionally(throwable -> {
publishReauthFailedEvent(throwable);
logger.error("Re-authentication failed {}.", getEpid(), throwable);
return null;
});
return new AsyncCommand<>(new Command<>(AUTH, new StatusOutput<>(codec), args));
}

private void publishReauthEvent() {
Expand Down Expand Up @@ -463,4 +472,5 @@ private String getEpid() {

return ((Endpoint) writer).getId();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

import io.lettuce.core.codec.StringCodec;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.CommandType;
import io.lettuce.core.protocol.PushHandler;
import io.lettuce.core.protocol.RedisCommand;
import io.lettuce.core.resource.ClientResources;
import io.lettuce.core.tracing.Tracing;
import io.lettuce.test.ReflectionTestUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentMatcher;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
Expand All @@ -21,6 +25,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
Expand All @@ -31,10 +36,8 @@
public class StatefulRedisConnectionImplUnitTests extends TestSupport {

RedisCommandBuilder<String, String> commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8);
StatefulRedisConnectionImpl<String,String> connection;

@Mock
RedisAsyncCommandsImpl<String, String> asyncCommands;
StatefulRedisConnectionImpl<String, String> connection;

@Mock
PushHandler pushHandler;
Expand All @@ -53,57 +56,36 @@ void setup() throws NoSuchFieldException, IllegalAccessException {
when(writer.getClientResources()).thenReturn(clientResources);
when(clientResources.tracing()).thenReturn(tracing);
when(tracing.isEnabled()).thenReturn(false);
when(asyncCommands.auth(any(CharSequence.class)))
.thenAnswer( invocation -> {
String pass = invocation.getArgument(0);
AsyncCommand<String, String, String> auth = new AsyncCommand<>(commandBuilder.auth(pass));
auth.complete();
return auth;
});
when(asyncCommands.auth(anyString(), any(CharSequence.class)))
.thenAnswer( invocation -> {
String user = invocation.getArgument(0); // Capture username
String pass = invocation.getArgument(1); // Capture password
AsyncCommand<String, String, String> auth = new AsyncCommand<>(commandBuilder.auth(user, pass));
auth.complete();
return auth;
});

Field asyncField = StatefulRedisConnectionImpl.class.getDeclaredField("async");
asyncField.setAccessible(true);


connection = new StatefulRedisConnectionImpl<>(writer, pushHandler, StringCodec.UTF8, Duration.ofSeconds(1));
asyncField.set(connection,asyncCommands);
}

@Test
public void testSetCredentialsWhenCredentialsAreNull() {
connection.setCredentials(null);

verify(asyncCommands, never()).auth(any(CharSequence.class));
verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class));
verify(writer, never()).write(ArgumentMatchers.<RedisCommand<String, String, String>> any());
}

@Test
void testSetCredentialsDispatchesAuthWhenNotInTransaction() {
connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));
verify(asyncCommands).auth(eq("user"), eq("pass"));
verify(writer).write(argThat(isAuthCommand("user", "pass")));
}


@Test
void testSetCredentialsDoesNotDispatchAuthIfInTransaction() {
AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction");
inTransaction.set(true);

connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));

verify(asyncCommands, never()).auth(any(CharSequence.class));
verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class));
verify(writer, never()).write(ArgumentMatchers.<RedisCommand<String, String, String>> any());
}


@Test
void testSetCredentialsDispatchesAuthAfterTransaction() {
AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction");
Expand All @@ -116,7 +98,7 @@ void testSetCredentialsDispatchesAuthAfterTransaction() {

assertThat(inTransaction.get()).isFalse();

verify(asyncCommands).auth(eq("user"), eq("pass"));
verify(writer).write(argThat(isAuthCommand("user", "pass")));
}

@Test
Expand All @@ -136,7 +118,30 @@ void testSetCredentialsDispatchesAuthAfterTransactionInAnotherThread() throws In
thread.join();

assertThat(inTransaction.get()).isFalse();
verify(asyncCommands).auth(eq("user"), eq("pass"));
verify(writer).write(argThat(isAuthCommand("user", "pass")));
}

public static <K, V, T> ArgumentMatcher<RedisCommand<K, V, T>> isAuthCommand(String expectedUsername,
String expectedPassword) {
return new ArgumentMatcher<RedisCommand<K, V, T>>() {

@Override
public boolean matches(RedisCommand command) {
if (command.getType() != CommandType.AUTH) {
return false;
}

// Retrieve arguments (adjust based on your RedisCommand implementation)
return command.getArgs().toCommandString().equals(expectedUsername + " " + expectedPassword);
}

@Override
public String toString() {
return String.format("Expected AUTH command with username=%s and password=%s", expectedUsername,
expectedPassword);
}

};
}

}

0 comments on commit 9a0e513

Please sign in to comment.