diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto
index 48f057bbc1ea..08f05fc51b69 100644
--- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto
@@ -1675,6 +1675,11 @@ message StandardRunnerProtocols {
     // https://s.apache.org/beam-fn-api-control-data-embedding
     CONTROL_RESPONSE_ELEMENTS_EMBEDDING = 6
         [(beam_urn) = "beam:protocol:control_response_elements_embedding:v1"];
+
+    // Indicates that this runner can handle the multimap_keys_values_side_input
+    // style read of a multimap side input.
+    MULTIMAP_KEYS_VALUES_SIDE_INPUT = 7
+        [(beam_urn) = "beam:protocol:multimap_keys_values_side_input:v1"];
   }
 }
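A runner advertises this protocol URN alongside its other capabilities (for
example in the provision API's runner_capabilities field, populated later in
this patch), and an SDK harness should only issue the bulk keys/values read
when the URN is present. A minimal sketch of that check, assuming
`runner_capabilities` is any iterable of URN strings handed to the harness:

    from apache_beam.portability import common_urns

    def supports_bulk_multimap_read(runner_capabilities):
        # Resolves to "beam:protocol:multimap_keys_values_side_input:v1".
        urn = common_urns.runner_protocols.MULTIMAP_KEYS_VALUES_SIDE_INPUT.urn
        return urn in frozenset(runner_capabilities)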
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index ddf52125b2e4..19c13775684e 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -33,6 +33,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.NavigableSet;
+import java.util.Set;
 import java.util.function.BiConsumer;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
@@ -168,6 +169,7 @@ static class Factory
       runner =
           new FnApiDoFnRunner<>(
               context.getPipelineOptions(),
+              context.getRunnerCapabilities(),
               context.getShortIdMap(),
               context.getBeamFnStateClient(),
               context.getPTransformId(),
@@ -336,6 +338,7 @@ static class Factory
+      Set<String> runnerCapabilities,
       ShortIdMap shortIds,
       BeamFnStateClient beamFnStateClient,
       String pTransformId,
@@ -740,6 +743,7 @@ private ByteString encodeProgress(double value) throws IOException {
     this.stateAccessor =
         new FnApiStateAccessor(
             pipelineOptions,
+            runnerCapabilities,
             pTransformId,
             processBundleInstructionId,
             cacheTokens,
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index 69d9a1ff6c8e..204c491dc102 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -26,6 +26,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.function.Function;
 import java.util.function.Supplier;
 import org.apache.beam.fn.harness.Cache;
@@ -33,7 +34,9 @@
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest.CacheToken;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.runners.core.construction.BeamUrns;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
@@ -74,6 +77,7 @@
 })
 public class FnApiStateAccessor<K> implements SideInputReader, StateBinder {
   private final PipelineOptions pipelineOptions;
+  private final Set<String> runnerCapabilities;
   private final Map<StateKey, Object> stateKeyObjectCache;
   private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
   private final BeamFnStateClient beamFnStateClient;
@@ -91,6 +95,7 @@ public class FnApiStateAccessor<K> implements SideInputReader, StateBinder {
   public FnApiStateAccessor(
       PipelineOptions pipelineOptions,
+      Set<String> runnerCapabilities,
       String ptransformId,
       Supplier<String> processBundleInstructionId,
       Supplier<List<CacheToken>> cacheTokens,
@@ -103,6 +108,7 @@ public FnApiStateAccessor(
       Supplier<K> currentKeySupplier,
       Supplier<BoundedWindow> currentWindowSupplier) {
     this.pipelineOptions = pipelineOptions;
+    this.runnerCapabilities = runnerCapabilities;
     this.stateKeyObjectCache = Maps.newHashMap();
     this.sideInputSpecMap = sideInputSpecMap;
     this.beamFnStateClient = beamFnStateClient;
@@ -238,7 +244,11 @@ public ResultT get() {
                     processBundleInstructionId.get(),
                     key,
                     ((KvCoder) sideInputSpec.getCoder()).getKeyCoder(),
-                    ((KvCoder) sideInputSpec.getCoder()).getValueCoder()));
+                    ((KvCoder) sideInputSpec.getCoder()).getValueCoder(),
+                    runnerCapabilities.contains(
+                        BeamUrns.getUrn(
+                            RunnerApi.StandardRunnerProtocols.Enum
+                                .MULTIMAP_KEYS_VALUES_SIDE_INPUT))));
       default:
         throw new IllegalStateException(
             String.format(
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
index ec7429fcdc0e..0c7726441021 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
@@ -56,17 +56,6 @@ public class MultimapSideInput<K, V> implements MultimapView<K, V> {
   private volatile Function<ByteString, Iterable<V>> bulkReadResult;
   private final boolean useBulkRead;
 
-  public MultimapSideInput(
-      Cache<?, ?> cache,
-      BeamFnStateClient beamFnStateClient,
-      String instructionId,
-      StateKey stateKey,
-      Coder<K> keyCoder,
-      Coder<V> valueCoder) {
-    // TODO(robertwb): Plumb the value of useBulkRead from runner capabilities.
-    this(cache, beamFnStateClient, instructionId, stateKey, keyCoder, valueCoder, false);
-  }
-
   public MultimapSideInput(
       Cache<?, ?> cache,
       BeamFnStateClient beamFnStateClient,
diff --git a/sdks/python/apache_beam/portability/common_urns.py b/sdks/python/apache_beam/portability/common_urns.py
index 3799af5d2e1b..e7b086c5a649 100644
--- a/sdks/python/apache_beam/portability/common_urns.py
+++ b/sdks/python/apache_beam/portability/common_urns.py
@@ -34,6 +34,7 @@
 StandardPTransforms = beam_runner_api_pb2_urns.StandardPTransforms
 StandardRequirements = beam_runner_api_pb2_urns.StandardRequirements
 StandardResourceHints = beam_runner_api_pb2_urns.StandardResourceHints
+StandardRunnerProtocols = beam_runner_api_pb2_urns.StandardRunnerProtocols
 StandardSideInputTypes = beam_runner_api_pb2_urns.StandardSideInputTypes
 StandardUserStateTypes = beam_runner_api_pb2_urns.StandardUserStateTypes
 ExpansionMethods = external_transforms_pb2_urns.ExpansionMethods
@@ -73,6 +74,7 @@
 monitoring_info_labels = MonitoringInfo.MonitoringInfoLabels
 
 protocols = StandardProtocols.Enum
+runner_protocols = StandardRunnerProtocols.Enum
 requirements = StandardRequirements.Enum
 
 displayData = StandardDisplayData.DisplayData
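The constructor removed above hard-coded useBulkRead to false; the remaining
constructor takes it as a parameter, now derived from the runner's advertised
capabilities. The flag selects between two ways of servicing lookups over the
State API, roughly as in this illustrative (non-Beam) Python sketch, where
`state` stands in for the state client:

    class MultimapSideInputSketch:
        def __init__(self, state, use_bulk_read):
            self._state = state
            self._use_bulk_read = use_bulk_read
            self._bulk = None  # lazily cached {key: [values]}

        def get(self, key):
            if self._use_bulk_read:
                if self._bulk is None:
                    # One multimap_keys_values_side_input read returning all
                    # key/value pairs; later lookups are served locally.
                    self._bulk = dict(self._state.read_keys_and_values())
                return self._bulk.get(key, [])
            # Without the capability: one multimap_side_input read per key.
            return self._state.read_values(key)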
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
index b0b4b1957dd8..b0421a6e43af 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
@@ -87,6 +87,10 @@
 # Time-based flush is enabled in the fn_api_runner by default.
 DATA_BUFFER_TIME_LIMIT_MS = 1000
 
+FNAPI_RUNNER_CAPABILITIES = frozenset([
+    common_urns.runner_protocols.MULTIMAP_KEYS_VALUES_SIDE_INPUT.urn,
+])
+
 _LOGGER = logging.getLogger(__name__)
 
 T = TypeVar('T')
@@ -363,6 +367,7 @@ def __init__(self,
     self.data_conn = self.data_plane_handler
     state_cache = StateCache(STATE_CACHE_SIZE_MB * MB_TO_BYTES)
     self.bundle_processor_cache = sdk_worker.BundleProcessorCache(
+        FNAPI_RUNNER_CAPABILITIES,
         SingletonStateHandlerFactory(
             sdk_worker.GlobalCachingStateHandler(state_cache, state)),
         data_plane.InMemoryDataChannelFactory(
@@ -433,6 +438,7 @@ def GetProvisionInfo(self, request, context=None):
       info.control_endpoint.CopyFrom(worker.control_api_service_descriptor())
     else:
       info = self._base_info
+    info.runner_capabilities[:] = FNAPI_RUNNER_CAPABILITIES
     return beam_provision_api_pb2.GetProvisionInfoResponse(info=info)
 
@@ -663,7 +669,8 @@ def start_worker(self):
         self.control_address,
         state_cache_size=self._state_cache_size,
         data_buffer_time_limit_ms=self._data_buffer_time_limit_ms,
-        worker_id=self.worker_id)
+        worker_id=self.worker_id,
+        runner_capabilities=FNAPI_RUNNER_CAPABILITIES)
     self.worker_thread = threading.Thread(
         name='run_worker', target=self.worker.run)
     self.worker_thread.daemon = True
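The fn_api_runner thus declares the new protocol in three places: directly to
in-process bundle processors, in provision responses, and when starting
embedded workers. The slice assignment above is the standard protobuf idiom
for replacing a repeated field's contents in place; a standalone illustration
(the `info` message instance here is hypothetical):

    # `info` has a `repeated string runner_capabilities` field.
    info.runner_capabilities[:] = [
        'beam:protocol:multimap_keys_values_side_input:v1',
    ]
    # Equivalent to deleting the existing entries and extending with the new
    # ones, but done as a single in-place update.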
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 02a2f6016f71..cf2b61d48b50 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -920,6 +920,7 @@ class BundleProcessor(object):
   """ A class for processing bundles of elements. """
   def __init__(self,
+               runner_capabilities,  # type: FrozenSet[str]
               process_bundle_descriptor,  # type: beam_fn_api_pb2.ProcessBundleDescriptor
               state_handler,  # type: sdk_worker.CachingStateHandler
               data_channel_factory,  # type: data_plane.DataChannelFactory
@@ -930,11 +931,14 @@ def __init__(self,
     """Initialize a bundle processor.
 
     Args:
+      runner_capabilities (``FrozenSet[str]``): The set of capabilities of the
+        runner with which we will be interacting.
       process_bundle_descriptor (``beam_fn_api_pb2.ProcessBundleDescriptor``):
         a description of the stage that this ``BundleProcessor`` is to execute.
       state_handler (CachingStateHandler).
       data_channel_factory (``data_plane.DataChannelFactory``).
     """
+    self.runner_capabilities = runner_capabilities
     self.process_bundle_descriptor = process_bundle_descriptor
     self.state_handler = state_handler
     self.data_channel_factory = data_channel_factory
@@ -976,12 +980,14 @@ def create_execution_tree(
   ):
     # type: (...) -> collections.OrderedDict[str, operations.DoOperation]
     transform_factory = BeamTransformFactory(
+        self.runner_capabilities,
         descriptor,
         self.data_channel_factory,
         self.counter_factory,
         self.state_sampler,
         self.state_handler,
-        self.data_sampler)
+        self.data_sampler,
+    )
 
     self.timers_info = transform_factory.extract_timers_info()
 
@@ -1267,6 +1273,7 @@ class ExecutionContext:
 class BeamTransformFactory(object):
   """Factory for turning transform_protos into executable operations."""
   def __init__(self,
+               runner_capabilities,  # type: FrozenSet[str]
               descriptor,  # type: beam_fn_api_pb2.ProcessBundleDescriptor
               data_channel_factory,  # type: data_plane.DataChannelFactory
               counter_factory,  # type: counters.CounterFactory
@@ -1274,6 +1281,7 @@ def __init__(self,
               state_handler,  # type: sdk_worker.CachingStateHandler
               data_sampler,  # type: Optional[data_sampler.DataSampler]
               ):
+    self.runner_capabilities = runner_capabilities
     self.descriptor = descriptor
     self.data_channel_factory = data_channel_factory
     self.counter_factory = counter_factory
@@ -1699,8 +1707,11 @@ def _create_pardo_operation(
             transform_id,
             tag,
             si,
-            input_tags_to_coders[tag]) for tag,
-        si in tagged_side_inputs
+            input_tags_to_coders[tag],
+            use_bulk_read=(
+                common_urns.runner_protocols.MULTIMAP_KEYS_VALUES_SIDE_INPUT.urn
+                in factory.runner_capabilities))
+        for (tag, si) in tagged_side_inputs
     ]
   else:
     side_input_maps = []
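BundleProcessor and BeamTransformFactory now take the runner capability set as
their first positional argument, and _create_pardo_operation uses it to
decide, per side input, whether the side-input map may use the bulk read.
Direct constructions must therefore supply a set; an empty set simply disables
capability-gated behavior, as the tests below do. A hypothetical direct call,
with the remaining arguments as before:

    processor = BundleProcessor(
        frozenset(),  # runner_capabilities: empty => no bulk side-input reads
        process_bundle_descriptor,
        state_handler,
        data_channel_factory)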
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
index 292b8431063c..dafb4dbd4bf0 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
@@ -267,7 +267,7 @@ def test_disabled_by_default(self):
     """
     descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()
     descriptor.pcollections['a'].unique_name = 'a'
-    _ = BundleProcessor(descriptor, None, None)
+    _ = BundleProcessor(set(), descriptor, None, None)
     self.assertEqual(len(descriptor.transforms), 0)
 
   def test_can_sample(self):
@@ -301,7 +301,7 @@ def test_can_sample(self):
     # Create and process a fake bundle. The instruction id doesn't matter
     # here.
     processor = BundleProcessor(
-        descriptor, None, None, data_sampler=data_sampler)
+        set(), descriptor, None, None, data_sampler=data_sampler)
     processor.process_bundle('instruction_id')
 
     samples = data_sampler.wait_for_samples([PCOLLECTION_ID])
@@ -377,7 +377,7 @@ def test_can_sample_exceptions(self):
     # Create and process a fake bundle. The instruction id doesn't matter
     # here.
     processor = BundleProcessor(
-        descriptor, None, None, data_sampler=data_sampler)
+        set(), descriptor, None, None, data_sampler=data_sampler)
     with self.assertRaisesRegex(RuntimeError, 'expected exception'):
       processor.process_bundle('instruction_id')
 
diff --git a/sdks/python/apache_beam/runners/worker/log_handler_test.py b/sdks/python/apache_beam/runners/worker/log_handler_test.py
index 9eb9299cac39..2cf7dff9d57f 100644
--- a/sdks/python/apache_beam/runners/worker/log_handler_test.py
+++ b/sdks/python/apache_beam/runners/worker/log_handler_test.py
@@ -286,7 +286,7 @@ def test_extracts_transform_id_during_exceptions(self):
     # Create and process a fake bundle. The instruction id doesn't matter
     # here.
-    processor = BundleProcessor(descriptor, None, None)
+    processor = BundleProcessor(set(), descriptor, None, None)
     with self.assertRaisesRegex(RuntimeError, 'expected exception'):
       processor.process_bundle('instruction_id')
 
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index bfd6544d802b..b55f505a6adc 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -174,6 +174,7 @@ def __init__(
       # Unrecoverable SDK harness initialization error (if any)
      # that should be reported to the runner when processing the first bundle.
      deferred_exception=None,  # type: Optional[Exception]
+      runner_capabilities=frozenset(),  # type: FrozenSet[str]
  ):
    # type: (...) -> None
    self._alive = True
@@ -202,6 +203,7 @@ def __init__(
         self._state_cache, credentials)
     self._profiler_factory = profiler_factory
     self.data_sampler = data_sampler
+    self.runner_capabilities = runner_capabilities
 
     def default_factory(id):
       # type: (str) -> beam_fn_api_pb2.ProcessBundleDescriptor
@@ -212,6 +214,7 @@ def default_factory(id):
     self._fns = KeyedDefaultDict(default_factory)
     # BundleProcessor cache across all workers.
     self._bundle_processor_cache = BundleProcessorCache(
+        self.runner_capabilities,
         state_handler_factory=self._state_handler_factory,
         data_channel_factory=self._data_channel_factory,
         fns=self._fns,
@@ -419,12 +422,14 @@ class BundleProcessorCache(object):
 
   def __init__(
       self,
+      runner_capabilities,  # type: FrozenSet[str]
       state_handler_factory,  # type: StateHandlerFactory
       data_channel_factory,  # type: data_plane.DataChannelFactory
       fns,  # type: MutableMapping[str, beam_fn_api_pb2.ProcessBundleDescriptor]
       data_sampler=None,  # type: Optional[data_sampler.DataSampler]
   ):
     # type: (...) -> None
+    self.runner_capabilities = runner_capabilities
     self.fns = fns
     self.state_handler_factory = state_handler_factory
     self.data_channel_factory = data_channel_factory
@@ -485,6 +490,7 @@ def get(self, instruction_id, bundle_descriptor_id):
     # Make sure we instantiate the processor while not holding the lock.
     processor = bundle_processor.BundleProcessor(
+        self.runner_capabilities,
         self.fns[bundle_descriptor_id],
         self.state_handler_factory.create_state_handler(
             self.fns[bundle_descriptor_id].state_api_service_descriptor),
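SdkHarness defaults runner_capabilities to an empty frozenset, so existing
embedders keep the current per-key read behavior unless a capability set is
passed explicitly. The cache's construction order changes accordingly; a
sketch of the updated call, with placeholder arguments:

    cache = BundleProcessorCache(
        frozenset(),            # runner_capabilities (new first argument)
        state_handler_factory,  # StateHandlerFactory
        data_channel_factory,   # data_plane.DataChannelFactory
        fns={})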
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index 1af0071edc14..cd49c69a80aa 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -103,10 +103,9 @@ def create_harness(environment, dry_run=False):
   pickle_library = sdk_pipeline_options.view_as(SetupOptions).pickle_library
   pickler.set_library(pickle_library)
 
-  if 'SEMI_PERSISTENT_DIRECTORY' in environment:
-    semi_persistent_directory = environment['SEMI_PERSISTENT_DIRECTORY']
-  else:
-    semi_persistent_directory = None
+  semi_persistent_directory = environment.get('SEMI_PERSISTENT_DIRECTORY', None)
+  runner_capabilities = frozenset(
+      environment.get('RUNNER_CAPABILITIES', '').split())
 
   _LOGGER.info('semi_persistent_directory: %s', semi_persistent_directory)
   _worker_id = environment.get('WORKER_ID', None)
@@ -167,7 +166,8 @@ def create_harness(environment, dry_run=False):
           sdk_pipeline_options.view_as(ProfilingOptions)),
       enable_heap_dump=enable_heap_dump,
       data_sampler=data_sampler,
-      deferred_exception=deferred_exception)
+      deferred_exception=deferred_exception,
+      runner_capabilities=runner_capabilities)
 
   return fn_log_handler, sdk_harness, sdk_pipeline_options
 
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index 8570c5a7722c..4e202910345c 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -126,7 +126,7 @@ def test_fn_registration(self):
 
   def test_inactive_bundle_processor_returns_empty_progress_response(self):
     bundle_processor = mock.MagicMock()
-    bundle_processor_cache = BundleProcessorCache(None, None, {})
+    bundle_processor_cache = BundleProcessorCache(None, None, None, {})
     bundle_processor_cache.activate('instruction_id')
     worker = SdkWorker(bundle_processor_cache)
     split_request = beam_fn_api_pb2.InstructionRequest(
@@ -153,7 +153,7 @@ def test_inactive_bundle_processor_returns_empty_progress_response(self):
 
   def test_failed_bundle_processor_returns_failed_progress_response(self):
     bundle_processor = mock.MagicMock()
-    bundle_processor_cache = BundleProcessorCache(None, None, {})
+    bundle_processor_cache = BundleProcessorCache(None, None, None, {})
     bundle_processor_cache.activate('instruction_id')
     worker = SdkWorker(bundle_processor_cache)
 
@@ -172,7 +172,7 @@ def test_failed_bundle_processor_returns_failed_progress_response(self):
 
   def test_inactive_bundle_processor_returns_empty_split_response(self):
     bundle_processor = mock.MagicMock()
-    bundle_processor_cache = BundleProcessorCache(None, None, {})
+    bundle_processor_cache = BundleProcessorCache(None, None, None, {})
     bundle_processor_cache.activate('instruction_id')
     worker = SdkWorker(bundle_processor_cache)
     split_request = beam_fn_api_pb2.InstructionRequest(
@@ -258,7 +258,7 @@ def test_harness_monitoring_infos_and_metadata(self):
 
   def test_failed_bundle_processor_returns_failed_split_response(self):
     bundle_processor = mock.MagicMock()
-    bundle_processor_cache = BundleProcessorCache(None, None, {})
+    bundle_processor_cache = BundleProcessorCache(None, None, None, {})
     bundle_processor_cache.activate('instruction_id')
     worker = SdkWorker(bundle_processor_cache)
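Boot environments hand the capability list to the harness as a single
whitespace-separated RUNNER_CAPABILITIES string, mirroring the parsing in
create_harness above. A standalone sketch of that contract:

    import os

    # e.g. RUNNER_CAPABILITIES="beam:protocol:multimap_keys_values_side_input:v1"
    runner_capabilities = frozenset(
        os.environ.get('RUNNER_CAPABILITIES', '').split())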
diff --git a/sdks/python/gen_protos.py b/sdks/python/gen_protos.py
index 2b488af0afb5..a2cd1bd4cef3 100644
--- a/sdks/python/gen_protos.py
+++ b/sdks/python/gen_protos.py
@@ -527,7 +527,6 @@ def generate_proto_files(force=False):
   generate_init_files_lite(PYTHON_OUTPUT_PATH)
   for proto_package in proto_packages:
     generate_urn_files(proto_package, PYTHON_OUTPUT_PATH)
-  generate_init_files_full(PYTHON_OUTPUT_PATH)