|
2 | 2 | from typing import Any |
3 | 3 |
|
4 | 4 | import tiktoken |
5 | | -from fastapi import APIRouter, Depends, Header, HTTPException, Query |
| 5 | +from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response |
6 | 6 | from mcp.server.fastmcp.prompts import base |
7 | 7 | from mcp.types import TextContent |
8 | 8 |
|
@@ -452,39 +452,28 @@ async def get_working_memory( |
452 | 452 | return WorkingMemoryResponse(**working_mem_data) |
453 | 453 |
|
454 | 454 |
|
455 | | -@router.put("/v1/working-memory/{session_id}", response_model=WorkingMemoryResponse) |
456 | | -async def put_working_memory( |
| 455 | +async def put_working_memory_core( |
457 | 456 | session_id: str, |
458 | 457 | memory: UpdateWorkingMemory, |
459 | 458 | background_tasks: HybridBackgroundTasks, |
460 | 459 | model_name: ModelNameLiteral | None = None, |
461 | 460 | context_window_max: int | None = None, |
462 | | - current_user: UserInfo = Depends(get_current_user), |
463 | | -): |
| 461 | +) -> WorkingMemoryResponse: |
464 | 462 | """ |
465 | | - Set working memory for a session. Replaces existing working memory. |
466 | | -
|
467 | | - The session_id comes from the URL path, not the request body. |
468 | | - If the token count exceeds the context window threshold, messages will be summarized |
469 | | - immediately and the updated memory state returned to the client. |
| 463 | + Core implementation of put_working_memory. |
470 | 464 |
|
471 | | - NOTE on context_percentage_* fields: |
472 | | - The response includes `context_percentage_total_used` and `context_percentage_until_summarization` |
473 | | - fields that show token usage. These fields will be `null` unless you provide either: |
474 | | - - `model_name` query parameter (e.g., `?model_name=gpt-4o-mini`) |
475 | | - - `context_window_max` query parameter (e.g., `?context_window_max=500`) |
| 465 | + This function contains the business logic for setting working memory and can be |
| 466 | + called from both the REST API endpoint and MCP tools. |
476 | 467 |
|
477 | 468 | Args: |
478 | | - session_id: The session ID (from URL path) |
479 | | - memory: Working memory data to save (session_id not required in body) |
| 469 | + session_id: The session ID |
| 470 | + memory: Working memory data to save |
| 471 | + background_tasks: Background tasks handler |
480 | 472 | model_name: The client's LLM model name for context window determination |
481 | | - context_window_max: Direct specification of context window max tokens (overrides model_name) |
482 | | - background_tasks: DocketBackgroundTasks instance (injected automatically) |
| 473 | + context_window_max: Direct specification of context window max tokens |
483 | 474 |
|
484 | 475 | Returns: |
485 | | - Updated working memory (potentially with summary if tokens were condensed). |
486 | | - Includes context_percentage_total_used and context_percentage_until_summarization |
487 | | - if model information is provided. |
| 476 | + Updated working memory response |
488 | 477 | """ |
489 | 478 | redis = await get_redis_conn() |
490 | 479 |
|
@@ -557,6 +546,61 @@ async def put_working_memory( |
557 | 546 | return WorkingMemoryResponse(**updated_memory_data) |
558 | 547 |
|
559 | 548 |
|
@router.put("/v1/working-memory/{session_id}", response_model=WorkingMemoryResponse)
async def put_working_memory(
    session_id: str,
    memory: UpdateWorkingMemory,
    background_tasks: HybridBackgroundTasks,
    response: Response,
    model_name: ModelNameLiteral | None = None,
    context_window_max: int | None = None,
    current_user: UserInfo = Depends(get_current_user),
):
    """
    Set working memory for a session. Replaces existing working memory.

    The session_id comes from the URL path, not the request body.
    If the token count exceeds the context window threshold, messages will be summarized
    immediately and the updated memory state returned to the client.

    NOTE on context_percentage_* fields:
    The response includes `context_percentage_total_used` and `context_percentage_until_summarization`
    fields that show token usage. These fields will be `null` unless you provide either:
    - `model_name` query parameter (e.g., `?model_name=gpt-4o-mini`)
    - `context_window_max` query parameter (e.g., `?context_window_max=500`)

    Args:
        session_id: The session ID (from URL path)
        memory: Working memory data to save (session_id not required in body)
        model_name: The client's LLM model name for context window determination
        context_window_max: Direct specification of context window max tokens (overrides model_name)
        background_tasks: DocketBackgroundTasks instance (injected automatically)
        response: FastAPI Response object for setting headers

    Returns:
        Updated working memory (potentially with summary if tokens were condensed).
        Includes context_percentage_total_used and context_percentage_until_summarization
        if model information is provided.
    """
    # Deprecation shim: when any incoming message omitted created_at
    # (tracked by the model via the _created_at_was_provided sentinel),
    # surface a warning header so clients can migrate before the field
    # becomes mandatory in the next major version.
    if any(
        not getattr(msg, "_created_at_was_provided", True)
        for msg in memory.messages
    ):
        response.headers["X-Deprecation-Warning"] = (
            "messages[].created_at will become required in the next major version. "
            "Please provide timestamps for all messages."
        )

    # All business logic lives in the shared core so MCP tools can reuse it.
    return await put_working_memory_core(
        session_id=session_id,
        memory=memory,
        background_tasks=background_tasks,
        model_name=model_name,
        context_window_max=context_window_max,
    )
| 603 | + |
560 | 604 | @router.delete("/v1/working-memory/{session_id}", response_model=AckResponse) |
561 | 605 | async def delete_working_memory( |
562 | 606 | session_id: str, |
|
0 commit comments