Deduplicate common environments. (#30681)
We deduplicate environments both on proto construction (as before, but fixed) and again after more environments have been resolved.
robertwb authored Mar 26, 2024
1 parent 4af19ff commit 7f04c4f
Showing 5 changed files with 133 additions and 30 deletions.
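For reference, a minimal sketch of what the new helper does once this commit is in place (hypothetical environment ids; it mirrors the unit test added in common_test.py below): structurally identical environments collapse into one, and references to the dropped ids are remapped.

from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.common import merge_common_environments

# Two environments with identical contents but different ids.
proto = beam_runner_api_pb2.Pipeline(
    components=beam_runner_api_pb2.Components(
        environments={
            'env_a': beam_runner_api_pb2.Environment(urn='X'),
            'env_b': beam_runner_api_pb2.Environment(urn='X'),
        },
        transforms={
            't': beam_runner_api_pb2.PTransform(
                unique_name='t', environment_id='env_b'),
        }))
merged = merge_common_environments(proto)
assert len(merged.components.environments) == 1
# The transform's reference was remapped to the surviving environment.
assert (merged.components.transforms['t'].environment_id
        in merged.components.environments)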
31 changes: 2 additions & 29 deletions sdks/python/apache_beam/pipeline.py
@@ -86,6 +86,7 @@
 from apache_beam.portability import common_urns
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.runners import PipelineRunner
+from apache_beam.runners import common
 from apache_beam.runners import create_runner
 from apache_beam.transforms import ParDo
 from apache_beam.transforms import ptransform
@@ -967,35 +968,7 @@ def merge_compatible_environments(proto):
     Mutates proto as contexts may have references to proto.components.
     """
-    env_map = {}
-    canonical_env = {}
-    files_by_hash = {}
-    for env_id, env in proto.components.environments.items():
-      # First deduplicate any file dependencies by their hash.
-      for dep in env.dependencies:
-        if dep.type_urn == common_urns.artifact_types.FILE.urn:
-          file_payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
-              dep.type_payload)
-          if file_payload.sha256:
-            if file_payload.sha256 in files_by_hash:
-              file_payload.path = files_by_hash[file_payload.sha256]
-              dep.type_payload = file_payload.SerializeToString()
-            else:
-              files_by_hash[file_payload.sha256] = file_payload.path
-      # Next check if we've ever seen this environment before.
-      normalized = env.SerializeToString(deterministic=True)
-      if normalized in canonical_env:
-        env_map[env_id] = canonical_env[normalized]
-      else:
-        canonical_env[normalized] = env_id
-    for old_env, new_env in env_map.items():
-      for transform in proto.components.transforms.values():
-        if transform.environment_id == old_env:
-          transform.environment_id = new_env
-      for windowing_strategy in proto.components.windowing_strategies.values():
-        if windowing_strategy.environment_id == old_env:
-          windowing_strategy.environment_id = new_env
-      del proto.components.environments[old_env]
+    common.merge_common_environments(proto, inplace=True)

   @staticmethod
   def from_runner_api(
67 changes: 67 additions & 0 deletions sdks/python/apache_beam/runners/common.py
@@ -24,6 +24,8 @@

 # pytype: skip-file

+import collections
+import copy
 import logging
 import sys
 import threading
Expand All @@ -43,6 +45,7 @@
 from apache_beam.internal import util
 from apache_beam.options.value_provider import RuntimeValueProvider
 from apache_beam.portability import common_urns
+from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.pvalue import TaggedOutput
 from apache_beam.runners.sdf_utils import NoOpWatermarkEstimatorProvider
 from apache_beam.runners.sdf_utils import RestrictionTrackerView
@@ -52,6 +55,7 @@
 from apache_beam.runners.sdf_utils import ThreadsafeWatermarkEstimator
 from apache_beam.transforms import DoFn
 from apache_beam.transforms import core
+from apache_beam.transforms import environments
 from apache_beam.transforms import userstate
 from apache_beam.transforms.core import RestrictionProvider
 from apache_beam.transforms.core import WatermarkEstimatorProvider
@@ -1941,3 +1945,66 @@ def validate_transform(transform_id):

   for t in pipeline_proto.root_transform_ids:
     validate_transform(t)


+def merge_common_environments(pipeline_proto, inplace=False):
+  def dep_key(dep):
+    if dep.type_urn == common_urns.artifact_types.FILE.urn:
+      payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
+          dep.type_payload)
+      if payload.sha256:
+        type_info = 'sha256', payload.sha256
+      else:
+        type_info = 'path', payload.path
+    elif dep.type_urn == common_urns.artifact_types.URL.urn:
+      payload = beam_runner_api_pb2.ArtifactUrlPayload.FromString(
+          dep.type_payload)
+      if payload.sha256:
+        type_info = 'sha256', payload.sha256
+      else:
+        type_info = 'url', payload.url
+    else:
+      type_info = dep.type_urn, dep.type_payload
+    return type_info, dep.role_urn, dep.role_payload
+
+  def base_env_key(env):
+    return (
+        env.urn,
+        env.payload,
+        tuple(sorted(env.capabilities)),
+        tuple(sorted(env.resource_hints.items())),
+        tuple(sorted(dep_key(dep) for dep in env.dependencies)))
+
+  def env_key(env):
+    return tuple(
+        sorted(
+            base_env_key(e)
+            for e in environments.expand_anyof_environments(env)))
+
+  canonical_environments = collections.defaultdict(list)
+  for env_id, env in pipeline_proto.components.environments.items():
+    canonical_environments[env_key(env)].append(env_id)
+
+  if len(canonical_environments) == len(
+      pipeline_proto.components.environments):
+    # All environments are already sufficiently distinct.
+    return pipeline_proto
+
+  environment_remappings = {
+      e: es[0]
+      for es in canonical_environments.values() for e in es
+  }
+
+  if not inplace:
+    pipeline_proto = copy.copy(pipeline_proto)
+
+  for t in pipeline_proto.components.transforms.values():
+    if t.environment_id:
+      t.environment_id = environment_remappings[t.environment_id]
+  for w in pipeline_proto.components.windowing_strategies.values():
+    if w.environment_id:
+      w.environment_id = environment_remappings[w.environment_id]
+  for e in set(pipeline_proto.components.environments.keys()) - set(
+      environment_remappings.values()):
+    del pipeline_proto.components.environments[e]
+  return pipeline_proto
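Note the role of dep_key above: FILE and URL artifacts that carry a sha256 are keyed by that hash rather than by their (often randomly generated) staged path or url, so environments that differ only in artifact naming still merge. A small sketch of that behavior, with hypothetical paths and hash values:

from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.common import merge_common_environments

def file_dep(path, sha256):
  # Hypothetical helper: a FILE artifact dependency with the given path/hash.
  return beam_runner_api_pb2.ArtifactInformation(
      type_urn=common_urns.artifact_types.FILE.urn,
      type_payload=beam_runner_api_pb2.ArtifactFilePayload(
          path=path, sha256=sha256).SerializeToString())

proto = beam_runner_api_pb2.Pipeline(
    components=beam_runner_api_pb2.Components(
        environments={
            'e1': beam_runner_api_pb2.Environment(
                urn='X', dependencies=[file_dep('/tmp/a.whl', 'abc123')]),
            'e2': beam_runner_api_pb2.Environment(
                urn='X', dependencies=[file_dep('/tmp/b.whl', 'abc123')]),
        }))
# Same sha256 under different staged paths: the environments still merge.
assert len(merge_common_environments(proto).components.environments) == 1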
59 changes: 59 additions & 0 deletions sdks/python/apache_beam/runners/common_test.py
@@ -26,8 +26,11 @@
 from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
 from apache_beam.io.watermark_estimators import ManualWatermarkEstimator
 from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.runners.common import DoFnSignature
 from apache_beam.runners.common import PerWindowInvoker
+from apache_beam.runners.common import merge_common_environments
+from apache_beam.runners.portability.expansion_service_test import FibTransform
 from apache_beam.runners.sdf_utils import SplitResultPrimary
 from apache_beam.runners.sdf_utils import SplitResultResidual
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -584,5 +587,61 @@ def test_window_observing_split_on_window_boundary_round_down_on_last_window(
     self.assertEqual(stop_index, 2)


+class UtilitiesTest(unittest.TestCase):
+  def test_equal_environments_merged(self):
+    pipeline_proto = merge_common_environments(
+        beam_runner_api_pb2.Pipeline(
+            components=beam_runner_api_pb2.Components(
+                environments={
+                    'a1': beam_runner_api_pb2.Environment(urn='A'),
+                    'a2': beam_runner_api_pb2.Environment(urn='A'),
+                    'b1': beam_runner_api_pb2.Environment(
+                        urn='B', payload=b'x'),
+                    'b2': beam_runner_api_pb2.Environment(
+                        urn='B', payload=b'x'),
+                    'b3': beam_runner_api_pb2.Environment(
+                        urn='B', payload=b'y'),
+                },
+                transforms={
+                    't1': beam_runner_api_pb2.PTransform(
+                        unique_name='t1', environment_id='a1'),
+                    't2': beam_runner_api_pb2.PTransform(
+                        unique_name='t2', environment_id='a2'),
+                },
+                windowing_strategies={
+                    'w1': beam_runner_api_pb2.WindowingStrategy(
+                        environment_id='b1'),
+                    'w2': beam_runner_api_pb2.WindowingStrategy(
+                        environment_id='b2'),
+                })))
+    self.assertEqual(len(pipeline_proto.components.environments), 3)
+    self.assertTrue(('a1' in pipeline_proto.components.environments)
+                    ^ ('a2' in pipeline_proto.components.environments))
+    self.assertTrue(('b1' in pipeline_proto.components.environments)
+                    ^ ('b2' in pipeline_proto.components.environments))
+    self.assertEqual(
+        len(
+            set(
+                t.environment_id
+                for t in pipeline_proto.components.transforms.values())),
+        1)
+    self.assertEqual(
+        len(
+            set(
+                w.environment_id for w in
+                pipeline_proto.components.windowing_strategies.values())),
+        1)
+
+  def test_external_merged(self):
+    p = beam.Pipeline()
+    # This transform recursively creates several external environments.
+    _ = p | FibTransform(4)
+    pipeline_proto = p.to_runner_api()
+    # All our external environments are equal and consolidated.
+    # We also have a placeholder "default" environment that has not been
+    # resolved to anything concrete yet.
+    self.assertEqual(len(pipeline_proto.components.environments), 2)


 if __name__ == '__main__':
   unittest.main()
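One behavior of the new helper that these tests do not pin down, sketched hypothetically below: with the default inplace=False, a merge leaves the input proto untouched and returns the deduplicated copy (assuming copy.copy of the pipeline proto yields an independent message), while inplace=True, as used by Pipeline.merge_compatible_environments, mutates the argument directly.

from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.common import merge_common_environments

proto = beam_runner_api_pb2.Pipeline(
    components=beam_runner_api_pb2.Components(
        environments={
            'e1': beam_runner_api_pb2.Environment(urn='E'),
            'e2': beam_runner_api_pb2.Environment(urn='E'),
        }))

merged = merge_common_environments(proto)
assert len(merged.components.environments) == 1
assert len(proto.components.environments) == 2  # input left untouched

merge_common_environments(proto, inplace=True)
assert len(proto.components.environments) == 1  # input mutated in place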
2 changes: 2 additions & 0 deletions sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -43,6 +43,7 @@
 from apache_beam.options.pipeline_options import WorkerOptions
 from apache_beam.portability import common_urns
 from apache_beam.runners.common import group_by_key_input_visitor
+from apache_beam.runners.common import merge_common_environments
 from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
 from apache_beam.runners.runner import PipelineResult
 from apache_beam.runners.runner import PipelineRunner
@@ -419,6 +420,7 @@ def run_pipeline(self, pipeline, options, pipeline_proto=None):
       self.proto_pipeline.components.environments[env_id].CopyFrom(
           environments.resolve_anyof_environment(
               env, common_urns.environments.DOCKER.urn))
+    self.proto_pipeline = merge_common_environments(self.proto_pipeline)

     # Optimize the pipeline if it is not streaming and the pre_optimize
     # experiment is set.
4 changes: 3 additions & 1 deletion sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -62,6 +62,7 @@
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.runners import runner
 from apache_beam.runners.common import group_by_key_input_visitor
+from apache_beam.runners.common import merge_common_environments
 from apache_beam.runners.common import validate_pipeline_graph
 from apache_beam.runners.portability import portable_metrics
 from apache_beam.runners.portability.fn_api_runner import execution
@@ -221,7 +222,8 @@ def run_via_runner_api(self, pipeline_proto, options):
     ]
     if direct_options.direct_embed_docker_python:
       pipeline_proto = self.embed_default_docker_image(pipeline_proto)
-    pipeline_proto = self.resolve_any_environments(pipeline_proto)
+    pipeline_proto = merge_common_environments(
+        self.resolve_any_environments(pipeline_proto))
     stage_context, stages = self.create_stages(pipeline_proto)
     return self.run_stages(stage_context, stages)
Expand Down
