Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Swing implementation (right now only when comm_size is a power of 2). #19

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/Schedgen2/mpi_colls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from goal import GoalComm
from patterns import binomialtree, recdoub, ring, linear
from patterns import binomialtree, recdoub, ring, linear, swing


def mpi_communication_pattern_selection(
Expand Down Expand Up @@ -226,6 +226,25 @@ def allreduce(
**kwargs,
)
)
elif ptrn == "swing":
comms.append(
swing(
comm_size=comm_size,
datasize=datasize,
tag=tag,
algorithm="reduce-scatter",
**kwargs,
)
)
comms.append(
swing(
comm_size=comm_size,
datasize=datasize,
tag=tag + comm_size,
algorithm="allgather",
**kwargs,
)
)
elif ptrn == "ring":
comms.append(
ring(
Expand Down
72 changes: 72 additions & 0 deletions src/Schedgen2/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,78 @@ def recdoub(
dependencies[rank] = calc
return comm

def swing(
comm_size: int,
datasize: int,
tag: int,
algorithm: str = "reduce-scatter",
compute_time_dependency: int = 0,
**kwargs,
) -> GoalComm:
"""
Create a Swing communication pattern.

:param comm_size: number of ranks in the communicator
:param datasize: size of data to send or receive
:param tag: tag that is used for all send and receive operations
:param algorithm: communication algorithm that uses this pattern; default is reduce-scatter
:param compute_time_dependency: compute time dependency for each send operation; if 0 (default), no compute time is added
:param kwargs: additional arguments that are ignored
:return: GoalComm object that represents the communication pattern
"""

assert algorithm in [
"reduce-scatter",
"allgather",
], f"the pattern does not currently support the {algorithm} algorithm"

if not log2(comm_size).is_integer():
raise ValueError("At the moment, Swing only support a number of ranks which is a power of 2")

# Add other values if you plan to run it on more than 2**20 nodes
rhos = [1, -1, 3, -5, 11, -21, 43, -85, 171, -341, 683, -1365, 2731, -5461, 10923, -21845, 43691, -87381, 174763, -349525]
comm = GoalComm(comm_size)
num_steps = int(log2(comm_size))
dependencies = [None] * comm_size

if num_steps > len(rhos):
raise ValueError("Please increase the values of rhos in the code.")

for r in range(num_steps):
for rank in range(comm_size):
if algorithm in ["reduce-scatter"]:
distance = rhos[r]
message_size = datasize // (2 ** (r + 1))
elif algorithm in ["allgather"]:
distance = rhos[num_steps - r - 1]
message_size = datasize // (2 ** (num_steps - r))
else:
raise ValueError(
f"the pattern does not currently support the {algorithm} algorithm"
)

# Flip the direction for odd ranks
if rank % 2:
distance = -distance

if (rank + distance) < 0:
dest = (rank + distance) + comm_size
else:
dest = (rank + distance) % comm_size

if dest < comm_size:
send = comm.Send(size=message_size, src=rank, dst=dest, tag=tag + r)
if dependencies[rank] is not None:
send.requires(dependencies[rank])
dependencies[rank] = comm.Recv(
size=message_size, src=dest, dst=rank, tag=tag + r
)
if compute_time_dependency > 0:
calc = comm.Calc(host=rank, size=compute_time_dependency)
calc.requires(dependencies[rank])
dependencies[rank] = calc
return comm


def ring(
comm_size: int,
Expand Down
5 changes: 4 additions & 1 deletion src/Schedgen2/schedgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
p.add_argument(
"--ptrn",
dest="ptrn",
choices=["datasize_based", "binomialtree", "recdoub", "ring", "linear"],
choices=["datasize_based", "binomialtree", "recdoub", "ring", "linear", "swing"],
default="datasize_based",
help="Pattern to use for communication, note that not all patterns are available for all communication types",
)
Expand Down Expand Up @@ -136,6 +136,9 @@ def verify_params(args):
assert (
args.ptrn != "recdoub" or args.comm_size & (args.comm_size - 1) == 0
), "Currently recdoub pattern requires a power of 2 communicator size."
assert (
args.ptrn != "swing" or args.comm_size & (args.comm_size - 1) == 0
), "Currently swing pattern requires a power of 2 communicator size."


def comm_to_func(comm: str) -> callable:
Expand Down