Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix iteration within iteration #1427

Merged
merged 2 commits into from
Dec 5, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ def get(self, *q_objs, **query):
except StopIteration:
return result

# If we were able to retrieve the 2nd doc, rewind the cursor and
# raise the MultipleObjectsReturned exception.
queryset.rewind()
message = u'%d items returned, instead of 1' % queryset.count()
raise queryset._document.MultipleObjectsReturned(message)
Expand Down
52 changes: 40 additions & 12 deletions mongoengine/queryset/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ def __iter__(self):
in batches of ``ITER_CHUNK_SIZE``.

If ``self._has_more`` the cursor hasn't been exhausted so cache then
batch. Otherwise iterate the result_cache.
batch. Otherwise iterate the result_cache.
"""
self._iter = True

if self._has_more:
return self._iter_results()

Expand All @@ -42,10 +43,12 @@ def __len__(self):
"""
if self._len is not None:
return self._len

# Populate the result cache with *all* of the docs in the cursor
if self._has_more:
# populate the cache
list(self._iter_results())

# Cache the length of the complete result cache and return it
self._len = len(self._result_cache)
return self._len

Expand All @@ -64,18 +67,33 @@ def __repr__(self):
def _iter_results(self):
"""A generator for iterating over the result cache.

Also populates the cache if there are more possible results to yield.
Raises StopIteration when there are no more results"""
Also populates the cache if there are more possible results to
yield. Raises StopIteration when there are no more results.
"""
if self._result_cache is None:
self._result_cache = []

pos = 0
while True:
upper = len(self._result_cache)
while pos < upper:

# For all positions lower than the length of the current result
# cache, serve the docs straight from the cache w/o hitting the
# database.
# XXX it's VERY important to compute the len within the `while`
# condition because the result cache might expand mid-iteration
# (e.g. if we call len(qs) inside a loop that iterates over the
# queryset). Fortunately len(list) is O(1) in Python, so this
# doesn't cause performance issues.
while pos < len(self._result_cache):
yield self._result_cache[pos]
pos += 1

# Raise StopIteration if we already established there were no more
# docs in the db cursor.
if not self._has_more:
raise StopIteration

# Otherwise, populate more of the cache and repeat.
if len(self._result_cache) <= pos:
self._populate_cache()

Expand All @@ -86,12 +104,22 @@ def _populate_cache(self):
"""
if self._result_cache is None:
self._result_cache = []
if self._has_more:
try:
for i in xrange(ITER_CHUNK_SIZE):
self._result_cache.append(self.next())
except StopIteration:
self._has_more = False

# Skip populating the cache if we already established there are no
# more docs to pull from the database.
if not self._has_more:
return

# Pull in ITER_CHUNK_SIZE docs from the database and store them in
# the result cache.
try:
for i in xrange(ITER_CHUNK_SIZE):
self._result_cache.append(self.next())
except StopIteration:
# Getting this exception means there are no more docs in the
# db cursor. Set _has_more to False so that we can use that
# information in other places.
self._has_more = False

def count(self, with_limit_and_skip=False):
"""Count the selected elements in the query.
Expand Down
50 changes: 50 additions & 0 deletions tests/queryset/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4890,6 +4890,56 @@ class Doc(Document):

self.assertEqual(1, Doc.objects(item__type__="axe").count())

    def test_len_during_iteration(self):
        """Tests that calling len on a queryset during iteration doesn't
        stop paging.

        Regression test: previously, len() consumed the shared cursor
        state and a concurrent iteration over the same queryset would
        terminate early instead of visiting every document.
        """
        class Data(Document):
            pass

        # Insert more documents than one iteration chunk so that the
        # queryset has to page through the cursor in multiple batches.
        for i in xrange(300):
            Data().save()

        records = Data.objects.limit(250)

        # This should pull all 250 docs from mongo and populate the result
        # cache
        len(records)

        # Assert that iterating over documents in the qs touches every
        # document even if we call len(qs) midway through the iteration.
        for i, r in enumerate(records):
            if i == 58:
                len(records)
        self.assertEqual(i, 249)

        # Assert the same behavior is true even if we didn't pre-populate the
        # result cache.
        records = Data.objects.limit(250)
        for i, r in enumerate(records):
            if i == 58:
                len(records)
        self.assertEqual(i, 249)

def test_iteration_within_iteration(self):
"""You should be able to reliably iterate over all the documents
in a given queryset even if there are multiple iterations of it
happening at the same time.
"""
class Data(Document):
pass

for i in xrange(300):
Data().save()

qs = Data.objects.limit(250)
for i, doc in enumerate(qs):
for j, doc2 in enumerate(qs):
pass

self.assertEqual(i, 249)
self.assertEqual(j, 249)


# Allow running this test module directly (e.g. `python queryset.py`).
if __name__ == '__main__':
    unittest.main()