Skip to content

Commit

Permalink
Informed choice of yaml provider for root transforms. (#28310)
Browse files Browse the repository at this point in the history
Beam now considers the possible providers of consuming transforms.

There is still room for improvement, but this should be a decent
general heuristic.
  • Loading branch information
robertwb authored Sep 8, 2023
1 parent 3a99c9d commit 19160f4
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 27 deletions.
16 changes: 14 additions & 2 deletions sdks/python/apache_beam/yaml/readme_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.typehints import trivial_inference
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform


Expand Down Expand Up @@ -121,14 +122,20 @@ def expand(self, pcoll):


RENDER_DIR = None
TEST_PROVIDERS = {
TEST_TRANSFORMS = {
'Sql': FakeSql,
'ReadFromPubSub': FakeReadFromPubSub,
'WriteToPubSub': FakeWriteToPubSub,
'SomeAggregation': SomeAggregation,
}


class TestProvider(yaml_provider.InlineProvider):
def _affinity(self, other):
# Always try to choose this one.
return float('inf')


class TestEnvironment:
def __enter__(self):
self.tempdir = tempfile.TemporaryDirectory()
Expand Down Expand Up @@ -196,8 +203,13 @@ def test(self):
os.path.join(RENDER_DIR, test_name + '.png')
]
options['render_leaf_composite_nodes'] = ['.*']
test_provider = TestProvider(TEST_TRANSFORMS)
p = beam.Pipeline(options=PipelineOptions(**options))
yaml_transform.expand_pipeline(p, modified_yaml, TEST_PROVIDERS)
yaml_transform.expand_pipeline(
p,
modified_yaml,
{t: test_provider
for t in test_provider.provided_transforms()})
if test_type == 'BUILD':
return
p.run().wait_until_finish()
Expand Down
135 changes: 111 additions & 24 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
import pprint
import re
import uuid
from typing import Any
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Set

import yaml
from yaml.loader import SafeLoader
Expand Down Expand Up @@ -162,13 +165,31 @@ def get_transform_id(self, transform_name):

class Scope(LightweightScope):
"""To look up PCollections (typically outputs of prior transforms) by name."""
def __init__(self, root, inputs, transforms, providers, input_providers):
def __init__(
self,
root,
inputs: Mapping[str, Any],
transforms: Iterable[dict],
providers: Mapping[str, Iterable[yaml_provider.Provider]],
input_providers: Iterable[yaml_provider.Provider]):
super().__init__(transforms)
self.root = root
self._inputs = inputs
self.providers = providers
self._seen_names = set()
self._seen_names: Set[str] = set()
self.input_providers = input_providers
self._all_followers = None

def followers(self, transform_name):
if self._all_followers is None:
self._all_followers = collections.defaultdict(list)
# TODO(yaml): Also trace through outputs and composites.
for transform in self._transforms:
if transform['type'] != 'composite':
for input in transform.get('input').values():
transform_id, _ = self.get_transform_id_and_output_name(input)
self._all_followers[transform_id].append(transform['__uuid__'])
return self._all_followers[self.get_transform_id(transform_name)]

def compute_all(self):
for transform_id in self._transforms_by_uuid.keys():
Expand Down Expand Up @@ -208,6 +229,77 @@ def get_outputs(self, transform_name):
def compute_outputs(self, transform_id):
return expand_transform(self._transforms_by_uuid[transform_id], self)

def best_provider(
self, t, input_providers: yaml_provider.Iterable[yaml_provider.Provider]):
if isinstance(t, dict):
spec = t
else:
spec = self._transforms_by_uuid[self.get_transform_id(t)]
possible_providers = [
p for p in self.providers[spec['type']] if p.available()
]
if not possible_providers:
raise ValueError(
'No available provider for type %r at %s' %
(spec['type'], identify_object(spec)))
# From here on, we have the invariant that possible_providers is not empty.

# Only one possible provider, no need to rank further.
if len(possible_providers) == 1:
return possible_providers[0]

def best_matches(
possible_providers: Iterable[yaml_provider.Provider],
adjacent_provider_options: Iterable[Iterable[yaml_provider.Provider]]
) -> List[yaml_provider.Provider]:
"""Given a set of possible providers, and a set of providers for each
adjacent transform, returns the top possible providers as ranked by
affinity to the adjacent transforms' providers.
"""
providers_by_score = collections.defaultdict(list)
for p in possible_providers:
# The sum of the affinity of the best provider
# for each adjacent transform.
providers_by_score[sum(
max(p.affinity(ap) for ap in apo)
for apo in adjacent_provider_options)].append(p)
return providers_by_score[max(providers_by_score.keys())]

# If there are any inputs, prefer to match them.
if input_providers:
possible_providers = best_matches(
possible_providers, [[p] for p in input_providers])

# Without __uuid__ we can't find downstream operations.
if '__uuid__' not in spec:
return possible_providers[0]

# Match against downstream transforms, continuing until there is no tie
# or we run out of downstream transforms.
if len(possible_providers) > 1:
adjacent_transforms = list(self.followers(spec['__uuid__']))
while adjacent_transforms:
# This is a list of all possible providers for each adjacent transform.
adjacent_provider_options = [[
p for p in self.providers[self._transforms_by_uuid[t]['type']]
if p.available()
] for t in adjacent_transforms]
if any(not apo for apo in adjacent_provider_options):
# One of the transforms had no available providers.
# We will throw an error later, doesn't matter what we return.
break
# Filter down the set of possible providers to the best ones.
possible_providers = best_matches(
possible_providers, adjacent_provider_options)
# If we are down to one option, no need to go further.
if len(possible_providers) == 1:
break
# Go downstream one more step.
adjacent_transforms = sum(
[list(self.followers(t)) for t in adjacent_transforms], [])

return possible_providers[0]

# A method on scope as providers may be scoped...
def create_ptransform(self, spec, input_pcolls):
if 'type' not in spec:
Expand All @@ -225,19 +317,7 @@ def create_ptransform(self, spec, input_pcolls):
providers_by_input[pcoll] for pcoll in input_pcolls
if pcoll in providers_by_input
]

def provider_score(p):
return sum(p.affinity(o) for o in input_providers)

for provider in sorted(self.providers.get(spec['type']),
key=provider_score,
reverse=True):
if provider.available():
break
else:
raise ValueError(
'No available provider for type %r at %s' %
(spec['type'], identify_object(spec)))
provider = self.best_provider(spec, input_providers)

config = SafeLineLoader.strip_metadata(spec.get('config', {}))
if not isinstance(config, dict):
Expand Down Expand Up @@ -510,14 +590,19 @@ def identify_object(spec):


def extract_name(spec):
if 'name' in spec:
return spec['name']
elif 'id' in spec:
return spec['id']
elif 'type' in spec:
return spec['type']
elif len(spec) == 1:
return extract_name(next(iter(spec.values())))
if isinstance(spec, dict):
if 'name' in spec:
return spec['name']
elif 'id' in spec:
return spec['id']
elif 'type' in spec:
return spec['type']
elif len(spec) == 1:
return extract_name(next(iter(spec.values())))
else:
return ''
elif isinstance(spec, str):
return spec
else:
return ''

Expand Down Expand Up @@ -602,7 +687,9 @@ def preprocess_windowing(spec):
'type': 'WindowInto',
'name': f'WindowInto[{out}]',
'windowing': windowing,
'input': modified_spec['__uuid__'] + ('.' + out if out else ''),
'input': {
'input': modified_spec['__uuid__'] + ('.' + out if out else '')
},
'__line__': spec['__line__'],
'__uuid__': SafeLineLoader.create_uuid(),
} for out in consumed_outputs]
Expand Down
142 changes: 142 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
# limitations under the License.
#

import collections
import logging
import unittest

import yaml

import apache_beam as beam
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform
from apache_beam.yaml.yaml_transform import LightweightScope
from apache_beam.yaml.yaml_transform import SafeLineLoader
from apache_beam.yaml.yaml_transform import Scope
Expand Down Expand Up @@ -136,6 +138,146 @@ def test_create_ptransform_with_inputs(self):
self.assertDictEqual(result_annotations, target_annotations)


class TestProvider(yaml_provider.InlineProvider):
def __init__(self, transform, name):
super().__init__({
name: lambda: beam.Map(lambda x: (x or ()) + (name, )), # or None
transform: lambda: beam.Map(lambda x: (x or ()) + (name, )),
})
self._transform = transform
self._name = name

def __repr__(self):
return 'TestProvider(%r, %r)' % (self._transform, self._name)

def _affinity(self, other):
if isinstance(other, TestProvider):
# Providers are closer based on how much their names match prefixes.
affinity = 1
for x, y in zip(self._name, other._name):
if x != y:
break
affinity *= 10
return affinity
else:
return -1000


class ProviderAffinityTest(unittest.TestCase):
@staticmethod
def create_scope(s, providers):
providers_dict = collections.defaultdict(list)
for provider in providers:
for transform_type in provider.provided_transforms():
providers_dict[transform_type].append(provider)
spec = yaml_transform.preprocess(yaml.load(s, Loader=SafeLineLoader))
return Scope(
None, {},
transforms=spec['transforms'],
providers=providers_dict,
input_providers={})

def test_best_provider_based_on_input(self):
provider_Ax = TestProvider('A', 'xxx')
provider_Ay = TestProvider('A', 'yyy')
provider_Bx = TestProvider('B', 'xxz')
provider_By = TestProvider('B', 'yyz')
scope = self.create_scope(
'''
type: chain
transforms:
- type: A
- type: B
''', [provider_Ax, provider_Ay, provider_Bx, provider_By])
self.assertEqual(scope.best_provider('B', [provider_Ax]), provider_Bx)
self.assertEqual(scope.best_provider('B', [provider_Ay]), provider_By)

def test_best_provider_based_on_followers(self):
close_provider = TestProvider('A', 'xxy')
far_provider = TestProvider('A', 'yyy')
following_provider = TestProvider('B', 'xxx')
scope = self.create_scope(
'''
type: chain
transforms:
- type: A
- type: B
''', [far_provider, close_provider, following_provider])
self.assertEqual(scope.best_provider('A', []), close_provider)

def test_best_provider_based_on_multiple_followers(self):
close_provider = TestProvider('A', 'xxy')
provider_B = TestProvider('B', 'xxx')
# These are not quite as close as the two above.
far_provider = TestProvider('A', 'yyy')
provider_C = TestProvider('C', 'yzz')
scope = self.create_scope(
'''
type: composite
transforms:
- type: A
- type: B
input: A
- type: C
input: A
''', [far_provider, close_provider, provider_B, provider_C])
self.assertEqual(scope.best_provider('A', []), close_provider)

def test_best_provider_based_on_distant_follower(self):
providers = [
# xxx and yyy vend both
TestProvider('A', 'xxx'),
TestProvider('A', 'yyy'),
TestProvider('B', 'xxx'),
TestProvider('B', 'yyy'),
TestProvider('C', 'xxx'),
TestProvider('C', 'yyy'),
# D and E are only provided by a single provider each.
TestProvider('D', 'xxx'),
TestProvider('E', 'yyy')
]

# If D is the eventual destination, pick the xxx one.
scope = self.create_scope(
'''
type: chain
transforms:
- type: A
- type: B
- type: C
- type: D
''',
providers)
self.assertEqual(scope.best_provider('A', []), providers[0])

# If instead E is the eventual destination, pick the yyy one.
scope = self.create_scope(
'''
type: chain
transforms:
- type: A
- type: B
- type: C
- type: E
''',
providers)
self.assertEqual(scope.best_provider('A', []), providers[1])

# If we have D then E, stay with xxx as long as possible to only switch once
scope = self.create_scope(
'''
type: chain
transforms:
- type: A
- type: B
- type: C
- type: D
- type: E
''',
providers)
self.assertEqual(scope.best_provider('A', []), providers[0])


class LightweightScopeTest(unittest.TestCase):
@staticmethod
def get_spec():
Expand Down
Loading

0 comments on commit 19160f4

Please sign in to comment.