Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,30 +48,54 @@
*/
public class JdbcChatMemoryRepository implements ChatMemoryRepository {

private static final String QUERY_GET_IDS = """
SELECT DISTINCT conversation_id FROM ai_chat_memory
""";
private final String queryGetIds;

private static final String QUERY_ADD = """
INSERT INTO ai_chat_memory (conversation_id, content, type, "timestamp") VALUES (?, ?, ?, ?)
""";
private final String queryAdd;

private static final String QUERY_GET = """
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp"
""";
private final String queryGet;

private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?";
private final String queryClear;

private final JdbcTemplate jdbcTemplate;

private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate) {
public final static String DEFAULT_TABLE_NAME = "ai_chat_memory";

public final static String DEFAULT_CONVERSION_ID_FIELD_NAME = "conversation_id";

public final static String DEFAULT_CONTENT_FIELD_NAME = "content";

public final static String DEFAULT_TYPE_FIELD_NAME = "type";

public final static String DEFAULT_TIMESTAMP_FIELD_NAME = "\"timestamp\"";

public final static String DEFAULT_GET_IDS_QUERY = "SELECT DISTINCT %s FROM %s";

public final static String DEFAULT_ADD_QUERY = "INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, ?)";

public final static String DEFAULT_GET_QUERY = "SELECT %s, %s FROM %s WHERE %s = ? ORDER BY %s";

public final static String DEFAULT_CLEAR_QUERY = "DELETE FROM %s WHERE %s = ?";

private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, String tableName, String conversionIdFiledName,
String contentFiledName, String typeFiledName, String timestampFiledName) {
Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null");
Assert.notNull(tableName, "tableName cannot be null");
Assert.notNull(conversionIdFiledName, "conversionIdFiledName cannot be null");
Assert.notNull(contentFiledName, "contentFiledName cannot be null");
Assert.notNull(typeFiledName, "typeFiledName cannot be null");
Assert.notNull(timestampFiledName, "timestampFiledName cannot be null");
this.jdbcTemplate = jdbcTemplate;
this.queryGetIds = DEFAULT_GET_IDS_QUERY.formatted(conversionIdFiledName, tableName);
this.queryAdd = DEFAULT_ADD_QUERY.formatted(tableName, conversionIdFiledName, contentFiledName, typeFiledName,
timestampFiledName);
this.queryGet = DEFAULT_GET_QUERY.formatted(contentFiledName, typeFiledName, tableName, conversionIdFiledName,
timestampFiledName);
this.queryClear = DEFAULT_CLEAR_QUERY.formatted(tableName, conversionIdFiledName);
}

@Override
public List<String> findConversationIds() {
List<String> conversationIds = this.jdbcTemplate.query(QUERY_GET_IDS, rs -> {
List<String> conversationIds = this.jdbcTemplate.query(queryGetIds, rs -> {
var ids = new ArrayList<String>();
while (rs.next()) {
ids.add(rs.getString(1));
Expand All @@ -84,7 +108,7 @@ public List<String> findConversationIds() {
@Override
public List<Message> findByConversationId(String conversationId) {
Assert.hasText(conversationId, "conversationId cannot be null or empty");
return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId);
return this.jdbcTemplate.query(queryGet, new MessageRowMapper(), conversationId);
}

@Override
Expand All @@ -93,13 +117,13 @@ public void saveAll(String conversationId, List<Message> messages) {
Assert.notNull(messages, "messages cannot be null");
Assert.noNullElements(messages, "messages cannot contain null elements");
this.deleteByConversationId(conversationId);
this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages));
this.jdbcTemplate.batchUpdate(queryAdd, new AddBatchPreparedStatement(conversationId, messages));
}

@Override
public void deleteByConversationId(String conversationId) {
Assert.hasText(conversationId, "conversationId cannot be null or empty");
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
this.jdbcTemplate.update(queryClear, conversationId);
}

private record AddBatchPreparedStatement(String conversationId, List<Message> messages,
Expand Down Expand Up @@ -154,6 +178,16 @@ public static class Builder {

private JdbcTemplate jdbcTemplate;

private String tableName = DEFAULT_TABLE_NAME;

private String conversionIdFiledName = DEFAULT_CONVERSION_ID_FIELD_NAME;

private String contentFiledName = DEFAULT_CONTENT_FIELD_NAME;

private String typeFiledName = DEFAULT_TYPE_FIELD_NAME;

private String timestampFiledName = DEFAULT_TIMESTAMP_FIELD_NAME;

private Builder() {
}

Expand All @@ -162,8 +196,34 @@ public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) {
return this;
}

public Builder tableName(String tableName) {
this.tableName = tableName;
return this;
}

public Builder conversionIdFiledName(String conversionIdFiledName) {
this.conversionIdFiledName = conversionIdFiledName;
return this;
}

public Builder contentFiledName(String contentFiledName) {
this.contentFiledName = contentFiledName;
return this;
}

public Builder typeFiledName(String typeFiledName) {
this.typeFiledName = typeFiledName;
return this;
}

public Builder timestampFiledName(String timestampFiledName) {
this.timestampFiledName = timestampFiledName;
return this;
}

public JdbcChatMemoryRepository build() {
return new JdbcChatMemoryRepository(this.jdbcTemplate);
return new JdbcChatMemoryRepository(this.jdbcTemplate, this.tableName, this.conversionIdFiledName,
this.contentFiledName, this.typeFiledName, this.timestampFiledName);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import java.util.UUID;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.ai.chat.memory.jdbc.JdbcChatMemoryRepository.*;

/**
* Integration tests for {@link JdbcChatMemoryRepository}.
Expand Down Expand Up @@ -160,7 +161,14 @@ static class TestConfiguration {

@Bean
ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate) {
return JdbcChatMemoryRepository.builder().jdbcTemplate(jdbcTemplate).build();
return JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.tableName(DEFAULT_TABLE_NAME)
.conversionIdFiledName(DEFAULT_CONVERSION_ID_FIELD_NAME)
.contentFiledName(DEFAULT_CONTENT_FIELD_NAME)
.typeFiledName(DEFAULT_TYPE_FIELD_NAME)
.timestampFiledName(DEFAULT_TIMESTAMP_FIELD_NAME)
.build();
}

}
Expand Down