Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENG-50887: Mask value using a masking config #227

Merged
merged 28 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package org.hypertrace.core.query.service;

import com.typesafe.config.Config;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Value;
import lombok.experimental.NonFinal;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class HandlerScopedMaskingConfig {
private static final String TENANT_SCOPED_MASKS_CONFIG_KEY = "tenantScopedMaskingCriteria";
private final Map<String, List<MaskValuesForTimeRange>> tenantToMaskValuesMap;

public HandlerScopedMaskingConfig(Config config) {
if (config.hasPath(TENANT_SCOPED_MASKS_CONFIG_KEY)) {
this.tenantToMaskValuesMap =
config.getConfigList(TENANT_SCOPED_MASKS_CONFIG_KEY).stream()
.map(maskConfig -> new TenantMasks(maskConfig))
.collect(
Collectors.toMap(
tenantFilters -> tenantFilters.tenantId,
tenantFilters -> tenantFilters.maskValues));
} else {
this.tenantToMaskValuesMap = Collections.emptyMap();
}
}

public List<String> getMaskedAttributes(ExecutionContext executionContext) {
String tenantId = executionContext.getTenantId();
List<String> maskedAttributes = new ArrayList<>();
// maskedValue.clear();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove this.

if (!tenantToMaskValuesMap.containsKey(tenantId)) {
return maskedAttributes;
}

Optional<QueryTimeRange> queryTimeRange = executionContext.getQueryTimeRange();
Instant queryStartTime, queryEndTime;
if (queryTimeRange.isPresent()) {
queryStartTime = queryTimeRange.get().getStartTime();
queryEndTime = queryTimeRange.get().getEndTime();
} else {
queryEndTime = Instant.MAX;
queryStartTime = Instant.MIN;
}
for (MaskValuesForTimeRange timeRangeAndMasks : tenantToMaskValuesMap.get(tenantId)) {
boolean timeRangeOverlap =
isTimeRangeOverlap(timeRangeAndMasks, queryStartTime, queryEndTime);

if (timeRangeOverlap) {
maskedAttributes.addAll(timeRangeAndMasks.maskedAttributes);
}
}

return maskedAttributes;
}

private static boolean isTimeRangeOverlap(
MaskValuesForTimeRange timeRangeAndMasks, Instant queryStartTime, Instant queryEndTime) {
boolean timeRangeOverlap = true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default should be false, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following conditionals check for no overlap, i.e. they set the timeRangeOverlap to false. This statement is correct.


if (timeRangeAndMasks.getStartTimeMillis().isPresent()) {
Instant startTimeInstant = Instant.ofEpochMilli(timeRangeAndMasks.getStartTimeMillis().get());
if (startTimeInstant.isBefore(queryStartTime) || startTimeInstant.isAfter(queryEndTime)) {
timeRangeOverlap = false;
}
}

if (timeRangeAndMasks.getEndTimeMillis().isPresent()) {
Instant endTimeInstant = Instant.ofEpochMilli(timeRangeAndMasks.getEndTimeMillis().get());
if (endTimeInstant.isBefore(queryStartTime) || endTimeInstant.isAfter(queryEndTime)) {
timeRangeOverlap = false;
}
}
return timeRangeOverlap;
}

@Value
@NonFinal
private class TenantMasks {
private static final String TENANT_ID_CONFIG_KEY = "tenantId";
private static final String TIME_RANGE_AND_MASK_VALUES_CONFIG_KEY = "timeRangeAndMaskValues";
String tenantId;
List<MaskValuesForTimeRange> maskValues;

private TenantMasks(Config config) {
this.tenantId = config.getString(TENANT_ID_CONFIG_KEY);
this.maskValues =
config.getConfigList(TIME_RANGE_AND_MASK_VALUES_CONFIG_KEY).stream()
.map(MaskValuesForTimeRange::new)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filter out the empty maskings

.collect(Collectors.toList());
}
}

@Value
@NonFinal
class MaskValuesForTimeRange {
private static final String START_TIME_CONFIG_PATH = "startTimeMillis";
private static final String END_TIME_CONFIG_PATH = "endTimeMillis";
private static final String MASK_ATTRIBUTES_CONFIG_PATH = "maskedAttributes";
Optional<Long> startTimeMillis;
Optional<Long> endTimeMillis;
ArrayList<String> maskedAttributes;

private MaskValuesForTimeRange(Config config) {
if (config.hasPath(START_TIME_CONFIG_PATH) && config.hasPath(END_TIME_CONFIG_PATH)) {
this.startTimeMillis = Optional.of(config.getLong(START_TIME_CONFIG_PATH));
this.endTimeMillis = Optional.of(config.getLong(END_TIME_CONFIG_PATH));
} else {
startTimeMillis = Optional.empty();
endTimeMillis = Optional.empty();
}
if (config.hasPath(MASK_ATTRIBUTES_CONFIG_PATH)) {
maskedAttributes = new ArrayList<>(config.getStringList(MASK_ATTRIBUTES_CONFIG_PATH));
} else {
maskedAttributes = new ArrayList<>();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.pinot.client.ResultSetGroup;
import org.hypertrace.core.query.service.ExecutionContext;
import org.hypertrace.core.query.service.HandlerScopedFiltersConfig;
import org.hypertrace.core.query.service.HandlerScopedMaskingConfig;
import org.hypertrace.core.query.service.QueryCost;
import org.hypertrace.core.query.service.RequestHandler;
import org.hypertrace.core.query.service.api.Expression;
Expand Down Expand Up @@ -58,6 +59,10 @@ public class PinotBasedRequestHandler implements RequestHandler {
private static final String START_TIME_ATTRIBUTE_NAME_CONFIG_KEY = "startTimeAttributeName";
private static final String SLOW_QUERY_THRESHOLD_MS_CONFIG = "slowQueryThresholdMs";

private static final String MASKED_VALUE = "*";
// This is how empty list is represented in Pinot
private static final String PINOT_EMPTY_LIST = "[\"\"]";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PINOT_EMPTY_LIST -> ARRAY_TYPE_MASKED_VALUE
MASKED_VALUE-> DEFAULT_MASKED_VALUE


private static final int DEFAULT_SLOW_QUERY_THRESHOLD_MS = 3000;
private static final Set<Operator> GTE_OPERATORS = Set.of(Operator.GE, Operator.GT, Operator.EQ);

Expand All @@ -67,6 +72,7 @@ public class PinotBasedRequestHandler implements RequestHandler {
private QueryRequestToPinotSQLConverter request2PinotSqlConverter;
private final PinotMapConverter pinotMapConverter;
private HandlerScopedFiltersConfig handlerScopedFiltersConfig;
private HandlerScopedMaskingConfig handlerScopedMaskingConfig;
// The implementations of ResultSet are package private and hence there's no way to determine the
// shape of the results
// other than to do string comparison on the simple class names. In order to be able to unit test
Expand Down Expand Up @@ -143,6 +149,7 @@ private void processConfig(Config config) {

this.handlerScopedFiltersConfig =
new HandlerScopedFiltersConfig(config, this.startTimeAttributeName);
this.handlerScopedMaskingConfig = new HandlerScopedMaskingConfig(config);
LOG.info(
"Using {}ms as the threshold for logging slow queries of handler: {}",
slowQueryThreshold,
Expand Down Expand Up @@ -424,7 +431,7 @@ public Observable<Row> handleRequest(
LOG.debug("Query results: [ {} ]", resultSetGroup.toString());
}
// need to merge data especially for Pinot. That's why we need to track the map columns
return this.convert(resultSetGroup, executionContext.getSelectedColumns())
return this.convert(resultSetGroup, executionContext)
.doOnComplete(
() -> {
long requestTimeMs = stopwatch.stop().elapsed(TimeUnit.MILLISECONDS);
Expand Down Expand Up @@ -493,21 +500,26 @@ private Filter rewriteLeafFilter(
return queryFilter;
}

Observable<Row> convert(ResultSetGroup resultSetGroup, LinkedHashSet<String> selectedAttributes) {
Observable<Row> convert(ResultSetGroup resultSetGroup, ExecutionContext executionContext) {
LinkedHashSet<String> selectedAttributes = executionContext.getSelectedColumns();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you move this also inside resultSetGroup.getResultSetCount() > 0?

List<Row.Builder> rowBuilderList = new ArrayList<>();
if (resultSetGroup.getResultSetCount() > 0) {
ResultSet resultSet = resultSetGroup.getResultSet(0);
List<String> maskedAttributes =
handlerScopedMaskingConfig.getMaskedAttributes(executionContext);
// Pinot has different Response format for selection and aggregation/group by query.
if (resultSetTypePredicateProvider.isSelectionResultSetType(resultSet)) {
// map merging is only supported in the selection. Filtering and Group by has its own
// syntax in Pinot
handleSelection(resultSetGroup, rowBuilderList, selectedAttributes);
handleSelection(resultSetGroup, rowBuilderList, selectedAttributes, maskedAttributes);
} else if (resultSetTypePredicateProvider.isResultTableResultSetType(resultSet)) {
handleTableFormatResultSet(resultSetGroup, rowBuilderList);
handleTableFormatResultSet(
resultSetGroup, rowBuilderList, selectedAttributes, maskedAttributes);
} else {
handleAggregationAndGroupBy(resultSetGroup, rowBuilderList);
}
}

return Observable.fromIterable(rowBuilderList)
.map(Builder::build)
.doOnNext(row -> LOG.debug("collect a row: {}", row));
Expand All @@ -516,7 +528,8 @@ Observable<Row> convert(ResultSetGroup resultSetGroup, LinkedHashSet<String> sel
private void handleSelection(
ResultSetGroup resultSetGroup,
List<Builder> rowBuilderList,
LinkedHashSet<String> selectedAttributes) {
LinkedHashSet<String> selectedAttributes,
List<String> maskedAttributes) {
int resultSetGroupCount = resultSetGroup.getResultSetCount();
for (int i = 0; i < resultSetGroupCount; i++) {
ResultSet resultSet = resultSetGroup.getResultSet(i);
Expand All @@ -536,7 +549,11 @@ private void handleSelection(
for (String logicalName : selectedAttributes) {
// colVal will never be null. But getDataRow can throw a runtime exception if it failed
// to retrieve data
String colVal = resultAnalyzer.getDataFromRow(rowId, logicalName);
String colVal =
!maskedAttributes.contains(logicalName)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit : get rid of the ! and invert to simplify

? resultAnalyzer.getDataFromRow(rowId, logicalName)
: MASKED_VALUE;

builder.addColumn(Value.newBuilder().setString(colVal).build());
}
}
Expand Down Expand Up @@ -588,10 +605,15 @@ private void handleAggregationAndGroupBy(
}

private void handleTableFormatResultSet(
ResultSetGroup resultSetGroup, List<Builder> rowBuilderList) {
ResultSetGroup resultSetGroup,
List<Builder> rowBuilderList,
LinkedHashSet<String> selectedAttributes,
List<String> maskedAttributes) {
int resultSetGroupCount = resultSetGroup.getResultSetCount();
for (int i = 0; i < resultSetGroupCount; i++) {
ResultSet resultSet = resultSetGroup.getResultSet(i);
PinotResultAnalyzer resultAnalyzer =
PinotResultAnalyzer.create(resultSet, selectedAttributes, viewDefinition);
for (int rowIdx = 0; rowIdx < resultSet.getRowCount(); rowIdx++) {
Builder builder;
builder = Row.newBuilder();
Expand All @@ -604,6 +626,15 @@ private void handleTableFormatResultSet(
// is structured
String mapKeys = resultSet.getString(rowIdx, colIdx);
String mapVals = resultSet.getString(rowIdx, colIdx + 1);

String logicalNameKey = resultAnalyzer.getLogicalNameFromColIdx(colIdx);
String logicalNameValue = resultAnalyzer.getLogicalNameFromColIdx(colIdx + 1);

if (maskedAttributes.contains(logicalNameKey)
|| maskedAttributes.contains(logicalNameValue)) {
mapVals = PINOT_EMPTY_LIST;
}

try {
builder.addColumn(
Value.newBuilder().setString(pinotMapConverter.merge(mapKeys, mapVals)).build());
Expand All @@ -616,6 +647,12 @@ private void handleTableFormatResultSet(
colIdx++;
} else {
String val = resultSet.getString(rowIdx, colIdx);
String columnLogicalName = resultAnalyzer.getLogicalNameFromColIdx(colIdx);

if (maskedAttributes.contains(columnLogicalName)) {
val = MASKED_VALUE;
}

builder.addColumn(Value.newBuilder().setString(val).build());
}
}
Expand Down Expand Up @@ -678,4 +715,8 @@ private boolean isInvalidExpression(Expression expression) {
&& viewDefinition.getColumnType(expression.getAttributeExpression().getAttributeId())
!= ValueType.STRING_MAP;
}

HandlerScopedMaskingConfig getHandlerScopedMaskingConfig() {
return handlerScopedMaskingConfig;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,20 @@ class PinotResultAnalyzer {
private final ViewDefinition viewDefinition;
private final Map<String, RateLimiter> attributeLogRateLimitter;
private final PinotMapConverter pinotMapConverter;
private final Map<Integer, String> indexToLogicalName;

PinotResultAnalyzer(
ResultSet resultSet,
LinkedHashSet<String> selectedAttributes,
ViewDefinition viewDefinition,
Map<String, Integer> mapLogicalNameToKeyIndex,
Map<String, Integer> mapLogicalNameToValueIndex,
Map<String, Integer> logicalNameToPhysicalNameIndex) {
Map<String, Integer> logicalNameToPhysicalNameIndex,
Map<Integer, String> indexToLogicalName) {
this.mapLogicalNameToKeyIndex = mapLogicalNameToKeyIndex;
this.mapLogicalNameToValueIndex = mapLogicalNameToValueIndex;
this.logicalNameToPhysicalNameIndex = logicalNameToPhysicalNameIndex;
this.indexToLogicalName = indexToLogicalName;
this.resultSet = resultSet;
this.viewDefinition = viewDefinition;
this.attributeLogRateLimitter = new HashMap<>();
Expand All @@ -53,6 +56,7 @@ static PinotResultAnalyzer create(
Map<String, Integer> mapLogicalNameToKeyIndex = new HashMap<>();
Map<String, Integer> mapLogicalNameToValueIndex = new HashMap<>();
Map<String, Integer> logicalNameToPhysicalNameIndex = new HashMap<>();
Map<Integer, String> indexToLogicalName = new HashMap<>();

for (String logicalName : selectedAttributes) {
if (viewDefinition.isMap(logicalName)) {
Expand All @@ -62,8 +66,10 @@ static PinotResultAnalyzer create(
String physName = resultSet.getColumnName(colIndex);
if (physName.equalsIgnoreCase(keyPhysicalName)) {
mapLogicalNameToKeyIndex.put(logicalName, colIndex);
indexToLogicalName.put(colIndex, logicalName);
} else if (physName.equalsIgnoreCase(valuePhysicalName)) {
mapLogicalNameToValueIndex.put(logicalName, colIndex);
indexToLogicalName.put(colIndex, logicalName);
}
}
} else {
Expand All @@ -73,6 +79,7 @@ static PinotResultAnalyzer create(
String physName = resultSet.getColumnName(colIndex);
if (physName.equalsIgnoreCase(names.get(0))) {
logicalNameToPhysicalNameIndex.put(logicalName, colIndex);
indexToLogicalName.put(colIndex, logicalName);
break;
}
}
Expand All @@ -87,7 +94,8 @@ static PinotResultAnalyzer create(
viewDefinition,
mapLogicalNameToKeyIndex,
mapLogicalNameToValueIndex,
logicalNameToPhysicalNameIndex);
logicalNameToPhysicalNameIndex,
indexToLogicalName);
}

@VisibleForTesting
Expand Down Expand Up @@ -149,4 +157,12 @@ String getDataFromRow(int rowIndex, String logicalName) {
}
return result;
}

String getLogicalNameFromColIdx(Integer colIdx) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Lets see if we can use Optional as return type.

if (indexToLogicalName.containsKey(colIdx)) {
return indexToLogicalName.get(colIdx);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove additional space.

return null;
}
}
Loading
Loading