|
16 | 16 |
|
17 | 17 | package org.springframework.ai.chat.memory.repository.jdbc; |
18 | 18 |
|
| 19 | +import org.springframework.jdbc.support.JdbcUtils; |
| 20 | +import org.springframework.jdbc.support.MetaDataAccessException; |
| 21 | + |
19 | 22 | import java.sql.Connection; |
| 23 | +import java.sql.DatabaseMetaData; |
20 | 24 |
|
21 | 25 | import javax.sql.DataSource; |
22 | 26 |
|
@@ -51,32 +55,24 @@ public interface JdbcChatMemoryRepositoryDialect { |
51 | 55 | */ |
52 | 56 |
|
53 | 57 | /** |
54 | | - * Detects the dialect from the DataSource or JDBC URL. |
| 58 | + * Detects the dialect from the DataSource. |
55 | 59 | */ |
56 | 60 | static JdbcChatMemoryRepositoryDialect from(DataSource dataSource) { |
57 | | - // Simple detection (could be improved) |
58 | | - try (Connection connection = dataSource.getConnection()) { |
59 | | - String url = connection.getMetaData().getURL().toLowerCase(); |
60 | | - if (url.contains("postgresql")) { |
61 | | - return new PostgresChatMemoryRepositoryDialect(); |
62 | | - } |
63 | | - if (url.contains("mysql")) { |
64 | | - return new MysqlChatMemoryRepositoryDialect(); |
65 | | - } |
66 | | - if (url.contains("mariadb")) { |
67 | | - return new MysqlChatMemoryRepositoryDialect(); |
68 | | - } |
69 | | - if (url.contains("sqlserver")) { |
70 | | - return new SqlServerChatMemoryRepositoryDialect(); |
71 | | - } |
72 | | - if (url.contains("hsqldb")) { |
73 | | - return new HsqldbChatMemoryRepositoryDialect(); |
74 | | - } |
75 | | - // Add more as needed |
| 61 | + String productName; |
| 62 | + try { |
| 63 | + productName = JdbcUtils.extractDatabaseMetaData(dataSource, DatabaseMetaData::getDatabaseProductName); |
76 | 64 | } |
77 | | - catch (Exception ignored) { |
| 65 | + catch (Exception e) { |
| 66 | + throw new RuntimeException("Failed to obtain JDBC product name or establish JDBC connection", e); |
78 | 67 | } |
79 | | - return new PostgresChatMemoryRepositoryDialect(); // default |
| 68 | + return switch (productName) { |
| 69 | + case "PostgreSQL" -> new PostgresChatMemoryRepositoryDialect(); |
| 70 | + case "MySQL", "MariaDB" -> new MysqlChatMemoryRepositoryDialect(); |
| 71 | + case "Microsoft SQL Server" -> new SqlServerChatMemoryRepositoryDialect(); |
| 72 | + case "HSQL Database Engine" -> new HsqldbChatMemoryRepositoryDialect(); |
| 73 | + default -> // Add more as needed |
| 74 | + new PostgresChatMemoryRepositoryDialect(); |
| 75 | + }; |
80 | 76 | } |
81 | 77 |
|
82 | 78 | } |
0 commit comments