diff --git a/eris/paged_list.py b/eris/paged_list.py index e6429a3..2b23f91 100644 --- a/eris/paged_list.py +++ b/eris/paged_list.py @@ -3,11 +3,18 @@ import functools +import itertools import os import pickle import shutil +def batch(iter_, page_size): + for _, batch in itertools.groupby( + enumerate(iter_), lambda tuple_: tuple_[0] // page_size): + yield [value for index, value in batch] + + class PagedList: def __init__(self, list_, pages_dir, page_size, cache_size, exist_ok=False, @@ -16,19 +23,17 @@ class PagedList: self.page_size = page_size self.cache_size = cache_size self.open_func = open_func - self._len = len(list_) + self._len = 0 tmp_dir = pages_dir + ".tmp" if exist_ok: shutil.rmtree(tmp_dir, ignore_errors=True) shutil.rmtree(pages_dir, ignore_errors=True) os.makedirs(tmp_dir) - pages = ([[]] if len(list_) == 0 else - (list_[start:start+self.page_size] - for start in range(0, len(list_), self.page_size))) - for index, page in enumerate(pages): + for index, page in enumerate(batch(list_, page_size)): pickle_path = os.path.join(tmp_dir, str(index)) with self.open_func(pickle_path, "wb") as file_: pickle.dump(page, file_, protocol=pickle.HIGHEST_PROTOCOL) + self._len += len(page) self.page_count = index + 1 os.rename(tmp_dir, self.pages_dir) self._setup_page_cache() diff --git a/tests/paged_list_test.py b/tests/paged_list_test.py index 26f0f6e..bfefc5f 100755 --- a/tests/paged_list_test.py +++ b/tests/paged_list_test.py @@ -13,6 +13,10 @@ import eris.paged_list as paged_list class PagedListTestCase(unittest.TestCase): + def test_batch(self): + self.assertEqual(list(paged_list.batch(iter([3,4,5,6,7]), 2)), + [[3, 4], [5, 6], [7]]) + def test_getitem(self): with tempfile.TemporaryDirectory() as temp_dir: list_ = paged_list.PagedList([3, 4, 5, 6], temp_dir, 4, 2)