#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Sources and sinks.
A Source manages record-oriented data input from a particular kind of source
(e.g. a set of files, a database table, etc.). The reader() method of a source
returns a reader object supporting the iterator protocol; iteration yields
raw records of unprocessed, serialized data.
A Sink manages record-oriented data output to a particular kind of sink
(e.g. a set of files, a database table, etc.). The writer() method of a sink
returns a writer object supporting writing records of serialized data to
the sink.
"""
# pytype: skip-file
import logging
import math
import random
import uuid
from collections import namedtuple
from typing import Any
from typing import Iterator
from typing import Optional
from typing import Tuple
from typing import Union
from apache_beam import coders
from apache_beam import pvalue
from apache_beam.coders.coders import _MemoizingPickleCoder
from apache_beam.internal import pickler
from apache_beam.portability import common_urns
from apache_beam.portability import python_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsIter
from apache_beam.pvalue import AsSingleton
from apache_beam.transforms import Impulse
from apache_beam.transforms import PTransform
from apache_beam.transforms import core
from apache_beam.transforms import ptransform
from apache_beam.transforms import window
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.transforms.display import HasDisplayData
from apache_beam.utils import timestamp
from apache_beam.utils import urns
from apache_beam.utils.windowed_value import WindowedValue
__all__ = [
'BoundedSource',
'RangeTracker',
'Read',
'RestrictionProgress',
'RestrictionTracker',
'WatermarkEstimator',
'Sink',
'Write',
'Writer'
]
_LOGGER = logging.getLogger(__name__)
# Encapsulates information about a bundle of a source generated when method
# BoundedSource.split() is invoked.
# This is a named 4-tuple that has following fields.
# * weight - a number that represents the size of the bundle. This value will
# be used to compare the relative sizes of bundles generated by the
# current source.
# The weight returned here could be specified using a unit of your
# choice (for example, bundles of sizes 100MB, 200MB, and 700MB may
# specify weights 100, 200, 700 or 1, 2, 7) but all bundles of a
# source should specify the weight using the same unit.
# * source - a BoundedSource object for the bundle.
# * start_position - starting position of the bundle
# * stop_position - ending position of the bundle.
#
# Type for start and stop positions are specific to the bounded source and must
# be consistent throughout.
SourceBundle = namedtuple(
'SourceBundle', 'weight source start_position stop_position')
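# Illustrative sketch (not part of the public API): how a hypothetical
# file-based source might package the bundles produced by ``split()``. The
# important point is that every bundle of a given source uses the same weight
# unit (bytes here) so their relative sizes are comparable. The helper
# ``sub_source_for_file`` is an assumed callable supplied by the caller.
def _example_bundles_for_files(file_sizes, sub_source_for_file):
  """Yields one SourceBundle per file, weighted by its size in bytes."""
  for path, size_bytes in file_sizes.items():
    yield SourceBundle(
        weight=size_bytes,  # consistent unit (bytes) across all bundles
        source=sub_source_for_file(path),
        start_position=0,
        stop_position=size_bytes)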
class SourceBase(HasDisplayData, urns.RunnerApiFn):
"""Base class for all sources that can be passed to beam.io.Read(...).
"""
urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_SOURCE)
def is_bounded(self):
# type: () -> bool
raise NotImplementedError
class BoundedSource(SourceBase):
"""A source that reads a finite amount of input records.
This class defines the following operations, which can be used to read the source
efficiently.
* Size estimation - method ``estimate_size()`` may return an accurate
estimation in bytes for the size of the source.
* Splitting into bundles of a given size - method ``split()`` can be used to
split the source into a set of sub-sources (bundles) based on a desired
bundle size.
* Getting a RangeTracker - method ``get_range_tracker()`` should return a
``RangeTracker`` object for a given position range for the position type
of the records returned by the source.
* Reading the data - method ``read()`` can be used to read data from the
source while respecting the boundaries defined by a given
``RangeTracker``.
A runner will read the source in two steps.
(1) Method ``get_range_tracker()`` will be invoked with start and end
positions to obtain a ``RangeTracker`` for the range of positions the
runner intends to read. The source must define a default initial start and end
position range. These positions must be used if the start and/or end
positions passed to the method ``get_range_tracker()`` are ``None``.
(2) Method ``read()`` will be invoked with the ``RangeTracker`` obtained in the
previous step.
**Mutability**
A ``BoundedSource`` object should not be mutated while
its methods (for example, ``read()``) are being invoked by a runner. Runner
implementations may invoke methods of ``BoundedSource`` objects through
multi-threaded and/or reentrant execution modes.
"""
def estimate_size(self):
# type: () -> Optional[int]
"""Estimates the size of source in bytes.
An estimate of the total size (in bytes) of the data that would be read
from this source. This estimate is in terms of external storage size,
before performing decompression or other processing.
Returns:
estimated size of the source if the size can be determined, ``None``
otherwise.
"""
raise NotImplementedError
def split(self,
desired_bundle_size, # type: int
start_position=None, # type: Optional[Any]
stop_position=None, # type: Optional[Any]
):
# type: (...) -> Iterator[SourceBundle]
"""Splits the source into a set of bundles.
Bundles should be approximately of size ``desired_bundle_size`` bytes.
Args:
desired_bundle_size: the desired size (in bytes) of the bundles returned.
start_position: if specified, the given position must be used as the
starting position of the first bundle.
stop_position: if specified, the given position must be used as the ending
position of the last bundle.
Returns:
an iterator of objects of type 'SourceBundle' that gives information about
the generated bundles.
"""
raise NotImplementedError
def get_range_tracker(self,
start_position, # type: Optional[Any]
stop_position, # type: Optional[Any]
):
# type: (...) -> RangeTracker
"""Returns a RangeTracker for a given position range.
The framework may invoke the ``read()`` method with the ``RangeTracker`` object
returned here to read data from the source.
Args:
start_position: starting position of the range. If 'None' default start
position of the source must be used.
stop_position: ending position of the range. If 'None' default stop
position of the source must be used.
Returns:
a ``RangeTracker`` for the given position range.
"""
raise NotImplementedError
def read(self, range_tracker):
"""Returns an iterator that reads data from the source.
The returned set of data must respect the boundaries defined by the given
``RangeTracker`` object. For example:
* Returned set of data must be for the range
``[range_tracker.start_position, range_tracker.stop_position)``. Note
that a source may decide to return records that start after
``range_tracker.stop_position``. See documentation in class
``RangeTracker`` for more details. Also, note that framework might
invoke ``range_tracker.try_split()`` to perform dynamic split
operations. range_tracker.stop_position may be updated
dynamically due to successful dynamic split operations.
* Method ``range_tracker.try_claim()`` must be invoked for every record
that starts at a split point.
* Method ``range_tracker.set_current_position()`` may be invoked for
records that do not start at split points.
Args:
range_tracker: a ``RangeTracker`` whose boundaries must be respected
when reading data from the source. A runner that reads this
source must pass a ``RangeTracker`` object that is not
``None``.
Returns:
an iterator of data read by the source.
"""
raise NotImplementedError
def default_output_coder(self):
"""Coder that should be used for the records returned by the source.
Should be overridden by sources that produce objects that can be encoded
more efficiently than pickling.
"""
return coders.registry.get_coder(object)
def is_bounded(self):
return True
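# Illustrative sketch only: a minimal in-memory BoundedSource over a Python
# list, tying together the operations described above. Positions are list
# indices, every record is a split point, and ``desired_bundle_size`` is
# treated as a record count for simplicity (real sources should interpret it
# as bytes). It reuses the existing ``OffsetRangeTracker`` from
# ``apache_beam.io.range_trackers``.
class _ExampleListSource(BoundedSource):
  """Reads the elements of an in-memory list (for illustration only)."""
  def __init__(self, values):
    self._values = list(values)

  def estimate_size(self):
    # A rough estimate; exact byte sizes are not required.
    return sum(len(str(v)) for v in self._values)

  def split(self, desired_bundle_size, start_position=None, stop_position=None):
    start = 0 if start_position is None else start_position
    stop = len(self._values) if stop_position is None else stop_position
    bundle_len = max(1, int(desired_bundle_size))
    for bundle_start in range(start, stop, bundle_len):
      bundle_stop = min(bundle_start + bundle_len, stop)
      yield SourceBundle(
          weight=bundle_stop - bundle_start,
          source=self,
          start_position=bundle_start,
          stop_position=bundle_stop)

  def get_range_tracker(self, start_position, stop_position):
    from apache_beam.io.range_trackers import OffsetRangeTracker
    start = 0 if start_position is None else start_position
    stop = len(self._values) if stop_position is None else stop_position
    return OffsetRangeTracker(start, stop)

  def read(self, range_tracker):
    for index in range(range_tracker.start_position(),
                       range_tracker.stop_position()):
      # Every index is a split point, so try_claim() must succeed before the
      # record is emitted; a False return means a dynamic split took the rest.
      if not range_tracker.try_claim(index):
        return
      yield self._values[index]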
class RangeTracker(object):
"""A thread safe object used by Dataflow source framework.
A Dataflow source is defined using a ''BoundedSource'' and a ''RangeTracker''
pair. A ''RangeTracker'' is used by Dataflow source framework to perform
dynamic work rebalancing of position-based sources.
**Position-based sources**
A position-based source is one where the source can be described by a range
of positions of an ordered type and the records returned by the reader can be
described by positions of the same type.
In case a record occupies a range of positions in the source, the most
important thing about the record is the position where it starts.
Defining the semantics of positions for a source is entirely up to the source
class; however, the chosen definitions have to obey certain properties in order
to make it possible to correctly split the source into parts, including
dynamic splitting. Two main aspects need to be defined:
1. How to assign starting positions to records.
2. Which records should be read by a source with a range '[A, B)'.
Moreover, reading a range must be *efficient*, i.e., the performance of
reading a range should not significantly depend on the location of the range.
For example, reading the range [A, B) should not require reading all data
before 'A'.
The sections below explain exactly what properties these definitions must
satisfy, and how to use a ``RangeTracker`` with a properly defined source.
**Properties of position-based sources**
The main requirement for position-based sources is *associativity*: reading
records from '[A, B)' and records from '[B, C)' should give the same
records as reading from '[A, C)', where 'A <= B <= C'. This property
ensures that no matter how a range of positions is split into arbitrarily many
sub-ranges, the total set of records described by them stays the same.
The other important property is how the source's range relates to positions of
records in the source. In many sources each record can be identified by a
unique starting position. In this case:
* All records returned by a source '[A, B)' must have starting positions in
this range.
* All but the last record should end within this range. The last record may or
may not extend past the end of the range.
* Records should not overlap.
Such sources should define "read '[A, B)'" as "read from the first record
starting at or after 'A', up to but not including the first record starting
at or after 'B'".
Some examples of such sources include reading lines or CSV from a text file,
reading keys and values from a BigTable, etc.
The concept of *split points* makes it possible to extend the definitions for dealing
with sources where some records cannot be identified by a unique starting
position.
In all cases, all records returned by a source '[A, B)' must *start* at or
after 'A'.
**Split points**
Some sources may have records that are not directly addressable. For example,
imagine a file format consisting of a sequence of compressed blocks. Each
block can be assigned an offset, but records within the block cannot be
directly addressed without decompressing the block. Let us refer to this
hypothetical format as *CBF (Compressed Blocks Format)*.
Many such formats can still satisfy the associativity property. For example,
in CBF, reading '[A, B)' can mean "read all the records in all blocks whose
starting offset is in '[A, B)'".
To support such complex formats, we introduce the notion of *split points*. We
say that a record is a split point if there exists a position 'A' such that
the record is the first one to be returned when reading the range
'[A, infinity)'. In CBF, the only split points would be the first records
in each block.
Split points allow us to define the meaning of a record's position and a
source's range in all cases:
* For a record that is at a split point, its position is defined to be the
largest 'A' such that reading a source with the range '[A, infinity)'
returns this record.
* Positions of other records are only required to be non-decreasing.
* Reading the source '[A, B)' must return records starting from the first
split point at or after 'A', up to but not including the first split point
at or after 'B'. In particular, this means that the first record returned
by a source MUST always be a split point.
* Positions of split points must be unique.
As a result, for any decomposition of the full range of the source into
position ranges, the total set of records will be the full set of records in
the source, and each record will be read exactly once.
**Consumed positions**
As the source is being read, and records read from it are being passed to the
downstream transforms in the pipeline, we say that positions in the source are
being *consumed*. When a reader has read a record (or promised to a caller
that a record will be returned), positions up to and including the record's
start position are considered *consumed*.
Dynamic splitting can happen only at *unconsumed* positions. If the reader
just returned a record at offset 42 in a file, dynamic splitting can happen
only at offset 43 or beyond, as otherwise that record could be read twice (by
the current reader and by a reader of the task starting at 43).
"""
SPLIT_POINTS_UNKNOWN = object()
def start_position(self):
"""Returns the starting position of the current range, inclusive."""
raise NotImplementedError(type(self))
def stop_position(self):
"""Returns the ending position of the current range, exclusive."""
raise NotImplementedError(type(self))
def try_claim(self, position): # pylint: disable=unused-argument
"""Atomically determines if a record at a split point is within the range.
This method should be called **if and only if** the record is at a split
point. This method may modify the internal state of the ``RangeTracker`` by
updating the last-consumed position to ``position``.
** Thread safety **
Methods of the class ``RangeTracker`` including this method may get invoked
by different threads, hence must be made thread-safe, e.g. by using a single
lock object.
Args:
position: starting position of a record being read by a source.
Returns:
``True`` if the given position falls within the current range,
``False`` otherwise.
"""
raise NotImplementedError
def set_current_position(self, position):
"""Updates the last-consumed position to the given position.
A source may invoke this method for records that do not start at split
points. This may modify the internal state of the ``RangeTracker``. If the
record starts at a split point, method ``try_claim()`` **must** be invoked
instead of this method.
Args:
position: starting position of a record being read by a source.
"""
raise NotImplementedError
def position_at_fraction(self, fraction):
"""Returns the position at the given fraction.
Given a fraction within the range [0.0, 1.0) this method will return the
position at the given fraction compared to the position range
[self.start_position, self.stop_position).
** Thread safety **
Methods of the class ``RangeTracker`` including this method may get invoked
by different threads, hence must be made thread-safe, e.g. by using a single
lock object.
Args:
fraction: a float value within the range [0.0, 1.0).
Returns:
a position within the range [self.start_position, self.stop_position).
"""
raise NotImplementedError
def try_split(self, position):
"""Atomically splits the current range.
Determines a position to split the current range, split_position, based on
the given position. In most cases split_position and position will be the
same.
Splits the current range '[self.start_position, self.stop_position)'
into a "primary" part '[self.start_position, split_position)' and a
"residual" part '[split_position, self.stop_position)', assuming the
current last-consumed position is within
'[self.start_position, split_position)' (i.e., split_position has not been
consumed yet).
If successful, updates the current range to be the primary and returns a
tuple (split_position, split_fraction). split_fraction should be the
fraction of size of range '[self.start_position, split_position)' compared
to the original (before split) range
'[self.start_position, self.stop_position)'.
If the split_position has already been consumed, returns ``None``.
** Thread safety **
Methods of the class ``RangeTracker`` including this method may get invoked
by different threads, hence must be made thread-safe, e.g. by using a single
lock object.
Args:
position: suggested position where the current range should try to
be split at.
Returns:
a tuple containing the split position and split fraction if split is
successful. Returns ``None`` otherwise.
"""
raise NotImplementedError
def fraction_consumed(self):
"""Returns the approximate fraction of consumed positions in the source.
** Thread safety **
Methods of the class ``RangeTracker`` including this method may get invoked
by different threads, hence must be made thread-safe, e.g. by using a single
lock object.
Returns:
the approximate fraction of positions that have been consumed by
successful 'try_split()' and 'try_claim()' calls, or
0.0 if no such calls have happened.
"""
raise NotImplementedError
def split_points(self):
"""Gives the number of split points consumed and remaining.
For a ``RangeTracker`` used by a ``BoundedSource`` (within a
``BoundedSource.read()`` invocation) this method produces a 2-tuple that
gives the number of split points consumed by the ``BoundedSource`` and the
number of split points remaining within the range of the ``RangeTracker``
that has not been consumed by the ``BoundedSource``.
More specifically, given that the position of the current record being read
by ``BoundedSource`` is current_position, this method produces a tuple that
consists of
(1) number of split points in the range [self.start_position(),
current_position) without including the split point that is currently being
consumed. This represents the total amount of parallelism in the consumed
part of the source.
(2) number of split points within the range
[current_position, self.stop_position()) including the split point that is
currently being consumed. This represents the total amount of parallelism in
the unconsumed part of the source.
Methods of the class ``RangeTracker`` including this method may get invoked
by different threads, hence must be made thread-safe, e.g. by using a single
lock object.
** General information about consumed and remaining number of split
points returned by this method. **
* Before a source read (``BoundedSource.read()`` invocation) claims the
first split point, number of consumed split points is 0. This condition
holds independent of whether the input is "splittable". A splittable
source is a source that has more than one split point.
* Any source read that has only claimed one split point has 0 consumed
split points since the first split point is the current split point and
is still being processed. This condition holds independent of whether
the input is splittable.
* For an empty source read which never invokes
``RangeTracker.try_claim()``, the consumed number of split points is 0.
This condition holds independent of whether the input is splittable.
* For a source read which has invoked ``RangeTracker.try_claim()`` n
times, the consumed number of split points is n - 1.
* If a ``BoundedSource`` sets a callback through function
``set_split_points_unclaimed_callback()``, ``RangeTracker`` can use that
callback when determining remaining number of split points.
* Remaining split points should include the split point that is currently
being consumed by the source read. Hence if the above callback returns
an integer value n, remaining number of split points should be (n + 1).
* After the last split point is claimed, the number of remaining split points becomes 1,
because this unfinished read itself represents an unfinished split
point.
* After all records of the source have been consumed, remaining number of
split points becomes 0 and consumed number of split points becomes equal
to the total number of split points within the range being read by the
source. This method does not address this condition and will continue to
report number of consumed split points as
("total number of split points" - 1) and number of remaining split
points as 1. A runner that performs the reading of the source can
detect when all records have been consumed and adjust remaining and
consumed number of split points accordingly.
** Examples **
(1) A "perfectly splittable" input which can be read in parallel down to the
individual records.
Consider a perfectly splittable input that consists of 50 split points.
* Before a source read (``BoundedSource.read()`` invocation) claims the
first split point, the number of consumed split points is 0 and the number of
remaining split points is 50.
* After claiming the first split point, the consumed number of split points is
0 and the remaining number of split points is 50.
* After claiming split point #30, consumed number of split points is 29
and remaining number of split points is 21.
* After claiming all 50 split points, consumed number of split points is
49 and remaining number of split points is 1.
(2) a "block-compressed" file format such as ``avroio``, in which a block of
records has to be read as a whole, but different blocks can be read in
parallel.
Consider a block compressed input that consists of 5 blocks.
* Before a source read (``BoundedSource.read()`` invocation) claims the
first split point (the first block), the number of consumed split points is 0
and the number of remaining split points is 5.
* After claiming the first split point, the consumed number of split points is
0 and the remaining number of split points is 5.
* After claiming split point #3, consumed number of split points is 2
and remaining number of split points is 3.
* After claiming all 5 split points, consumed number of split points is
4 and remaining number of split points is 1.
(3) an "unsplittable" input such as a cursor in a database or a gzip
compressed file.
Such an input is considered to have only a single split point. Number of
consumed split points is always 0 and number of remaining split points
is always 1.
By default ``RangeTracker`` returns ``RangeTracker.SPLIT_POINTS_UNKNOWN`` for
both consumed and remaining number of split points, which indicates that the
number of split points consumed and remaining is unknown.
Returns:
A pair that gives consumed and remaining number of split points. Consumed
number of split points should be an integer larger than or equal to zero
or ``RangeTracker.SPLIT_POINTS_UNKNOWN``. Remaining number of split points
should be an integer larger than zero or
``RangeTracker.SPLIT_POINTS_UNKNOWN``.
"""
return (
RangeTracker.SPLIT_POINTS_UNKNOWN, RangeTracker.SPLIT_POINTS_UNKNOWN)
def set_split_points_unclaimed_callback(self, callback):
"""Sets a callback for determining the unclaimed number of split points.
By invoking this function, a ``BoundedSource`` can set a callback function
that may get invoked by the ``RangeTracker`` to determine the number of
unclaimed split points. A split point is unclaimed if
``RangeTracker.try_claim()`` method has not been successfully invoked for
that particular split point. The callback function accepts a single
parameter, a stop position for the BoundedSource (stop_position). If the
record currently being consumed by the ``BoundedSource`` is at position
current_position, callback should return the number of split points within
the range (current_position, stop_position). Note that this should not
include the split point that is currently being consumed by the source.
This function must be implemented by subclasses before being used.
Args:
callback: a function that takes a single parameter, a stop position,
and returns unclaimed number of split points for the source read
operation that is calling this function. Value returned from
callback should be either an integer larger than or equal to
zero or ``RangeTracker.SPLIT_POINTS_UNKNOWN``.
"""
raise NotImplementedError
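# Illustrative sketch only: a simplified offset-based tracker in the spirit of
# the existing ``apache_beam.io.range_trackers.OffsetRangeTracker``, showing
# the claim/split semantics and the single-lock thread-safety pattern described
# above. It is intentionally minimal and not a drop-in replacement for the
# real implementation.
import threading  # used only by the example below


class _ExampleOffsetRangeTracker(RangeTracker):
  def __init__(self, start, stop):
    self._start = start
    self._stop = stop
    self._last_claim = None
    self._lock = threading.Lock()

  def start_position(self):
    return self._start

  def stop_position(self):
    with self._lock:
      return self._stop

  def try_claim(self, position):
    with self._lock:
      if position >= self._stop:
        return False  # Outside the (possibly already split) range.
      self._last_claim = position
      return True

  def set_current_position(self, position):
    with self._lock:
      self._last_claim = position

  def position_at_fraction(self, fraction):
    with self._lock:
      return self._start + int(fraction * (self._stop - self._start))

  def try_split(self, position):
    with self._lock:
      # A split is only possible at an unconsumed position strictly inside
      # the current range.
      if self._last_claim is not None and position <= self._last_claim:
        return None
      if position <= self._start or position >= self._stop:
        return None
      split_fraction = (position - self._start) / float(
          self._stop - self._start)
      self._stop = position  # The current range becomes the primary part.
      return position, split_fraction

  def fraction_consumed(self):
    with self._lock:
      if self._last_claim is None:
        return 0.0
      return (self._last_claim - self._start) / float(
          self._stop - self._start)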
class Sink(HasDisplayData):
"""This class is deprecated, no backwards-compatibility guarantees.
A resource that can be written to using the ``beam.io.Write`` transform.
Here ``beam`` stands for the Apache Beam Python SDK, imported in the following manner:
``import apache_beam as beam``.
A parallel write to an ``iobase.Sink`` consists of three phases:
1. A sequential *initialization* phase (e.g., creating a temporary output
directory, etc.)
2. A parallel write phase where workers write *bundles* of records
3. A sequential *finalization* phase (e.g., committing the writes, merging
output files, etc.)
Implementing a new sink requires extending two classes.
1. iobase.Sink
``iobase.Sink`` is an immutable logical description of the location/resource
to write to. Depending on the type of sink, it may contain fields such as the
path to an output directory on a filesystem, a database table name,
etc. ``iobase.Sink`` provides methods for performing a write operation to the
sink described by it. To this end, implementors of an extension of
``iobase.Sink`` must implement three methods:
``initialize_write()``, ``open_writer()``, and ``finalize_write()``.
2. iobase.Writer
``iobase.Writer`` is used to write a single bundle of records. An
``iobase.Writer`` defines two methods: ``write()`` which writes a
single record from the bundle and ``close()`` which is called once
at the end of writing a bundle.
See also ``apache_beam.io.filebasedsink.FileBasedSink`` which provides a
simpler API for writing sinks that produce files.
**Execution of the Write transform**
``initialize_write()``, ``pre_finalize()``, and ``finalize_write()`` are
conceptually called once. However, implementors must
ensure that these methods are *idempotent*, as they may be called multiple
times on different machines in the case of failure/retry. A method may be
called more than once concurrently, in which case it's okay to have a
transient failure (such as due to a race condition). This failure should not
prevent subsequent retries from succeeding.
``initialize_write()`` should perform any initialization that needs to be done
prior to writing to the sink. ``initialize_write()`` may return a result
(let's call this ``init_result``) that contains any parameters it wants to
pass on to its writers about the sink. For example, a sink that writes to a
file system may return an ``init_result`` that contains a dynamically
generated unique directory to which data should be written.
To perform writing of a bundle of elements, Dataflow execution engine will
create an ``iobase.Writer`` using the implementation of
``iobase.Sink.open_writer()``. When invoking ``open_writer()`` execution
engine will provide the ``init_result`` returned by ``initialize_write()``
invocation as well as a *bundle id* (let's call this ``bundle_id``) that is
unique for each invocation of ``open_writer()``.
Execution engine will then invoke ``iobase.Writer.write()`` implementation for
each element that has to be written. Once all elements of a bundle are
written, execution engine will invoke ``iobase.Writer.close()`` implementation
which should return a result (let's call this ``write_result``) that contains
information that encodes the result of the write and, in most cases, some
encoding of the unique bundle id. For example, if each bundle is written to a
unique temporary file, ``close()`` method may return an object that contains
the temporary file name. After writing of all bundles is complete, execution
engine will invoke ``pre_finalize()`` and then ``finalize_write()``
implementation.
The execution of a write transform can be illustrated using the following
pseudo-code (assume that the outer for loop happens in parallel across many
machines)::
init_result = sink.initialize_write()
write_results = []
for bundle in partition(pcoll):
writer = sink.open_writer(init_result, generate_bundle_id())
for elem in bundle:
writer.write(elem)
write_results.append(writer.close())
pre_finalize_result = sink.pre_finalize(init_result, write_results)
sink.finalize_write(init_result, write_results, pre_finalize_result)
**init_result**
Methods of 'iobase.Sink' should agree on the 'init_result' type that will be
returned when initializing the sink. This type can be a client-defined object
or an existing type. The returned type must be picklable using Dataflow coder
``coders.PickleCoder``. Returning an init_result is optional.
**bundle_id**
In order to ensure fault-tolerance, a bundle may be executed multiple times
(e.g., in the event of failure/retry or for redundancy). However, exactly one
of these executions will have its result passed to the
``iobase.Sink.finalize_write()`` method. Each call to
``iobase.Sink.open_writer()`` is passed a unique bundle id when it is called
by the ``WriteImpl`` transform, so even redundant or retried bundles will have
a unique way of identifying their output.
The bundle id should be used to guarantee that a bundle's output is unique.
This uniqueness guarantee is important; if a bundle is to be output to a file,
for example, the name of the file must be unique to avoid conflicts with other
writers. The bundle id should be encoded in the writer result returned by the
writer and subsequently used by the ``finalize_write()`` method to identify
the results of successful writes.
For example, consider the scenario where a Writer writes files containing
serialized records and the ``finalize_write()`` is to merge or rename these
output files. In this case, a writer may use its unique id to name its output
file (to avoid conflicts) and return the name of the file it wrote as its
writer result. The ``finalize_write()`` will then receive an ``Iterable`` of
output file names that it can then merge or rename using some bundle naming
scheme.
**write_result**
``iobase.Writer.close()`` and ``finalize_write()`` implementations must agree
on type of the ``write_result`` object returned when invoking
``iobase.Writer.close()``. This type can be a client-defined object or
an existing type. The returned type must be picklable using Dataflow coder
``coders.PickleCoder``. Returning a ``write_result`` when
``iobase.Writer.close()`` is invoked is optional, but if unique
``write_result`` objects are not returned, the sink should guarantee idempotency
when the same bundle is written multiple times due to failure/retry or
redundancy.
**More information**
For more information on creating new sinks please refer to the official
documentation at
``https://beam.apache.org/documentation/sdks/python-custom-io#creating-sinks``
"""
# Whether Beam should skip writing any shards if all are empty.
skip_if_empty = False
def initialize_write(self):
"""Initializes the sink before writing begins.
Invoked before any data is written to the sink.
Please see documentation in ``iobase.Sink`` for an example.
Returns:
An object that contains any sink specific state generated by
initialization. This object will be passed to open_writer() and
finalize_write() methods.
"""
raise NotImplementedError
def open_writer(self, init_result, uid):
"""Opens a writer for writing a bundle of elements to the sink.
Args:
init_result: the result of initialize_write() invocation.
uid: a unique identifier generated by the system.
Returns:
an ``iobase.Writer`` that can be used to write a bundle of records to the
current sink.
"""
raise NotImplementedError
def pre_finalize(self, init_result, writer_results):
"""Pre-finalization stage for sink.
Called after all bundle writes are complete and before finalize_write.
Used to set up and verify filesystem and sink states.
Args:
init_result: the result of ``initialize_write()`` invocation.
writer_results: an iterable containing results of ``Writer.close()``
invocations. This will only contain results of successful writes, and
will only contain the result of a single successful write for a given
bundle.
Returns:
An object that contains any sink specific state generated.
This object will be passed to finalize_write().
"""
raise NotImplementedError
def finalize_write(self, init_result, writer_results, pre_finalize_result):
"""Finalizes the sink after all data is written to it.
Given the result of initialization and an iterable of results from bundle
writes, performs finalization after writing and closes the sink. Called
after all bundle writes are complete.
The bundle write results that are passed to finalize are those returned by
bundles that completed successfully. Although bundles may have been run
multiple times (for fault-tolerance), only one writer result will be passed
to finalize for each bundle. An implementation of finalize should perform
clean up of any failed and successfully retried bundles. Note that these
failed bundles will not have their writer result passed to finalize, so
finalize should be capable of locating any temporary/partial output written
by failed bundles.
If all retries of a bundle fail, the whole pipeline will fail *without*
finalize_write() being invoked.
A best practice is to make finalize atomic. If this is impossible given the
semantics of the sink, finalize should be idempotent, as it may be called
multiple times in the case of failure/retry or for redundancy.
Note that the iteration order of the writer results is not guaranteed to be
consistent if finalize is called multiple times.
Args:
init_result: the result of ``initialize_write()`` invocation.
writer_results: an iterable containing results of ``Writer.close()``
invocations. This will only contain results of successful writes, and
will only contain the result of a single successful write for a given
bundle.
pre_finalize_result: the result of ``pre_finalize()`` invocation.
"""
raise NotImplementedError
class Writer(object):
"""This class is deprecated, no backwards-compatibility guarantees.
Writes a bundle of elements from a ``PCollection`` to a sink.
A Writer's ``iobase.Writer.write()`` method writes elements to the sink, while
``iobase.Writer.close()`` is called after all elements in the bundle have been
written.
See ``iobase.Sink`` for more detailed documentation about the process of
writing to a sink.
"""
def write(self, value):
"""Writes a value to the sink using the current writer.
"""
raise NotImplementedError
def close(self):
"""Closes the current writer.
Please see documentation in ``iobase.Sink`` for an example.
Returns:
An object representing the writes that were performed by the current
writer.
"""
raise NotImplementedError
def at_capacity(self) -> bool:
"""Returns whether this writer should be considered at capacity
and a new one should be created.
"""
return False
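# Illustrative sketch only: a toy Sink/Writer pair following the lifecycle
# documented on ``Sink`` (initialize_write -> open_writer -> write/close ->
# pre_finalize -> finalize_write). It writes newline-delimited values to local
# files under a unique temporary directory; real file-based sinks should build
# on ``apache_beam.io.filebasedsink.FileBasedSink`` instead. All names here
# are hypothetical and chosen for the example.
import os  # used only by the example below
import shutil  # used only by the example below


class _ExampleDirectorySink(Sink):
  def __init__(self, final_dir):
    self._final_dir = final_dir

  def initialize_write(self):
    # init_result: a unique temporary directory for this write operation.
    temp_dir = '%s.tmp-%s' % (self._final_dir, uuid.uuid4().hex)
    os.makedirs(temp_dir, exist_ok=True)
    return temp_dir

  def open_writer(self, init_result, uid):
    # The unique bundle id keeps concurrent/retried bundles from colliding.
    return _ExampleDirectoryWriter(os.path.join(init_result, uid))

  def pre_finalize(self, init_result, writer_results):
    return None  # Nothing to verify for this toy sink.

  def finalize_write(self, init_result, writer_results, pre_finalize_result):
    # Only files reported by successful bundle writes are moved into place;
    # leftovers from failed or retried bundles vanish with the temp directory.
    os.makedirs(self._final_dir, exist_ok=True)
    for bundle_file in writer_results:
      os.rename(
          bundle_file,
          os.path.join(self._final_dir, os.path.basename(bundle_file)))
    shutil.rmtree(init_result, ignore_errors=True)


class _ExampleDirectoryWriter(Writer):
  def __init__(self, path):
    self._path = path
    self._file = open(path, 'w')

  def write(self, value):
    self._file.write('%s\n' % value)

  def close(self):
    self._file.close()
    # The writer result encodes this bundle's unique output (its file name).
    return self._path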
class Read(ptransform.PTransform):
"""A transform that reads a PCollection."""
# Import runners here to prevent circular imports
from apache_beam.runners.pipeline_context import PipelineContext
def __init__(self, source):
# type: (SourceBase) -> None
"""Initializes a Read transform.
Args:
source: Data source to read from.
"""
super().__init__()
self.source = source
@staticmethod
def get_desired_chunk_size(total_size):
if total_size:
# 1MB = 1 shard, 1GB = 32 shards, 1TB = 1000 shards, 1PB = 32k shards
chunk_size = max(1 << 20, 1000 * int(math.sqrt(total_size)))
else:
chunk_size = 64 << 20 # 64 MB
return chunk_size
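  # Worked example of the heuristic above (values approximate): for
  # total_size = 1 GiB, 1000 * sqrt(2**30) is about 32.8 MB per chunk, i.e.
  # roughly 32 shards; for 1 TiB the chunk grows to about 1 GB (roughly 1000
  # shards). Inputs of 1 MiB or less stay a single shard because of the
  # 1 << 20 floor.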
def expand(self, pbegin):
if isinstance(self.source, BoundedSource):
coders.registry.register_coder(BoundedSource, _MemoizingPickleCoder)
display_data = self.source.display_data() or {}
display_data['source'] = self.source.__class__
return (
pbegin
| Impulse()
| core.Map(lambda _: self.source).with_output_types(BoundedSource)
| SDFBoundedSourceReader(display_data))
elif isinstance(self.source, ptransform.PTransform):
# The Read transform can also admit a full PTransform as an input
# rather than an actual source. If the input is a PTransform, then
# just apply it directly.
return pbegin.pipeline | self.source
else:
# Treat Read itself as a primitive.
return pvalue.PCollection(
pbegin.pipeline, is_bounded=self.source.is_bounded())
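  # Typical usage (illustrative; ``_ExampleListSource`` refers to the sketch
  # of a custom BoundedSource earlier in this file):
  #
  #   with beam.Pipeline() as p:
  #     records = p | 'ReadList' >> Read(_ExampleListSource([1, 2, 3]))
  #
  # For a BoundedSource this expands to Impulse | Map | SDFBoundedSourceReader
  # as shown in expand() above; a PTransform passed as ``source`` is applied
  # directly to the pipeline.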
def get_windowing(self, unused_inputs):
# type: (...) -> core.Windowing
return core.Windowing(window.GlobalWindows())
def _infer_output_coder(self, input_type=None, input_coder=None):
# type: (...) -> Optional[coders.Coder]
from apache_beam.runners.dataflow.native_io import iobase as dataflow_io
if isinstance(self.source, BoundedSource):
return self.source.default_output_coder()
elif isinstance(self.source, dataflow_io.NativeSource):
return self.source.coder
else:
return None
def display_data(self):
return {
'source': DisplayDataItem(self.source.__class__, label='Read Source'),
'source_dd': self.source
}
def to_runner_api_parameter(
self,
context: PipelineContext,
) -> Tuple[str, Any]:
from apache_beam.runners.dataflow.native_io import iobase as dataflow_io
if isinstance(self.source, (BoundedSource, dataflow_io.NativeSource)):
from apache_beam.io.gcp.pubsub import _PubSubSource
if isinstance(self.source, _PubSubSource):
return (
common_urns.composites.PUBSUB_READ.urn,
beam_runner_api_pb2.PubSubReadPayload(
topic=self.source.full_topic,
subscription=self.source.full_subscription,
timestamp_attribute=self.source.timestamp_attribute,
with_attributes=self.source.with_attributes,
id_attribute=self.source.id_label))
return (
common_urns.deprecated_primitives.READ.urn,
beam_runner_api_pb2.ReadPayload(
source=self.source.to_runner_api(context),
is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED
if self.source.is_bounded() else
beam_runner_api_pb2.IsBounded.UNBOUNDED))
elif isinstance(self.source, ptransform.PTransform):
return self.source.to_runner_api_parameter(context)
raise NotImplementedError(
"to_runner_api_parameter not "
"implemented for type")
@staticmethod
def from_runner_api_parameter(
transform: beam_runner_api_pb2.PTransform,
payload: Union[beam_runner_api_pb2.ReadPayload,
beam_runner_api_pb2.PubSubReadPayload],
context: PipelineContext,
) -> "Read":
if transform.spec.urn == common_urns.composites.PUBSUB_READ.urn:
assert isinstance(payload, beam_runner_api_pb2.PubSubReadPayload)
# Importing locally to prevent circular dependencies.
from apache_beam.io.gcp.pubsub import _PubSubSource
source = _PubSubSource(
topic=payload.topic or None,
subscription=payload.subscription or None,
id_label=payload.id_attribute or None,
with_attributes=payload.with_attributes,
timestamp_attribute=payload.timestamp_attribute or None)
return Read(source)
else:
assert isinstance(payload, beam_runner_api_pb2.ReadPayload)
return Read(SourceBase.from_runner_api(payload.source, context))
@staticmethod
def _from_runner_api_parameter_read(
transform: beam_runner_api_pb2.PTransform,
payload: beam_runner_api_pb2.ReadPayload,
context: PipelineContext,
) -> "Read":
"""Method for type proxying when calling register_urn due to limitations
in type exprs in Python"""
return Read.from_runner_api_parameter(transform, payload, context)