Skip to content

Commit

Permalink
Merge pull request #2783 from tdadela/refactor-users-dispatcher-user-gen
Browse files Browse the repository at this point in the history
Simplify fixed_count Users generation in UsersDispatcher._user_gen
  • Loading branch information
cyberw authored Aug 2, 2024
2 parents 5209718 + 7f6df03 commit a30782a
Showing 1 changed file with 27 additions and 48 deletions.
75 changes: 27 additions & 48 deletions locust/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import time
from collections import defaultdict
from collections.abc import Generator, Iterator
from collections.abc import Iterator
from heapq import heapify, heapreplace
from math import log2
from operator import attrgetter
Expand All @@ -17,41 +17,31 @@
from locust import User
from locust.runners import WorkerNode

from collections.abc import Generator, Iterable
from typing import TypeVar

# To profile line-by-line, uncomment the code below (i.e. `import line_profiler ...`) and
# place `@profile` on the functions/methods you wish to profile. Then, in the unit test you are
# running, use `from locust.dispatch import profile; profile.print_stats()` at the end of the unit test.
# Placing it in a `finally` block is recommended.
# import line_profiler
#
# profile = line_profiler.LineProfiler()
T = TypeVar("T")


def _kl_generator(users: list[tuple[type[User], float]]) -> Iterator[str | None]:
def _kl_generator(users: Iterable[tuple[T, float]]) -> Iterator[T | None]:
"""Generator based on Kullback-Leibler divergence
For example, given users A, B with weights 5 and 1 respectively,
this algorithm will yield AAABAAAAABAA.
"""
if not users:
heap = [(x * log2(x / (x + 1.0)), x + 1.0, x, name) for name, x in users]
if not heap:
while True:
yield None

names = [u[0].__name__ for u in users]
weights = [u[1] for u in users]
generated = weights.copy()

heap = [(x * log2(x / (x + 1.0)), i) for i, x in enumerate(generated)]
heapify(heap)

while True:
i = heap[0][1] # choose element which choosing minimizes divergence the most
yield names[i]
generated[i] += 1.0
x = generated[i]
kl_diff = weights[i] * log2(x / (x + 1.0))
_, x, weight, name = heap[0]
# (divergence diff, number of generated elements + initial weight, initial weight, name) = heap[0]
yield name
kl_diff = weight * log2(x / (x + 1.0))
# calculate how much choosing element i for (x + 1)th time decreases divergence
heapreplace(heap, (kl_diff, i))
heapreplace(heap, (kl_diff, x + 1.0, weight, name))


class UsersDispatcher(Iterator):
Expand Down Expand Up @@ -378,35 +368,24 @@ def _distribute_users(
return users_on_workers, user_gen, worker_gen, active_users

def _user_gen(self) -> Iterator[str | None]:
fixed_users = {u.__name__: u for u in self._user_classes if u.fixed_count}

fixed_users_gen = _kl_generator([(u, u.fixed_count) for u in fixed_users.values()])
weighted_users_gen = _kl_generator([(u, u.weight) for u in self._user_classes if not u.fixed_count])
weighted_users_gen = _kl_generator((u.__name__, u.weight) for u in self._user_classes if not u.fixed_count)

# Spawn users
while True:
if self._try_dispatch_fixed:
if self._try_dispatch_fixed: # Fixed_count users are spawned before weight users.
# Some peoples treat this implementation detail as a feature.
self._try_dispatch_fixed = False
current_fixed_users_count = {u: self._get_user_current_count(u) for u in fixed_users}
spawned_classes: set[str] = set()
while len(spawned_classes) != len(fixed_users):
user_name: str | None = next(fixed_users_gen)
if not user_name:
break

if current_fixed_users_count[user_name] < fixed_users[user_name].fixed_count:
current_fixed_users_count[user_name] += 1
yield user_name

# 'self._try_dispatch_fixed' was changed outhere, we have to recalculate current count
if self._try_dispatch_fixed:
current_fixed_users_count = {u: self._get_user_current_count(u) for u in fixed_users}
spawned_classes.clear()
self._try_dispatch_fixed = False
else:
spawned_classes.add(user_name)

yield next(weighted_users_gen)
fixed_users_missing = [
(u.__name__, miss)
for u in self._user_classes
if u.fixed_count and (miss := u.fixed_count - self._get_user_current_count(u.__name__)) > 0
]
total_miss = sum(miss for _, miss in fixed_users_missing)
fixed_users_gen = _kl_generator(fixed_users_missing) # type: ignore[arg-type]
# https://mypy.readthedocs.io/en/stable/common_issues.html#variance
for _ in range(total_miss):
yield next(fixed_users_gen)
else:
yield next(weighted_users_gen)

@staticmethod
def _fast_users_on_workers_copy(users_on_workers: dict[str, dict[str, int]]) -> dict[str, dict[str, int]]:
Expand Down

0 comments on commit a30782a

Please sign in to comment.