diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e9550a3 --- /dev/null +++ b/Makefile @@ -0,0 +1,69 @@ +.PHONY: help setup sync lint format test test-api pre-commit clean server mcp worker + +help: ## Show this help message + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' + +# Setup and dependencies +setup: ## Initial setup: create virtualenv and install dependencies + pip install uv + uv venv + uv sync --all-extras + uv run pre-commit install + +sync: ## Sync dependencies from lock file + uv sync --all-extras + +# Code quality +lint: ## Run linting checks (ruff) + uv run ruff check . + +format: ## Format code (ruff) + uv run ruff format . + uv run ruff check --fix . + +pre-commit: ## Run all pre-commit hooks + uv run pre-commit run --all-files + +# Testing +test: ## Run tests (excludes API tests requiring keys) + uv run pytest + +test-api: ## Run all tests including API tests (requires OPENAI_API_KEY) + uv run pytest --run-api-tests + +test-unit: ## Run only unit tests + uv run pytest tests/unit/ + +test-integration: ## Run only integration tests + uv run pytest tests/integration/ + +test-cov: ## Run tests with coverage report + uv run pytest --cov + +# Running services +server: ## Start the REST API server + uv run agent-memory api + +mcp: ## Start the MCP server (stdio mode) + uv run agent-memory mcp + +mcp-sse: ## Start the MCP server (SSE mode on port 9000) + uv run agent-memory mcp --mode sse --port 9000 + +worker: ## Start the background task worker + uv run agent-memory task-worker + +# Database operations +rebuild-index: ## Rebuild Redis search index + uv run agent-memory rebuild-index + +migrate: ## Run memory migrations + uv run agent-memory migrate-memories + +# Cleanup +clean: ## Clean up generated files and caches + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . -type d -name .pytest_cache -exec rm -rf {} + 2>/dev/null || true + find . -type d -name .ruff_cache -exec rm -rf {} + 2>/dev/null || true + find . -type f -name "*.pyc" -delete 2>/dev/null || true + rm -rf .coverage htmlcov/ 2>/dev/null || true diff --git a/README.md b/README.md index 3f2feab..926a53c 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ A memory layer for AI agents. **[Documentation](https://redis.github.io/agent-memory-server/)** • **[GitHub](https://github.com/redis/agent-memory-server)** • **[Docker](https://hub.docker.com/r/redislabs/agent-memory-server)** - + ## Features - **Dual Interface**: REST API and Model Context Protocol (MCP) server - **Two-Tier Memory**: Working memory (session-scoped) and long-term memory (persistent) diff --git a/agent-memory-client/agent-memory-client-java/build.gradle.kts b/agent-memory-client/agent-memory-client-java/build.gradle.kts index 029d178..3e0c5cd 100644 --- a/agent-memory-client/agent-memory-client-java/build.gradle.kts +++ b/agent-memory-client/agent-memory-client-java/build.gradle.kts @@ -132,4 +132,4 @@ publishing { url = uri(layout.buildDirectory.dir("staging-deploy")) } } -} \ No newline at end of file +} diff --git a/agent-memory-client/agent-memory-client-java/gradle.properties b/agent-memory-client/agent-memory-client-java/gradle.properties index cade5ff..de55ab6 100644 --- a/agent-memory-client/agent-memory-client-java/gradle.properties +++ b/agent-memory-client/agent-memory-client-java/gradle.properties @@ -1,2 +1 @@ version=0.1.0 - diff --git a/agent-memory-client/agent-memory-client-java/settings.gradle.kts b/agent-memory-client/agent-memory-client-java/settings.gradle.kts index a959bdb..35d993b 100644 --- a/agent-memory-client/agent-memory-client-java/settings.gradle.kts +++ b/agent-memory-client/agent-memory-client-java/settings.gradle.kts @@ -1 +1 @@ -rootProject.name = "agent-memory-client-java" \ No newline at end of file +rootProject.name = "agent-memory-client-java" diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/MemoryAPIClient.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/MemoryAPIClient.java index 58a4ed6..e98ef56 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/MemoryAPIClient.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/MemoryAPIClient.java @@ -142,7 +142,7 @@ public void close() { httpClient.dispatcher().executorService().shutdown(); httpClient.connectionPool().evictAll(); } - + /** * Creates a new builder for MemoryAPIClient. * @param baseUrl the base URL of the memory server diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryClientException.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryClientException.java index b7e0cc9..f93e555 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryClientException.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryClientException.java @@ -4,13 +4,12 @@ * Base exception for all memory client errors. */ public class MemoryClientException extends Exception { - + public MemoryClientException(String message) { super(message); } - + public MemoryClientException(String message, Throwable cause) { super(message, cause); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryNotFoundException.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryNotFoundException.java index 0a1cfdb..a24fc91 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryNotFoundException.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryNotFoundException.java @@ -4,13 +4,12 @@ * Raised when a requested memory or session is not found. */ public class MemoryNotFoundException extends MemoryClientException { - + public MemoryNotFoundException(String message) { super(message); } - + public MemoryNotFoundException(String message, Throwable cause) { super(message, cause); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryServerException.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryServerException.java index 56d418c..4aada3b 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryServerException.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryServerException.java @@ -6,27 +6,26 @@ * Raised when the memory server returns an error. */ public class MemoryServerException extends MemoryClientException { - + @Nullable private final Integer statusCode; - + public MemoryServerException(String message) { this(message, null); } - + public MemoryServerException(String message, @Nullable Integer statusCode) { super(message); this.statusCode = statusCode; } - + public MemoryServerException(String message, @Nullable Integer statusCode, Throwable cause) { super(message, cause); this.statusCode = statusCode; } - + @Nullable public Integer getStatusCode() { return statusCode; } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryValidationException.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryValidationException.java index 260d3ba..a94da98 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryValidationException.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/exceptions/MemoryValidationException.java @@ -7,13 +7,12 @@ * requests to the server, allowing for early error detection. */ public class MemoryValidationException extends MemoryClientException { - + public MemoryValidationException(String message) { super(message); } - + public MemoryValidationException(String message, Throwable cause) { super(message, cause); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/common/AckResponse.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/common/AckResponse.java index acd6384..ee7829a 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/common/AckResponse.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/common/AckResponse.java @@ -6,22 +6,22 @@ * Generic acknowledgement response. */ public class AckResponse { - + @NotNull private String status; - + public AckResponse() { } - + public AckResponse(@NotNull String status) { this.status = status; } - + @NotNull public String getStatus() { return status; } - + public void setStatus(@NotNull String status) { this.status = status; } @@ -33,4 +33,3 @@ public String toString() { '}'; } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/health/HealthCheckResponse.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/health/HealthCheckResponse.java index c506790..28f8344 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/health/HealthCheckResponse.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/health/HealthCheckResponse.java @@ -4,20 +4,20 @@ * Health check response from the server. */ public class HealthCheckResponse { - + private double now; - + public HealthCheckResponse() { } - + public HealthCheckResponse(double now) { this.now = now; } - + public double getNow() { return now; } - + public void setNow(double now) { this.now = now; } @@ -29,4 +29,3 @@ public String toString() { '}'; } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/ForgetResponse.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/ForgetResponse.java index 2a45073..1d929e0 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/ForgetResponse.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/ForgetResponse.java @@ -9,61 +9,61 @@ * Response from the "forget" endpoint. */ public class ForgetResponse { - + private int scanned; - + private int deleted; - + @NotNull @JsonProperty("deleted_ids") private List deletedIds; - + @JsonProperty("dry_run") private boolean dryRun; - + public ForgetResponse() { } - + public ForgetResponse(int scanned, int deleted, @NotNull List deletedIds, boolean dryRun) { this.scanned = scanned; this.deleted = deleted; this.deletedIds = deletedIds; this.dryRun = dryRun; } - + public int getScanned() { return scanned; } - + public void setScanned(int scanned) { this.scanned = scanned; } - + public int getDeleted() { return deleted; } - + public void setDeleted(int deleted) { this.deleted = deleted; } - + @NotNull public List getDeletedIds() { return deletedIds; } - + public void setDeletedIds(@NotNull List deletedIds) { this.deletedIds = deletedIds; } - + public boolean isDryRun() { return dryRun; } - + public void setDryRun(boolean dryRun) { this.dryRun = dryRun; } - + @Override public String toString() { return "ForgetResponse{" + @@ -74,4 +74,3 @@ public String toString() { '}'; } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecord.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecord.java index b9e3ea4..c7e6447 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecord.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecord.java @@ -12,66 +12,66 @@ * A memory record in the system. */ public class MemoryRecord { - + @NotNull private String id; - + @NotNull private String text; - + @Nullable @JsonProperty("session_id") private String sessionId; - + @Nullable @JsonProperty("user_id") private String userId; - + @Nullable private String namespace; - + @NotNull @JsonProperty("last_accessed") private Instant lastAccessed; - + @NotNull @JsonProperty("created_at") private Instant createdAt; - + @NotNull @JsonProperty("updated_at") private Instant updatedAt; - + @Nullable private List topics; - + @Nullable private List entities; - + @Nullable @JsonProperty("memory_hash") private String memoryHash; - + @NotNull @JsonProperty("discrete_memory_extracted") private String discreteMemoryExtracted; - + @NotNull @JsonProperty("memory_type") private MemoryType memoryType; - + @Nullable @JsonProperty("persisted_at") private Instant persistedAt; - + @Nullable @JsonProperty("extracted_from") private List extractedFrom; - + @Nullable @JsonProperty("event_date") private Instant eventDate; - + public MemoryRecord() { this.id = UlidCreator.getUlid().toString(); Instant now = Instant.now(); @@ -81,68 +81,68 @@ public MemoryRecord() { this.discreteMemoryExtracted = "f"; this.memoryType = MemoryType.MESSAGE; } - + public MemoryRecord(@NotNull String text) { this(); this.text = text; } - + // Getters and setters - + @NotNull public String getId() { return id; } - + public void setId(@NotNull String id) { this.id = id; } - + @NotNull public String getText() { return text; } - + public void setText(@NotNull String text) { this.text = text; } - + @Nullable public String getSessionId() { return sessionId; } - + public void setSessionId(@Nullable String sessionId) { this.sessionId = sessionId; } - + @Nullable public String getUserId() { return userId; } - + public void setUserId(@Nullable String userId) { this.userId = userId; } - + @Nullable public String getNamespace() { return namespace; } - + public void setNamespace(@Nullable String namespace) { this.namespace = namespace; } - + @NotNull public Instant getLastAccessed() { return lastAccessed; } - + public void setLastAccessed(@NotNull Instant lastAccessed) { this.lastAccessed = lastAccessed; } - + @NotNull public Instant getCreatedAt() { return createdAt; @@ -510,4 +510,3 @@ public MemoryRecord build() { } } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecordResult.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecordResult.java index db023ec..8fb2c50 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecordResult.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecordResult.java @@ -4,17 +4,17 @@ * Result from a memory search operation. */ public class MemoryRecordResult extends MemoryRecord { - + private double dist; - + public MemoryRecordResult() { super(); } - + public double getDist() { return dist; } - + public void setDist(double dist) { this.dist = dist; } @@ -27,4 +27,3 @@ public String toString() { '}'; } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecordResults.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecordResults.java index 8c5cf35..6ac22bf 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecordResults.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecordResults.java @@ -10,46 +10,46 @@ * Results from memory search operations. */ public class MemoryRecordResults { - + @NotNull private List memories; - + private int total; - + @Nullable @JsonProperty("next_offset") private Integer nextOffset; - + public MemoryRecordResults() { } - + public MemoryRecordResults(@NotNull List memories, int total) { this.memories = memories; this.total = total; } - + @NotNull public List getMemories() { return memories; } - + public void setMemories(@NotNull List memories) { this.memories = memories; } - + public int getTotal() { return total; } - + public void setTotal(int total) { this.total = total; } - + @Nullable public Integer getNextOffset() { return nextOffset; } - + public void setNextOffset(@Nullable Integer nextOffset) { this.nextOffset = nextOffset; } @@ -63,4 +63,3 @@ public String toString() { '}'; } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryType.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryType.java index 59ea9f0..682a2a9 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryType.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryType.java @@ -9,18 +9,18 @@ public enum MemoryType { EPISODIC("episodic"), SEMANTIC("semantic"), MESSAGE("message"); - + private final String value; - + MemoryType(String value) { this.value = value; } - + @JsonValue public String getValue() { return value; } - + public static MemoryType fromValue(String value) { for (MemoryType type : values()) { if (type.value.equals(value)) { @@ -30,4 +30,3 @@ public static MemoryType fromValue(String value) { throw new IllegalArgumentException("Unknown memory type: " + value); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java index 12ff6db..85efd44 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java @@ -11,113 +11,113 @@ */ @JsonInclude(JsonInclude.Include.NON_NULL) public class SearchRequest { - + @Nullable private String text; - + @Nullable @JsonProperty("session_id") private String sessionId; - + @Nullable private String namespace; - + @Nullable private List topics; - + @Nullable private List entities; - + @Nullable @JsonProperty("user_id") private String userId; - + @Nullable @JsonProperty("distance_threshold") private Double distanceThreshold; - + private int limit = 10; - + private int offset = 0; - + public SearchRequest() { } - + @Nullable public String getText() { return text; } - + public void setText(@Nullable String text) { this.text = text; } - + @Nullable public String getSessionId() { return sessionId; } - + public void setSessionId(@Nullable String sessionId) { this.sessionId = sessionId; } - + @Nullable public String getNamespace() { return namespace; } - + public void setNamespace(@Nullable String namespace) { this.namespace = namespace; } - + @Nullable public List getTopics() { return topics; } - + public void setTopics(@Nullable List topics) { this.topics = topics; } - + @Nullable public List getEntities() { return entities; } - + public void setEntities(@Nullable List entities) { this.entities = entities; } - + @Nullable public String getUserId() { return userId; } - + public void setUserId(@Nullable String userId) { this.userId = userId; } - + @Nullable public Double getDistanceThreshold() { return distanceThreshold; } - + public void setDistanceThreshold(@Nullable Double distanceThreshold) { this.distanceThreshold = distanceThreshold; } - + public int getLimit() { return limit; } - + public void setLimit(int limit) { this.limit = limit; } - + public int getOffset() { return offset; } - + public void setOffset(int offset) { this.offset = offset; } @@ -201,4 +201,3 @@ public SearchRequest build() { } } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MemoryMessage.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MemoryMessage.java index 0f6825a..1a6218e 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MemoryMessage.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MemoryMessage.java @@ -11,87 +11,87 @@ * A message in the memory system. */ public class MemoryMessage { - + @NotNull private String role; - + @NotNull private String content; - + @NotNull private String id; - + @NotNull @JsonProperty("created_at") private Instant createdAt; - + @Nullable @JsonProperty("persisted_at") private Instant persistedAt; - + @NotNull @JsonProperty("discrete_memory_extracted") private String discreteMemoryExtracted; - + public MemoryMessage() { this.id = UlidCreator.getUlid().toString(); this.createdAt = Instant.now(); this.discreteMemoryExtracted = "f"; } - + public MemoryMessage(@NotNull String role, @NotNull String content) { this(); this.role = role; this.content = content; } - + // Getters and setters - + @NotNull public String getRole() { return role; } - + public void setRole(@NotNull String role) { this.role = role; } - + @NotNull public String getContent() { return content; } - + public void setContent(@NotNull String content) { this.content = content; } - + @NotNull public String getId() { return id; } - + public void setId(@NotNull String id) { this.id = id; } - + @NotNull public Instant getCreatedAt() { return createdAt; } - + public void setCreatedAt(@NotNull Instant createdAt) { this.createdAt = createdAt; } - + @Nullable public Instant getPersistedAt() { return persistedAt; } - + public void setPersistedAt(@Nullable Instant persistedAt) { this.persistedAt = persistedAt; } - + @NotNull public String getDiscreteMemoryExtracted() { return discreteMemoryExtracted; @@ -223,4 +223,3 @@ public MemoryMessage build() { } } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MemoryStrategyConfig.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MemoryStrategyConfig.java index 8915292..3e61d53 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MemoryStrategyConfig.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MemoryStrategyConfig.java @@ -10,37 +10,37 @@ * Configuration for memory extraction strategy. */ public class MemoryStrategyConfig { - + @NotNull private String strategy; - + @NotNull private Map config; - + public MemoryStrategyConfig() { this.strategy = "discrete"; this.config = new HashMap<>(); } - + public MemoryStrategyConfig(@NotNull String strategy) { this.strategy = strategy; this.config = new HashMap<>(); } - + public MemoryStrategyConfig(@NotNull String strategy, @NotNull Map config) { this.strategy = strategy; this.config = config; } - + @NotNull public String getStrategy() { return strategy; } - + public void setStrategy(@NotNull String strategy) { this.strategy = strategy; } - + @NotNull public Map getConfig() { return config; @@ -122,4 +122,3 @@ public MemoryStrategyConfig build() { } } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MergeStrategy.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MergeStrategy.java index d63e2b9..483b53e 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MergeStrategy.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/MergeStrategy.java @@ -8,15 +8,14 @@ public enum MergeStrategy { * Replace existing data entirely with new data. */ REPLACE, - + /** * Shallow merge - top-level keys from new data override existing keys. */ MERGE, - + /** * Deep merge - recursively merge nested maps. */ DEEP_MERGE } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/SessionListResponse.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/SessionListResponse.java index 3411ac6..1ba161b 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/SessionListResponse.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/SessionListResponse.java @@ -8,33 +8,33 @@ * Response containing a list of sessions. */ public class SessionListResponse { - + @NotNull private List sessions; - + private int total; - + public SessionListResponse() { } - + public SessionListResponse(@NotNull List sessions, int total) { this.sessions = sessions; this.total = total; } - + @NotNull public List getSessions() { return sessions; } - + public void setSessions(@NotNull List sessions) { this.sessions = sessions; } - + public int getTotal() { return total; } - + public void setTotal(int total) { this.total = total; } @@ -47,4 +47,3 @@ public String toString() { '}'; } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemory.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemory.java index 1b2fb78..4669e14 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemory.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemory.java @@ -15,44 +15,44 @@ * Working memory for a session - contains both messages and structured memory records. */ public class WorkingMemory { - + @NotNull private List messages; - + @NotNull private List memories; - + @Nullable private Map data; - + @Nullable private String context; - + @Nullable @JsonProperty("user_id") private String userId; - + private int tokens; - + @NotNull @JsonProperty("session_id") private String sessionId; - + @Nullable private String namespace; - + @NotNull @JsonProperty("long_term_memory_strategy") private MemoryStrategyConfig longTermMemoryStrategy; - + @Nullable @JsonProperty("ttl_seconds") private Integer ttlSeconds; - + @NotNull @JsonProperty("last_accessed") private Instant lastAccessed; - + public WorkingMemory() { this.messages = new ArrayList<>(); this.memories = new ArrayList<>(); @@ -61,85 +61,85 @@ public WorkingMemory() { this.longTermMemoryStrategy = new MemoryStrategyConfig(); this.lastAccessed = Instant.now(); } - + public WorkingMemory(@NotNull String sessionId) { this(); this.sessionId = sessionId; } - + // Getters and setters - + @NotNull public List getMessages() { return messages; } - + public void setMessages(@NotNull List messages) { this.messages = messages; } - + @NotNull public List getMemories() { return memories; } - + public void setMemories(@NotNull List memories) { this.memories = memories; } - + @Nullable public Map getData() { return data; } - + public void setData(@Nullable Map data) { this.data = data; } - + @Nullable public String getContext() { return context; } - + public void setContext(@Nullable String context) { this.context = context; } - + @Nullable public String getUserId() { return userId; } - + public void setUserId(@Nullable String userId) { this.userId = userId; } - + public int getTokens() { return tokens; } - + public void setTokens(int tokens) { this.tokens = tokens; } - + @NotNull public String getSessionId() { return sessionId; } - + public void setSessionId(@NotNull String sessionId) { this.sessionId = sessionId; } - + @Nullable public String getNamespace() { return namespace; } - + public void setNamespace(@Nullable String namespace) { this.namespace = namespace; } - + @NotNull public MemoryStrategyConfig getLongTermMemoryStrategy() { return longTermMemoryStrategy; @@ -388,4 +388,3 @@ public WorkingMemory build() { } } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryResponse.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryResponse.java index bf660c5..8d7d7a5 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryResponse.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryResponse.java @@ -7,58 +7,58 @@ * Response from working memory operations. */ public class WorkingMemoryResponse extends WorkingMemory { - + @Nullable @JsonProperty("context_percentage_total_used") private Double contextPercentageTotalUsed; - + @Nullable @JsonProperty("context_percentage_until_summarization") private Double contextPercentageUntilSummarization; - + @Nullable @JsonProperty("new_session") private Boolean newSession; - + @Nullable private Boolean unsaved; - + public WorkingMemoryResponse() { super(); } - + @Nullable public Double getContextPercentageTotalUsed() { return contextPercentageTotalUsed; } - + public void setContextPercentageTotalUsed(@Nullable Double contextPercentageTotalUsed) { this.contextPercentageTotalUsed = contextPercentageTotalUsed; } - + @Nullable public Double getContextPercentageUntilSummarization() { return contextPercentageUntilSummarization; } - + public void setContextPercentageUntilSummarization(@Nullable Double contextPercentageUntilSummarization) { this.contextPercentageUntilSummarization = contextPercentageUntilSummarization; } - + @Nullable public Boolean getNewSession() { return newSession; } - + public void setNewSession(@Nullable Boolean newSession) { this.newSession = newSession; } - + @Nullable public Boolean getUnsaved() { return unsaved; } - + public void setUnsaved(@Nullable Boolean unsaved) { this.unsaved = unsaved; } @@ -74,4 +74,3 @@ public String toString() { '}'; } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryResult.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryResult.java index df53658..e402c3c 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryResult.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryResult.java @@ -7,26 +7,26 @@ public class WorkingMemoryResult { private final boolean created; private final WorkingMemoryResponse memory; - + public WorkingMemoryResult(boolean created, WorkingMemoryResponse memory) { this.created = created; this.memory = memory; } - + /** * @return true if the memory was created, false if it already existed */ public boolean isCreated() { return created; } - + /** * @return the working memory (either newly created or existing) */ public WorkingMemoryResponse getMemory() { return memory; } - + @Override public String toString() { return "WorkingMemoryResult{" + @@ -35,4 +35,3 @@ public String toString() { '}'; } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/BaseService.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/BaseService.java index 0947858..c088101 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/BaseService.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/BaseService.java @@ -17,16 +17,16 @@ * Base service class providing common functionality for all service classes. */ public abstract class BaseService { - + protected static final MediaType JSON = MediaType.get("application/json; charset=utf-8"); - + protected final String baseUrl; protected final OkHttpClient httpClient; protected final ObjectMapper objectMapper; protected final String defaultNamespace; protected final String defaultModelName; protected final Integer defaultContextWindowMax; - + protected BaseService( @NotNull String baseUrl, @NotNull OkHttpClient httpClient, @@ -41,7 +41,7 @@ protected BaseService( this.defaultModelName = defaultModelName; this.defaultContextWindowMax = defaultContextWindowMax; } - + /** * Handle HTTP errors and throw appropriate exceptions. */ @@ -82,4 +82,3 @@ protected void handleHttpError(@NotNull Response response) throws MemoryClientEx } } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/HealthService.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/HealthService.java index 939e2dd..ea56a60 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/HealthService.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/HealthService.java @@ -17,7 +17,7 @@ * Service for health check operations. */ public class HealthService extends BaseService { - + public HealthService( @NotNull String baseUrl, @NotNull OkHttpClient httpClient, @@ -27,10 +27,10 @@ public HealthService( @Nullable Integer defaultContextWindowMax) { super(baseUrl, httpClient, objectMapper, defaultNamespace, defaultModelName, defaultContextWindowMax); } - + /** * Check the health of the memory server. - * + * * @return HealthCheckResponse with current server timestamp * @throws MemoryClientException if the request fails */ @@ -39,21 +39,20 @@ public HealthCheckResponse healthCheck() throws MemoryClientException { .url(baseUrl + "/v1/health") .get() .build(); - + try (Response response = httpClient.newCall(request).execute()) { if (!response.isSuccessful()) { handleHttpError(response); } - + ResponseBody body = response.body(); if (body == null) { throw new MemoryServerException("Empty response body"); } - + return objectMapper.readValue(body.string(), HealthCheckResponse.class); } catch (IOException e) { throw new MemoryClientException("Failed to execute health check", e); } } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java index e8a15af..61e7abd 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java @@ -17,7 +17,7 @@ * Service for long-term memory operations. */ public class LongTermMemoryService extends BaseService { - + public LongTermMemoryService( @NotNull String baseUrl, @NotNull OkHttpClient httpClient, @@ -27,7 +27,7 @@ public LongTermMemoryService( @Nullable Integer defaultContextWindowMax) { super(baseUrl, httpClient, objectMapper, defaultNamespace, defaultModelName, defaultContextWindowMax); } - + /** * Create long-term memories. * @@ -38,31 +38,31 @@ public LongTermMemoryService( public AckResponse createLongTermMemories(@NotNull List memories) throws MemoryClientException { Map payload = new HashMap<>(); payload.put("memories", memories); - + try { String json = objectMapper.writeValueAsString(payload); RequestBody body = RequestBody.create(json, JSON); - + Request request = new Request.Builder() .url(baseUrl + "/v1/long-term-memory/") .post(body) .build(); - + try (Response response = httpClient.newCall(request).execute()) { handleHttpError(response); - + ResponseBody responseBody = response.body(); if (responseBody == null) { throw new MemoryClientException("Empty response body"); } - + return objectMapper.readValue(responseBody.string(), AckResponse.class); } } catch (IOException e) { throw new MemoryClientException("Failed to create long-term memories", e); } } - + /** * Search long-term memories. * @@ -76,7 +76,7 @@ public MemoryRecordResults searchLongTermMemories(@NotNull SearchRequest request payload.put("text", request.getText()); payload.put("limit", request.getLimit()); payload.put("offset", request.getOffset()); - + // Add filters if present if (request.getSessionId() != null) { payload.put("session_id", Map.of("eq", request.getSessionId())); @@ -89,38 +89,38 @@ public MemoryRecordResults searchLongTermMemories(@NotNull SearchRequest request } else if (defaultNamespace != null) { payload.put("namespace", Map.of("eq", defaultNamespace)); } - + if (request.getTopics() != null && !request.getTopics().isEmpty()) { payload.put("topics", Map.of("any", request.getTopics())); } if (request.getEntities() != null && !request.getEntities().isEmpty()) { payload.put("entities", Map.of("any", request.getEntities())); } - + try { String json = objectMapper.writeValueAsString(payload); RequestBody body = RequestBody.create(json, JSON); - + Request httpRequest = new Request.Builder() .url(baseUrl + "/v1/long-term-memory/search") .post(body) .build(); - + try (Response response = httpClient.newCall(httpRequest).execute()) { handleHttpError(response); - + ResponseBody responseBody = response.body(); if (responseBody == null) { throw new MemoryClientException("Empty response body"); } - + return objectMapper.readValue(responseBody.string(), MemoryRecordResults.class); } } catch (IOException e) { throw new MemoryClientException("Failed to search long-term memories", e); } } - + /** * Search long-term memories with simple text query. */ @@ -130,7 +130,7 @@ public MemoryRecordResults searchLongTermMemories(@NotNull String text) throws M .build(); return searchLongTermMemories(request); } - + /** * Get a single long-term memory by ID. * @@ -143,15 +143,15 @@ public MemoryRecord getLongTermMemory(@NotNull String memoryId) throws MemoryCli .url(baseUrl + "/v1/long-term-memory/" + memoryId) .get() .build(); - + try (Response response = httpClient.newCall(request).execute()) { handleHttpError(response); - + ResponseBody body = response.body(); if (body == null) { throw new MemoryClientException("Empty response body"); } - + return objectMapper.readValue(body.string(), MemoryRecord.class); } catch (IOException e) { throw new MemoryClientException("Failed to get long-term memory", e); @@ -462,4 +462,3 @@ public Stream searchAllLongTermMemoriesStream( ); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/MemoryHydrationService.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/MemoryHydrationService.java index dbdc790..7c21d85 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/MemoryHydrationService.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/MemoryHydrationService.java @@ -14,7 +14,7 @@ * Service for memory hydration operations (memory prompt). */ public class MemoryHydrationService extends BaseService { - + public MemoryHydrationService( @NotNull String baseUrl, @NotNull OkHttpClient httpClient, @@ -24,7 +24,7 @@ public MemoryHydrationService( @Nullable Integer defaultContextWindowMax) { super(baseUrl, httpClient, objectMapper, defaultNamespace, defaultModelName, defaultContextWindowMax); } - + /** * Hydrate a user query with memory context and return a prompt ready to send to an LLM. * @@ -48,44 +48,44 @@ public Map memoryPrompt( @Nullable Map longTermSearch, @Nullable String userId, boolean optimizeQuery) throws MemoryClientException { - + Map payload = new HashMap<>(); payload.put("query", query); - + // Add session parameters if provided if (sessionId != null) { Map sessionParams = new HashMap<>(); sessionParams.put("session_id", sessionId); - + if (namespace != null) { sessionParams.put("namespace", namespace); } else if (defaultNamespace != null) { sessionParams.put("namespace", defaultNamespace); } - + String effectiveModelName = modelName != null ? modelName : defaultModelName; if (effectiveModelName != null) { sessionParams.put("model_name", effectiveModelName); } - + Integer effectiveContextWindowMax = contextWindowMax != null ? contextWindowMax : defaultContextWindowMax; if (effectiveContextWindowMax != null) { sessionParams.put("context_window_max", effectiveContextWindowMax); } - + if (userId != null) { sessionParams.put("user_id", userId); } - + payload.put("session", sessionParams); } - + // Add long-term search parameters if provided if (longTermSearch != null) { Map searchParams = new HashMap<>(longTermSearch); - + // Add namespace to long-term search if not present if (!searchParams.containsKey("namespace")) { if (namespace != null) { @@ -98,30 +98,30 @@ public Map memoryPrompt( searchParams.put("namespace", namespaceFilter); } } - + payload.put("long_term_search", searchParams); } - + HttpUrl.Builder urlBuilder = HttpUrl.parse(baseUrl + "/v1/memory/prompt").newBuilder(); urlBuilder.addQueryParameter("optimize_query", String.valueOf(optimizeQuery)); - + try { String json = objectMapper.writeValueAsString(payload); RequestBody body = RequestBody.create(json, JSON); - + Request request = new Request.Builder() .url(urlBuilder.build()) .post(body) .build(); - + try (Response response = httpClient.newCall(request).execute()) { handleHttpError(response); - + ResponseBody responseBody = response.body(); if (responseBody == null) { throw new MemoryClientException("Empty response body"); } - + @SuppressWarnings("unchecked") Map result = objectMapper.readValue(responseBody.string(), Map.class); return result; @@ -130,7 +130,7 @@ public Map memoryPrompt( throw new MemoryClientException("Failed to hydrate memory prompt: " + e.getMessage(), e); } } - + /** * Hydrate a query with minimal parameters. */ @@ -138,4 +138,3 @@ public Map memoryPrompt(@NotNull String query) throws MemoryClie return memoryPrompt(query, null, null, null, null, null, null, false); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/WorkingMemoryService.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/WorkingMemoryService.java index e1fff4f..3a79c50 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/WorkingMemoryService.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/WorkingMemoryService.java @@ -17,7 +17,7 @@ * Service for working memory operations. */ public class WorkingMemoryService extends BaseService { - + public WorkingMemoryService( @NotNull String baseUrl, @NotNull OkHttpClient httpClient, @@ -27,10 +27,10 @@ public WorkingMemoryService( @Nullable Integer defaultContextWindowMax) { super(baseUrl, httpClient, objectMapper, defaultNamespace, defaultModelName, defaultContextWindowMax); } - + /** * List available sessions with optional pagination and filtering. - * + * * @param limit Maximum number of sessions to return * @param offset Offset for pagination * @param namespace Optional namespace filter @@ -48,43 +48,43 @@ public SessionListResponse listSessions( .newBuilder() .addQueryParameter("limit", String.valueOf(limit)) .addQueryParameter("offset", String.valueOf(offset)); - + if (namespace != null) { urlBuilder.addQueryParameter("namespace", namespace); } else if (defaultNamespace != null) { urlBuilder.addQueryParameter("namespace", defaultNamespace); } - + if (userId != null) { urlBuilder.addQueryParameter("user_id", userId); } - + Request request = new Request.Builder() .url(urlBuilder.build()) .get() .build(); - + try (Response response = httpClient.newCall(request).execute()) { handleHttpError(response); - + ResponseBody body = response.body(); if (body == null) { throw new MemoryClientException("Empty response body"); } - + return objectMapper.readValue(body.string(), SessionListResponse.class); } catch (IOException e) { throw new MemoryClientException("Failed to list sessions", e); } } - + /** * List sessions with default pagination. */ public SessionListResponse listSessions() throws MemoryClientException { return listSessions(100, 0, null, null); } - + /** * Get working memory for a session. * @@ -106,42 +106,42 @@ public WorkingMemoryResponse getWorkingMemory( HttpUrl.Builder urlBuilder = HttpUrl.parse( baseUrl + "/v1/working-memory/" + sessionId ).newBuilder(); - + if (userId != null) { urlBuilder.addQueryParameter("user_id", userId); } - + if (namespace != null) { urlBuilder.addQueryParameter("namespace", namespace); } else if (defaultNamespace != null) { urlBuilder.addQueryParameter("namespace", defaultNamespace); } - + String effectiveModelName = modelName != null ? modelName : defaultModelName; if (effectiveModelName != null) { urlBuilder.addQueryParameter("model_name", effectiveModelName); } - + Integer effectiveContextWindowMax = contextWindowMax != null ? contextWindowMax : defaultContextWindowMax; if (effectiveContextWindowMax != null) { urlBuilder.addQueryParameter("context_window_max", String.valueOf(effectiveContextWindowMax)); } - + Request request = new Request.Builder() .url(urlBuilder.build()) .get() .build(); - + try (Response response = httpClient.newCall(request).execute()) { handleHttpError(response); - + ResponseBody body = response.body(); if (body == null) { throw new MemoryClientException("Empty response body"); } - + return objectMapper.readValue(body.string(), WorkingMemoryResponse.class); } catch (IOException e) { throw new MemoryClientException("Failed to get working memory", e); diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/MemoryAPIClientTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/MemoryAPIClientTest.java index 5bf08e4..220d060 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/MemoryAPIClientTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/MemoryAPIClientTest.java @@ -19,11 +19,11 @@ import static org.junit.jupiter.api.Assertions.*; class MemoryAPIClientTest { - + private MockWebServer mockServer; private MemoryAPIClient client; private ObjectMapper objectMapper; - + @BeforeEach void setUp() throws IOException { mockServer = new MockWebServer(); @@ -38,13 +38,13 @@ void setUp() throws IOException { objectMapper.registerModule(new JavaTimeModule()); objectMapper.disable(com.fasterxml.jackson.databind.SerializationFeature.WRITE_DATES_AS_TIMESTAMPS); } - + @AfterEach void tearDown() throws Exception { client.close(); mockServer.shutdown(); } - + @Test void testNotFoundError() { diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/MemoryRecordBuilderTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/MemoryRecordBuilderTest.java index a91429c..b6ff632 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/MemoryRecordBuilderTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/MemoryRecordBuilderTest.java @@ -56,7 +56,7 @@ void testMemoryRecordBuilderRequiresText() { void testMemoryRecordDefaultConstructor() { // Test that default constructor still uses old defaults (for deserialization) MemoryRecord memory = new MemoryRecord(); - + assertNotNull(memory.getId()); assertEquals("f", memory.getDiscreteMemoryExtracted()); // Should be "f" for default constructor assertEquals(MemoryType.MESSAGE, memory.getMemoryType()); // Should be MESSAGE for default constructor @@ -66,10 +66,9 @@ void testMemoryRecordDefaultConstructor() { void testMemoryRecordConstructorWithText() { // Test that text constructor uses old defaults (for deserialization) MemoryRecord memory = new MemoryRecord("Test text"); - + assertEquals("Test text", memory.getText()); assertEquals("f", memory.getDiscreteMemoryExtracted()); assertEquals(MemoryType.MESSAGE, memory.getMemoryType()); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/BaseIntegrationTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/BaseIntegrationTest.java index db6ae57..7e8638f 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/BaseIntegrationTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/BaseIntegrationTest.java @@ -141,4 +141,3 @@ protected ObjectMapper getObjectMapper() { return objectMapper; } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/EndToEndIntegrationTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/EndToEndIntegrationTest.java index 1bf3341..e507652 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/EndToEndIntegrationTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/EndToEndIntegrationTest.java @@ -194,4 +194,3 @@ void testMultipleSessionsIsolation() throws Exception { assertEquals("2", retrieved2.getData().get("session")); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/LongTermMemoryIntegrationTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/LongTermMemoryIntegrationTest.java index 03e1e33..c912275 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/LongTermMemoryIntegrationTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/LongTermMemoryIntegrationTest.java @@ -437,4 +437,3 @@ void testSearchAllLongTermMemoriesStream() throws Exception { assertTrue(count >= 0); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/MemoryHydrationIntegrationTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/MemoryHydrationIntegrationTest.java index 8c1a5b4..db90172 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/MemoryHydrationIntegrationTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/MemoryHydrationIntegrationTest.java @@ -73,7 +73,7 @@ void testMemoryPromptWithWorkingMemory() throws Exception { assertNotNull(result); assertTrue(result.containsKey("messages")); - + // The result should include the conversation history Object resultMessages = result.get("messages"); assertNotNull(resultMessages); @@ -205,4 +205,3 @@ void testMemoryPromptWithQueryOptimization() throws Exception { } } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/WorkingMemoryIntegrationTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/WorkingMemoryIntegrationTest.java index c04a7a1..2aafc0b 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/WorkingMemoryIntegrationTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/integration/WorkingMemoryIntegrationTest.java @@ -299,4 +299,3 @@ void testUpdateWorkingMemoryData() throws Exception { assertEquals("active", response.getData().get("status")); // Should still exist } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/JsonSerializationTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/JsonSerializationTest.java index ca7c4f3..f3bf830 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/JsonSerializationTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/JsonSerializationTest.java @@ -20,9 +20,9 @@ import static org.junit.jupiter.api.Assertions.*; class JsonSerializationTest { - + private ObjectMapper objectMapper; - + @BeforeEach void setUp() { objectMapper = new ObjectMapper(); @@ -30,21 +30,21 @@ void setUp() { objectMapper.disable(com.fasterxml.jackson.databind.SerializationFeature.WRITE_DATES_AS_TIMESTAMPS); objectMapper.disable(com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); } - + @Test void testMemoryMessageSerialization() throws Exception { MemoryMessage message = new MemoryMessage("user", "Hello, world!"); - + String json = objectMapper.writeValueAsString(message); assertNotNull(json); assertTrue(json.contains("\"role\":\"user\"")); assertTrue(json.contains("\"content\":\"Hello, world!\"")); - + MemoryMessage deserialized = objectMapper.readValue(json, MemoryMessage.class); assertEquals("user", deserialized.getRole()); assertEquals("Hello, world!", deserialized.getContent()); } - + @Test void testMemoryRecordSerialization() throws Exception { MemoryRecord record = new MemoryRecord("Test memory"); @@ -54,13 +54,13 @@ void testMemoryRecordSerialization() throws Exception { record.setMemoryType(MemoryType.SEMANTIC); record.setTopics(Arrays.asList("topic1", "topic2")); record.setEntities(Arrays.asList("entity1", "entity2")); - + String json = objectMapper.writeValueAsString(record); assertNotNull(json); assertTrue(json.contains("\"text\":\"Test memory\"")); assertTrue(json.contains("\"user_id\":\"user-123\"")); assertTrue(json.contains("\"memory_type\":\"semantic\"")); - + MemoryRecord deserialized = objectMapper.readValue(json, MemoryRecord.class); assertEquals("Test memory", deserialized.getText()); assertEquals("user-123", deserialized.getUserId()); @@ -68,20 +68,20 @@ void testMemoryRecordSerialization() throws Exception { assertNotNull(deserialized.getTopics()); assertEquals(2, deserialized.getTopics().size()); } - + @Test void testMemoryTypeSerialization() throws Exception { // Test enum serialization assertEquals("\"message\"", objectMapper.writeValueAsString(MemoryType.MESSAGE)); assertEquals("\"semantic\"", objectMapper.writeValueAsString(MemoryType.SEMANTIC)); assertEquals("\"episodic\"", objectMapper.writeValueAsString(MemoryType.EPISODIC)); - + // Test enum deserialization assertEquals(MemoryType.MESSAGE, objectMapper.readValue("\"message\"", MemoryType.class)); assertEquals(MemoryType.SEMANTIC, objectMapper.readValue("\"semantic\"", MemoryType.class)); assertEquals(MemoryType.EPISODIC, objectMapper.readValue("\"episodic\"", MemoryType.class)); } - + @Test void testWorkingMemorySerialization() throws Exception { WorkingMemory memory = new WorkingMemory("session-123"); @@ -89,23 +89,23 @@ void testWorkingMemorySerialization() throws Exception { memory.setNamespace("test-namespace"); memory.setContext("Previous conversation"); memory.setTokens(1000); - + MemoryMessage message = new MemoryMessage("user", "Hello"); memory.getMessages().add(message); - + MemoryRecord record = new MemoryRecord("User said hello"); memory.getMemories().add(record); - + Map data = new HashMap<>(); data.put("key1", "value1"); data.put("key2", 42); memory.setData(data); - + String json = objectMapper.writeValueAsString(memory); assertNotNull(json); assertTrue(json.contains("\"session_id\":\"session-123\"")); assertTrue(json.contains("\"user_id\":\"user-456\"")); - + WorkingMemory deserialized = objectMapper.readValue(json, WorkingMemory.class); assertEquals("session-123", deserialized.getSessionId()); assertEquals("user-456", deserialized.getUserId()); @@ -114,24 +114,24 @@ void testWorkingMemorySerialization() throws Exception { assertNotNull(deserialized.getData()); assertEquals("value1", deserialized.getData().get("key1")); } - + @Test void testInstantSerialization() throws Exception { MemoryMessage message = new MemoryMessage("user", "Test"); Instant now = Instant.now(); message.setCreatedAt(now); - + String json = objectMapper.writeValueAsString(message); assertNotNull(json); - + // Should be in ISO-8601 format, not timestamp assertFalse(json.contains("\"created_at\":" + now.toEpochMilli())); assertTrue(json.contains("\"created_at\":\"")); - + MemoryMessage deserialized = objectMapper.readValue(json, MemoryMessage.class); assertNotNull(deserialized.getCreatedAt()); } - + @Test void testHealthCheckResponseDeserialization() throws Exception { String json = "{\"now\":1705318200.0}"; @@ -140,25 +140,24 @@ void testHealthCheckResponseDeserialization() throws Exception { assertNotNull(response); assertTrue(response.getNow() > 0); } - + @Test void testSessionListResponseDeserialization() throws Exception { String json = "{\"sessions\":[\"session-1\",\"session-2\"],\"total\":2}"; - + SessionListResponse response = objectMapper.readValue(json, SessionListResponse.class); assertNotNull(response); assertEquals(2, response.getTotal()); assertEquals(2, response.getSessions().size()); assertTrue(response.getSessions().contains("session-1")); } - + @Test void testAckResponseDeserialization() throws Exception { String json = "{\"status\":\"ok\"}"; - + AckResponse response = objectMapper.readValue(json, AckResponse.class); assertNotNull(response); assertEquals("ok", response.getStatus()); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/longtermemory/MemoryRecordTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/longtermemory/MemoryRecordTest.java index d766795..84de099 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/longtermemory/MemoryRecordTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/longtermemory/MemoryRecordTest.java @@ -8,11 +8,11 @@ import static org.junit.jupiter.api.Assertions.*; class MemoryRecordTest { - + @Test void testDefaultConstructor() { MemoryRecord record = new MemoryRecord(); - + assertNotNull(record.getId()); assertNotNull(record.getCreatedAt()); assertNotNull(record.getLastAccessed()); @@ -20,32 +20,32 @@ void testDefaultConstructor() { assertEquals("f", record.getDiscreteMemoryExtracted()); assertEquals(MemoryType.MESSAGE, record.getMemoryType()); } - + @Test void testConstructorWithText() { MemoryRecord record = new MemoryRecord("Test memory"); - + assertEquals("Test memory", record.getText()); assertNotNull(record.getId()); } - + @Test void testSettersAndGetters() { MemoryRecord record = new MemoryRecord(); - + record.setText("Test memory"); record.setSessionId("session-123"); record.setUserId("user-456"); record.setNamespace("test-namespace"); - + List topics = Arrays.asList("topic1", "topic2"); record.setTopics(topics); - + List entities = Arrays.asList("entity1", "entity2"); record.setEntities(entities); - + record.setMemoryType(MemoryType.SEMANTIC); - + assertEquals("Test memory", record.getText()); assertEquals("session-123", record.getSessionId()); assertEquals("user-456", record.getUserId()); @@ -55,4 +55,3 @@ void testSettersAndGetters() { assertEquals(MemoryType.SEMANTIC, record.getMemoryType()); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/workingmemory/MemoryMessageTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/workingmemory/MemoryMessageTest.java index d3c4287..0860ea6 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/workingmemory/MemoryMessageTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/workingmemory/MemoryMessageTest.java @@ -5,37 +5,36 @@ import static org.junit.jupiter.api.Assertions.*; class MemoryMessageTest { - + @Test void testDefaultConstructor() { MemoryMessage message = new MemoryMessage(); - + assertNotNull(message.getId()); assertNotNull(message.getCreatedAt()); assertEquals("f", message.getDiscreteMemoryExtracted()); assertNull(message.getPersistedAt()); } - + @Test void testConstructorWithRoleAndContent() { MemoryMessage message = new MemoryMessage("user", "Hello, world!"); - + assertEquals("user", message.getRole()); assertEquals("Hello, world!", message.getContent()); assertNotNull(message.getId()); assertNotNull(message.getCreatedAt()); assertEquals("f", message.getDiscreteMemoryExtracted()); } - + @Test void testSettersAndGetters() { MemoryMessage message = new MemoryMessage(); - + message.setRole("assistant"); message.setContent("Test content"); - + assertEquals("assistant", message.getRole()); assertEquals("Test content", message.getContent()); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryTest.java index 23280be..ee04df9 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/workingmemory/WorkingMemoryTest.java @@ -6,11 +6,11 @@ import static org.junit.jupiter.api.Assertions.*; class WorkingMemoryTest { - + @Test void testDefaultConstructor() { WorkingMemory memory = new WorkingMemory(); - + assertNotNull(memory.getMessages()); assertTrue(memory.getMessages().isEmpty()); assertNotNull(memory.getMemories()); @@ -20,57 +20,57 @@ void testDefaultConstructor() { assertNotNull(memory.getLongTermMemoryStrategy()); assertNotNull(memory.getLastAccessed()); } - + @Test void testConstructorWithSessionId() { WorkingMemory memory = new WorkingMemory("session-123"); - + assertEquals("session-123", memory.getSessionId()); assertNotNull(memory.getMessages()); assertNotNull(memory.getMemories()); } - + @Test void testAddingMessages() { WorkingMemory memory = new WorkingMemory("session-123"); - + MemoryMessage message1 = new MemoryMessage("user", "Hello"); MemoryMessage message2 = new MemoryMessage("assistant", "Hi there!"); - + memory.getMessages().add(message1); memory.getMessages().add(message2); - + assertEquals(2, memory.getMessages().size()); assertEquals("Hello", memory.getMessages().get(0).getContent()); assertEquals("Hi there!", memory.getMessages().get(1).getContent()); } - + @Test void testAddingMemories() { WorkingMemory memory = new WorkingMemory("session-123"); - + MemoryRecord record1 = new MemoryRecord("Memory 1"); MemoryRecord record2 = new MemoryRecord("Memory 2"); - + memory.getMemories().add(record1); memory.getMemories().add(record2); - + assertEquals(2, memory.getMemories().size()); assertEquals("Memory 1", memory.getMemories().get(0).getText()); assertEquals("Memory 2", memory.getMemories().get(1).getText()); } - + @Test void testSettersAndGetters() { WorkingMemory memory = new WorkingMemory(); - + memory.setSessionId("session-456"); memory.setUserId("user-789"); memory.setNamespace("test-namespace"); memory.setContext("Previous conversation summary"); memory.setTokens(1000); memory.setTtlSeconds(3600); - + assertEquals("session-456", memory.getSessionId()); assertEquals("user-789", memory.getUserId()); assertEquals("test-namespace", memory.getNamespace()); @@ -79,4 +79,3 @@ void testSettersAndGetters() { assertEquals(3600, memory.getTtlSeconds()); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/HealthServiceTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/HealthServiceTest.java index 07d7b21..2c33059 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/HealthServiceTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/HealthServiceTest.java @@ -19,11 +19,11 @@ * Tests for HealthService functionality. */ class HealthServiceTest { - + private MockWebServer mockServer; private MemoryAPIClient client; private ObjectMapper objectMapper; - + @BeforeEach void setUp() throws IOException { mockServer = new MockWebServer(); @@ -38,13 +38,13 @@ void setUp() throws IOException { objectMapper.registerModule(new JavaTimeModule()); objectMapper.disable(com.fasterxml.jackson.databind.SerializationFeature.WRITE_DATES_AS_TIMESTAMPS); } - + @AfterEach void tearDown() throws Exception { client.close(); mockServer.shutdown(); } - + @Test void testHealthCheck() throws Exception { // Mock response @@ -68,4 +68,3 @@ void testHealthCheck() throws Exception { assertTrue(request.getPath().contains("/v1/health")); } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java index 921d6e3..4f39540 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java @@ -291,4 +291,3 @@ void testForgetLongTermMemories_MinimalParams() throws Exception { } } - diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/MemoryHydrationServiceTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/MemoryHydrationServiceTest.java index 30861b6..fa0016f 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/MemoryHydrationServiceTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/MemoryHydrationServiceTest.java @@ -19,11 +19,11 @@ * Tests for MemoryHydrationService functionality. */ class MemoryHydrationServiceTest { - + private MockWebServer mockServer; private MemoryAPIClient client; private ObjectMapper objectMapper; - + @BeforeEach void setUp() throws IOException { mockServer = new MockWebServer(); @@ -38,13 +38,13 @@ void setUp() throws IOException { objectMapper.registerModule(new JavaTimeModule()); objectMapper.disable(com.fasterxml.jackson.databind.SerializationFeature.WRITE_DATES_AS_TIMESTAMPS); } - + @AfterEach void tearDown() throws Exception { client.close(); mockServer.shutdown(); } - + @Test void testMemoryPrompt() throws Exception { // Mock response @@ -83,4 +83,3 @@ void testMemoryPrompt() throws Exception { assertTrue(request.getPath().contains("optimize_query=true")); } } - diff --git a/agent-memory-client/agent_memory_client/models.py b/agent-memory-client/agent_memory_client/models.py index 4575ba3..9710e25 100644 --- a/agent-memory-client/agent_memory_client/models.py +++ b/agent-memory-client/agent_memory_client/models.py @@ -27,6 +27,10 @@ "o1", "o1-mini", "o3-mini", + "gpt-5-mini", + "gpt-5-nano", + "gpt-5.1-chat-latest", + "gpt-5.2-chat-latest", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", diff --git a/agent_memory_server/api.py b/agent_memory_server/api.py index a7e5934..1f7eb8c 100644 --- a/agent_memory_server/api.py +++ b/agent_memory_server/api.py @@ -5,6 +5,7 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response from mcp.server.fastmcp.prompts import base from mcp.types import TextContent +from ulid import ULID from agent_memory_server import long_term_memory, working_memory from agent_memory_server.auth import UserInfo, get_current_user @@ -16,6 +17,7 @@ from agent_memory_server.models import ( AckResponse, CreateMemoryRecordRequest, + CreateSummaryViewRequest, EditMemoryRecordRequest, GetSessionsQuery, MemoryMessage, @@ -24,14 +26,30 @@ MemoryRecord, MemoryRecordResultsResponse, ModelNameLiteral, + RunSummaryViewPartitionRequest, + RunSummaryViewRequest, SearchRequest, SessionListResponse, + SummaryView, + SummaryViewPartitionResult, SystemMessage, + Task, + TaskStatusEnum, + TaskTypeEnum, UpdateWorkingMemory, WorkingMemory, WorkingMemoryResponse, ) from agent_memory_server.summarization import _incremental_summary +from agent_memory_server.summary_views import ( + get_summary_view as get_summary_view_config, + list_partition_results, + list_summary_views, + save_partition_result, + save_summary_view, + summarize_partition_for_view, +) +from agent_memory_server.tasks import create_task, get_task from agent_memory_server.utils.redis import get_redis_conn @@ -1074,3 +1092,232 @@ async def memory_prompt( ) return MemoryPromptResponse(messages=_messages) + + +def _validate_summary_view_keys(payload: CreateSummaryViewRequest) -> None: + """Validate group_by and filter keys for a SummaryView. + + For v1 we explicitly restrict these keys to a small, known set so we can + implement execution safely. We also currently only support long-term + memory as the source for SummaryViews. + """ + + if payload.source != "long_term": + raise HTTPException( + status_code=400, + detail=( + "SummaryView.source must be 'long_term' for now; " + "'working_memory' is not yet supported." + ), + ) + + allowed_group_by = {"user_id", "namespace", "session_id", "memory_type"} + allowed_filters = { + "user_id", + "namespace", + "session_id", + "memory_type", + } + + invalid_group = [k for k in payload.group_by if k not in allowed_group_by] + if invalid_group: + raise HTTPException( + status_code=400, + detail=("Unsupported group_by fields: " + ", ".join(sorted(invalid_group))), + ) + + invalid_filters = [k for k in payload.filters if k not in allowed_filters] + if invalid_filters: + raise HTTPException( + status_code=400, + detail=("Unsupported filter fields: " + ", ".join(sorted(invalid_filters))), + ) + + +@router.post("/v1/summary-views", response_model=SummaryView) +async def create_summary_view( + payload: CreateSummaryViewRequest, + current_user: UserInfo = Depends(get_current_user), +): + """Create a new SummaryView configuration. + + The server assigns an ID; the configuration can then be run on-demand or + by background workers. + """ + + _validate_summary_view_keys(payload) + + view = SummaryView( + id=str(ULID()), + name=payload.name, + source=payload.source, + group_by=payload.group_by, + filters=payload.filters, + time_window_days=payload.time_window_days, + continuous=payload.continuous, + prompt=payload.prompt, + model_name=payload.model_name, + ) + + await save_summary_view(view) + return view + + +@router.get("/v1/summary-views", response_model=list[SummaryView]) +async def list_summary_views_endpoint( + current_user: UserInfo = Depends(get_current_user), +): + """List all registered SummaryViews. + + Filtering by source/continuous can be added later if needed. + """ + + return await list_summary_views() + + +@router.get("/v1/summary-views/{view_id}", response_model=SummaryView) +async def get_summary_view( + view_id: str, + current_user: UserInfo = Depends(get_current_user), +): + """Get a SummaryView configuration by ID.""" + + view = await get_summary_view_config(view_id) + if view is None: + raise HTTPException(status_code=404, detail=f"SummaryView {view_id} not found") + return view + + +@router.delete("/v1/summary-views/{view_id}", response_model=AckResponse) +async def delete_summary_view_endpoint( + view_id: str, + current_user: UserInfo = Depends(get_current_user), +): + """Delete a SummaryView configuration. + + Stored partition summaries are left as-is for now. + """ + + from agent_memory_server.summary_views import delete_summary_view + + await delete_summary_view(view_id) + return AckResponse(status="ok") + + +@router.post( + "/v1/summary-views/{view_id}/partitions/run", + response_model=SummaryViewPartitionResult, +) +async def run_summary_view_partition( + view_id: str, + payload: RunSummaryViewPartitionRequest, + current_user: UserInfo = Depends(get_current_user), +): + """Synchronously compute a summary for a single partition of a view. + + For long-term memory views this will query the underlying memories + and run a real summarization. For other sources it currently returns + a placeholder summary. + """ + + view = await get_summary_view_config(view_id) + if view is None: + raise HTTPException(status_code=404, detail=f"SummaryView {view_id} not found") + + # Ensure the provided group keys match the view's group_by definition. + group_keys = set(payload.group.keys()) + expected_keys = set(view.group_by) + if group_keys != expected_keys: + raise HTTPException( + status_code=400, + detail=( + f"group keys {sorted(group_keys)} must exactly match " + f"view.group_by {sorted(expected_keys)}" + ), + ) + + result = await summarize_partition_for_view(view, payload.group) + # Persist the result so it appears in materialized listings. + await save_partition_result(result) + return result + + +@router.get( + "/v1/summary-views/{view_id}/partitions", + response_model=list[SummaryViewPartitionResult], +) +async def list_summary_view_partitions( + view_id: str, + user_id: str | None = None, + namespace: str | None = None, + session_id: str | None = None, + memory_type: str | None = None, + current_user: UserInfo = Depends(get_current_user), +): + """List materialized partition summaries for a SummaryView. + + This does not trigger recomputation; it simply reads stored + SummaryViewPartitionResult entries from Redis. Optional query + parameters filter by group fields when present. + """ + + view = await get_summary_view_config(view_id) + if view is None: + raise HTTPException(status_code=404, detail=f"SummaryView {view_id} not found") + + group_filter: dict[str, str] = {} + if user_id is not None: + group_filter["user_id"] = user_id + if namespace is not None: + group_filter["namespace"] = namespace + if session_id is not None: + group_filter["session_id"] = session_id + if memory_type is not None: + group_filter["memory_type"] = memory_type + + return await list_partition_results(view_id, group_filter or None) + + +@router.post("/v1/summary-views/{view_id}/run", response_model=Task) +async def run_summary_view_full( + view_id: str, + payload: RunSummaryViewRequest, + background_tasks: HybridBackgroundTasks, + current_user: UserInfo = Depends(get_current_user), +): + """Trigger an asynchronous full recompute of all partitions for a view. + + Returns a Task that can be polled for status. The actual work is + performed by a Docket worker running refresh_summary_view. + """ + + view = await get_summary_view_config(view_id) + if view is None: + raise HTTPException(status_code=404, detail=f"SummaryView {view_id} not found") + + task_id = payload.task_id or str(ULID()) + task = Task( + id=task_id, + type=TaskTypeEnum.SUMMARY_VIEW_FULL_RUN, + status=TaskStatusEnum.PENDING, + view_id=view_id, + ) + await create_task(task) + + from agent_memory_server.summary_views import refresh_summary_view + + background_tasks.add_task(refresh_summary_view, view_id=view_id, task_id=task_id) + return task + + +@router.get("/v1/tasks/{task_id}", response_model=Task) +async def get_task_status( + task_id: str, + current_user: UserInfo = Depends(get_current_user), +): + """Get the status of a background Task by ID.""" + + task = await get_task(task_id) + if task is None: + raise HTTPException(status_code=404, detail=f"Task {task_id} not found") + return task diff --git a/agent_memory_server/config.py b/agent_memory_server/config.py index 83ffcba..933ce4f 100644 --- a/agent_memory_server/config.py +++ b/agent_memory_server/config.py @@ -35,6 +35,18 @@ class ModelConfig(BaseModel): # Model configuration mapping MODEL_CONFIGS = { # OpenAI Models + "gpt-4.1": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-4.1", + max_tokens=128000, + embedding_dimensions=1536, + ), + "gpt-4.1-mini": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-4.1-mini", + max_tokens=128000, + embedding_dimensions=1536, + ), "gpt-3.5-turbo": ModelConfig( provider=ModelProvider.OPENAI, name="gpt-3.5-turbo", @@ -90,6 +102,31 @@ class ModelConfig(BaseModel): max_tokens=200000, embedding_dimensions=1536, ), + # GPT-5 family + "gpt-5-mini": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-5-mini", + max_tokens=400000, + embedding_dimensions=1536, + ), + "gpt-5-nano": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-5-nano", + max_tokens=400000, + embedding_dimensions=1536, + ), + "gpt-5.1-chat-latest": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-5.1-chat-latest", + max_tokens=128000, + embedding_dimensions=1536, + ), + "gpt-5.2-chat-latest": ModelConfig( + provider=ModelProvider.OPENAI, + name="gpt-5.2-chat-latest", + max_tokens=128000, + embedding_dimensions=1536, + ), # Embedding models "text-embedding-ada-002": ModelConfig( provider=ModelProvider.OPENAI, @@ -134,7 +171,26 @@ class ModelConfig(BaseModel): max_tokens=200000, embedding_dimensions=1536, ), - # Latest Anthropic Models + # Claude 4.5 family (direct Anthropic API IDs) + "claude-sonnet-4-5-20250929": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-sonnet-4-5-20250929", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-haiku-4-5-20251001": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-haiku-4-5-20251001", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-opus-4-5-20251101": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-opus-4-5-20251101", + max_tokens=200000, + embedding_dimensions=1536, + ), + # Latest Anthropic Models (Claude 3.x family) "claude-3-7-sonnet-20250219": ModelConfig( provider=ModelProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219", @@ -178,6 +234,25 @@ class ModelConfig(BaseModel): max_tokens=200000, embedding_dimensions=1536, ), + # Aliases for Claude 4.5 family + "claude-sonnet-4-5": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-sonnet-4-5-20250929", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-haiku-4-5": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-haiku-4-5-20251001", + max_tokens=200000, + embedding_dimensions=1536, + ), + "claude-opus-4-5": ModelConfig( + provider=ModelProvider.ANTHROPIC, + name="claude-opus-4-5-20251101", + max_tokens=200000, + embedding_dimensions=1536, + ), # AWS Bedrock Embedding Models "amazon.titan-embed-text-v2:0": ModelConfig( provider=ModelProvider.AWS_BEDROCK, diff --git a/agent_memory_server/docket_tasks.py b/agent_memory_server/docket_tasks.py index e3b8cb0..6a14531 100644 --- a/agent_memory_server/docket_tasks.py +++ b/agent_memory_server/docket_tasks.py @@ -21,6 +21,10 @@ update_last_accessed, ) from agent_memory_server.summarization import summarize_session +from agent_memory_server.summary_views import ( + periodic_refresh_summary_views, + refresh_summary_view, +) logger = logging.getLogger(__name__) @@ -38,6 +42,8 @@ forget_long_term_memories, periodic_forget_long_term_memories, update_last_accessed, + refresh_summary_view, + periodic_refresh_summary_views, ] diff --git a/agent_memory_server/models.py b/agent_memory_server/models.py index ce6730e..1f6564e 100644 --- a/agent_memory_server/models.py +++ b/agent_memory_server/models.py @@ -8,7 +8,7 @@ from agent_memory_client.models import ClientMemoryRecord from mcp.server.fastmcp.prompts import base -from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent +from mcp.types import TextContent from pydantic import BaseModel, Field, PrivateAttr, model_validator from ulid import ULID @@ -40,6 +40,9 @@ class MemoryTypeEnum(str, Enum): # These should match the keys in MODEL_CONFIGS ModelNameLiteral = Literal[ + # OpenAI chat and reasoning models + "gpt-4.1", + "gpt-4.1-mini", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", @@ -49,9 +52,15 @@ class MemoryTypeEnum(str, Enum): "o1", "o1-mini", "o3-mini", + "gpt-5-mini", + "gpt-5-nano", + "gpt-5.1-chat-latest", + "gpt-5.2-chat-latest", + # OpenAI embedding models "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", + # Anthropic Claude 3.x family "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", @@ -63,6 +72,13 @@ class MemoryTypeEnum(str, Enum): "claude-3-5-sonnet-latest", "claude-3-5-haiku-latest", "claude-3-opus-latest", + # Anthropic Claude 4.5 family (direct API IDs and aliases) + "claude-sonnet-4-5-20250929", + "claude-haiku-4-5-20251001", + "claude-opus-4-5-20251101", + "claude-sonnet-4-5", + "claude-haiku-4-5", + "claude-opus-4-5", ] @@ -781,7 +797,7 @@ class SystemMessage(BaseModel): """A system message""" role: Literal["system"] = "system" - content: str | TextContent | ImageContent | AudioContent | EmbeddedResource + content: str | TextContent class UserMessage(base.Message): @@ -838,3 +854,222 @@ class EditMemoryRecordRequest(BaseModel): event_date: datetime | None = Field( default=None, description="Updated event date for episodic memories" ) + + +class TaskStatusEnum(str, Enum): + """Status values for background tasks exposed to clients.""" + + PENDING = "pending" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + +class TaskTypeEnum(str, Enum): + """Type of background task. + + We start with summary view refreshes but keep this extensible. + """ + + SUMMARY_VIEW_FULL_RUN = "summary_view_full_run" + + +class Task(BaseModel): + """Client-visible background task tracked in Redis as JSON. + + These tasks represent long-running operations such as a full recompute + of all partitions for a SummaryView. + """ + + id: str = Field(description="Unique task identifier (client or server generated)") + type: TaskTypeEnum = Field( + description="Type of task, e.g. summary_view_full_run", + ) + status: TaskStatusEnum = Field( + default=TaskStatusEnum.PENDING, + description="Current task status", + ) + view_id: str | None = Field( + default=None, + description="Associated SummaryView ID, if applicable", + ) + created_at: datetime = Field( + default_factory=lambda: datetime.now(UTC), + description="When the task record was created", + ) + started_at: datetime | None = Field( + default=None, + description="When execution of the task actually started", + ) + completed_at: datetime | None = Field( + default=None, + description="When execution of the task finished (success or failure)", + ) + error_message: str | None = Field( + default=None, + description="Error message if the task failed", + ) + + +class SummaryView(BaseModel): + """Configuration for a summary view over memories. + + A SummaryView fully specifies what pool of memories to summarize and how + to partition and filter them, so it can be run on-demand or by a + background worker without additional runtime parameters. + """ + + id: str = Field(description="Unique identifier for the summary view") + name: str | None = Field( + default=None, + description="Optional human-readable name for the view", + ) + source: Literal["long_term", "working_memory"] = Field( + description=( + "Memory source to summarize. Currently only 'long_term' is " + "supported; 'working_memory' is reserved for future use." + ), + ) + group_by: list[str] = Field( + default_factory=list, + description=( + "Fields used to partition summaries (e.g. ['user_id'], " + "['user_id', 'namespace'])." + ), + ) + filters: dict[str, Any] = Field( + default_factory=dict, + description=( + "Static filters applied to every run (e.g. memory_type, namespace). " + "Only a small, known set of keys is supported in v1." + ), + ) + time_window_days: int | None = Field( + default=None, + ge=1, + description=( + "If set, each run uses now() - time_window_days as a cutoff " + "for eligible memories." + ), + ) + continuous: bool = Field( + default=False, + description=( + "If true, background workers periodically refresh all partitions " + "for this view." + ), + ) + prompt: str | None = Field( + default=None, + description=( + "Optional custom summarization instructions. If omitted, a " + "server-defined default prompt is used." + ), + ) + model_name: str | None = Field( + default=None, + description=( + "Optional model override for summarization. Defaults to a fast " + "model from settings when not provided." + ), + ) + + +class SummaryViewPartitionResult(BaseModel): + """Result of summarizing one partition of a SummaryView. + + A partition is defined by a concrete combination of the view's + group_by fields, e.g. {"user_id": "alice"} or + {"user_id": "alice", "namespace": "chat"}. + """ + + view_id: str = Field(description="ID of the SummaryView that produced this result") + group: dict[str, str] = Field( + description="Concrete values for the view's group_by fields", + ) + summary: str = Field(description="Summarized text for this partition") + memory_count: int = Field( + ge=0, + description="Number of memories that contributed to this summary", + ) + computed_at: datetime = Field( + default_factory=lambda: datetime.now(UTC), + description="When this summary was computed", + ) + + +class CreateSummaryViewRequest(BaseModel): + """Payload for creating a new SummaryView. + + Same fields as SummaryView except for the server-assigned id. + """ + + name: str | None = Field( + default=None, + description="Optional human-readable name for the view", + ) + source: Literal["long_term", "working_memory"] = Field( + description="Memory source to summarize: long-term or working memory", + ) + group_by: list[str] = Field( + default_factory=list, + description=( + "Fields used to partition summaries (e.g. ['user_id'], " + "['user_id', 'namespace'])." + ), + ) + filters: dict[str, Any] = Field( + default_factory=dict, + description=( + "Static filters applied to every run (e.g. memory_type, namespace). " + "Only a small, known set of keys is supported in v1." + ), + ) + time_window_days: int | None = Field( + default=None, + ge=1, + description=( + "If set, each run uses now() - time_window_days as a cutoff " + "for eligible memories." + ), + ) + continuous: bool = Field( + default=False, + description=( + "If true, background workers periodically refresh all partitions " + "for this view." + ), + ) + prompt: str | None = Field( + default=None, + description=( + "Optional custom summarization instructions. If omitted, a " + "server-defined default prompt is used." + ), + ) + model_name: str | None = Field( + default=None, + description=( + "Optional model override for summarization. Defaults to a fast " + "model from settings when not provided." + ), + ) + + +class RunSummaryViewPartitionRequest(BaseModel): + """Request body for running a single partition of a SummaryView.""" + + group: dict[str, str] = Field( + description="Concrete values for this view's group_by fields", + ) + + +class RunSummaryViewRequest(BaseModel): + """Request body for triggering a full SummaryView run as a Task.""" + + task_id: str | None = Field( + default=None, + description=( + "Optional client-provided task ID. If omitted, the server generates one." + ), + ) diff --git a/agent_memory_server/summary_views.py b/agent_memory_server/summary_views.py new file mode 100644 index 0000000..be1bde0 --- /dev/null +++ b/agent_memory_server/summary_views.py @@ -0,0 +1,646 @@ +"""Helpers for SummaryView configs, stored results, and summarization logic. + +This module implements the execution logic for summarizing long-term memory +sources using LLMs, including Redis JSON storage, key conventions, and +partitioned summary management so the API surface is wired end-to-end. +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Iterable +from datetime import UTC, datetime, timedelta +from typing import Any +from urllib.parse import quote, unquote + +import tiktoken +from docket import Perpetual + +from agent_memory_server import long_term_memory +from agent_memory_server.config import settings +from agent_memory_server.filters import ( + CreatedAt, + MemoryType, + Namespace, + SessionId, + UserId, +) +from agent_memory_server.models import ( + MemoryRecord, + SummaryView, + SummaryViewPartitionResult, + TaskStatusEnum, +) +from agent_memory_server.tasks import update_task_status +from agent_memory_server.utils.redis import get_redis_conn + + +logger = logging.getLogger(__name__) + + +_SUMMARY_VIEW_INDEX_KEY = "summary_view:index" + +# Conservative cap on how many memories we inline into a single LLM prompt. +# We still report the full memory_count separately. +_MAX_MEMORIES_FOR_LLM_PROMPT = 200 + + +def _config_key(view_id: str) -> str: + return f"summary_view:{view_id}:config" + + +def _summary_key(view_id: str, partition_key: str) -> str: + return f"summary_view:{view_id}:summary:{partition_key}" + + +def encode_partition_key(group: dict[str, str]) -> str: + """Create a stable key representation from group_by values. + + Keys are sorted alphabetically so the same group always produces the + same identifier. Both keys and values are URL-encoded to handle any + special characters (including '|' and '=') that could otherwise cause + ambiguous or colliding partition keys. + + The resulting string is treated as an opaque key; use + decode_partition_key() if you need to reverse the encoding. + """ + + parts: list[str] = [] + for key in sorted(group.keys()): + encoded_key = quote(key, safe="") + encoded_value = quote(group[key], safe="") + parts.append(f"{encoded_key}={encoded_value}") + return "|".join(parts) + + +def decode_partition_key(partition_key: str) -> dict[str, str]: + """Decode a partition key back into a group dictionary. + + This reverses the URL-encoding applied by encode_partition_key(). + """ + + if not partition_key: + return {} + + result: dict[str, str] = {} + for part in partition_key.split("|"): + if "=" not in part: + continue + encoded_key, encoded_value = part.split("=", 1) + result[unquote(encoded_key)] = unquote(encoded_value) + return result + + +def _matches_group_filter(group: dict[str, str], group_filter: dict[str, str]) -> bool: + return all(group.get(key) == value for key, value in group_filter.items()) + + +async def save_summary_view(view: SummaryView) -> None: + """Persist a SummaryView definition in Redis as JSON and index it.""" + + redis = await get_redis_conn() + await redis.set(_config_key(view.id), view.model_dump_json()) + await redis.sadd(_SUMMARY_VIEW_INDEX_KEY, view.id) + + +async def get_summary_view(view_id: str) -> SummaryView | None: + """Load a SummaryView by ID from Redis JSON storage.""" + + redis = await get_redis_conn() + raw = await redis.get(_config_key(view_id)) + if raw is None: + return None + + if isinstance(raw, bytes): + raw = raw.decode("utf-8") + + try: + return SummaryView.model_validate_json(raw) + except Exception: + logger.exception("Failed to decode SummaryView JSON for %s", view_id) + return None + + +async def list_summary_views() -> list[SummaryView]: + """Return all SummaryViews registered in the index. + + This performs one GET per view ID; acceptable for the current scale. + """ + + redis = await get_redis_conn() + ids: Iterable[bytes] = await redis.smembers(_SUMMARY_VIEW_INDEX_KEY) + views: list[SummaryView] = [] + + for raw_id in ids: + view_id = raw_id.decode("utf-8") if isinstance(raw_id, bytes) else str(raw_id) + view = await get_summary_view(view_id) + if view is not None: + views.append(view) + + return views + + +async def delete_summary_view(view_id: str) -> None: + """Delete a SummaryView config and remove it from the index. + + Stored partition summaries are left as-is for now; they can be cleaned + up in a later pass if needed. + """ + + redis = await get_redis_conn() + await redis.delete(_config_key(view_id)) + await redis.srem(_SUMMARY_VIEW_INDEX_KEY, view_id) + + +async def save_partition_result(result: SummaryViewPartitionResult) -> None: + """Persist a single partition result for a SummaryView.""" + + redis = await get_redis_conn() + partition_key = encode_partition_key(result.group) + await redis.set( + _summary_key(result.view_id, partition_key), result.model_dump_json() + ) + + +async def list_partition_results( + view_id: str, group_filter: dict[str, str] | None = None +) -> list[SummaryViewPartitionResult]: + """List stored partition results for a view, optionally filtered by group. + + This reads whatever has been materialized so far; it does not trigger + recomputation. + """ + + redis = await get_redis_conn() + pattern = _summary_key(view_id, "*") + results: list[SummaryViewPartitionResult] = [] + + async for key in redis.scan_iter(match=pattern): + raw = await redis.get(key) + if raw is None: + continue + if isinstance(raw, bytes): + raw = raw.decode("utf-8") + try: + result = SummaryViewPartitionResult.model_validate_json(raw) + except Exception: + logger.exception( + "Failed to decode SummaryViewPartitionResult for key %s", key + ) + continue + + if group_filter and not _matches_group_filter(result.group, group_filter): + continue + + results.append(result) + + return results + + +def _build_long_term_filters_for_view( + view: SummaryView, + extra_group: dict[str, str] | None = None, +) -> dict[str, Any]: + """Build keyword arguments for search_long_term_memories. + + Maps SummaryView.filters and optionally a concrete group dict into + typed filter objects used by long_term_memory.search_long_term_memories. + """ + + filters: dict[str, Any] = {} + + def _apply_filter(key: str, value: str | Any) -> None: + """Apply a single filter mapping from a raw key/value pair. + + Both static view.filters and extra_group values are coerced to str + for consistency. + """ + + if key == "user_id": + filters["user_id"] = UserId(eq=str(value)) + elif key == "namespace": + filters["namespace"] = Namespace(eq=str(value)) + elif key == "session_id": + filters["session_id"] = SessionId(eq=str(value)) + elif key == "memory_type": + filters["memory_type"] = MemoryType(eq=str(value)) + + # Static filters from the view config + for key, value in view.filters.items(): + _apply_filter(key, value) + + # Group-specific filters + if extra_group: + for key, value in extra_group.items(): + _apply_filter(key, value) + + # Time window: apply to created_at for now + if view.time_window_days is not None and view.time_window_days > 0: + cutoff = datetime.now(UTC) - timedelta(days=view.time_window_days) + filters["created_at"] = CreatedAt(gte=cutoff) + + return filters + + +async def _fetch_long_term_memories_for_view( + view: SummaryView, + extra_group: dict[str, str] | None = None, + page_size: int = 1000, + overall_limit: int | None = None, +) -> list[MemoryRecord]: + """Fetch long-term memories matching a SummaryView and optional group. + + Uses the filter-only listing path of search_long_term_memories by + providing an empty text query and paginating through results. + + If overall_limit is provided, it serves as an upper bound on the total + number of memories returned; otherwise, all available pages are fetched. + """ + + if page_size <= 0: + raise ValueError("page_size must be positive") + + filters = _build_long_term_filters_for_view(view, extra_group) + + memories: list[MemoryRecord] = [] + offset = 0 + + while True: + # Respect an overall cap if provided + if overall_limit is not None: + remaining = overall_limit - len(memories) + if remaining <= 0: + break + current_limit = min(page_size, remaining) + else: + current_limit = page_size + + results = await long_term_memory.search_long_term_memories( + text="", + limit=current_limit, + offset=offset, + **filters, + ) + batch = list(results.memories) + if not batch: + break + + memories.extend(batch) + + # If fewer results than requested were returned, we've reached the + # end of the result set. + if len(batch) < current_limit: + break + + offset += len(batch) + + # If we applied an overall limit, enforce it defensively here too. + if overall_limit is not None and len(memories) > overall_limit: + memories = memories[:overall_limit] + + return memories + + +def _partition_memories_by_group( + view: SummaryView, memories: list[MemoryRecord] +) -> dict[tuple[tuple[str, str], ...], list[MemoryRecord]]: + """Group memories into partitions based on view.group_by fields. + + Returns a mapping from a stable tuple key to a list of MemoryRecord. + The key is a sorted tuple of (field, value) pairs. + """ + + partitions: dict[tuple[tuple[str, str], ...], list[MemoryRecord]] = {} + for mem in memories: + group_dict: dict[str, str] = {} + for field in view.group_by: + value = getattr(mem, field, None) + if value is None: + # If a grouping field is missing for this memory, skip it + break + group_dict[field] = str(value) + else: + # Only executed if the inner loop did not break + key = tuple(sorted(group_dict.items())) + partitions.setdefault(key, []).append(mem) + + return partitions + + +def _build_long_term_summary_prompt( + view: SummaryView, + group: dict[str, str], + memories: list[MemoryRecord], + model_name: str, + instructions: str, +) -> str: + """Build a token-aware prompt for long-term memory summarization. + + Uses tiktoken and the model's configured context window to truncate the + inlined memories so the prompt stays within a safe fraction of the + context limit while still leaving room for the model's response. + """ + + # Import here to avoid circular imports at module load time. + from agent_memory_server.llms import get_model_config + + encoding = tiktoken.get_encoding("cl100k_base") + + model_config = get_model_config(model_name) + full_context_tokens = max(model_config.max_tokens, 1) + + # Use the same summarization_threshold knob as working-memory + # summarization to control how much of the context window we devote + # to the prompt itself. + prompt_budget = int(full_context_tokens * settings.summarization_threshold) + + # Reserve some space for the model's response and any overhead. + reserved_completion_tokens = min(4096, full_context_tokens // 10) + max_prompt_tokens = max(prompt_budget - reserved_completion_tokens, 1024) + + base_prefix = ( + f"{instructions}\n\n" + f"GROUP: {json.dumps(group, sort_keys=True)}\n\n" + "MEMORIES:\n" + ) + base_tokens = len(encoding.encode(base_prefix)) + + remaining_tokens = max_prompt_tokens - base_tokens + if remaining_tokens <= 0: + return ( + base_prefix + + "[Memories omitted due to token budget constraints.]\n\nSUMMARY:" + ) + + # Cap the size of each individual memory text we inline so that a + # single extremely long memory cannot dominate the prompt. + max_bullet_tokens = min(1024, full_context_tokens // 20) + + bullet_lines: list[str] = [] + for mem in memories[:_MAX_MEMORIES_FOR_LLM_PROMPT]: + text = mem.text or "" + bullet = f"- {text}" + bullet_tokens = len(encoding.encode(bullet)) + + if bullet_tokens > max_bullet_tokens: + # Roughly truncate very long memories by characters, then + # recompute tokens. This mirrors the approach used in + # agent_memory_server.summarization. + approx_chars = max_bullet_tokens * 4 + text = text[:approx_chars] + bullet = f"- {text}" + bullet_tokens = len(encoding.encode(bullet)) + + if bullet_tokens > remaining_tokens: + break + + bullet_lines.append(bullet) + remaining_tokens -= bullet_tokens + + memories_text = "\n".join(bullet_lines) + total_memories = len(memories) + used_memories = len(bullet_lines) + if total_memories > used_memories: + memories_text += ( + f"\n\n[Memories truncated to fit token budget: used {used_memories} " + f"of {total_memories} entries]" + ) + + return f"{base_prefix}{memories_text}\n\nSUMMARY:" + + +async def summarize_partition_long_term( + view: SummaryView, + group: dict[str, str], + memories: list[MemoryRecord], +) -> SummaryViewPartitionResult: + """Summarize a partition of long-term memories. + + For now we keep the prompt simple and use a single chat completion + call with a textual join of memory texts. + """ + + if not memories: + summary_text = f"No memories found for group {group!r}." + return SummaryViewPartitionResult( + view_id=view.id, + group=group, + summary=summary_text, + memory_count=0, + computed_at=datetime.now(UTC), + ) + + # If no LLM credentials are configured, fall back to a simple + # deterministic summary that just concatenates memory texts. + if not ( + settings.openai_api_key + or settings.anthropic_api_key + or settings.aws_access_key_id + ): + joined = "\n".join(f"- {m.text}" for m in memories[:50]) + summary_text = ( + "LLM summarization disabled (no API keys configured). " + "Concatenated up to 50 memories:\n" + joined + ) + else: + from agent_memory_server.llms import get_model_client + + model_name = view.model_name or settings.fast_model + client = await get_model_client(model_name) + + # Build a prompt using either the view's prompt or a default, then + # construct a token-aware memories section based on the model's + # configured context window. + default_instructions = ( + "You are a summarization assistant. Given a set of long-term " + "memories, produce a concise summary that highlights key facts, " + "stable preferences, and important events relevant to the group." + ) + instructions = view.prompt or default_instructions + + prompt = _build_long_term_summary_prompt( + view=view, + group=group, + memories=memories, + model_name=model_name, + instructions=instructions, + ) + + # We use the same interface pattern as other summarization helpers, + # but add minimal defensive checks around the response structure. + try: + response = await client.create_chat_completion(model_name, prompt) + choices = getattr(response, "choices", None) or [] + content = None + if choices: + first_message = getattr(choices[0], "message", None) + content = getattr(first_message, "content", None) + except Exception: + logger.exception( + "Error calling summarization model %s for SummaryView %s group %r", + model_name, + view.id, + group, + ) + content = None + + if not content: + logger.warning( + "Summarization model %s returned empty response for SummaryView %s " + "group %r; using fallback text.", + model_name, + view.id, + group, + ) + summary_text = "No summary could be generated for this partition." + else: + summary_text = content + + return SummaryViewPartitionResult( + view_id=view.id, + group=group, + summary=summary_text, + memory_count=len(memories), + computed_at=datetime.now(UTC), + ) + + +async def summarize_partition_placeholder( + view: SummaryView, group: dict[str, str] +) -> SummaryViewPartitionResult: + """Fallback placeholder summary for unsupported sources. + + Used currently for SummaryViews whose source is not yet implemented. + """ + + summary_text = f"Placeholder summary for view {view.id} with group {group!r}." + return SummaryViewPartitionResult( + view_id=view.id, + group=group, + summary=summary_text, + memory_count=0, + computed_at=datetime.now(UTC), + ) + + +async def summarize_partition_for_view( + view: SummaryView, group: dict[str, str] +) -> SummaryViewPartitionResult: + """High-level entry point to summarize a single partition for a view. + + Dispatches to the appropriate implementation based on view.source. + """ + + if view.source == "long_term": + memories = await _fetch_long_term_memories_for_view(view, extra_group=group) + return await summarize_partition_long_term(view, group, memories) + + # Fallback for sources we haven't implemented yet. + return await summarize_partition_placeholder(view, group) + + +async def refresh_summary_view(view_id: str, task_id: str | None = None) -> None: + """Docket task to recompute all partitions for a SummaryView. + + For long-term memory sources, this will fetch memories matching the + view filters, partition them by group_by, summarize each partition, + and store SummaryViewPartitionResult entries. + """ + + view = await get_summary_view(view_id) + now = datetime.now(UTC) + + if task_id is not None: + if view is None: + await update_task_status( + task_id, + status=TaskStatusEnum.FAILED, + completed_at=now, + error_message=f"SummaryView {view_id} not found", + ) + return + + await update_task_status( + task_id, + status=TaskStatusEnum.RUNNING, + started_at=now, + ) + + if view is None: + # Nothing to do; already handled task status above if needed. + return + + # Threshold above which we log a warning about large memory sets. + # This helps operators identify views that may benefit from tighter + # filters or time windows. + large_memory_threshold = 5000 + + try: + if view.source == "long_term": + # Fetch all relevant memories and partition them. + memories = await _fetch_long_term_memories_for_view(view) + + if len(memories) >= large_memory_threshold: + logger.warning( + "refresh_summary_view: fetched %d memories for view %s; " + "consider adding filters or a time_window_days to reduce volume.", + len(memories), + view.id, + ) + + partitions = _partition_memories_by_group(view, memories) + + for key, mems in partitions.items(): + group = dict(key) + result = await summarize_partition_long_term(view, group, mems) + await save_partition_result(result) + else: + # For unsupported sources, we currently do nothing. + logger.info( + "refresh_summary_view: source %s not yet implemented for view %s", + view.source, + view.id, + ) + + if task_id is not None: + await update_task_status( + task_id, + status=TaskStatusEnum.SUCCESS, + completed_at=datetime.now(UTC), + ) + except Exception as exc: # noqa: BLE001 + # We deliberately catch all exceptions here so that background workers + # never crash silently and any failure is reflected in the Task record + # as FAILED. The original error is logged with traceback above. + logger.exception("Error refreshing SummaryView %s", view_id) + if task_id is not None: + await update_task_status( + task_id, + status=TaskStatusEnum.FAILED, + completed_at=datetime.now(UTC), + error_message=str(exc), + ) + + +async def periodic_refresh_summary_views( + perpetual: Perpetual = Perpetual( + every=timedelta(minutes=60), + automatic=True, + ), +) -> None: + """Periodic Docket task to refresh all continuous SummaryViews. + + Uses the same refresh_summary_view helper but without task tracking. + """ + + if not settings.long_term_memory: + # If long-term memory is entirely disabled, there may still be + # working-memory backed views later, but for now we bail out. + return + + views = await list_summary_views() + for view in views: + if not view.continuous: + continue + await refresh_summary_view(view.id, task_id=None) diff --git a/agent_memory_server/tasks.py b/agent_memory_server/tasks.py new file mode 100644 index 0000000..eb5b29d --- /dev/null +++ b/agent_memory_server/tasks.py @@ -0,0 +1,99 @@ +import logging +from datetime import UTC, datetime + +from agent_memory_server.models import Task, TaskStatusEnum +from agent_memory_server.utils.redis import get_redis_conn + + +logger = logging.getLogger(__name__) + + +# Tasks are operational metadata; we don't need to retain them forever. +# Use a conservative TTL so Redis state cannot grow without bound. +_TASK_TTL_SECONDS = 7 * 24 * 60 * 60 # 7 days + + +def _task_key(task_id: str) -> str: + """Return the Redis key for a task JSON payload.""" + + return f"task:{task_id}" + + +async def create_task(task: Task) -> None: + """Persist a new Task as JSON in Redis. + + This overwrites any existing task with the same ID. + """ + + redis = await get_redis_conn() + await redis.set( + _task_key(task.id), + task.model_dump_json(), + ex=_TASK_TTL_SECONDS, + ) + + +async def get_task(task_id: str) -> Task | None: + """Load a Task from Redis JSON storage. + + Returns None if the task does not exist. + """ + + redis = await get_redis_conn() + raw = await redis.get(_task_key(task_id)) + if raw is None: + return None + + if isinstance(raw, bytes): + raw = raw.decode("utf-8") + + try: + return Task.model_validate_json(raw) + except Exception: + logger.exception("Failed to decode task JSON for %s", task_id) + return None + + +async def update_task_status( + task_id: str, + *, + status: TaskStatusEnum | None = None, + started_at: datetime | None = None, + completed_at: datetime | None = None, + error_message: str | None = None, +) -> None: + """Update status and timestamps for an existing Task. + + If the task does not exist, this is a no-op. + """ + + redis = await get_redis_conn() + key = _task_key(task_id) + raw = await redis.get(key) + if raw is None: + logger.warning("Attempted to update missing task %s", task_id) + return + + if isinstance(raw, bytes): + raw = raw.decode("utf-8") + + try: + task = Task.model_validate_json(raw) + except Exception: + logger.exception("Failed to decode task JSON for %s during update", task_id) + return + + if status is not None: + task.status = status + if started_at is not None: + task.started_at = started_at + if completed_at is not None: + task.completed_at = completed_at + if error_message is not None: + task.error_message = error_message + + # Ensure created_at is always set + if task.created_at is None: + task.created_at = datetime.now(UTC) + + await redis.set(key, task.model_dump_json(), ex=_TASK_TTL_SECONDS) diff --git a/docs/configuration.md b/docs/configuration.md index 4b2524f..352fe6c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -207,8 +207,12 @@ uv run agent-memory task-worker --concurrency 5 --redelivery-timeout 60 ## Supported Models ### Generation Models (OpenAI) -- `gpt-4o` - Latest GPT-4 Optimized (recommended) -- `gpt-4o-mini` - Faster, smaller GPT-4 (good for fast_model) +- `gpt-5.2-chat-latest` - Latest GPT-5.2 Chat snapshot used in ChatGPT (recommended when available) +- `gpt-5.1-chat-latest` - GPT-5.1 Chat snapshot (fast, chat-optimized) +- `gpt-5-mini` - Smaller GPT-5 model (good candidate for `FAST_MODEL`) +- `gpt-5-nano` - Smallest GPT-5 model (ultra fast, cost efficient) +- `gpt-4o` - GPT-4 Optimized (default in this project) +- `gpt-4o-mini` - Faster, smaller GPT-4 (good for `FAST_MODEL`) - `gpt-4` - Previous GPT-4 version - `gpt-3.5-turbo` - Older, faster model - `o1` - OpenAI o1 reasoning model diff --git a/docs/getting-started.md b/docs/getting-started.md index 0ad7be1..ba20f5d 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -69,7 +69,7 @@ When configuring MCP-enabled apps (e.g., Claude Desktop), prefer `uvx` so the ap ``` Notes: -- API keys: Default models use OpenAI. Set `OPENAI_API_KEY`. To use Anthropic instead, set `ANTHROPIC_API_KEY` and also `GENERATION_MODEL` to an Anthropic model (e.g., `claude-3-5-haiku-20241022`). +- API keys: Default models use OpenAI. Set `OPENAI_API_KEY`. To use Anthropic instead, set `ANTHROPIC_API_KEY` and also `GENERATION_MODEL` to an Anthropic model (e.g., `claude-3-5-haiku-20241022`). If you have access to GPT-5 models, you can instead set `GENERATION_MODEL` to `gpt-5.2-chat-latest`, `gpt-5.1-chat-latest`, `gpt-5-mini`, or `gpt-5-nano`. - Make sure your MCP host can find `uvx` (on its PATH or by using an absolute command path). macOS: `brew install uv`. If not on PATH, set `"command"` to an absolute path (e.g., `/opt/homebrew/bin/uvx` on Apple Silicon, `/usr/local/bin/uvx` on Intel macOS). - For production, remove `DISABLE_AUTH` and configure auth. diff --git a/docs/query-optimization.md b/docs/query-optimization.md index 3bac790..1514463 100644 --- a/docs/query-optimization.md +++ b/docs/query-optimization.md @@ -64,10 +64,11 @@ QUERY_OPTIMIZATION_MODEL=gpt-4o-mini # Use a more powerful model for memory extraction and other tasks GENERATION_MODEL=gpt-4o -# Supported models include: -# - gpt-4o, gpt-4o-mini -# - claude-3-5-sonnet-20241022, claude-3-haiku-20240307 -# - Any model supported by your LLM provider + # Supported models include: + # - gpt-5.2-chat-latest, gpt-5.1-chat-latest, gpt-5-mini, gpt-5-nano + # - gpt-4o, gpt-4o-mini + # - claude-3-5-sonnet-20241022, claude-3-haiku-20240307 + # - Any model supported by your LLM provider ``` ## Usage Examples diff --git a/tests/conftest.py b/tests/conftest.py index b4cf706..b97cde0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -332,13 +332,19 @@ def patched_docket_init(self, name, url=None, *args, **kwargs): import agent_memory_server.vectorstore_factory with ( + # Core Redis helper patch("agent_memory_server.utils.redis.get_redis_conn", mock_get_redis_conn), - patch("docket.docket.Docket.__init__", patched_docket_init), - patch("agent_memory_server.working_memory.get_redis_conn", mock_get_redis_conn), + # Modules that imported get_redis_conn directly must also be patched patch("agent_memory_server.api.get_redis_conn", mock_get_redis_conn), + patch("agent_memory_server.working_memory.get_redis_conn", mock_get_redis_conn), patch( "agent_memory_server.long_term_memory.get_redis_conn", mock_get_redis_conn ), + patch("agent_memory_server.summary_views.get_redis_conn", mock_get_redis_conn), + patch("agent_memory_server.tasks.get_redis_conn", mock_get_redis_conn), + # Ensure Docket uses the test Redis URL + patch("docket.docket.Docket.__init__", patched_docket_init), + # Point settings.redis_url at the testcontainer Redis patch.object(settings, "redis_url", redis_url), ): # Reset global state to force recreation with test Redis diff --git a/tests/test_summary_views.py b/tests/test_summary_views.py new file mode 100644 index 0000000..4f2e7f8 --- /dev/null +++ b/tests/test_summary_views.py @@ -0,0 +1,278 @@ +import pytest + +from agent_memory_server.models import MemoryRecord, SummaryView, TaskStatusEnum + + +@pytest.mark.asyncio +async def test_create_and_get_summary_view(client): + # Create a summary view + payload = { + "name": "ltm_by_user_30d", + "source": "long_term", + "group_by": ["user_id"], + "filters": {"memory_type": "semantic"}, + "time_window_days": 30, + "continuous": False, + "prompt": None, + "model_name": None, + } + resp = await client.post("/v1/summary-views", json=payload) + assert resp.status_code == 200, resp.text + view = resp.json() + view_id = view["id"] + + # Fetch it back + resp_get = await client.get(f"/v1/summary-views/{view_id}") + assert resp_get.status_code == 200 + fetched = resp_get.json() + assert fetched["id"] == view_id + assert fetched["group_by"] == ["user_id"] + + +@pytest.mark.asyncio +async def test_create_summary_view_rejects_invalid_keys(client): + """SummaryView creation should reject unsupported group_by / filter keys.""" + + payload = { + "name": "invalid_keys_view", + "source": "long_term", + # "invalid" is not in the allowed group_by set + "group_by": ["user_id", "invalid"], + "filters": {"memory_type": "semantic"}, + "time_window_days": 30, + "continuous": False, + "prompt": None, + "model_name": None, + } + + resp = await client.post("/v1/summary-views", json=payload) + assert resp.status_code == 400 + data = resp.json() + assert "Unsupported group_by fields" in data["detail"] + + +@pytest.mark.asyncio +async def test_run_single_partition_and_list_partitions(client): + # Create a simple view grouped by user_id + payload = { + "name": "ltm_by_user", + "source": "long_term", + "group_by": ["user_id"], + "filters": {}, + "time_window_days": None, + "continuous": False, + "prompt": None, + "model_name": None, + } + resp = await client.post("/v1/summary-views", json=payload) + assert resp.status_code == 200, resp.text + view_id = resp.json()["id"] + + # Run a single partition synchronously + run_payload = {"group": {"user_id": "alice"}} + resp_run = await client.post( + f"/v1/summary-views/{view_id}/partitions/run", json=run_payload + ) + assert resp_run.status_code == 200, resp_run.text + result = resp_run.json() + assert result["group"] == {"user_id": "alice"} + assert "summary" in result + + # List materialized partitions + resp_list = await client.get( + f"/v1/summary-views/{view_id}/partitions", params={"user_id": "alice"} + ) + assert resp_list.status_code == 200 + partitions = resp_list.json() + assert len(partitions) == 1 + assert partitions[0]["group"]["user_id"] == "alice" + + +@pytest.mark.asyncio +async def test_delete_summary_view_removes_it_from_get_and_list(client): + """Deleting a SummaryView should remove it from retrieval and listings.""" + + # Create a view we can delete + payload = { + "name": "ltm_to_delete", + "source": "long_term", + "group_by": ["user_id"], + "filters": {}, + "time_window_days": None, + "continuous": False, + "prompt": None, + "model_name": None, + } + resp = await client.post("/v1/summary-views", json=payload) + assert resp.status_code == 200, resp.text + view_id = resp.json()["id"] + + # Ensure it appears in the list + list_before = await client.get("/v1/summary-views") + assert list_before.status_code == 200 + ids_before = {v["id"] for v in list_before.json()} + assert view_id in ids_before + + # Delete the view + resp_delete = await client.delete(f"/v1/summary-views/{view_id}") + assert resp_delete.status_code == 200, resp_delete.text + + # GET should now return 404 + resp_get = await client.get(f"/v1/summary-views/{view_id}") + assert resp_get.status_code == 404 + + # And it should no longer appear in the list + list_after = await client.get("/v1/summary-views") + assert list_after.status_code == 200 + ids_after = {v["id"] for v in list_after.json()} + assert view_id not in ids_after + + +@pytest.mark.asyncio +async def test_run_full_view_creates_task_and_updates_status(client): + # Create a summary view + payload = { + "name": "ltm_full_run", + "source": "long_term", + "group_by": ["user_id"], + "filters": {}, + "time_window_days": None, + "continuous": False, + "prompt": None, + "model_name": None, + } + resp = await client.post("/v1/summary-views", json=payload) + assert resp.status_code == 200, resp.text + view_id = resp.json()["id"] + + # Trigger a full run + resp_run = await client.post(f"/v1/summary-views/{view_id}/run", json={}) + assert resp_run.status_code == 200, resp_run.text + task = resp_run.json() + task_id = task["id"] + + # Poll the task status via the API. We intentionally do not wait for the + # background Docket worker here; the goal is to verify that the Task is + # created and visible through the status endpoint, not that the worker + # has actually completed the refresh. + resp_task = await client.get(f"/v1/tasks/{task_id}") + assert resp_task.status_code == 200 + polled = resp_task.json() + assert polled["status"] in { + TaskStatusEnum.PENDING, + TaskStatusEnum.RUNNING, + TaskStatusEnum.SUCCESS, + } + + +@pytest.mark.asyncio +async def test_fetch_long_term_memories_for_view_paginates(monkeypatch): + """_fetch_long_term_memories_for_view should paginate through results. + + We monkeypatch long_term_memory.search_long_term_memories to return + deterministic pages and verify that multiple calls are made when the + number of results exceeds the configured page_size. + """ + + from agent_memory_server import summary_views + + calls: list[tuple[int, int]] = [] + + class FakeResults: + def __init__(self, memories: list[MemoryRecord]): + self.memories = memories + + async def fake_search_long_term_memories( + *, text: str, limit: int, offset: int, **_: object + ): # type: ignore[override] + # Record the (limit, offset) pair for assertions. + calls.append((limit, offset)) + + # Pretend we have 2500 total memories; each page returns `limit` + # until we reach that total. + total = 2500 + remaining = max(total - offset, 0) + batch_size = min(limit, remaining) + + memories = [ + MemoryRecord( + id=f"mem-{offset + i}", + text=f"memory {offset + i}", + session_id=None, + user_id=None, + namespace=None, + ) + for i in range(batch_size) + ] + return FakeResults(memories) + + monkeypatch.setattr( + summary_views.long_term_memory, + "search_long_term_memories", + fake_search_long_term_memories, + ) + + view = SummaryView( + id="view-1", + name="test", + source="long_term", + group_by=["user_id"], + filters={}, + time_window_days=None, + continuous=False, + prompt=None, + model_name=None, + ) + + # Use a small page_size so multiple pages are required; also set + # an overall_limit below the total so we exercise that branch. + memories = await summary_views._fetch_long_term_memories_for_view( + view, + extra_group=None, + page_size=1000, + overall_limit=2100, + ) + + # We should have respected the overall_limit. + assert len(memories) == 2100 + + # And we should have made at least two paginated calls with advancing + # offsets. + assert calls[0] == (1000, 0) + assert calls[1] == (1000, 1000) + # The final page only needs 100 records to reach 2100. + assert calls[2] == (100, 2000) + + +def test_encode_partition_key_handles_special_characters(): + """encode_partition_key should URL-encode special characters in values.""" + + from agent_memory_server.summary_views import ( + decode_partition_key, + encode_partition_key, + ) + + # Values containing the delimiter characters '|' and '=' + group = {"user_id": "alice|bob", "namespace": "test=value"} + + encoded = encode_partition_key(group) + + # The encoded key should not have raw '|' or '=' from values + # (keys are sorted alphabetically, so namespace comes first) + assert "alice%7Cbob" in encoded # %7C is URL-encoded '|' + assert "test%3Dvalue" in encoded # %3D is URL-encoded '=' + + # Decoding should restore the original values + decoded = decode_partition_key(encoded) + assert decoded == group + + +def test_encode_partition_key_is_stable(): + """encode_partition_key should produce the same key regardless of dict order.""" + + from agent_memory_server.summary_views import encode_partition_key + + group1 = {"user_id": "alice", "namespace": "chat"} + group2 = {"namespace": "chat", "user_id": "alice"} + + assert encode_partition_key(group1) == encode_partition_key(group2) diff --git a/tests/test_tasks.py b/tests/test_tasks.py new file mode 100644 index 0000000..560c6d9 --- /dev/null +++ b/tests/test_tasks.py @@ -0,0 +1,52 @@ +import pytest + +from agent_memory_server.models import TaskStatusEnum + + +@pytest.mark.asyncio +async def test_task_lifecycle_via_api(client): + """Basic sanity check for Task creation and retrieval via the API. + + This verifies that: + - POST /v1/summary-views/{id}/run creates a Task + - GET /v1/tasks/{task_id} returns that Task with the expected ID and type + """ + + # Create a minimal summary view we can run + payload = { + "name": "task_lifecycle_test_view", + "source": "long_term", + "group_by": ["user_id"], + "filters": {}, + "time_window_days": None, + "continuous": False, + "prompt": None, + "model_name": None, + } + resp = await client.post("/v1/summary-views", json=payload) + assert resp.status_code == 200, resp.text + view_id = resp.json()["id"] + + # Trigger a full run to create a Task + resp_run = await client.post(f"/v1/summary-views/{view_id}/run", json={}) + assert resp_run.status_code == 200, resp_run.text + task = resp_run.json() + + assert task["id"] + assert task["view_id"] == view_id + assert task["status"] == TaskStatusEnum.PENDING + + task_id = task["id"] + + # Fetch the task via the task status endpoint + resp_task = await client.get(f"/v1/tasks/{task_id}") + assert resp_task.status_code == 200, resp_task.text + polled = resp_task.json() + + assert polled["id"] == task_id + assert polled["view_id"] == view_id + assert polled["status"] in { + TaskStatusEnum.PENDING, + TaskStatusEnum.RUNNING, + TaskStatusEnum.SUCCESS, + }