Skip to content

Commit

Permalink
make @SuppressWarnings annotations local and with comments
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu committed Oct 19, 2023
1 parent 156102b commit 7732ecf
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,15 @@
import javax.annotation.Nullable;

/**
* Entry in the side input cache that stores the value (null if not ready), and the encoded size of
* the value.
* Entry in the side input cache that stores the value and the encoded size of the value.
*
* <p>Can be in 1 of 3 states:
*
* <ul>
* <li>Ready with a {@code T} value.
* <li>Ready with no value, represented as {@code Optional<T>}.
* <li>Not ready.
* </ul>
*/
@AutoValue
public abstract class SideInput<T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.concurrent.TimeUnit;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Weigher;
Expand All @@ -37,15 +38,14 @@
* types of all objects.
*/
@CheckReturnValue
@SuppressWarnings("unchecked")
final class SideInputCache {

private static final long MAXIMUM_CACHE_WEIGHT = 100000000; /* 100 MB */
private static final long CACHE_ENTRY_EXPIRY_MINUTES = 1L;

private final Cache<Key, SideInput<?>> sideInputCache;
private final Cache<Key<?>, SideInput<?>> sideInputCache;

SideInputCache(Cache<Key, SideInput<?>> sideInputCache) {
SideInputCache(Cache<Key<?>, SideInput<?>> sideInputCache) {
this.sideInputCache = sideInputCache;
}

Expand All @@ -54,40 +54,60 @@ static SideInputCache create() {
CacheBuilder.newBuilder()
.maximumWeight(MAXIMUM_CACHE_WEIGHT)
.expireAfterWrite(CACHE_ENTRY_EXPIRY_MINUTES, TimeUnit.MINUTES)
.weigher((Weigher<Key, SideInput<?>>) (id, entry) -> entry.size())
.weigher((Weigher<Key<?>, SideInput<?>>) (id, entry) -> entry.size())
.build());
}

synchronized <T> SideInput<T> invalidateThenLoadNewEntry(
Key key, Callable<SideInput<T>> cacheLoaderFn) throws ExecutionException {
Key<T> key, Callable<SideInput<T>> cacheLoaderFn) throws ExecutionException {
// Invalidate the existing not-ready entry. This must be done atomically
// so that another thread doesn't replace the entry with a ready entry, which
// would then be deleted here.
SideInput<?> newEntry = sideInputCache.getIfPresent(key);
if (newEntry != null && !newEntry.isReady()) {
Optional<SideInput<T>> newEntry = getIfPresentUnchecked(key);
if (newEntry.isPresent() && !newEntry.get().isReady()) {
sideInputCache.invalidate(key);
}

return (SideInput<T>) sideInputCache.get(key, cacheLoaderFn);
return getUnchecked(key, cacheLoaderFn);
}

<T> Optional<SideInput<T>> get(Key key) {
return Optional.ofNullable((SideInput<T>) sideInputCache.getIfPresent(key));
<T> Optional<SideInput<T>> get(Key<T> key) {
return getIfPresentUnchecked(key);
}

/**
 * Returns the {@link SideInput} cached under {@code key}, delegating to {@link #getUnchecked} so
 * that {@code cacheLoaderFn} is handed to the underlying cache's {@code get(key, loader)} call.
 *
 * @param key cache key whose type parameter matches the value type of the side input.
 * @param cacheLoaderFn loader passed through to the underlying cache.
 * @throws ExecutionException propagated from the underlying cache lookup.
 */
<T> SideInput<T> getOrLoad(Key<T> key, Callable<SideInput<T>> cacheLoaderFn)
throws ExecutionException {
return getUnchecked(key, cacheLoaderFn);
}

<T> SideInput<T> getOrLoad(Key key, Callable<SideInput<T>> cacheLoaderFn)
/**
 * Looks up {@code key} in the backing cache with {@code cacheLoaderFn} as the loader and narrows
 * the wildcard-typed result to {@code SideInput<T>}.
 *
 * <p>The backing cache is declared as {@code Cache<Key<?>, SideInput<?>>}, so the cast below is
 * the single place where the per-key value type is re-established.
 *
 * @throws ExecutionException propagated from the underlying cache lookup.
 */
@SuppressWarnings({
"unchecked" // The loader produces SideInput<T> and the key is a Key<T>, so the value stored
// under this key is expected to always be a SideInput<T>.
})
private <T> SideInput<T> getUnchecked(Key<T> key, Callable<SideInput<T>> cacheLoaderFn)
throws ExecutionException {
return (SideInput<T>) sideInputCache.get(key, cacheLoaderFn);
}

/**
 * Returns the entry currently cached under {@code key}, or {@link Optional#empty()} when the
 * backing cache has no entry for it. Never invokes a loader.
 */
@SuppressWarnings({
"unchecked" // Entries are written through getUnchecked, which pairs a Key<T> with a loader of
// SideInput<T>, so a value found under a Key<T> is expected to be a SideInput<T>.
// NOTE(review): this assumes a cache injected via the constructor is not populated elsewhere.
})
private <T> Optional<SideInput<T>> getIfPresentUnchecked(Key<T> key) {
return Optional.ofNullable((SideInput<T>) sideInputCache.getIfPresent(key));
}

@AutoValue
abstract static class Key {
abstract static class Key<T> {
static <T> Key<T> create(
TupleTag<?> tag, BoundedWindow window, TypeDescriptor<T> typeDescriptor) {
return new AutoValue_SideInputCache_Key<>(tag, window, typeDescriptor);
}

abstract TupleTag<?> tag();

abstract BoundedWindow window();

static Key create(TupleTag<?> tag, BoundedWindow window) {
return new AutoValue_SideInputCache_Key(tag, window);
}
abstract TypeDescriptor<T> typeDescriptor();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,14 @@
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Class responsible for fetching state from the windmill server. */
@SuppressWarnings({
"rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
@NotThreadSafe
public class SideInputStateFetcher {
private static final Logger LOG = LoggerFactory.getLogger(SideInputStateFetcher.class);
Expand All @@ -73,14 +70,32 @@ public SideInputStateFetcher(MetricTrackingWindmillServerStub server) {
this.sideInputCache = sideInputCache;
}

@SuppressWarnings("deprecation")
private static Iterable<?> decodeRawData(Coder<?> viewInternalCoder, GlobalData data)
private static <T> Iterable<?> decodeRawData(PCollectionView<T> view, GlobalData data)
throws IOException {
return !data.getData().isEmpty()
? IterableCoder.of(viewInternalCoder).decode(data.getData().newInput(), Coder.Context.OUTER)
? IterableCoder.of(getCoder(view)).decode(data.getData().newInput())
: Collections.emptyList();
}

/**
 * Returns the view's internal {@link TupleTag}. Centralizes the call to {@code getTagInternal()}
 * so the deprecation suppression is scoped to this one accessor.
 */
@SuppressWarnings({
"deprecation" // The internal tag is needed to build the SideInputCache.Key and to set the tag
// on the Windmill GlobalDataId; it is not exposed by this class.
})
private static <T> TupleTag<?> getInternalTag(PCollectionView<T> view) {
return view.getTagInternal();
}

/**
 * Returns the view's {@link ViewFn}. Centralizes the call to {@code getViewFn()} so the
 * deprecation suppression is scoped to this one accessor.
 */
// The ViewFn supplies the type descriptor for the cache key, the materialization URN, and the
// function applied to the decoded data, so the deprecated accessor cannot be avoided here.
@SuppressWarnings("deprecation")
private static <T> ViewFn<?, T> getViewFn(PCollectionView<T> view) {
return view.getViewFn();
}

/**
 * Returns the view's internal element {@link Coder}. Centralizes the call to {@code
 * getCoderInternal()} so the deprecation suppression is scoped to this one accessor.
 */
@SuppressWarnings({
"deprecation" // The view's internal coder is required to decode the raw data.
})
private static <T> Coder<?> getCoder(PCollectionView<T> view) {
return view.getCoderInternal();
}

/**
 * Returns a new {@link SideInputStateFetcher} constructed with this fetcher's server and side
 * input cache, so the cache is shared while bytes read are tracked separately by the returned
 * instance.
 */
public SideInputStateFetcher byteTrackingView() {
// NOTE(review): separate tracking presumes the byte counter is a per-instance field — confirm.
return new SideInputStateFetcher(server, sideInputCache);
}
*
* <p>If state is KNOWN_READY, attempt to fetch the data regardless of whether a not-ready entry
* was cached.
*
* <p>Returns {@literal null} if the side input was not ready, {@literal Optional.absent()} if the
* side input was null, and {@literal Optional.present(...)} if the side input was non-null.
*/
@SuppressWarnings("deprecation")
public <T> SideInput<T> fetchSideInput(
PCollectionView<T> view,
BoundedWindow sideWindow,
Expand All @@ -108,9 +119,9 @@ public <T> SideInput<T> fetchSideInput(
Supplier<Closeable> scopedReadStateSupplier) {
Callable<SideInput<T>> loadSideInputFromWindmill =
() -> loadSideInputFromWindmill(view, sideWindow, stateFamily, scopedReadStateSupplier);

SideInputCache.Key sideInputCacheKey =
SideInputCache.Key.create(view.getTagInternal(), sideWindow);
SideInputCache.Key<T> sideInputCacheKey =
SideInputCache.Key.create(
getInternalTag(view), sideWindow, getViewFn(view).getTypeDescriptor());

try {
if (state == SideInputState.KNOWN_READY) {
Expand All @@ -134,26 +145,29 @@ public <T> SideInput<T> fetchSideInput(
}
}

@SuppressWarnings({"deprecation", "unchecked"})
private <T, SideWindowT extends BoundedWindow> GlobalData fetchGlobalDataFromWindmill(
PCollectionView<T> view,
SideWindowT sideWindow,
String stateFamily,
Supplier<Closeable> scopedReadStateSupplier)
throws IOException {
@SuppressWarnings({
"deprecation", // Internal windowStrategy is required to fetch side input data from Windmill.
"unchecked" // Internal windowing strategy matches WindowingStrategy<?, SideWindowT>.
})
WindowingStrategy<?, SideWindowT> sideWindowStrategy =
(WindowingStrategy<?, SideWindowT>) view.getWindowingStrategyInternal();

Coder<SideWindowT> windowCoder = sideWindowStrategy.getWindowFn().windowCoder();

ByteStringOutputStream windowStream = new ByteStringOutputStream();
windowCoder.encode(sideWindow, windowStream, Coder.Context.OUTER);
windowCoder.encode(sideWindow, windowStream);

Windmill.GlobalDataRequest request =
Windmill.GlobalDataRequest.newBuilder()
.setDataId(
Windmill.GlobalDataId.newBuilder()
.setTag(view.getTagInternal().getId())
.setTag(getInternalTag(view).getId())
.setVersion(windowStream.toByteString())
.build())
.setStateFamily(stateFamily)
Expand All @@ -167,49 +181,65 @@ private <T, SideWindowT extends BoundedWindow> GlobalData fetchGlobalDataFromWin
}
}

@SuppressWarnings("deprecation")
private <T> SideInput<T> loadSideInputFromWindmill(
PCollectionView<T> view,
BoundedWindow sideWindow,
String stateFamily,
Supplier<Closeable> scopedReadStateSupplier)
throws IOException {
checkState(
SUPPORTED_MATERIALIZATIONS.contains(view.getViewFn().getMaterialization().getUrn()),
"Only materialization's of type %s supported, received %s",
SUPPORTED_MATERIALIZATIONS,
view.getViewFn().getMaterialization().getUrn());

validateViewMaterialization(view);
GlobalData data =
fetchGlobalDataFromWindmill(view, sideWindow, stateFamily, scopedReadStateSupplier);
bytesRead += data.getSerializedSize();
return data.getIsReady() ? createSideInputCacheEntry(view, data) : SideInput.notReady();
}

@SuppressWarnings({"deprecation", "unchecked"})
/**
 * Verifies that the materialization URN of the view's {@code ViewFn} is one of {@code
 * SUPPORTED_MATERIALIZATIONS}; otherwise the {@code checkState} precondition fails with a message
 * listing the supported URNs and the one received.
 */
private <T> void validateViewMaterialization(PCollectionView<T> view) {
  final String urn = getViewFn(view).getMaterialization().getUrn();
  final boolean isSupported = SUPPORTED_MATERIALIZATIONS.contains(urn);
  checkState(
      isSupported,
      "Only materialization's of type %s supported, received %s",
      SUPPORTED_MATERIALIZATIONS,
      urn);
}

private <T> SideInput<T> createSideInputCacheEntry(PCollectionView<T> view, GlobalData data)
throws IOException {
Iterable<?> rawData = decodeRawData(view.getCoderInternal(), data);
switch (view.getViewFn().getMaterialization().getUrn()) {
Iterable<?> rawData = decodeRawData(view, data);
switch (getViewFn(view).getMaterialization().getUrn()) {
case ITERABLE_MATERIALIZATION_URN:
{
ViewFn<IterableView, T> viewFn = (ViewFn<IterableView, T>) view.getViewFn();
@SuppressWarnings({
"unchecked", // ITERABLE_MATERIALIZATION_URN has ViewFn<IterableView, T>.
"rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
})
ViewFn<IterableView, T> viewFn = (ViewFn<IterableView, T>) getViewFn(view);
return SideInput.ready(viewFn.apply(() -> rawData), data.getData().size());
}
case MULTIMAP_MATERIALIZATION_URN:
{
ViewFn<MultimapView, T> viewFn = (ViewFn<MultimapView, T>) view.getViewFn();
Coder<?> keyCoder = ((KvCoder<?, ?>) view.getCoderInternal()).getKeyCoder();
return SideInput.ready(
@SuppressWarnings({
"unchecked", // MULTIMAP_MATERIALIZATION_URN has ViewFn<MultimapView, T>.
"rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
})
ViewFn<MultimapView, T> viewFn = (ViewFn<MultimapView, T>) getViewFn(view);
Coder<?> keyCoder = ((KvCoder<?, ?>) getCoder(view)).getKeyCoder();

@SuppressWarnings({
"unchecked", // Safe since multimap rawData is of type Iterable<KV<K, V>>
"rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
})
T multimapSideInputValue =
viewFn.apply(
InMemoryMultimapSideInputView.fromIterable(keyCoder, (Iterable) rawData)),
data.getData().size());
InMemoryMultimapSideInputView.fromIterable(keyCoder, (Iterable) rawData));
return SideInput.ready(multimapSideInputValue, data.getData().size());
}
default:
throw new IllegalStateException(
String.format(
"Unknown side input materialization format requested '%s'",
view.getViewFn().getMaterialization().getUrn()));
{
throw new IllegalStateException(
"Unknown side input materialization format requested: "
+ getViewFn(view).getMaterialization().getUrn());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,15 @@
import org.mockito.MockitoAnnotations;

/** Unit tests for {@link SideInputStateFetcher}. */
// TODO: Add tests with different encoded windows to verify version is correctly plumbed.
@SuppressWarnings("deprecation")
@RunWith(JUnit4.class)
public class SideInputStateFetcherTest {
private static final String STATE_FAMILY = "state";

@Mock MetricTrackingWindmillServerStub server;
@Mock private MetricTrackingWindmillServerStub server;

@Mock Supplier<Closeable> readStateSupplier;
@Mock private Supplier<Closeable> readStateSupplier;

@Before
public void setUp() {
Expand Down Expand Up @@ -215,7 +216,7 @@ public void testFetchGlobalDataCacheOverflow() throws Exception {
coder.encode(Collections.singletonList("data2"), stream, Coder.Context.OUTER);
ByteString encodedIterable2 = stream.toByteString();

Cache<SideInputCache.Key, SideInput<?>> cache = CacheBuilder.newBuilder().build();
Cache<SideInputCache.Key<?>, SideInput<?>> cache = CacheBuilder.newBuilder().build();

SideInputStateFetcher fetcher = new SideInputStateFetcher(server, new SideInputCache(cache));

Expand Down Expand Up @@ -331,7 +332,7 @@ public void testEmptyFetchGlobalData() throws Exception {
verifyNoMoreInteractions(server);
}

private Windmill.GlobalData buildGlobalDataResponse(
private static Windmill.GlobalData buildGlobalDataResponse(
String tag, boolean isReady, ByteString data) {
Windmill.GlobalData.Builder builder =
Windmill.GlobalData.newBuilder()
Expand All @@ -349,9 +350,9 @@ private Windmill.GlobalData buildGlobalDataResponse(
return builder.build();
}

private Windmill.GlobalDataRequest buildGlobalDataRequest(String tag) {
private static Windmill.GlobalDataRequest buildGlobalDataRequest(String tag, ByteString version) {
Windmill.GlobalDataId id =
Windmill.GlobalDataId.newBuilder().setTag(tag).setVersion(ByteString.EMPTY).build();
Windmill.GlobalDataId.newBuilder().setTag(tag).setVersion(version).build();

return Windmill.GlobalDataRequest.newBuilder()
.setDataId(id)
Expand All @@ -360,4 +361,8 @@ private Windmill.GlobalDataRequest buildGlobalDataRequest(String tag) {
TimeUnit.MILLISECONDS.toMicros(GlobalWindow.INSTANCE.maxTimestamp().getMillis()))
.build();
}

/**
 * Convenience overload that builds a request for {@code tag} with {@link ByteString#EMPTY} as the
 * version.
 */
private static Windmill.GlobalDataRequest buildGlobalDataRequest(String tag) {
return buildGlobalDataRequest(tag, ByteString.EMPTY);
}
}

0 comments on commit 7732ecf

Please sign in to comment.