Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

suggested changes #229

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package org.hypertrace.core.query.service;

import com.typesafe.config.Config;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Value;
import lombok.experimental.NonFinal;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class HandlerScopedMaskingConfig {
private static final String TENANT_SCOPED_MASKS_CONFIG_KEY = "tenantScopedMaskingCriteria";
private final Map<String, List<MaskValuesForTimeRange>> tenantToMaskValuesMap;

public HandlerScopedMaskingConfig(Config config) {
if (config.hasPath(TENANT_SCOPED_MASKS_CONFIG_KEY)) {
this.tenantToMaskValuesMap =
config.getConfigList(TENANT_SCOPED_MASKS_CONFIG_KEY).stream()
.map(maskConfig -> new TenantMasks(maskConfig))
.collect(
Collectors.toMap(
tenantFilters -> tenantFilters.tenantId,
tenantFilters -> tenantFilters.maskValues));
} else {
this.tenantToMaskValuesMap = Collections.emptyMap();
}
}

public Map<String, String> getMaskedColumnsToValueMap(ExecutionContext executionContext) {
Map<String, String> maskedColumnsToValueMap = new HashMap<>();

String tenantId = executionContext.getTenantId();
if (!tenantToMaskValuesMap.containsKey(tenantId)) {
return maskedColumnsToValueMap;
}

Optional<QueryTimeRange> queryTimeRange = executionContext.getQueryTimeRange();
Instant queryStartTime, queryEndTime;
if (queryTimeRange.isPresent()) {
queryStartTime = queryTimeRange.get().getStartTime();
queryEndTime = queryTimeRange.get().getEndTime();
} else {
queryEndTime = Instant.MAX;
queryStartTime = Instant.MIN;
}
for (MaskValuesForTimeRange timeRangeAndMasks : tenantToMaskValuesMap.get(tenantId)) {
boolean timeRangeOverlap =
isTimeRangeOverlap(timeRangeAndMasks, queryStartTime, queryEndTime);

if (timeRangeOverlap) {
Map<String, String> attributeToMaskedValue =
timeRangeAndMasks.maskValues.attributeToMaskedValue;
for (String attribute : attributeToMaskedValue.keySet()) {
maskedColumnsToValueMap.put(attribute, attributeToMaskedValue.get(attribute));
}
}
}
return maskedColumnsToValueMap;
}

private static boolean isTimeRangeOverlap(
MaskValuesForTimeRange timeRangeAndMasks, Instant queryStartTime, Instant queryEndTime) {
boolean timeRangeOverlap = true;

if (timeRangeAndMasks.getStartTimeMillis().isPresent()) {
Instant startTimeInstant = Instant.ofEpochMilli(timeRangeAndMasks.getStartTimeMillis().get());
if (startTimeInstant.isBefore(queryStartTime) || startTimeInstant.isAfter(queryEndTime)) {
timeRangeOverlap = false;
}

Instant endTimeInstant = Instant.ofEpochMilli(timeRangeAndMasks.getStartTimeMillis().get());
if (endTimeInstant.isBefore(queryStartTime) || endTimeInstant.isAfter(queryEndTime)) {
timeRangeOverlap = false;
}
}
return timeRangeOverlap;
}

@Value
@NonFinal
private class TenantMasks {
private static final String TENANT_ID_CONFIG_KEY = "tenantId";
private static final String TIME_RANGE_AND_MASK_VALUES_CONFIG_KEY = "timeRangeAndMaskValues";
String tenantId;
List<MaskValuesForTimeRange> maskValues;

private TenantMasks(Config config) {
this.tenantId = config.getString(TENANT_ID_CONFIG_KEY);
this.maskValues =
config.getConfigList(TIME_RANGE_AND_MASK_VALUES_CONFIG_KEY).stream()
.map(MaskValuesForTimeRange::new)
.collect(Collectors.toList());
}
}

@Value
private class MaskValues {
Map<String, String> attributeToMaskedValue;

MaskValues(Map<String, String> columnToMaskedValue) {
this.attributeToMaskedValue = columnToMaskedValue;
}
}

@Value
@NonFinal
class MaskValuesForTimeRange {
private static final String START_TIME_CONFIG_PATH = "startTimeMillis";
private static final String END_TIME_CONFIG_PATH = "endTimeMillis";
private static final String MASK_VALUE_CONFIG_PATH = "maskValues";
private static final String ATTRIBUTE_ID_CONFIG_PATH = "attributeId";
private static final String MASKED_VALUE_CONFIG_PATH = "maskedValue";
Optional<Long> startTimeMillis;
Optional<Long> endTimeMillis;
MaskValues maskValues;

private MaskValuesForTimeRange(Config config) {
if (config.hasPath(START_TIME_CONFIG_PATH) && config.hasPath(END_TIME_CONFIG_PATH)) {
this.startTimeMillis = Optional.of(config.getLong(START_TIME_CONFIG_PATH));
this.endTimeMillis = Optional.of(config.getLong(END_TIME_CONFIG_PATH));
} else {
startTimeMillis = Optional.empty();
endTimeMillis = Optional.empty();
}
if (config.hasPath(MASK_VALUE_CONFIG_PATH)) {
List<Config> maskedValuesList =
new ArrayList<>(config.getConfigList(MASK_VALUE_CONFIG_PATH));
HashMap<String, String> maskedValuesMap = new HashMap<>();
maskedValuesList.forEach(
maskedValue -> {
maskedValuesMap.put(
maskedValue.getString(ATTRIBUTE_ID_CONFIG_PATH),
maskedValue.getString(MASKED_VALUE_CONFIG_PATH));
});

maskValues = new MaskValues(maskedValuesMap);
} else {
maskValues = new MaskValues(new HashMap<>());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.pinot.client.ResultSetGroup;
import org.hypertrace.core.query.service.ExecutionContext;
import org.hypertrace.core.query.service.HandlerScopedFiltersConfig;
import org.hypertrace.core.query.service.HandlerScopedMaskingConfig;
import org.hypertrace.core.query.service.QueryCost;
import org.hypertrace.core.query.service.RequestHandler;
import org.hypertrace.core.query.service.api.Expression;
Expand Down Expand Up @@ -67,6 +68,7 @@ public class PinotBasedRequestHandler implements RequestHandler {
private QueryRequestToPinotSQLConverter request2PinotSqlConverter;
private final PinotMapConverter pinotMapConverter;
private HandlerScopedFiltersConfig handlerScopedFiltersConfig;
private HandlerScopedMaskingConfig handlerScopedMaskingConfig;
// The implementations of ResultSet are package private and hence there's no way to determine the
// shape of the results
// other than to do string comparison on the simple class names. In order to be able to unit test
Expand Down Expand Up @@ -143,6 +145,7 @@ private void processConfig(Config config) {

this.handlerScopedFiltersConfig =
new HandlerScopedFiltersConfig(config, this.startTimeAttributeName);
this.handlerScopedMaskingConfig = new HandlerScopedMaskingConfig(config);
LOG.info(
"Using {}ms as the threshold for logging slow queries of handler: {}",
slowQueryThreshold,
Expand Down Expand Up @@ -423,8 +426,9 @@ public Observable<Row> handleRequest(
if (LOG.isDebugEnabled()) {
LOG.debug("Query results: [ {} ]", resultSetGroup.toString());
}

// need to merge data especially for Pinot. That's why we need to track the map columns
return this.convert(resultSetGroup, executionContext.getSelectedColumns())
return this.convert(resultSetGroup, executionContext.getSelectedColumns(), handlerScopedMaskingConfig.getMaskedColumnsToValueMap(executionContext))
.doOnComplete(
() -> {
long requestTimeMs = stopwatch.stop().elapsed(TimeUnit.MILLISECONDS);
Expand Down Expand Up @@ -493,21 +497,22 @@ private Filter rewriteLeafFilter(
return queryFilter;
}

Observable<Row> convert(ResultSetGroup resultSetGroup, LinkedHashSet<String> selectedAttributes) {
Observable<Row> convert(ResultSetGroup resultSetGroup, LinkedHashSet<String> selectedAttributes, Map<String, String> maskedColumnsToValueMap) {
List<Row.Builder> rowBuilderList = new ArrayList<>();
if (resultSetGroup.getResultSetCount() > 0) {
ResultSet resultSet = resultSetGroup.getResultSet(0);
// Pinot has different Response format for selection and aggregation/group by query.
if (resultSetTypePredicateProvider.isSelectionResultSetType(resultSet)) {
// map merging is only supported in the selection. Filtering and Group by has its own
// syntax in Pinot
handleSelection(resultSetGroup, rowBuilderList, selectedAttributes);
handleSelection(resultSetGroup, rowBuilderList, selectedAttributes, maskedColumnsToValueMap);
} else if (resultSetTypePredicateProvider.isResultTableResultSetType(resultSet)) {
handleTableFormatResultSet(resultSetGroup, rowBuilderList);
handleTableFormatResultSet(resultSetGroup, rowBuilderList, maskedColumnsToValueMap);
} else {
handleAggregationAndGroupBy(resultSetGroup, rowBuilderList);
handleAggregationAndGroupBy(resultSetGroup, rowBuilderList, maskedColumnsToValueMap);
}
}

return Observable.fromIterable(rowBuilderList)
.map(Builder::build)
.doOnNext(row -> LOG.debug("collect a row: {}", row));
Expand All @@ -516,7 +521,8 @@ Observable<Row> convert(ResultSetGroup resultSetGroup, LinkedHashSet<String> sel
private void handleSelection(
ResultSetGroup resultSetGroup,
List<Builder> rowBuilderList,
LinkedHashSet<String> selectedAttributes) {
LinkedHashSet<String> selectedAttributes,
Map<String, String> maskedColumnsToValueMap) {
int resultSetGroupCount = resultSetGroup.getResultSetCount();
for (int i = 0; i < resultSetGroupCount; i++) {
ResultSet resultSet = resultSetGroup.getResultSet(i);
Expand All @@ -536,15 +542,17 @@ private void handleSelection(
for (String logicalName : selectedAttributes) {
// colVal will never be null. But getDataRow can throw a runtime exception if it failed
// to retrieve data
String colVal = resultAnalyzer.getDataFromRow(rowId, logicalName);
String colVal = maskedColumnsToValueMap.containsKey(logicalName)
? maskedColumnsToValueMap.get(logicalName): resultAnalyzer.getDataFromRow(rowId, logicalName);

builder.addColumn(Value.newBuilder().setString(colVal).build());
}
}
}
}

private void handleAggregationAndGroupBy(
ResultSetGroup resultSetGroup, List<Builder> rowBuilderList) {
ResultSetGroup resultSetGroup, List<Builder> rowBuilderList, Map<String, String> maskedColumnsToValueMap) {
int resultSetGroupCount = resultSetGroup.getResultSetCount();
Map<String, Integer> groupKey2RowIdMap = new HashMap<>();
for (int i = 0; i < resultSetGroupCount; i++) {
Expand Down Expand Up @@ -588,7 +596,7 @@ private void handleAggregationAndGroupBy(
}

private void handleTableFormatResultSet(
ResultSetGroup resultSetGroup, List<Builder> rowBuilderList) {
ResultSetGroup resultSetGroup, List<Builder> rowBuilderList, Map<String, String> maskedColumnsToValueMap) {
int resultSetGroupCount = resultSetGroup.getResultSetCount();
for (int i = 0; i < resultSetGroupCount; i++) {
ResultSet resultSet = resultSetGroup.getResultSet(i);
Expand Down Expand Up @@ -678,4 +686,8 @@ private boolean isInvalidExpression(Expression expression) {
&& viewDefinition.getColumnType(expression.getAttributeExpression().getAttributeId())
!= ValueType.STRING_MAP;
}

HandlerScopedMaskingConfig getHandlerScopedMaskingConfig() {
return handlerScopedMaskingConfig;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import com.typesafe.config.ConfigFactory;
import io.reactivex.rxjava3.core.Observable;
import java.io.IOException;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -1353,7 +1352,9 @@ public void testConvertSimpleSelectionsQueryResultSet() throws IOException {
ResultSetGroup resultSetGroup = mockResultSetGroup(List.of(resultSet));

verifyResponseRows(
pinotBasedRequestHandler.convert(resultSetGroup, new LinkedHashSet<>()), resultTable);
pinotBasedRequestHandler.convert(
resultSetGroup, new ExecutionContext("__default", QueryRequest.newBuilder().build())),
resultTable);
}

@Test
Expand All @@ -1371,7 +1372,9 @@ public void testConvertAggregationColumnsQueryResultSet() throws IOException {
ResultSetGroup resultSetGroup = mockResultSetGroup(List.of(resultSet));

verifyResponseRows(
pinotBasedRequestHandler.convert(resultSetGroup, new LinkedHashSet<>()), resultTable);
pinotBasedRequestHandler.convert(
resultSetGroup, new ExecutionContext("__default", QueryRequest.newBuilder().build())),
resultTable);
}

@Test
Expand Down Expand Up @@ -1432,7 +1435,9 @@ public void testConvertSelectionsWithMapKeysAndValuesQueryResultSet() throws IOE
};

verifyResponseRows(
pinotBasedRequestHandler.convert(resultSetGroup, new LinkedHashSet<>()), expectedRows);
pinotBasedRequestHandler.convert(
resultSetGroup, new ExecutionContext("__default", QueryRequest.newBuilder().build())),
expectedRows);
}

@Test
Expand Down Expand Up @@ -1467,7 +1472,9 @@ public void testConvertMultipleResultSetsInFResultSetGroup() throws IOException
};

verifyResponseRows(
pinotBasedRequestHandler.convert(resultSetGroup, new LinkedHashSet<>()), expectedRows);
pinotBasedRequestHandler.convert(
resultSetGroup, new ExecutionContext("__default", QueryRequest.newBuilder().build())),
expectedRows);
}

@Test
Expand Down Expand Up @@ -1756,6 +1763,80 @@ public boolean isResultTableResultSetType(ResultSet resultSet) {
}
}

@Test
public void testMaskColumnValue() throws IOException {
for (Config config : serviceConfig.getConfigList("queryRequestHandlersConfig")) {
if (!isPinotConfig(config)) {
continue;
}

if (!config.getString("name").equals("span-event-view-handler")) {
continue;
}

// Mock the PinotClient
PinotClient pinotClient = mock(PinotClient.class);
PinotClientFactory factory = mock(PinotClientFactory.class);
when(factory.getPinotClient(any())).thenReturn(pinotClient);

String[][] resultTable =
new String[][] {
{
"test-span-id-1", "trace-id-1",
},
{"test-span-id-2", "trace-id-1"},
{"test-span-id-3", "trace-id-1"},
{"test-span-id-4", "trace-id-2"}
};
List<String> columnNames = List.of("span_id", "trace_id");
ResultSet resultSet = mockResultSet(4, 2, columnNames, resultTable);
ResultSetGroup resultSetGroup = mockResultSetGroup(List.of(resultSet));

PinotBasedRequestHandler handler =
new PinotBasedRequestHandler(
config.getString("name"),
config.getConfig("requestHandlerInfo"),
new ResultSetTypePredicateProvider() {
@Override
public boolean isSelectionResultSetType(ResultSet resultSet) {
return true;
}

@Override
public boolean isResultTableResultSetType(ResultSet resultSet) {
return false;
}
},
factory);

QueryRequest request =
QueryRequest.newBuilder()
.addSelection(QueryRequestBuilderUtils.createColumnExpression("EVENT.id"))
.addSelection(QueryRequestBuilderUtils.createColumnExpression("EVENT.traceId"))
.build();
ExecutionContext context = new ExecutionContext("maskTenant", request);

// The query filter is based on both isEntrySpan and startTime. Since the viewFilter
// checks for both the true and false values of isEntrySpan and query filter only needs
// "true", isEntrySpan predicate is still passed to the store in the query.
String expectedQuery = "Select span_id, trace_id FROM spanEventView WHERE tenant_id = ?";
Params params = Params.newBuilder().addStringParam("maskTenant").build();
when(pinotClient.executeQuery(expectedQuery, params)).thenReturn(resultSetGroup);

String[][] expectedTable =
new String[][] {
{
"*", "trace-id-1",
},
{"*", "trace-id-1"},
{"*", "trace-id-1"},
{"*", "trace-id-2"}
};

verifyResponseRows(handler.handleRequest(request, context), expectedTable);
}
}

@Test
public void testViewColumnFilterRemovalComplexCase() throws IOException {
for (Config config : serviceConfig.getConfigList("queryRequestHandlersConfig")) {
Expand Down
Loading
Loading