Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENG-50887: Mask value using a masking config #227

Merged
merged 28 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package org.hypertrace.core.query.service;

import com.typesafe.config.Config;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Value;
import lombok.experimental.NonFinal;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class HandlerScopedMaskingConfig {
private static final String TENANT_SCOPED_MASKS_CONFIG_KEY = "tenantScopedMaskingConfig";
private Map<String, List<TimeRangeToMaskedAttributes>> tenantToTimeRangeMaskedAttributes =
Collections.emptyMap();

public HandlerScopedMaskingConfig(Config config) {
if (config.hasPath(TENANT_SCOPED_MASKS_CONFIG_KEY)) {
this.tenantToTimeRangeMaskedAttributes =
config.getConfigList(TENANT_SCOPED_MASKS_CONFIG_KEY).stream()
.map(TenantMaskingConfig::new)
.collect(
Collectors.toMap(
TenantMaskingConfig::getTenantId,
TenantMaskingConfig::getTimeRangeToMaskedAttributes));
}
}

public Set<String> getMaskedAttributes(ExecutionContext executionContext) {
String tenantId = executionContext.getTenantId();
HashSet<String> maskedAttributes = new HashSet<>();
if (!tenantToTimeRangeMaskedAttributes.containsKey(tenantId)) {
return maskedAttributes;
}

Optional<QueryTimeRange> queryTimeRange = executionContext.getQueryTimeRange();
Instant queryStartTime = Instant.MIN, queryEndTime = Instant.MAX;
if (queryTimeRange.isPresent()) {
queryStartTime = queryTimeRange.get().getStartTime();
queryEndTime = queryTimeRange.get().getEndTime();
}
for (TimeRangeToMaskedAttributes timeRangeAndMasks :
tenantToTimeRangeMaskedAttributes.get(tenantId)) {
if (isTimeRangeOverlap(timeRangeAndMasks, queryStartTime, queryEndTime)) {
maskedAttributes.addAll(timeRangeAndMasks.maskedAttributes);
}
}
return maskedAttributes;
}

private static boolean isTimeRangeOverlap(
TimeRangeToMaskedAttributes timeRangeAndMasks, Instant queryStartTime, Instant queryEndTime) {
return !(timeRangeAndMasks.startTimeMillis.isAfter(queryEndTime)
|| timeRangeAndMasks.endTimeMillis.isBefore(queryStartTime));
}

@Value
@NonFinal
static class TenantMaskingConfig {
private static final String TENANT_ID_CONFIG_KEY = "tenantId";
private static final String TIME_RANGE_AND_MASK_VALUES_CONFIG_KEY =
"timeRangeToMaskedAttributes";
String tenantId;
List<TimeRangeToMaskedAttributes> timeRangeToMaskedAttributes;

private TenantMaskingConfig(Config config) {
this.tenantId = config.getString(TENANT_ID_CONFIG_KEY);
this.timeRangeToMaskedAttributes =
config.getConfigList(TIME_RANGE_AND_MASK_VALUES_CONFIG_KEY).stream()
.map(TimeRangeToMaskedAttributes::new)
.filter(
timeRangeToMaskedAttributes -> {
if (!timeRangeToMaskedAttributes.isValid()) {
log.warn(
"Invalid masking configuration for tenant: {}. Either the time range is missing or the mask list is empty.",
this.tenantId);
return false;
}
return true;
})
.collect(Collectors.toList());
}
}

@NonFinal
static class TimeRangeToMaskedAttributes {
private static final String START_TIME_CONFIG_PATH = "startTimeMillis";
private static final String END_TIME_CONFIG_PATH = "endTimeMillis";
private static final String MASK_ATTRIBUTES_CONFIG_PATH = "maskedAttributes";
Instant startTimeMillis = null;
Instant endTimeMillis = null;
ArrayList<String> maskedAttributes = new ArrayList<>();

private TimeRangeToMaskedAttributes(Config config) {
if (config.hasPath(START_TIME_CONFIG_PATH) && config.hasPath(END_TIME_CONFIG_PATH)) {
Instant startTimeMillis = Instant.ofEpochMilli(config.getLong(START_TIME_CONFIG_PATH));
Instant endTimeMillis = Instant.ofEpochMilli(config.getLong(END_TIME_CONFIG_PATH));

if (startTimeMillis.isBefore(endTimeMillis)) {
this.startTimeMillis = startTimeMillis;
this.endTimeMillis = endTimeMillis;
if (config.hasPath(MASK_ATTRIBUTES_CONFIG_PATH)) {
maskedAttributes = new ArrayList<>(config.getStringList(MASK_ATTRIBUTES_CONFIG_PATH));
}
}
}
}

boolean isValid() {
return startTimeMillis != null && endTimeMillis != null && !maskedAttributes.isEmpty();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.pinot.client.ResultSetGroup;
import org.hypertrace.core.query.service.ExecutionContext;
import org.hypertrace.core.query.service.HandlerScopedFiltersConfig;
import org.hypertrace.core.query.service.HandlerScopedMaskingConfig;
import org.hypertrace.core.query.service.QueryCost;
import org.hypertrace.core.query.service.RequestHandler;
import org.hypertrace.core.query.service.api.Expression;
Expand Down Expand Up @@ -58,6 +59,10 @@ public class PinotBasedRequestHandler implements RequestHandler {
private static final String START_TIME_ATTRIBUTE_NAME_CONFIG_KEY = "startTimeAttributeName";
private static final String SLOW_QUERY_THRESHOLD_MS_CONFIG = "slowQueryThresholdMs";

private static final String DEFAULT_MASKED_VALUE = "*";
// This is how empty list is represented in Pinot
private static final String ARRAY_TYPE_MASKED_VALUE = "[\"\"]";

private static final int DEFAULT_SLOW_QUERY_THRESHOLD_MS = 3000;
private static final Set<Operator> GTE_OPERATORS = Set.of(Operator.GE, Operator.GT, Operator.EQ);

Expand All @@ -67,6 +72,7 @@ public class PinotBasedRequestHandler implements RequestHandler {
private QueryRequestToPinotSQLConverter request2PinotSqlConverter;
private final PinotMapConverter pinotMapConverter;
private HandlerScopedFiltersConfig handlerScopedFiltersConfig;
private HandlerScopedMaskingConfig handlerScopedMaskingConfig;
// The implementations of ResultSet are package private and hence there's no way to determine the
// shape of the results
// other than to do string comparison on the simple class names. In order to be able to unit test
Expand Down Expand Up @@ -143,6 +149,7 @@ private void processConfig(Config config) {

this.handlerScopedFiltersConfig =
new HandlerScopedFiltersConfig(config, this.startTimeAttributeName);
this.handlerScopedMaskingConfig = new HandlerScopedMaskingConfig(config);
LOG.info(
"Using {}ms as the threshold for logging slow queries of handler: {}",
slowQueryThreshold,
Expand Down Expand Up @@ -424,7 +431,7 @@ public Observable<Row> handleRequest(
LOG.debug("Query results: [ {} ]", resultSetGroup.toString());
}
// need to merge data especially for Pinot. That's why we need to track the map columns
return this.convert(resultSetGroup, executionContext.getSelectedColumns())
return this.convert(resultSetGroup, executionContext)
.doOnComplete(
() -> {
long requestTimeMs = stopwatch.stop().elapsed(TimeUnit.MILLISECONDS);
Expand Down Expand Up @@ -493,17 +500,21 @@ private Filter rewriteLeafFilter(
return queryFilter;
}

Observable<Row> convert(ResultSetGroup resultSetGroup, LinkedHashSet<String> selectedAttributes) {
Observable<Row> convert(ResultSetGroup resultSetGroup, ExecutionContext executionContext) {
List<Row.Builder> rowBuilderList = new ArrayList<>();
if (resultSetGroup.getResultSetCount() > 0) {
LinkedHashSet<String> selectedAttributes = executionContext.getSelectedColumns();
Set<String> maskedAttributes =
handlerScopedMaskingConfig.getMaskedAttributes(executionContext);
ResultSet resultSet = resultSetGroup.getResultSet(0);
// Pinot has different Response format for selection and aggregation/group by query.
if (resultSetTypePredicateProvider.isSelectionResultSetType(resultSet)) {
// map merging is only supported in the selection. Filtering and Group by has its own
// syntax in Pinot
handleSelection(resultSetGroup, rowBuilderList, selectedAttributes);
handleSelection(resultSetGroup, rowBuilderList, selectedAttributes, maskedAttributes);
} else if (resultSetTypePredicateProvider.isResultTableResultSetType(resultSet)) {
handleTableFormatResultSet(resultSetGroup, rowBuilderList);
handleTableFormatResultSet(
resultSetGroup, rowBuilderList, selectedAttributes, maskedAttributes);
} else {
handleAggregationAndGroupBy(resultSetGroup, rowBuilderList);
}
Expand All @@ -516,7 +527,8 @@ Observable<Row> convert(ResultSetGroup resultSetGroup, LinkedHashSet<String> sel
private void handleSelection(
ResultSetGroup resultSetGroup,
List<Builder> rowBuilderList,
LinkedHashSet<String> selectedAttributes) {
LinkedHashSet<String> selectedAttributes,
Set<String> maskedAttributes) {
int resultSetGroupCount = resultSetGroup.getResultSetCount();
for (int i = 0; i < resultSetGroupCount; i++) {
ResultSet resultSet = resultSetGroup.getResultSet(i);
Expand All @@ -536,7 +548,11 @@ private void handleSelection(
for (String logicalName : selectedAttributes) {
// colVal will never be null. But getDataRow can throw a runtime exception if it failed
// to retrieve data
String colVal = resultAnalyzer.getDataFromRow(rowId, logicalName);
String colVal =
maskedAttributes.contains(logicalName)
? DEFAULT_MASKED_VALUE
: resultAnalyzer.getDataFromRow(rowId, logicalName);

builder.addColumn(Value.newBuilder().setString(colVal).build());
}
}
Expand Down Expand Up @@ -588,10 +604,15 @@ private void handleAggregationAndGroupBy(
}

private void handleTableFormatResultSet(
ResultSetGroup resultSetGroup, List<Builder> rowBuilderList) {
ResultSetGroup resultSetGroup,
List<Builder> rowBuilderList,
LinkedHashSet<String> selectedAttributes,
Set<String> maskedAttributes) {
int resultSetGroupCount = resultSetGroup.getResultSetCount();
for (int i = 0; i < resultSetGroupCount; i++) {
ResultSet resultSet = resultSetGroup.getResultSet(i);
PinotResultAnalyzer resultAnalyzer =
PinotResultAnalyzer.create(resultSet, selectedAttributes, viewDefinition);
for (int rowIdx = 0; rowIdx < resultSet.getRowCount(); rowIdx++) {
Builder builder;
builder = Row.newBuilder();
Expand All @@ -602,8 +623,13 @@ private void handleTableFormatResultSet(
// Read the key and value column values. The columns should be side by side. That's how
// the Pinot query
// is structured
String logicalName = resultAnalyzer.getLogicalNameFromColIdx(colIdx);
String mapKeys = resultSet.getString(rowIdx, colIdx);
String mapVals = resultSet.getString(rowIdx, colIdx + 1);
String mapVals =
maskedAttributes.contains(logicalName)
? ARRAY_TYPE_MASKED_VALUE
: resultSet.getString(rowIdx, colIdx + 1);

try {
builder.addColumn(
Value.newBuilder().setString(pinotMapConverter.merge(mapKeys, mapVals)).build());
Expand All @@ -615,7 +641,11 @@ private void handleTableFormatResultSet(
// advance colIdx by 1 since we have read 2 columns
colIdx++;
} else {
String val = resultSet.getString(rowIdx, colIdx);
String logicalName = resultAnalyzer.getLogicalNameFromColIdx(colIdx);
String val =
maskedAttributes.contains(logicalName)
? DEFAULT_MASKED_VALUE
: resultSet.getString(rowIdx, colIdx);
builder.addColumn(Value.newBuilder().setString(val).build());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,20 @@ class PinotResultAnalyzer {
private final ViewDefinition viewDefinition;
private final Map<String, RateLimiter> attributeLogRateLimitter;
private final PinotMapConverter pinotMapConverter;
private final Map<Integer, String> indexToLogicalName;

PinotResultAnalyzer(
ResultSet resultSet,
LinkedHashSet<String> selectedAttributes,
ViewDefinition viewDefinition,
Map<String, Integer> mapLogicalNameToKeyIndex,
Map<String, Integer> mapLogicalNameToValueIndex,
Map<String, Integer> logicalNameToPhysicalNameIndex) {
Map<String, Integer> logicalNameToPhysicalNameIndex,
Map<Integer, String> indexToLogicalName) {
this.mapLogicalNameToKeyIndex = mapLogicalNameToKeyIndex;
this.mapLogicalNameToValueIndex = mapLogicalNameToValueIndex;
this.logicalNameToPhysicalNameIndex = logicalNameToPhysicalNameIndex;
this.indexToLogicalName = indexToLogicalName;
this.resultSet = resultSet;
this.viewDefinition = viewDefinition;
this.attributeLogRateLimitter = new HashMap<>();
Expand All @@ -53,6 +56,7 @@ static PinotResultAnalyzer create(
Map<String, Integer> mapLogicalNameToKeyIndex = new HashMap<>();
Map<String, Integer> mapLogicalNameToValueIndex = new HashMap<>();
Map<String, Integer> logicalNameToPhysicalNameIndex = new HashMap<>();
Map<Integer, String> indexToLogicalName = new HashMap<>();

for (String logicalName : selectedAttributes) {
if (viewDefinition.isMap(logicalName)) {
Expand All @@ -62,8 +66,10 @@ static PinotResultAnalyzer create(
String physName = resultSet.getColumnName(colIndex);
if (physName.equalsIgnoreCase(keyPhysicalName)) {
mapLogicalNameToKeyIndex.put(logicalName, colIndex);
indexToLogicalName.put(colIndex, logicalName);
} else if (physName.equalsIgnoreCase(valuePhysicalName)) {
mapLogicalNameToValueIndex.put(logicalName, colIndex);
indexToLogicalName.put(colIndex, logicalName);
}
}
} else {
Expand All @@ -73,21 +79,24 @@ static PinotResultAnalyzer create(
String physName = resultSet.getColumnName(colIndex);
if (physName.equalsIgnoreCase(names.get(0))) {
logicalNameToPhysicalNameIndex.put(logicalName, colIndex);
indexToLogicalName.put(colIndex, logicalName);
break;
}
}
}
}
LOG.info("Map LogicalName to Key Index: {} ", mapLogicalNameToKeyIndex);
LOG.info("Map LogicalName to Value Index: {}", mapLogicalNameToValueIndex);
LOG.info("Attributes to Index: {}", logicalNameToPhysicalNameIndex);
LOG.debug("Map LogicalName to Key Index: {} ", mapLogicalNameToKeyIndex);
LOG.debug("Map LogicalName to Value Index: {}", mapLogicalNameToValueIndex);
LOG.debug("Attributes to Index: {}", logicalNameToPhysicalNameIndex);
LOG.debug("Index to LogicalName: {}", indexToLogicalName);
return new PinotResultAnalyzer(
resultSet,
selectedAttributes,
viewDefinition,
mapLogicalNameToKeyIndex,
mapLogicalNameToValueIndex,
logicalNameToPhysicalNameIndex);
logicalNameToPhysicalNameIndex,
indexToLogicalName);
}

@VisibleForTesting
Expand Down Expand Up @@ -149,4 +158,8 @@ String getDataFromRow(int rowIndex, String logicalName) {
}
return result;
}

String getLogicalNameFromColIdx(Integer colIdx) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Lets see if we can use Optional as return type.

return indexToLogicalName.get(colIdx);
}
}
Loading
Loading