Skip to content

Commit 5382165

Browse files
authored
ENH: Add API to yield notebooks in bulk. (#30)
Added functions under `pgcontents.query`, separate ones to generate current files and remote checkpoints. Also added tests for these new functions.
1 parent 488507f commit 5382165

File tree

2 files changed

+322
-2
lines changed

2 files changed

+322
-2
lines changed

pgcontents/query.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from .api_utils import (
1919
from_api_dirname,
2020
from_api_filename,
21+
reads_base64,
2122
split_api_filepath,
23+
to_api_path,
2224
)
2325
from .constants import UNLIMITED
2426
from .db_utils import (
@@ -547,6 +549,36 @@ def save_file(db, user_id, path, content, encrypt_func, max_size_bytes):
547549
return res
548550

549551

552+
def generate_files(engine, crypto_factory, min_dt=None, max_dt=None):
    """
    Yield decrypted copies of all current notebook files.

    Selects every current notebook (optionally restricted to a last-modified
    datetime window), decrypts each one, and returns a generator yielding
    dicts, each holding a decoded notebook plus metadata including the user,
    filepath, and timestamp.

    Parameters
    ----------
    engine : SQLAlchemy.engine
        Engine encapsulating database connections.
    crypto_factory : function[str -> Any]
        Maps a user_id to an object providing the interface required by
        PostgresContentsManager.crypto.  The results are used to decrypt
        the selected notebooks.
    min_dt : datetime.datetime, optional
        Minimum last modified datetime at which a file will be included.
    max_dt : datetime.datetime, optional
        Last modified datetime at and after which a file will be excluded.
    """
    conditions = []
    if min_dt is not None:
        conditions.append(files.c.created_at >= min_dt)
    if max_dt is not None:
        conditions.append(files.c.created_at < max_dt)
    return _generate_notebooks(files, engine, conditions, crypto_factory)
580+
581+
550582
# =======================================
551583
# Checkpoints (PostgresCheckpoints)
552584
# =======================================
@@ -700,6 +732,79 @@ def purge_remote_checkpoints(db, user_id):
700732
)
701733

702734

735+
def generate_checkpoints(engine, crypto_factory, min_dt=None, max_dt=None):
    """
    Yield decrypted copies of all remote notebook checkpoints.

    Selects every notebook checkpoint (optionally restricted to a
    last-modified datetime window), decrypts each one, and returns a
    generator yielding dicts, each holding a decoded notebook plus metadata
    including the user, filepath, and timestamp.

    Parameters
    ----------
    engine : SQLAlchemy.engine
        Engine encapsulating database connections.
    crypto_factory : function[str -> Any]
        Maps a user_id to an object providing the interface required by
        PostgresContentsManager.crypto.  The results are used to decrypt
        the selected notebooks.
    min_dt : datetime.datetime, optional
        Minimum last modified datetime at which a file will be included.
    max_dt : datetime.datetime, optional
        Last modified datetime at and after which a file will be excluded.
    """
    conditions = []
    if min_dt is not None:
        conditions.append(remote_checkpoints.c.last_modified >= min_dt)
    if max_dt is not None:
        conditions.append(remote_checkpoints.c.last_modified < max_dt)
    return _generate_notebooks(
        remote_checkpoints, engine, conditions, crypto_factory,
    )
764+
765+
766+
# ====================
767+
# Files or Checkpoints
768+
# ====================
769+
def _generate_notebooks(table, engine, where_conds, crypto_factory):
    """
    Shared implementation for `generate_files` and `generate_checkpoints`;
    see their docstrings for details.

    `where_conds` should be a list of SQLAlchemy expressions, applied as the
    WHERE clauses on the SELECT query issued against `table`.
    """
    # Build the query; ordering by user lets us reuse one decryptor across
    # consecutive rows belonging to the same user.
    query = select([table]).order_by(table.c.user_id)
    for condition in where_conds:
        query = query.where(condition)
    rows = engine.execute(query)

    decrypt_func = None
    current_user = None
    for row in rows:
        user_id = row['user_id']
        # The decrypt function depends only on the user, so it carries over
        # while consecutive rows share the same user.
        if user_id != current_user:
            decrypt_func = crypto_factory(user_id).decrypt
            current_user = user_id

        record = to_dict_with_content(table.c, row, decrypt_func)
        if table is files:
            # The files schema differs somewhat from checkpoints: build the
            # full path and normalize the timestamp column name.
            record['path'] = record['parent_name'] + record['name']
            record['last_modified'] = record['created_at']

        # 'content' is decoded via `reads_base64` directly.  If the db
        # content format is ever changed from base64, the decoding here
        # must be changed as well.
        yield {
            'id': record['id'],
            'user_id': record['user_id'],
            'path': to_api_path(record['path']),
            'last_modified': record['last_modified'],
            'content': reads_base64(record['content']),
        }
806+
807+
703808
##########################
704809
# Reencryption Utilities #
705810
##########################
@@ -776,7 +881,6 @@ def reencrypt_user_content(engine,
776881
# file-reencryption process, but we might not see that checkpoint here,
777882
# which means that we would never update the content of that checkpoint
778883
# to the new encryption key.
779-
780884
logger.info("Re-encrypting files for %s", user_id)
781885
for (file_id,) in select_file_ids(db, user_id):
782886
reencrypt_row_content(

pgcontents/tests/test_synchronization.py

Lines changed: 217 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
from sqlalchemy import create_engine
1111

1212
from pgcontents import PostgresContentsManager
13-
from pgcontents.crypto import FernetEncryption, NoEncryption
13+
from pgcontents.crypto import (
14+
FernetEncryption,
15+
NoEncryption,
16+
single_password_crypto_factory,
17+
)
18+
from pgcontents.query import generate_files, generate_checkpoints
1419
from pgcontents.utils.ipycompat import new_markdown_cell
1520

1621
from .utils import (
@@ -177,3 +182,214 @@ def check_reencryption(old, new):
177182
# crypto manager.
178183
unencrypt_all_users(engine, crypto2_factory, logger)
179184
check_reencryption(manager2, no_crypto_manager)
185+
186+
187+
class TestGenerateNotebooks(TestCase):

    def setUp(self):
        remigrate_test_schema()
        self.db_url = TEST_DB_URL
        self.engine = create_engine(self.db_url)
        # All test users share one password-derived crypto factory.
        self.crypto_factory = single_password_crypto_factory(u'foobar')

    def tearDown(self):
        clear_test_db()

    def populate_users(self, user_ids):
        """
        Create a `PostgresContentsManager` and notebooks for each user.
        """
        def make_manager(user_id):
            return PostgresContentsManager(
                user_id=user_id,
                db_url=self.db_url,
                crypto=self.crypto_factory(user_id),
                create_user_on_startup=True,
            )

        managers = {uid: make_manager(uid) for uid in user_ids}
        paths = {uid: populate(managers[uid]) for uid in user_ids}
        return (managers, paths)

    def test_generate_files(self):
        """
        Create files for three users; try fetching them using `generate_files`.
        """
        user_ids = ['test_generate_files0',
                    'test_generate_files1',
                    'test_generate_files2']
        (managers, paths) = self.populate_users(user_ids)

        def get_file_dt(user_id, idx):
            path = paths[user_id][idx]
            return managers[user_id].get(path, content=False)['last_modified']

        # Find a split datetime midway through each user's list of files.
        split_idx = len(paths[user_ids[0]]) // 2
        split_dts = [get_file_dt(uid, split_idx) for uid in user_ids]

        def check_call(kwargs, expect_files_by_user):
            """
            Call `generate_files`; check that all expected files are found,
            with the correct content.
            """
            seen = {uid: [] for uid in expect_files_by_user}
            for result in generate_files(self.engine, self.crypto_factory,
                                         **kwargs):
                manager = managers[result['user_id']]

                # Recreate functionality from
                # `manager._notebook_model_from_db` so that the notebook
                # matches the model returned by `manager.get`.
                nb = result['content']
                manager.mark_trusted_cells(nb, result['path'])

                # The content returned by the pgcontents manager must match
                # that returned by `generate_files`.
                self.assertEqual(nb, manager.get(result['path'])['content'])

                seen[result['user_id']].append(result['path'])

            # Every expected file must have been yielded.
            for uid in expect_files_by_user:
                self.assertEqual(sorted(seen[uid]),
                                 sorted(expect_files_by_user[uid]))

        # With no `min_dt`/`max_dt`, expect every file.
        check_call({}, paths)

        # `min_dt` falls in the middle of 1's files; expect the latter half
        # of 1's files plus all of 2's.
        check_call({'min_dt': split_dts[1]},
                   {
                       user_ids[0]: [],
                       user_ids[1]: paths[user_ids[1]][split_idx:],
                       user_ids[2]: paths[user_ids[2]],
                   })

        # `max_dt` falls in the middle of 1's files; expect all of 0's plus
        # the beginning half of 1's.
        check_call({'max_dt': split_dts[1]},
                   {
                       user_ids[0]: paths[user_ids[0]],
                       user_ids[1]: paths[user_ids[1]][:split_idx],
                       user_ids[2]: [],
                   })

        # `min_dt` cuts off the beginning half of 0's files, while `max_dt`
        # cuts off the latter half of 2's.
        check_call({'min_dt': split_dts[0], 'max_dt': split_dts[2]},
                   {
                       user_ids[0]: paths[user_ids[0]][split_idx:],
                       user_ids[1]: paths[user_ids[1]],
                       user_ids[2]: paths[user_ids[2]][:split_idx],
                   })

    def test_generate_checkpoints(self):
        """
        Create checkpoints in three stages; try fetching them with
        `generate_checkpoints`.
        """
        user_ids = ['test_generate_checkpoints0',
                    'test_generate_checkpoints1',
                    'test_generate_checkpoints2']
        (managers, paths) = self.populate_users(user_ids)

        def update_content(user_id, path, text):
            """
            Add a Markdown cell and save the notebook.

            Returns the new notebook content.
            """
            manager = managers[user_id]
            model = manager.get(path)
            model['content'].cells.append(
                new_markdown_cell(text + ' on path: ' + path)
            )
            manager.save(model, path)
            return manager.get(path)['content']

        # Each of the next three stages creates a checkpoint for every
        # notebook and records its content in a dict keyed by the user id,
        # the path, and the datetime of the new checkpoint.

        # Stage one: checkpoint the original notebook content.
        beginning_content = {}
        for uid in user_ids:
            for path in paths[uid]:
                content = managers[uid].get(path)['content']
                dt = managers[uid].create_checkpoint(path)['last_modified']
                beginning_content[uid, path, dt] = content

        # Stage two: edit each notebook and make a new checkpoint.
        middle_content = {}
        middle_min_dt = None
        for uid in user_ids:
            for path in paths[uid]:
                content = update_content(uid, path, '1st addition')
                dt = managers[uid].create_checkpoint(path)['last_modified']
                middle_content[uid, path, dt] = content
                if middle_min_dt is None:
                    middle_min_dt = dt

        # Stage three: edit each notebook again and checkpoint once more.
        end_content = {}
        end_min_dt = None
        for uid in user_ids:
            for path in paths[uid]:
                content = update_content(uid, path, '2nd addition')
                dt = managers[uid].create_checkpoint(path)['last_modified']
                end_content[uid, path, dt] = content
                if end_min_dt is None:
                    end_min_dt = dt

        def merge_dicts(*dicts):
            merged = {}
            for d in dicts:
                merged.update(d)
            return merged

        def check_call(kwargs, expect_checkpoints_content):
            """
            Call `generate_checkpoints`; check that all expected checkpoints
            are found, with the correct content.
            """
            seen = []
            for result in generate_checkpoints(self.engine,
                                               self.crypto_factory,
                                               **kwargs):
                manager = managers[result['user_id']]

                # Recreate functionality from
                # `manager._notebook_model_from_db` so that the notebook
                # matches the model returned by `manager.get`.
                nb = result['content']
                manager.mark_trusted_cells(nb, result['path'])

                # The checkpoint content must match what was recorded for
                # this (user, path, datetime) key.
                key = (result['user_id'], result['path'],
                       result['last_modified'])
                self.assertEqual(nb, expect_checkpoints_content[key])

                seen.append(key)

            # Every expected checkpoint must have been yielded.
            self.assertEqual(sorted(seen),
                             sorted(expect_checkpoints_content.keys()))

        # No `min_dt`/`max_dt`: everything is returned.
        check_call({}, merge_dicts(beginning_content,
                                   middle_content, end_content))

        # `min_dt` cuts off `beginning_content` checkpoints.
        check_call({'min_dt': middle_min_dt},
                   merge_dicts(middle_content, end_content))

        # `max_dt` cuts off `end_content` checkpoints.
        check_call({'max_dt': end_min_dt},
                   merge_dicts(beginning_content, middle_content))

        # `min_dt` and `max_dt` together isolate `middle_content`.
        check_call({'min_dt': middle_min_dt, 'max_dt': end_min_dt},
                   middle_content)

0 commit comments

Comments
 (0)