|
2 | 2 | from typing import Any |
3 | 3 |
|
4 | 4 | import tiktoken |
5 | | -from fastapi import APIRouter, Depends, Header, HTTPException, Query |
| 5 | +from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response |
6 | 6 | from mcp.server.fastmcp.prompts import base |
7 | 7 | from mcp.types import TextContent |
8 | 8 | from ulid import ULID |
@@ -470,39 +470,28 @@ async def get_working_memory( |
470 | 470 | return WorkingMemoryResponse(**working_mem_data) |
471 | 471 |
|
472 | 472 |
|
473 | | -@router.put("/v1/working-memory/{session_id}", response_model=WorkingMemoryResponse) |
474 | | -async def put_working_memory( |
| 473 | +async def put_working_memory_core( |
475 | 474 | session_id: str, |
476 | 475 | memory: UpdateWorkingMemory, |
477 | 476 | background_tasks: HybridBackgroundTasks, |
478 | 477 | model_name: ModelNameLiteral | None = None, |
479 | 478 | context_window_max: int | None = None, |
480 | | - current_user: UserInfo = Depends(get_current_user), |
481 | | -): |
| 479 | +) -> WorkingMemoryResponse: |
482 | 480 | """ |
483 | | - Set working memory for a session. Replaces existing working memory. |
484 | | -
|
485 | | - The session_id comes from the URL path, not the request body. |
486 | | - If the token count exceeds the context window threshold, messages will be summarized |
487 | | - immediately and the updated memory state returned to the client. |
| 481 | + Core implementation of put_working_memory. |
488 | 482 |
|
489 | | - NOTE on context_percentage_* fields: |
490 | | - The response includes `context_percentage_total_used` and `context_percentage_until_summarization` |
491 | | - fields that show token usage. These fields will be `null` unless you provide either: |
492 | | - - `model_name` query parameter (e.g., `?model_name=gpt-4o-mini`) |
493 | | - - `context_window_max` query parameter (e.g., `?context_window_max=500`) |
| 483 | + This function contains the business logic for setting working memory and can be |
| 484 | + called from both the REST API endpoint and MCP tools. |
494 | 485 |
|
495 | 486 | Args: |
496 | | - session_id: The session ID (from URL path) |
497 | | - memory: Working memory data to save (session_id not required in body) |
| 487 | + session_id: The session ID |
| 488 | + memory: Working memory data to save |
| 489 | + background_tasks: Background tasks handler |
498 | 490 | model_name: The client's LLM model name for context window determination |
499 | | - context_window_max: Direct specification of context window max tokens (overrides model_name) |
500 | | - background_tasks: DocketBackgroundTasks instance (injected automatically) |
| 491 | + context_window_max: Direct specification of context window max tokens |
501 | 492 |
|
502 | 493 | Returns: |
503 | | - Updated working memory (potentially with summary if tokens were condensed). |
504 | | - Includes context_percentage_total_used and context_percentage_until_summarization |
505 | | - if model information is provided. |
| 494 | + Updated working memory response |
506 | 495 | """ |
507 | 496 | redis = await get_redis_conn() |
508 | 497 |
|
@@ -575,6 +564,61 @@ async def put_working_memory( |
575 | 564 | return WorkingMemoryResponse(**updated_memory_data) |
576 | 565 |
|
577 | 566 |
|
| 567 | +@router.put("/v1/working-memory/{session_id}", response_model=WorkingMemoryResponse) |
| 568 | +async def put_working_memory( |
| 569 | + session_id: str, |
| 570 | + memory: UpdateWorkingMemory, |
| 571 | + background_tasks: HybridBackgroundTasks, |
| 572 | + response: Response, |
| 573 | + model_name: ModelNameLiteral | None = None, |
| 574 | + context_window_max: int | None = None, |
| 575 | + current_user: UserInfo = Depends(get_current_user), |
| 576 | +): |
| 577 | + """ |
| 578 | + Set working memory for a session. Replaces existing working memory. |
| 579 | +
|
| 580 | + The session_id comes from the URL path, not the request body. |
| 581 | + If the token count exceeds the context window threshold, messages will be summarized |
| 582 | + immediately and the updated memory state returned to the client. |
| 583 | +
|
| 584 | + NOTE on context_percentage_* fields: |
| 585 | + The response includes `context_percentage_total_used` and `context_percentage_until_summarization` |
| 586 | + fields that show token usage. These fields will be `null` unless you provide either: |
| 587 | + - `model_name` query parameter (e.g., `?model_name=gpt-4o-mini`) |
| 588 | + - `context_window_max` query parameter (e.g., `?context_window_max=500`) |
| 589 | +
|
| 590 | + Args: |
| 591 | + session_id: The session ID (from URL path) |
| 592 | + memory: Working memory data to save (session_id not required in body) |
| 593 | + model_name: The client's LLM model name for context window determination |
| 594 | + context_window_max: Direct specification of context window max tokens (overrides model_name) |
| 595 | +    background_tasks: HybridBackgroundTasks instance (injected automatically) |
| 596 | + response: FastAPI Response object for setting headers |
| 597 | +
|
| 598 | + Returns: |
| 599 | + Updated working memory (potentially with summary if tokens were condensed). |
| 600 | + Includes context_percentage_total_used and context_percentage_until_summarization |
| 601 | + if model information is provided. |
| 602 | + """ |
| 603 | + # Check if any messages are missing created_at timestamps and add deprecation header |
| 604 | + messages_missing_timestamp = any( |
| 605 | + not getattr(msg, "_created_at_was_provided", True) for msg in memory.messages |
| 606 | + ) |
| 607 | + if messages_missing_timestamp: |
| 608 | + response.headers["X-Deprecation-Warning"] = ( |
| 609 | + "messages[].created_at will become required in the next major version. " |
| 610 | + "Please provide timestamps for all messages." |
| 611 | + ) |
| 612 | + |
| 613 | + return await put_working_memory_core( |
| 614 | + session_id=session_id, |
| 615 | + memory=memory, |
| 616 | + background_tasks=background_tasks, |
| 617 | + model_name=model_name, |
| 618 | + context_window_max=context_window_max, |
| 619 | + ) |
| 620 | + |
| 621 | + |
578 | 622 | @router.delete("/v1/working-memory/{session_id}", response_model=AckResponse) |
579 | 623 | async def delete_working_memory( |
580 | 624 | session_id: str, |
|
0 commit comments