Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENG-50887: Mask value using a masking config #227

Merged
merged 28 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package org.hypertrace.core.query.service;

import com.typesafe.config.Config;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Value;
import lombok.experimental.NonFinal;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class HandlerScopedMaskingConfig {
// Config key under which the list of per-tenant masking entries is stored.
private static final String TENANT_SCOPED_MASKS_CONFIG_KEY = "tenantScopedMaskingCriteria";
// Tenant id -> time-ranged mask definitions; assigned exactly once by the constructor.
private final Map<String, List<MaskValuesForTimeRange>> tenantToMaskValuesMap;
// Attribute -> replacement value for the request most recently passed to parseColumns().
// NOTE(review): this is per-instance mutable state that parseColumns() clears and refills; if one
// instance serves concurrent requests the lookups of different requests can interleave - confirm.
private final Map<String, String> maskedValue = new HashMap<>();

public HandlerScopedMaskingConfig(Config config) {
if (config.hasPath(TENANT_SCOPED_MASKS_CONFIG_KEY)) {
this.tenantToMaskValuesMap =
config.getConfigList(TENANT_SCOPED_MASKS_CONFIG_KEY).stream()
.map(maskConfig -> new TenantMasks(maskConfig))
.collect(
Collectors.toMap(
tenantFilters -> tenantFilters.tenantId,
tenantFilters -> tenantFilters.maskValues));
} else {
this.tenantToMaskValuesMap = Collections.emptyMap();
}
}

/**
 * Recomputes which attributes must be masked for the given request.
 *
 * <p>The previously computed masks are always discarded first. If the request's tenant has mask
 * entries, every entry accepted by {@code isTimeRangeOverlap} for the request's time range
 * contributes its attribute-to-mask pairs; when two accepted entries name the same attribute, the
 * one that comes later in the configured list wins. A request without a time range is compared
 * using {@link Instant#MIN} and {@link Instant#MAX} as its bounds.
 *
 * <p>NOTE(review): the result is kept in instance state and read back through
 * {@code shouldMask}/{@code getMaskedValue}; confirm that an instance is never used by two
 * requests at the same time.
 *
 * @param executionContext context of the request whose result columns are about to be converted
 */
public void parseColumns(ExecutionContext executionContext) {
  String tenantId = executionContext.getTenantId();
  maskedValue.clear();

  List<MaskValuesForTimeRange> tenantEntries = tenantToMaskValuesMap.get(tenantId);
  if (tenantEntries == null) {
    return;
  }

  Optional<QueryTimeRange> requestedRange = executionContext.getQueryTimeRange();
  Instant rangeStart =
      requestedRange.isPresent() ? requestedRange.get().getStartTime() : Instant.MIN;
  Instant rangeEnd = requestedRange.isPresent() ? requestedRange.get().getEndTime() : Instant.MAX;

  for (MaskValuesForTimeRange entry : tenantEntries) {
    if (!isTimeRangeOverlap(entry, rangeStart, rangeEnd)) {
      continue;
    }
    maskedValue.putAll(entry.maskValues.attributeToMaskedValue);
  }
}

private static boolean isTimeRangeOverlap(
MaskValuesForTimeRange timeRangeAndMasks, Instant queryStartTime, Instant queryEndTime) {
boolean timeRangeOverlap = true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default should be false, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following conditionals check for no overlap, i.e. they set the timeRangeOverlap to false. This statement is correct.


if (timeRangeAndMasks.getStartTimeMillis().isPresent()) {
Instant startTimeInstant = Instant.ofEpochMilli(timeRangeAndMasks.getStartTimeMillis().get());
if (startTimeInstant.isBefore(queryStartTime) || startTimeInstant.isAfter(queryEndTime)) {
timeRangeOverlap = false;
}
}

if (timeRangeAndMasks.getEndTimeMillis().isPresent()) {
Instant endTimeInstant = Instant.ofEpochMilli(timeRangeAndMasks.getEndTimeMillis().get());
if (endTimeInstant.isBefore(queryStartTime) || endTimeInstant.isAfter(queryEndTime)) {
timeRangeOverlap = false;
}
}
return timeRangeOverlap;
}

/**
 * Reports whether the last {@code parseColumns} call registered a mask for the attribute.
 *
 * @param attributeName attribute to look up
 * @return {@code true} if a replacement value is registered for {@code attributeName}
 */
public boolean shouldMask(String attributeName) {
  return maskedValue.containsKey(attributeName);
}

/**
 * Returns the replacement value registered for the attribute by the last {@code parseColumns}
 * call.
 *
 * @param attributeName attribute to look up
 * @return the replacement value, or {@code null} when no mask is registered for the attribute
 */
public String getMaskedValue(String attributeName) {
  return maskedValue.get(attributeName);
}

@Value
@NonFinal
private class TenantMasks {
private static final String TENANT_ID_CONFIG_KEY = "tenantId";
private static final String TIME_RANGE_AND_MASK_VALUES_CONFIG_KEY = "timeRangeAndMaskValues";
String tenantId;
List<MaskValuesForTimeRange> maskValues;

private TenantMasks(Config config) {
this.tenantId = config.getString(TENANT_ID_CONFIG_KEY);
this.maskValues =
config.getConfigList(TIME_RANGE_AND_MASK_VALUES_CONFIG_KEY).stream()
.map(MaskValuesForTimeRange::new)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filter out the empty maskings

.collect(Collectors.toList());
}
}

@Value
private class MaskValues {
Map<String, String> attributeToMaskedValue;

MaskValues(Map<String, String> columnToMaskedValue) {
this.attributeToMaskedValue = columnToMaskedValue;
}
}

/**
 * One mask entry: an optional time window plus the attribute masks that apply to it.
 *
 * <p>The window is only recorded when the config defines both {@code startTimeMillis} and
 * {@code endTimeMillis}; if either one is missing, both bounds are left empty. The masks come
 * from the optional {@code maskValues} list, each element of which supplies an {@code
 * attributeId} and its {@code maskedValue}; if an attribute id is repeated, the last element
 * wins.
 */
@Value
@NonFinal
class MaskValuesForTimeRange {
  private static final String START_TIME_CONFIG_PATH = "startTimeMillis";
  private static final String END_TIME_CONFIG_PATH = "endTimeMillis";
  private static final String MASK_VALUE_CONFIG_PATH = "maskValues";
  private static final String ATTRIBUTE_ID_CONFIG_PATH = "attributeId";
  private static final String MASKED_VALUE_CONFIG_PATH = "maskedValue";
  Optional<Long> startTimeMillis;
  Optional<Long> endTimeMillis;
  MaskValues maskValues;

  private MaskValuesForTimeRange(Config config) {
    // The time window is all-or-nothing: a lone start or end bound is ignored.
    boolean hasBothBounds =
        config.hasPath(START_TIME_CONFIG_PATH) && config.hasPath(END_TIME_CONFIG_PATH);
    this.startTimeMillis =
        hasBothBounds ? Optional.of(config.getLong(START_TIME_CONFIG_PATH)) : Optional.empty();
    this.endTimeMillis =
        hasBothBounds ? Optional.of(config.getLong(END_TIME_CONFIG_PATH)) : Optional.empty();

    Map<String, String> maskByAttribute = new HashMap<>();
    if (config.hasPath(MASK_VALUE_CONFIG_PATH)) {
      for (Config maskEntry : config.getConfigList(MASK_VALUE_CONFIG_PATH)) {
        maskByAttribute.put(
            maskEntry.getString(ATTRIBUTE_ID_CONFIG_PATH),
            maskEntry.getString(MASKED_VALUE_CONFIG_PATH));
      }
    }
    this.maskValues = new MaskValues(maskByAttribute);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.pinot.client.ResultSetGroup;
import org.hypertrace.core.query.service.ExecutionContext;
import org.hypertrace.core.query.service.HandlerScopedFiltersConfig;
import org.hypertrace.core.query.service.HandlerScopedMaskingConfig;
import org.hypertrace.core.query.service.QueryCost;
import org.hypertrace.core.query.service.RequestHandler;
import org.hypertrace.core.query.service.api.Expression;
Expand Down Expand Up @@ -67,6 +68,7 @@ public class PinotBasedRequestHandler implements RequestHandler {
private QueryRequestToPinotSQLConverter request2PinotSqlConverter;
private final PinotMapConverter pinotMapConverter;
private HandlerScopedFiltersConfig handlerScopedFiltersConfig;
private HandlerScopedMaskingConfig handlerScopedMaskingConfig;
// The implementations of ResultSet are package private and hence there's no way to determine the
// shape of the results
// other than to do string comparison on the simple class names. In order to be able to unit test
Expand Down Expand Up @@ -143,6 +145,7 @@ private void processConfig(Config config) {

this.handlerScopedFiltersConfig =
new HandlerScopedFiltersConfig(config, this.startTimeAttributeName);
this.handlerScopedMaskingConfig = new HandlerScopedMaskingConfig(config);
LOG.info(
"Using {}ms as the threshold for logging slow queries of handler: {}",
slowQueryThreshold,
Expand Down Expand Up @@ -424,7 +427,7 @@ public Observable<Row> handleRequest(
LOG.debug("Query results: [ {} ]", resultSetGroup.toString());
}
// need to merge data especially for Pinot. That's why we need to track the map columns
return this.convert(resultSetGroup, executionContext.getSelectedColumns())
return this.convert(resultSetGroup, executionContext)
.doOnComplete(
() -> {
long requestTimeMs = stopwatch.stop().elapsed(TimeUnit.MILLISECONDS);
Expand Down Expand Up @@ -493,10 +496,13 @@ private Filter rewriteLeafFilter(
return queryFilter;
}

Observable<Row> convert(ResultSetGroup resultSetGroup, LinkedHashSet<String> selectedAttributes) {
Observable<Row> convert(ResultSetGroup resultSetGroup, ExecutionContext executionContext) {
String tenantId = executionContext.getTenantId();
LinkedHashSet<String> selectedAttributes = executionContext.getSelectedColumns();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you move this also inside resultSetGroup.getResultSetCount() > 0?

List<Row.Builder> rowBuilderList = new ArrayList<>();
if (resultSetGroup.getResultSetCount() > 0) {
ResultSet resultSet = resultSetGroup.getResultSet(0);
handlerScopedMaskingConfig.parseColumns(executionContext);
// Pinot has different Response format for selection and aggregation/group by query.
if (resultSetTypePredicateProvider.isSelectionResultSetType(resultSet)) {
// map merging is only supported in the selection. Filtering and Group by has its own
Expand All @@ -508,6 +514,7 @@ Observable<Row> convert(ResultSetGroup resultSetGroup, LinkedHashSet<String> sel
handleAggregationAndGroupBy(resultSetGroup, rowBuilderList);
}
}

return Observable.fromIterable(rowBuilderList)
.map(Builder::build)
.doOnNext(row -> LOG.debug("collect a row: {}", row));
Expand Down Expand Up @@ -536,7 +543,11 @@ private void handleSelection(
for (String logicalName : selectedAttributes) {
// colVal will never be null. But getDataRow can throw a runtime exception if it failed
// to retrieve data
String colVal = resultAnalyzer.getDataFromRow(rowId, logicalName);
String colVal =
!handlerScopedMaskingConfig.shouldMask(logicalName)
? resultAnalyzer.getDataFromRow(rowId, logicalName)
: handlerScopedMaskingConfig.getMaskedValue(logicalName);

builder.addColumn(Value.newBuilder().setString(colVal).build());
}
}
Expand Down Expand Up @@ -678,4 +689,8 @@ private boolean isInvalidExpression(Expression expression) {
&& viewDefinition.getColumnType(expression.getAttributeExpression().getAttributeId())
!= ValueType.STRING_MAP;
}

/**
 * Package-private accessor for the masking config held by this handler.
 *
 * @return the handler's current {@code HandlerScopedMaskingConfig}
 */
HandlerScopedMaskingConfig getHandlerScopedMaskingConfig() {
  return this.handlerScopedMaskingConfig;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import com.typesafe.config.ConfigFactory;
import io.reactivex.rxjava3.core.Observable;
import java.io.IOException;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -1353,7 +1352,9 @@ public void testConvertSimpleSelectionsQueryResultSet() throws IOException {
ResultSetGroup resultSetGroup = mockResultSetGroup(List.of(resultSet));

verifyResponseRows(
pinotBasedRequestHandler.convert(resultSetGroup, new LinkedHashSet<>()), resultTable);
pinotBasedRequestHandler.convert(
resultSetGroup, new ExecutionContext("__default", QueryRequest.newBuilder().build())),
resultTable);
}

@Test
Expand All @@ -1371,7 +1372,9 @@ public void testConvertAggregationColumnsQueryResultSet() throws IOException {
ResultSetGroup resultSetGroup = mockResultSetGroup(List.of(resultSet));

verifyResponseRows(
pinotBasedRequestHandler.convert(resultSetGroup, new LinkedHashSet<>()), resultTable);
pinotBasedRequestHandler.convert(
resultSetGroup, new ExecutionContext("__default", QueryRequest.newBuilder().build())),
resultTable);
}

@Test
Expand Down Expand Up @@ -1432,7 +1435,9 @@ public void testConvertSelectionsWithMapKeysAndValuesQueryResultSet() throws IOE
};

verifyResponseRows(
pinotBasedRequestHandler.convert(resultSetGroup, new LinkedHashSet<>()), expectedRows);
pinotBasedRequestHandler.convert(
resultSetGroup, new ExecutionContext("__default", QueryRequest.newBuilder().build())),
expectedRows);
}

@Test
Expand Down Expand Up @@ -1467,7 +1472,9 @@ public void testConvertMultipleResultSetsInFResultSetGroup() throws IOException
};

verifyResponseRows(
pinotBasedRequestHandler.convert(resultSetGroup, new LinkedHashSet<>()), expectedRows);
pinotBasedRequestHandler.convert(
resultSetGroup, new ExecutionContext("__default", QueryRequest.newBuilder().build())),
expectedRows);
}

@Test
Expand Down Expand Up @@ -1756,6 +1763,80 @@ public boolean isResultTableResultSetType(ResultSet resultSet) {
}
}

/**
 * Verifies that a selected attribute configured for masking is returned with its replacement
 * value instead of the value coming back from Pinot, while other columns pass through untouched.
 */
@Test
public void testMaskColumnValue() throws IOException {
  for (Config handlerConfig : serviceConfig.getConfigList("queryRequestHandlersConfig")) {
    // Run only against the Pinot handler named "span-event-view-handler".
    boolean isTargetHandler =
        isPinotConfig(handlerConfig)
            && handlerConfig.getString("name").equals("span-event-view-handler");
    if (!isTargetHandler) {
      continue;
    }

    // Mock the PinotClient
    PinotClient mockedClient = mock(PinotClient.class);
    PinotClientFactory clientFactory = mock(PinotClientFactory.class);
    when(clientFactory.getPinotClient(any())).thenReturn(mockedClient);

    // Rows as Pinot would return them, before any masking is applied.
    String[][] pinotRows = {
      {"test-span-id-1", "trace-id-1"},
      {"test-span-id-2", "trace-id-1"},
      {"test-span-id-3", "trace-id-1"},
      {"test-span-id-4", "trace-id-2"}
    };
    List<String> pinotColumns = List.of("span_id", "trace_id");
    ResultSet pinotResultSet = mockResultSet(4, 2, pinotColumns, pinotRows);
    ResultSetGroup pinotResultSetGroup = mockResultSetGroup(List.of(pinotResultSet));

    // Force every result set to be treated as a selection result.
    ResultSetTypePredicateProvider selectionOnly =
        new ResultSetTypePredicateProvider() {
          @Override
          public boolean isSelectionResultSetType(ResultSet resultSet) {
            return true;
          }

          @Override
          public boolean isResultTableResultSetType(ResultSet resultSet) {
            return false;
          }
        };

    PinotBasedRequestHandler handler =
        new PinotBasedRequestHandler(
            handlerConfig.getString("name"),
            handlerConfig.getConfig("requestHandlerInfo"),
            selectionOnly,
            clientFactory);

    QueryRequest request =
        QueryRequest.newBuilder()
            .addSelection(QueryRequestBuilderUtils.createColumnExpression("EVENT.id"))
            .addSelection(QueryRequestBuilderUtils.createColumnExpression("EVENT.traceId"))
            .build();
    ExecutionContext context = new ExecutionContext("maskTenant", request);

    // The request carries no filter, so the tenant id is the only predicate sent to Pinot.
    String expectedQuery = "Select span_id, trace_id FROM spanEventView WHERE tenant_id = ?";
    Params expectedParams = Params.newBuilder().addStringParam("maskTenant").build();
    when(mockedClient.executeQuery(expectedQuery, expectedParams))
        .thenReturn(pinotResultSetGroup);

    // The test application.conf masks EVENT.id with "*" for tenant "maskTenant"; the trace id
    // column is expected to come back unchanged.
    String[][] expectedRows = {
      {"*", "trace-id-1"},
      {"*", "trace-id-1"},
      {"*", "trace-id-1"},
      {"*", "trace-id-2"}
    };

    verifyResponseRows(handler.handleRequest(request, context), expectedRows);
  }
}

@Test
public void testViewColumnFilterRemovalComplexCase() throws IOException {
for (Config config : serviceConfig.getConfigList("queryRequestHandlersConfig")) {
Expand Down
15 changes: 15 additions & 0 deletions query-service-impl/src/test/resources/application.conf
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,21 @@ service.config = {
]
}
]
tenantScopedMaskingCriteria = [
{
"tenantId": "maskTenant",
"timeRangeAndMaskValues": [
{
"maskValues": [
{
"attributeId": "EVENT.id",
"maskedValue": "*"
}
]
},
]
}
]
viewDefinition = {
viewName = spanEventView
mapFields = ["tags"]
Expand Down
Loading
Loading