Skip to content

Commit

Permalink
Add unit tests for setCredenatials
Browse files Browse the repository at this point in the history
  • Loading branch information
ggivo committed Dec 13, 2024
1 parent 086ccf3 commit 61158f2
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 6 deletions.
33 changes: 27 additions & 6 deletions src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import static io.lettuce.core.ClientOptions.DEFAULT_JSON_PARSER;
import static io.lettuce.core.protocol.CommandType.*;

import java.nio.CharBuffer;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -355,6 +354,29 @@ public void setClientName(String clientName) {
dispatch((RedisCommand) async);
}

/**
* Authenticates the current connection using the provided credentials.
* <p>
* Unlike using dispatch of {@link RedisAsyncCommands#auth}, this method defers the {@code AUTH} command if the connection is within an active
* transaction. The authentication command will only be dispatched after the enclosing {@code DISCARD} or {@code EXEC}
* command is executed, ensuring that authentication does not interfere with ongoing transactions.
* </p>
*
* @param credentials the {@link RedisCredentials} to authenticate the connection. If {@code null}, no action is performed.
*
* <p>
* <b>Behavior:</b>
* <ul>
* <li>If the provided credentials are {@code null}, the method exits immediately.</li>
* <li>If a transaction is active (as indicated by {@code inTransaction}), the {@code AUTH} command is not dispatched
* immediately but deferred until the transaction ends.</li>
* <li>If no transaction is active, the {@code AUTH} command is dispatched immediately using the provided
* credentials.</li>
* </ul>
* </p>
*
* @see RedisAsyncCommands#auth
*/
public void setCredentials(RedisCredentials credentials) {
if (credentials == null) {
return;
Expand All @@ -363,7 +385,7 @@ public void setCredentials(RedisCredentials credentials) {
try {
credentialsRef.set(credentials);
if (!inTransaction.get()) {
dispatchAuthCommand(credentialsRef.getAndSet(null));
dispatchAuth(credentialsRef.getAndSet(null));
}
} finally {
reAuthSafety.unlock();
Expand Down Expand Up @@ -394,16 +416,16 @@ public void setAuthenticationHandler(RedisAuthenticationHandler handler) {
authHandler = handler;
}

private void dispatchAuthCommand(RedisCredentials credentials) {
protected void dispatchAuth(RedisCredentials credentials) {
if (credentials == null) {
return;
}

RedisFuture<String> auth;
if (credentials.getUsername() != null) {
auth = async().auth(credentials.getUsername(), CharBuffer.wrap(credentials.getPassword()));
auth = async().auth(credentials.getUsername(), String.valueOf(credentials.getPassword()));
} else {
auth = async().auth(CharBuffer.wrap(credentials.getPassword()));
auth = async().auth(String.valueOf(credentials.getPassword()));
}
auth.thenRun(() -> {
publishReauthEvent();
Expand Down Expand Up @@ -441,5 +463,4 @@ private String getEpid() {

return ((Endpoint) writer).getId();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package io.lettuce.core;

import io.lettuce.core.codec.StringCodec;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.PushHandler;
import io.lettuce.core.resource.ClientResources;
import io.lettuce.core.tracing.Tracing;
import io.lettuce.test.ReflectionTestUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;

import java.lang.reflect.Field;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
public class StatefulRedisConnectionImplUnitTests extends TestSupport {

RedisCommandBuilder<String, String> commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8);
StatefulRedisConnectionImpl<String,String> connection;

@Mock
RedisAsyncCommandsImpl<String, String> asyncCommands;

@Mock
PushHandler pushHandler;

@Mock
RedisChannelWriter writer;

@Mock
ClientResources clientResources;

@Mock
Tracing tracing;

@BeforeEach
void setup() throws NoSuchFieldException, IllegalAccessException {
when(writer.getClientResources()).thenReturn(clientResources);
when(clientResources.tracing()).thenReturn(tracing);
when(tracing.isEnabled()).thenReturn(false);
when(asyncCommands.auth(any(CharSequence.class)))
.thenAnswer( invocation -> {
String pass = invocation.getArgument(0);
AsyncCommand<String, String, String> auth = new AsyncCommand<>(commandBuilder.auth(pass));
auth.complete();
return auth;
});
when(asyncCommands.auth(anyString(), any(CharSequence.class)))
.thenAnswer( invocation -> {
String user = invocation.getArgument(0); // Capture username
String pass = invocation.getArgument(1); // Capture password
AsyncCommand<String, String, String> auth = new AsyncCommand<>(commandBuilder.auth(user, pass));
auth.complete();
return auth;
});

Field asyncField = StatefulRedisConnectionImpl.class.getDeclaredField("async");
asyncField.setAccessible(true);


connection = new StatefulRedisConnectionImpl<>(writer, pushHandler, StringCodec.UTF8, Duration.ofSeconds(1));
asyncField.set(connection,asyncCommands);
}

@Test
public void testSetCredentialsWhenCredentialsAreNull() {
connection.setCredentials(null);

verify(asyncCommands, never()).auth(any(CharSequence.class));
verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class));
}

@Test
void testSetCredentialsDispatchesAuthWhenNotInTransaction() {
connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));
verify(asyncCommands).auth(eq("user"), eq("pass"));
}


@Test
void testSetCredentialsDoesNotDispatchAuthIfInTransaction() {
AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction");
inTransaction.set(true);

connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));

verify(asyncCommands, never()).auth(any(CharSequence.class));
verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class));
}


@Test
void testSetCredentialsDispatchesAuthAfterTransaction() {
AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction");

connection.dispatch(commandBuilder.multi());
assertThat(inTransaction.get()).isTrue();

connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));
connection.dispatch(commandBuilder.discard());

assertThat(inTransaction.get()).isFalse();

verify(asyncCommands).auth(eq("user"), eq("pass"));
}

@Test
void testSetCredentialsDispatchesAuthAfterTransactionInAnotherThread() throws InterruptedException {
AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction");

connection.dispatch(commandBuilder.multi());
assertThat(inTransaction.get()).isTrue();

Thread thread = new Thread(() -> {
connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));
});
thread.start();

connection.dispatch(commandBuilder.discard());

thread.join();

assertThat(inTransaction.get()).isFalse();
verify(asyncCommands).auth(eq("user"), eq("pass"));
}

}

0 comments on commit 61158f2

Please sign in to comment.