Skip to content

Commit

Permalink
Refactor breakup functions (#1126)
Browse files Browse the repository at this point in the history
  • Loading branch information
abulenok authored Sep 1, 2023
1 parent 42f8485 commit 82af010
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 25 deletions.
31 changes: 18 additions & 13 deletions PySDM/backends/impl_numba/methods/collisions_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def coalesce( # pylint: disable=too-many-arguments


@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}})
def breakup0_compute_mult_transfer(
def compute_transfer_multiplicities(
gamma, j, k, multiplicity, volume, fragment_size_i, max_multiplicity
): # pylint: disable=too-many-arguments
overflow_flag = False
Expand Down Expand Up @@ -91,27 +91,28 @@ def breakup0_compute_mult_transfer(


@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}})
def breakup1_update_mult_attributes(
def get_new_multiplicities_and_update_attributes(
j, k, attributes, multiplicity, take_from_j, new_mult_k
): # pylint: disable=too-many-arguments
for a in range(len(attributes)):
attributes[a, k] *= multiplicity[k]
attributes[a, k] += take_from_j * attributes[a, j]
attributes[a, k] /= new_mult_k

if multiplicity[j] == take_from_j:
if multiplicity[j] > take_from_j:
nj = multiplicity[j] - take_from_j
nk = new_mult_k

else: # take_from_j == multiplicity[j]
nj = new_mult_k / 2
nk = nj
for a in range(len(attributes)):
attributes[a, j] = attributes[a, k]
else: # take_from_j < multiplicity[j]
nj = multiplicity[j] - take_from_j
nk = new_mult_k
return nj, nk


@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}})
def breakup2_round_mults_to_ints(
def round_multiplicities_to_ints_and_update_attributes(
j,
k,
nj,
Expand Down Expand Up @@ -144,7 +145,7 @@ def break_up( # pylint: disable=too-many-arguments,c,too-many-locals
warn_overflows,
volume,
): # breakup0 guarantees take_from_j <= multiplicity[j]
take_from_j, new_mult_k, gamma_j_k, overflow_flag = breakup0_compute_mult_transfer(
take_from_j, new_mult_k, gamma_j_k, overflow_flag = compute_transfer_multiplicities(
gamma[i],
j,
k,
Expand All @@ -156,15 +157,17 @@ def break_up( # pylint: disable=too-many-arguments,c,too-many-locals
gamma_deficit = gamma[i] - gamma_j_k

# breakup1 also handles new_n[j] == 0 case via splitting
nj, nk = breakup1_update_mult_attributes(
nj, nk = get_new_multiplicities_and_update_attributes(
j, k, attributes, multiplicity, take_from_j, new_mult_k
)

atomic_add(breakup_rate, cid, gamma_j_k * multiplicity[k])
atomic_add(breakup_rate_deficit, cid, gamma_deficit * multiplicity[k])

# breakup2 also guarantees that no multiplicities are set to 0
breakup2_round_mults_to_ints(j, k, nj, nk, attributes, multiplicity)
round_multiplicities_to_ints_and_update_attributes(
j, k, nj, nk, attributes, multiplicity
)
if overflow_flag and warn_overflows:
warn("overflow", __file__)

Expand Down Expand Up @@ -207,7 +210,7 @@ def break_up_while(
new_mult_k,
gamma_j_k,
overflow_flag,
) = breakup0_compute_mult_transfer(
) = compute_transfer_multiplicities(
gamma_deficit,
j,
k,
Expand All @@ -217,13 +220,15 @@ def break_up_while(
max_multiplicity,
)

nj, nk = breakup1_update_mult_attributes(
nj, nk = get_new_multiplicities_and_update_attributes(
j, k, attributes, multiplicity, take_from_j, new_mult_k
)

atomic_add(breakup_rate, cid, gamma_j_k * multiplicity[k])
gamma_deficit -= gamma_j_k
breakup2_round_mults_to_ints(j, k, nj, nk, attributes, multiplicity)
round_multiplicities_to_ints_and_update_attributes(
j, k, nj, nk, attributes, multiplicity
)

atomic_add(breakup_rate_deficit, cid, gamma_deficit * multiplicity[k])

Expand Down
26 changes: 14 additions & 12 deletions PySDM/backends/impl_thrust_rtc/methods/collisions_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
return false;
}
static __device__ auto breakup_fun0(
static __device__ auto compute_transfer_multiplicities(
real_type gamma,
int64_t j,
int64_t k,
Expand Down Expand Up @@ -107,7 +107,7 @@
return gamma_j_k;
}
static __device__ void breakup_fun1(
static __device__ void get_new_multiplicities_and_update_attributes(
int64_t j,
int64_t k,
VectorView<real_type> attributes,
Expand All @@ -131,10 +131,14 @@
} else {
nj[0] = new_mult_k / 2;
nk[0] = nj[0];
for (auto a = 0; a < n_attr; a += 1) {
attributes[a + j] = attributes[a + k];
}
}
}
static __device__ void breakup_fun2(
static __device__ void round_multiplicities_to_ints_and_update_attributes(
int64_t j,
int64_t k,
real_type nj,
Expand All @@ -144,12 +148,6 @@
real_type take_from_j,
int64_t n_attr
) {
if (multiplicity[j] <= take_from_j) {
for (auto a = 0; a < n_attr; a += 1) {
attributes[a + j] = attributes[a + k];
}
}
multiplicity[j] = max((int64_t)(round(nj)), (int64_t)(1));
multiplicity[k] = max((int64_t)(round(nk)), (int64_t)(1));
auto factor_j = nj / multiplicity[j];
Expand Down Expand Up @@ -178,7 +176,7 @@
) {
real_type take_from_j[1] = {}; // float
real_type new_mult_k[1] = {}; // float
auto gamma_j_k = Commons::breakup_fun0(
auto gamma_j_k = Commons::compute_transfer_multiplicities(
gamma[i],
j,
k,
Expand All @@ -194,7 +192,9 @@
real_type nj[1] = {}; // float
real_type nk[1] = {}; // float
Commons::breakup_fun1(j, k, attributes, multiplicity, take_from_j[0], new_mult_k[0], n_attr, nj, nk);
Commons::get_new_multiplicities_and_update_attributes(
j, k, attributes, multiplicity, take_from_j[0], new_mult_k[0], n_attr, nj, nk
);
atomicAdd(
(unsigned long long int*)&breakup_rate[cid],
Expand All @@ -204,7 +204,9 @@
(unsigned long long int*)&breakup_rate_deficit[cid],
(unsigned long long int)(gamma_deficit * multiplicity[k])
);
Commons::breakup_fun2(j, k, nj[0], nk[0], attributes, multiplicity, take_from_j[0], n_attr);
Commons::round_multiplicities_to_ints_and_update_attributes(
j, k, nj[0], nk[0], attributes, multiplicity, take_from_j[0], n_attr
);
}
};
"""
Expand Down

0 comments on commit 82af010

Please sign in to comment.