|
10 | 10 | from sqlalchemy import create_engine |
11 | 11 |
|
12 | 12 | from pgcontents import PostgresContentsManager |
13 | | -from pgcontents.crypto import FernetEncryption, NoEncryption |
| 13 | +from pgcontents.crypto import ( |
| 14 | + FernetEncryption, |
| 15 | + NoEncryption, |
| 16 | + single_password_crypto_factory, |
| 17 | +) |
| 18 | +from pgcontents.query import generate_files, generate_checkpoints |
14 | 19 | from pgcontents.utils.ipycompat import new_markdown_cell |
15 | 20 |
|
16 | 21 | from .utils import ( |
@@ -177,3 +182,214 @@ def check_reencryption(old, new): |
177 | 182 | # crypto manager. |
178 | 183 | unencrypt_all_users(engine, crypto2_factory, logger) |
179 | 184 | check_reencryption(manager2, no_crypto_manager) |
| 185 | + |
| 186 | + |
class TestGenerateNotebooks(TestCase):
    """
    Tests for the `generate_files` and `generate_checkpoints` query helpers,
    exercised against encrypted notebook content belonging to several users.
    """

    def setUp(self):
        # Start each test from a freshly-migrated, empty schema.
        remigrate_test_schema()
        self.db_url = TEST_DB_URL
        self.engine = create_engine(self.db_url)
        # Every user's crypto is derived from the same shared password.
        encryption_pw = u'foobar'
        self.crypto_factory = single_password_crypto_factory(encryption_pw)

    def tearDown(self):
        clear_test_db()

    def populate_users(self, user_ids):
        """
        Create a `PostgresContentsManager` and notebooks for each user.
        """
        managers = {}
        for uid in user_ids:
            managers[uid] = PostgresContentsManager(
                user_id=uid,
                db_url=self.db_url,
                crypto=self.crypto_factory(uid),
                create_user_on_startup=True,
            )
        paths = {}
        for uid in user_ids:
            paths[uid] = populate(managers[uid])
        return (managers, paths)

    def test_generate_files(self):
        """
        Create files for three users; try fetching them using `generate_files`.
        """
        user_ids = ['test_generate_files0',
                    'test_generate_files1',
                    'test_generate_files2']
        (managers, paths) = self.populate_users(user_ids)

        def last_modified_of(uid, idx):
            # Timestamp of the idx-th notebook belonging to `uid`.
            model = managers[uid].get(paths[uid][idx], content=False)
            return model['last_modified']

        # Choose a split datetime midway through each user's list of files.
        split_idx = len(paths[user_ids[0]]) // 2
        split_dts = [last_modified_of(uid, split_idx) for uid in user_ids]

        def check_call(kwargs, expect_files_by_user):
            """
            Call `generate_files`; check that all expected files are found,
            with the correct content.
            """
            seen = {uid: [] for uid in expect_files_by_user}
            for result in generate_files(self.engine, self.crypto_factory,
                                         **kwargs):
                manager = managers[result['user_id']]
                nb = result['content']

                # Recreate what `manager._notebook_model_from_db` does so
                # this model is comparable with the one `manager.get`
                # returns.
                manager.mark_trusted_cells(nb, result['path'])

                # Content yielded by `generate_files` must match what the
                # contents manager itself serves.
                self.assertEqual(nb, manager.get(result['path'])['content'])

                seen[result['user_id']].append(result['path'])

            # Every expected file — and nothing else — must have appeared.
            for uid in expect_files_by_user:
                self.assertEqual(sorted(seen[uid]),
                                 sorted(expect_files_by_user[uid]))

        # With no `min_dt`/`max_dt`, every file is returned.
        check_call({}, paths)

        # `min_dt` falls mid-way through user 1's files: expect the latter
        # half of 1's files plus all of 2's.
        check_call({'min_dt': split_dts[1]},
                   {
                       user_ids[0]: [],
                       user_ids[1]: paths[user_ids[1]][split_idx:],
                       user_ids[2]: paths[user_ids[2]],
                   })

        # `max_dt` falls mid-way through user 1's files: expect all of 0's
        # files plus the first half of 1's.
        check_call({'max_dt': split_dts[1]},
                   {
                       user_ids[0]: paths[user_ids[0]],
                       user_ids[1]: paths[user_ids[1]][:split_idx],
                       user_ids[2]: [],
                   })

        # `min_dt` trims the first half of 0's files; `max_dt` trims the
        # latter half of 2's.
        check_call({'min_dt': split_dts[0], 'max_dt': split_dts[2]},
                   {
                       user_ids[0]: paths[user_ids[0]][split_idx:],
                       user_ids[1]: paths[user_ids[1]],
                       user_ids[2]: paths[user_ids[2]][:split_idx],
                   })

    def test_generate_checkpoints(self):
        """
        Create checkpoints in three stages; try fetching them with
        `generate_checkpoints`.
        """
        user_ids = ['test_generate_checkpoints0',
                    'test_generate_checkpoints1',
                    'test_generate_checkpoints2']
        (managers, paths) = self.populate_users(user_ids)

        def update_content(uid, path, text):
            """
            Add a Markdown cell and save the notebook.

            Returns the new notebook content.
            """
            manager = managers[uid]
            model = manager.get(path)
            model['content'].cells.append(
                new_markdown_cell(text + ' on path: ' + path)
            )
            manager.save(model, path)
            return manager.get(path)['content']

        def checkpoint_all(make_content):
            """
            Create a checkpoint for every notebook of every user.

            `make_content(uid, path)` produces (and may update) the notebook
            content to snapshot.  Returns a dict keyed by
            (user_id, path, checkpoint datetime) mapping to that content,
            together with the datetime of the first checkpoint created.
            """
            contents = {}
            first_dt = None
            for uid in user_ids:
                for path in paths[uid]:
                    content = make_content(uid, path)
                    dt = managers[uid].create_checkpoint(path)['last_modified']
                    contents[uid, path, dt] = content
                    if first_dt is None:
                        first_dt = dt
            return (contents, first_dt)

        # Stage 1: checkpoint the original notebook content.
        beginning_content, _ = checkpoint_all(
            lambda uid, path: managers[uid].get(path)['content'])

        # Stage 2: update every notebook and checkpoint again.
        middle_content, middle_min_dt = checkpoint_all(
            lambda uid, path: update_content(uid, path, '1st addition'))

        # Stage 3: another round of updates and checkpoints.
        end_content, end_min_dt = checkpoint_all(
            lambda uid, path: update_content(uid, path, '2nd addition'))

        def merge_dicts(*dicts):
            merged = {}
            for d in dicts:
                merged.update(d)
            return merged

        def check_call(kwargs, expect_checkpoints_content):
            """
            Call `generate_checkpoints`; check that all expected checkpoints
            are found, with the correct content.
            """
            found = []
            for result in generate_checkpoints(self.engine,
                                               self.crypto_factory, **kwargs):
                manager = managers[result['user_id']]
                nb = result['content']

                # Recreate what `manager._notebook_model_from_db` does so
                # this model is comparable with the one `manager.get`
                # returns.
                manager.mark_trusted_cells(nb, result['path'])

                # The checkpoint's content must match what we recorded when
                # it was created.
                key = (result['user_id'], result['path'],
                       result['last_modified'])
                self.assertEqual(nb, expect_checkpoints_content[key])

                found.append(key)

            # Every expected checkpoint — and nothing else — must have
            # appeared.
            self.assertEqual(sorted(found),
                             sorted(expect_checkpoints_content.keys()))

        # No `min_dt`/`max_dt`: all three stages are returned.
        check_call({}, merge_dicts(beginning_content,
                                   middle_content, end_content))

        # `min_dt` excludes the `beginning_content` checkpoints.
        check_call({'min_dt': middle_min_dt},
                   merge_dicts(middle_content, end_content))

        # `max_dt` excludes the `end_content` checkpoints.
        check_call({'max_dt': end_min_dt},
                   merge_dicts(beginning_content, middle_content))

        # `min_dt` and `max_dt` together isolate `middle_content`.
        check_call({'min_dt': middle_min_dt, 'max_dt': end_min_dt},
                   middle_content)
0 commit comments