@@ -49,78 +49,82 @@ class Neo4jChatMemoryAutoConfigurationIT {
4949
5050 static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName .parse ("neo4j" );
5151
52- @ SuppressWarnings ({"rawtypes" , "resource" })
52+ @ SuppressWarnings ({ "rawtypes" , "resource" })
5353 @ Container
54- static Neo4jContainer neo4jContainer = (Neo4jContainer ) new Neo4jContainer (DEFAULT_IMAGE_NAME .withTag ("5" )).withoutAuthentication ().withExposedPorts (7474 ,7687 );
54+ static Neo4jContainer neo4jContainer = (Neo4jContainer ) new Neo4jContainer (DEFAULT_IMAGE_NAME .withTag ("5" ))
55+ .withoutAuthentication ()
56+ .withExposedPorts (7474 , 7687 );
5557
5658 private final ApplicationContextRunner contextRunner = new ApplicationContextRunner ()
57- .withConfiguration (
58- AutoConfigurations .of (Neo4jChatMemoryAutoConfiguration .class , Neo4jAutoConfiguration .class ));
59-
59+ .withConfiguration (AutoConfigurations .of (Neo4jChatMemoryAutoConfiguration .class , Neo4jAutoConfiguration .class ));
6060
6161 @ Test
6262 void addAndGet () {
63- this .contextRunner .withPropertyValues ("spring.neo4j.uri=" + neo4jContainer .getBoltUrl ())
64- .run (context -> {
65- Neo4jChatMemory memory = context .getBean (Neo4jChatMemory .class );
66-
67- String sessionId = UUIDs .timeBased ().toString ();
68- assertThat (memory .get (sessionId , Integer .MAX_VALUE )).isEmpty ();
69-
70- UserMessage userMessage = new UserMessage ("test question" );
71-
72-
73- memory .add (sessionId , userMessage );
74- List <Message > messages = memory .get (sessionId , Integer .MAX_VALUE );
75- assertThat (messages ).hasSize (1 );
76- assertThat (messages .get (0 )).usingRecursiveAssertion ().isEqualTo (userMessage );
77-
78- memory .clear (sessionId );
79- assertThat (memory .get (sessionId , Integer .MAX_VALUE )).isEmpty ();
80-
81- AssistantMessage assistantMessage = new AssistantMessage ("test answer" , Map .of (),
82- List .of (new AssistantMessage .ToolCall (
83- "id" , "type" , "name" , "arguments" )));
84-
85- memory .add (sessionId , List .of (userMessage , assistantMessage ));
86- messages = memory .get (sessionId , Integer .MAX_VALUE );
87- assertThat (messages ).hasSize (2 );
88- assertThat (messages .get (1 )).isEqualTo (userMessage );
89-
90- assertThat (messages .get (0 )).isEqualTo (assistantMessage );
91- memory .clear (sessionId );
92- MimeType textPlain = MimeType .valueOf ("text/plain" );
93- List <Media > media = List .of (Media .builder ().name ("some media" ).id (UUIDs .random ().toString ())
94- .mimeType (textPlain ).data ("hello" .getBytes (StandardCharsets .UTF_8 )).build (),
95- Media .builder ().data (URI .create ("http://www.google.com" ).toURL ()).mimeType (textPlain ).build ());
96- UserMessage userMessageWithMedia = new UserMessage ("Message with media" , media );
97- memory .add (sessionId , userMessageWithMedia );
98-
99- messages = memory .get (sessionId , Integer .MAX_VALUE );
100- assertThat (messages .size ()).isEqualTo (1 );
101- assertThat (messages .get (0 )).isEqualTo (userMessageWithMedia );
102- assertThat (((UserMessage )messages .get (0 )).getMedia ()).hasSize (2 );
103- assertThat (((UserMessage ) messages .get (0 )).getMedia ()).usingRecursiveFieldByFieldElementComparator ().isEqualTo (media );
104- memory .clear (sessionId );
105- ToolResponseMessage toolResponseMessage = new ToolResponseMessage (List .of (
106- new ToolResponse ("id" , "name" , "responseData" ),
107- new ToolResponse ("id2" , "name2" , "responseData2" )),
108- Map .of ("id" , "id" , "metadataKey" , "metadata" ));
109- memory .add (sessionId , toolResponseMessage );
110- messages = memory .get (sessionId , Integer .MAX_VALUE );
111- assertThat (messages .size ()).isEqualTo (1 );
112- assertThat (messages .get (0 )).isEqualTo (toolResponseMessage );
113-
114- memory .clear (sessionId );
115- SystemMessage sm = new SystemMessage ("this is a System message" );
116- memory .add (sessionId , sm );
117- messages = memory .get (sessionId , Integer .MAX_VALUE );
118- assertThat (messages ).hasSize (1 );
119- assertThat (messages .get (0 )).usingRecursiveAssertion ().isEqualTo (sm );
120- });
63+ this .contextRunner .withPropertyValues ("spring.neo4j.uri=" + neo4jContainer .getBoltUrl ()).run (context -> {
64+ Neo4jChatMemory memory = context .getBean (Neo4jChatMemory .class );
65+
66+ String sessionId = UUIDs .timeBased ().toString ();
67+ assertThat (memory .get (sessionId , Integer .MAX_VALUE )).isEmpty ();
68+
69+ UserMessage userMessage = new UserMessage ("test question" );
70+
71+ memory .add (sessionId , userMessage );
72+ List <Message > messages = memory .get (sessionId , Integer .MAX_VALUE );
73+ assertThat (messages ).hasSize (1 );
74+ assertThat (messages .get (0 )).usingRecursiveAssertion ().isEqualTo (userMessage );
75+
76+ memory .clear (sessionId );
77+ assertThat (memory .get (sessionId , Integer .MAX_VALUE )).isEmpty ();
78+
79+ AssistantMessage assistantMessage = new AssistantMessage ("test answer" , Map .of (),
80+ List .of (new AssistantMessage .ToolCall ("id" , "type" , "name" , "arguments" )));
81+
82+ memory .add (sessionId , List .of (userMessage , assistantMessage ));
83+ messages = memory .get (sessionId , Integer .MAX_VALUE );
84+ assertThat (messages ).hasSize (2 );
85+ assertThat (messages .get (1 )).isEqualTo (userMessage );
86+
87+ assertThat (messages .get (0 )).isEqualTo (assistantMessage );
88+ memory .clear (sessionId );
89+ MimeType textPlain = MimeType .valueOf ("text/plain" );
90+ List <Media > media = List .of (
91+ Media .builder ()
92+ .name ("some media" )
93+ .id (UUIDs .random ().toString ())
94+ .mimeType (textPlain )
95+ .data ("hello" .getBytes (StandardCharsets .UTF_8 ))
96+ .build (),
97+ Media .builder ().data (URI .create ("http://www.google.com" ).toURL ()).mimeType (textPlain ).build ());
98+ UserMessage userMessageWithMedia = new UserMessage ("Message with media" , media );
99+ memory .add (sessionId , userMessageWithMedia );
100+
101+ messages = memory .get (sessionId , Integer .MAX_VALUE );
102+ assertThat (messages .size ()).isEqualTo (1 );
103+ assertThat (messages .get (0 )).isEqualTo (userMessageWithMedia );
104+ assertThat (((UserMessage ) messages .get (0 )).getMedia ()).hasSize (2 );
105+ assertThat (((UserMessage ) messages .get (0 )).getMedia ()).usingRecursiveFieldByFieldElementComparator ()
106+ .isEqualTo (media );
107+ memory .clear (sessionId );
108+ ToolResponseMessage toolResponseMessage = new ToolResponseMessage (
109+ List .of (new ToolResponse ("id" , "name" , "responseData" ),
110+ new ToolResponse ("id2" , "name2" , "responseData2" )),
111+ Map .of ("id" , "id" , "metadataKey" , "metadata" ));
112+ memory .add (sessionId , toolResponseMessage );
113+ messages = memory .get (sessionId , Integer .MAX_VALUE );
114+ assertThat (messages .size ()).isEqualTo (1 );
115+ assertThat (messages .get (0 )).isEqualTo (toolResponseMessage );
116+
117+ memory .clear (sessionId );
118+ SystemMessage sm = new SystemMessage ("this is a System message" );
119+ memory .add (sessionId , sm );
120+ messages = memory .get (sessionId , Integer .MAX_VALUE );
121+ assertThat (messages ).hasSize (1 );
122+ assertThat (messages .get (0 )).usingRecursiveAssertion ().isEqualTo (sm );
123+ });
121124 }
125+
122126 @ Test
123- void setCustomConfiguration (){
127+ void setCustomConfiguration () {
124128 final String sessionLabel = "LabelSession" ;
125129 final String toolCallLabel = "LabelToolCall" ;
126130 final String metadataLabel = "LabelMetadata" ;
@@ -129,25 +133,24 @@ void setCustomConfiguration(){
129133 final String mediaLabel = "LabelMedia" ;
130134
131135 final String propertyBase = "spring.ai.chat.memory.neo4j.%s=%s" ;
132- this .contextRunner .withPropertyValues ("spring.neo4j.uri=" + neo4jContainer .getBoltUrl (),
133- propertyBase .formatted ("sessionlabel" , sessionLabel ),
134- propertyBase .formatted ("toolcallLabel" , toolCallLabel ),
135- propertyBase .formatted ("metadatalabel" , metadataLabel ),
136- propertyBase .formatted ("messagelabel" , messageLabel ),
137- propertyBase .formatted ("toolresponselabel" , toolResponseLabel ),
138- propertyBase .formatted ("medialabel" , mediaLabel ))
139- .run (context -> {
140- Neo4jChatMemory chatMemory = context .getBean (Neo4jChatMemory .class );
141- Neo4jChatMemoryConfig config = chatMemory .getConfig ();
142- assertThat (config .getMessageLabel ()).isEqualTo (messageLabel );
143- assertThat (config .getMediaLabel ()).isEqualTo (mediaLabel );
144- assertThat (config .getMetadataLabel ()).isEqualTo (metadataLabel );
145- assertThat (config .getSessionLabel ()).isEqualTo (sessionLabel );
146- assertThat (config .getToolResponseLabel ()).isEqualTo (toolResponseLabel );
147- assertThat (config .getToolCallLabel ()).isEqualTo (toolCallLabel );
148- });
136+ this .contextRunner
137+ .withPropertyValues ("spring.neo4j.uri=" + neo4jContainer .getBoltUrl (),
138+ propertyBase .formatted ("sessionlabel" , sessionLabel ),
139+ propertyBase .formatted ("toolcallLabel" , toolCallLabel ),
140+ propertyBase .formatted ("metadatalabel" , metadataLabel ),
141+ propertyBase .formatted ("messagelabel" , messageLabel ),
142+ propertyBase .formatted ("toolresponselabel" , toolResponseLabel ),
143+ propertyBase .formatted ("medialabel" , mediaLabel ))
144+ .run (context -> {
145+ Neo4jChatMemory chatMemory = context .getBean (Neo4jChatMemory .class );
146+ Neo4jChatMemoryConfig config = chatMemory .getConfig ();
147+ assertThat (config .getMessageLabel ()).isEqualTo (messageLabel );
148+ assertThat (config .getMediaLabel ()).isEqualTo (mediaLabel );
149+ assertThat (config .getMetadataLabel ()).isEqualTo (metadataLabel );
150+ assertThat (config .getSessionLabel ()).isEqualTo (sessionLabel );
151+ assertThat (config .getToolResponseLabel ()).isEqualTo (toolResponseLabel );
152+ assertThat (config .getToolCallLabel ()).isEqualTo (toolCallLabel );
153+ });
149154 }
150155
151-
152-
153156}
0 commit comments