Skip to content

Commit 8d8aabe

Browse files
author
chao.wang
committed
Add JdbcAssertingPartyMetadataRepository
Closes gh-16012
1 parent 517ce22 commit 8d8aabe

File tree

6 files changed

+654
-0
lines changed

6 files changed

+654
-0
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/AssertingPartyMetadata.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,19 @@
1616

1717
package org.springframework.security.saml2.provider.service.registration;
1818

19+
import java.io.IOException;
20+
import java.io.InputStream;
1921
import java.io.Serializable;
22+
import java.util.ArrayList;
2023
import java.util.Collection;
2124
import java.util.List;
2225
import java.util.function.Consumer;
2326

27+
import org.opensaml.saml.common.xml.SAMLConstants;
28+
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
29+
import org.springframework.core.io.DefaultResourceLoader;
30+
import org.springframework.core.io.ResourceLoader;
31+
import org.springframework.security.saml2.Saml2Exception;
2432
import org.springframework.security.saml2.core.Saml2X509Credential;
2533

2634
/**
@@ -274,4 +282,33 @@ interface Builder<B extends Builder<B>> {
274282

275283
}
276284

285+
static final ResourceLoader resourceLoader = new DefaultResourceLoader();
286+
287+
static Collection<Builder<?>> collectionFromMetadataLocation(String location) {
288+
try (InputStream source = resourceLoader.getResource(location).getInputStream()) {
289+
return collectionFromMetadata(source);
290+
}
291+
catch (IOException ex) {
292+
if (ex.getCause() instanceof Saml2Exception) {
293+
throw (Saml2Exception) ex.getCause();
294+
}
295+
throw new Saml2Exception(ex);
296+
}
297+
}
298+
299+
static Collection<Builder<?>> collectionFromMetadata(InputStream source) {
300+
Collection<Builder<?>> builders = new ArrayList<>();
301+
for (EntityDescriptor descriptor : OpenSamlMetadataUtils.descriptors(source)) {
302+
if (descriptor.getIDPSSODescriptor(SAMLConstants.SAML20P_NS) != null) {
303+
OpenSamlAssertingPartyDetails.Builder builder = OpenSamlAssertingPartyDetails
304+
.withEntityDescriptor(descriptor);
305+
builders.add(builder);
306+
}
307+
}
308+
if (builders.isEmpty()) {
309+
throw new Saml2Exception("Metadata response is missing the necessary IDPSSODescriptor element");
310+
}
311+
return builders;
312+
}
313+
277314
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
/*
2+
* Copyright 2002-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.registration;
18+
19+
import java.io.IOException;
20+
import java.sql.PreparedStatement;
21+
import java.sql.ResultSet;
22+
import java.sql.SQLException;
23+
import java.sql.Types;
24+
import java.util.ArrayList;
25+
import java.util.Collection;
26+
import java.util.Iterator;
27+
import java.util.List;
28+
import java.util.function.Function;
29+
30+
import org.apache.commons.logging.Log;
31+
import org.apache.commons.logging.LogFactory;
32+
import org.slf4j.Logger;
33+
import org.slf4j.LoggerFactory;
34+
import org.springframework.core.log.LogMessage;
35+
import org.springframework.core.serializer.DefaultDeserializer;
36+
import org.springframework.core.serializer.DefaultSerializer;
37+
import org.springframework.core.serializer.Deserializer;
38+
import org.springframework.core.serializer.Serializer;
39+
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
40+
import org.springframework.jdbc.core.JdbcOperations;
41+
import org.springframework.jdbc.core.PreparedStatementSetter;
42+
import org.springframework.jdbc.core.RowMapper;
43+
import org.springframework.jdbc.core.SqlParameterValue;
44+
import org.springframework.security.saml2.core.Saml2X509Credential;
45+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails;
46+
import org.springframework.util.Assert;
47+
48+
/**
49+
* A JDBC implementation of {@link AssertingPartyMetadataRepository}.
50+
*
51+
* @author Cathy Wang
52+
* @since 7.0
53+
*/
54+
public final class JdbcAssertingPartyMetadataRepository implements AssertingPartyMetadataRepository {
55+
56+
private final JdbcOperations jdbcOperations;
57+
58+
private RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper =
59+
new AssertingPartyMetadataRowMapper(ResultSet::getBytes);
60+
61+
private Function<AssertingPartyMetadata, List<SqlParameterValue>> assertingPartyMetadataParametersMapper =
62+
new AssertingPartyMetadataParametersMapper();
63+
64+
private final SetBytes setBytes = PreparedStatement::setBytes;
65+
66+
// @formatter:off
67+
static final String COLUMN_NAMES = "entity_id, "
68+
+ "singlesignon_url, "
69+
+ "singlesignon_binding, "
70+
+ "singlesignon_sign_request, "
71+
+ "signing_algorithms, "
72+
+ "verification_credentials, "
73+
+ "encryption_credentials, "
74+
+ "singlelogout_url, "
75+
+ "singlelogout_response_url, "
76+
+ "singlelogout_binding";
77+
// @formatter:on
78+
79+
private static final String TABLE_NAME = "saml2_asserting_party_metadata";
80+
81+
private static final String ENTITY_ID_FILTER = "entity_id = ?";
82+
83+
// @formatter:off
84+
private static final String LOAD_BY_ID_SQL = "SELECT " + COLUMN_NAMES
85+
+ " FROM " + TABLE_NAME
86+
+ " WHERE " + ENTITY_ID_FILTER;
87+
88+
private static final String LOAD_ALL_SQL = "SELECT " + COLUMN_NAMES
89+
+ " FROM " + TABLE_NAME;
90+
91+
private static final String SAVE_SQL = "INSERT INTO " + TABLE_NAME + " ("
92+
+ COLUMN_NAMES
93+
+ ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
94+
// @formatter:on
95+
96+
private static final String DELETE_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + ENTITY_ID_FILTER;
97+
98+
// @formatter:off
99+
private static final String UPDATE_SQL = "UPDATE " + TABLE_NAME
100+
+ " SET singlesignon_url = ?, " +
101+
"singlesignon_binding = ?, " +
102+
"singlesignon_sign_request = ?, " +
103+
"signing_algorithms = ?, " +
104+
"verification_credentials = ?, " +
105+
"encryption_credentials = ?, " +
106+
"singlelogout_url = ? ," +
107+
"singlelogout_response_url = ?, " +
108+
"singlelogout_binding = ?"
109+
+ " WHERE " + ENTITY_ID_FILTER;
110+
// @formatter:on
111+
112+
/**
113+
* Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided
114+
* parameters.
115+
*
116+
* @param jdbcOperations the JDBC operations
117+
*/
118+
public JdbcAssertingPartyMetadataRepository(JdbcOperations jdbcOperations) {
119+
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
120+
this.jdbcOperations = jdbcOperations;
121+
}
122+
123+
/**
124+
* Sets the {@link RowMapper} used for mapping the current row in
125+
* {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}. The default is
126+
* {@link AssertingPartyMetadataRowMapper}.
127+
*
128+
* @param assertingPartyMetadataRowMapper the {@link RowMapper} used for mapping the
129+
* current row in {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}
130+
*/
131+
public void setAssertingPartyMetadataRowMapper(
132+
RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper) {
133+
Assert.notNull(assertingPartyMetadataRowMapper, "assertingPartyMetadataRowMapper cannot be null");
134+
this.assertingPartyMetadataRowMapper = assertingPartyMetadataRowMapper;
135+
}
136+
137+
public void setAssertingPartyMetadataParametersMapper(Function<AssertingPartyMetadata, List<SqlParameterValue>> assertingPartyMetadataParametersMapper) {
138+
Assert.notNull(assertingPartyMetadataParametersMapper, "assertingPartyMetadataParametersMapper cannot be null");
139+
this.assertingPartyMetadataParametersMapper = assertingPartyMetadataParametersMapper;
140+
}
141+
142+
public void save(AssertingPartyMetadata metadata) {
143+
Assert.notNull(metadata, "metadata cannot be null");
144+
int rows = update(metadata);
145+
if (rows == 0) {
146+
insert(metadata);
147+
}
148+
}
149+
150+
private void insert(AssertingPartyMetadata metadata) {
151+
List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
152+
PreparedStatementSetter pss = new BlobArgumentPreparedStatementSetter(this.setBytes, parameters.toArray());
153+
this.jdbcOperations.update(SAVE_SQL, pss);
154+
}
155+
156+
private int update(AssertingPartyMetadata metadata) {
157+
List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
158+
SqlParameterValue credentialId = parameters.remove(0);
159+
parameters.add(credentialId);
160+
PreparedStatementSetter pss = new BlobArgumentPreparedStatementSetter(this.setBytes, parameters.toArray());
161+
return this.jdbcOperations.update(UPDATE_SQL, pss);
162+
}
163+
164+
public void delete(String entityId) {
165+
Assert.notNull(entityId, "entityId cannot be null");
166+
SqlParameterValue[] parameters = new SqlParameterValue[]{
167+
new SqlParameterValue(Types.VARCHAR, entityId),};
168+
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
169+
this.jdbcOperations.update(DELETE_SQL, pss);
170+
}
171+
172+
@Override
173+
public AssertingPartyMetadata findByEntityId(String entityId) {
174+
Assert.hasText(entityId, "entityId cannot be empty");
175+
SqlParameterValue[] parameters = new SqlParameterValue[]{
176+
new SqlParameterValue(Types.VARCHAR, entityId)};
177+
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
178+
List<AssertingPartyMetadata> result = this.jdbcOperations.query(LOAD_BY_ID_SQL, pss,
179+
this.assertingPartyMetadataRowMapper);
180+
return !result.isEmpty() ? result.get(0) : null;
181+
}
182+
183+
@Override
184+
public Iterator<AssertingPartyMetadata> iterator() {
185+
List<AssertingPartyMetadata> result = this.jdbcOperations.query(LOAD_ALL_SQL,
186+
this.assertingPartyMetadataRowMapper);
187+
return result.iterator();
188+
}
189+
190+
private static class AssertingPartyMetadataParametersMapper
191+
implements Function<AssertingPartyMetadata, List<SqlParameterValue>> {
192+
193+
private final Logger logger = LoggerFactory.getLogger(AssertingPartyMetadataParametersMapper.class);
194+
195+
private final Serializer<Object> serializer = new DefaultSerializer();
196+
197+
@Override
198+
public List<SqlParameterValue> apply(AssertingPartyMetadata record) {
199+
List<SqlParameterValue> parameters = new ArrayList<>();
200+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getEntityId()));
201+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceLocation()));
202+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceBinding().getUrn()));
203+
parameters.add(new SqlParameterValue(Types.BOOLEAN, record.getWantAuthnRequestsSigned()));
204+
try {
205+
parameters.add(new SqlParameterValue(Types.BLOB,
206+
this.serializer.serializeToByteArray(record.getSigningAlgorithms())));
207+
} catch (IOException ex) {
208+
this.logger.debug("Failed to serialize signing algorithms", ex);
209+
throw new IllegalArgumentException(ex);
210+
}
211+
try {
212+
parameters.add(new SqlParameterValue(Types.BLOB,
213+
this.serializer.serializeToByteArray(record.getVerificationX509Credentials())));
214+
} catch (IOException ex) {
215+
this.logger.debug("Failed to serialize verification credentials", ex);
216+
throw new IllegalArgumentException(ex);
217+
}
218+
try {
219+
parameters.add(new SqlParameterValue(Types.BLOB,
220+
this.serializer.serializeToByteArray(record.getEncryptionX509Credentials())));
221+
} catch (IOException ex) {
222+
this.logger.debug("Failed to serialize encryption credentials", ex);
223+
throw new IllegalArgumentException(ex);
224+
}
225+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceLocation()));
226+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceResponseLocation()));
227+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceBinding().getUrn()));
228+
return parameters;
229+
}
230+
}
231+
232+
private static final class BlobArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
233+
234+
private final SetBytes setBytes;
235+
236+
private BlobArgumentPreparedStatementSetter(SetBytes setBytes, Object[] args) {
237+
super(args);
238+
this.setBytes = setBytes;
239+
}
240+
241+
@Override
242+
protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException {
243+
if (argValue instanceof SqlParameterValue paramValue) {
244+
if (paramValue.getSqlType() == Types.BLOB) {
245+
if (paramValue.getValue() != null) {
246+
Assert.isInstanceOf(byte[].class, paramValue.getValue(),
247+
"Value of blob parameter must be byte[]");
248+
}
249+
byte[] valueBytes = (byte[]) paramValue.getValue();
250+
this.setBytes.setBytes(ps, parameterPosition, valueBytes);
251+
return;
252+
}
253+
}
254+
super.doSetValue(ps, parameterPosition, argValue);
255+
}
256+
257+
}
258+
259+
/**
260+
* The default {@link RowMapper} that maps the current row in
261+
* {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}.
262+
*/
263+
private final static class AssertingPartyMetadataRowMapper implements RowMapper<AssertingPartyMetadata> {
264+
265+
private final Log logger = LogFactory.getLog(AssertingPartyMetadataRowMapper.class);
266+
267+
private final Deserializer<Object> deserializer = new DefaultDeserializer();
268+
269+
private final GetBytes getBytes;
270+
271+
AssertingPartyMetadataRowMapper(GetBytes getBytes) {
272+
this.getBytes = getBytes;
273+
}
274+
275+
@Override
276+
public AssertingPartyMetadata mapRow(ResultSet rs, int rowNum) throws SQLException {
277+
String entityId = rs.getString("entity_id");
278+
String singleSignOnUrl = rs.getString("singlesignon_url");
279+
Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding
280+
.from(rs.getString("singlesignon_binding"));
281+
boolean singleSignOnSignRequest = rs.getBoolean("singlesignon_sign_request");
282+
List<String> signingAlgorithms;
283+
try {
284+
signingAlgorithms = (List<String>) deserializer.deserializeFromByteArray(
285+
this.getBytes.getBytes(rs, "signing_algorithms"));
286+
} catch (IOException ex) {
287+
this.logger.debug(
288+
LogMessage.format("Verification credentials of %s could not be parsed.", entityId), ex);
289+
return null;
290+
}
291+
Collection<Saml2X509Credential> verificationCredentials;
292+
try {
293+
verificationCredentials = (Collection<Saml2X509Credential>) deserializer.deserializeFromByteArray(
294+
this.getBytes.getBytes(rs, "verification_credentials"));
295+
} catch (IOException ex) {
296+
this.logger.debug(
297+
LogMessage.format("Verification credentials of %s could not be parsed.", entityId), ex);
298+
return null;
299+
}
300+
Collection<Saml2X509Credential> encryptionCredentials;
301+
try {
302+
encryptionCredentials = (Collection<Saml2X509Credential>) deserializer.deserializeFromByteArray(
303+
this.getBytes.getBytes(rs, "encryption_credentials"));
304+
} catch (IOException ex) {
305+
this.logger.debug(
306+
LogMessage.format("Encryption credentials of %s could not be parsed.", entityId), ex);
307+
return null;
308+
}
309+
String singleLogoutUrl = rs.getString("singlelogout_url");
310+
String singleLogoutResponseUrl = rs.getString("singlelogout_response_url");
311+
Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding
312+
.from(rs.getString("singlelogout_binding"));
313+
314+
return new AssertingPartyDetails.Builder()
315+
.entityId(entityId)
316+
.wantAuthnRequestsSigned(singleSignOnSignRequest)
317+
.signingAlgorithms(algorithms -> algorithms.addAll(signingAlgorithms))
318+
.verificationX509Credentials(credentials -> credentials.addAll(verificationCredentials))
319+
.encryptionX509Credentials(credentials -> credentials.addAll(encryptionCredentials))
320+
.singleSignOnServiceLocation(singleSignOnUrl)
321+
.singleSignOnServiceBinding(singleSignOnBinding)
322+
.singleLogoutServiceLocation(singleLogoutUrl)
323+
.singleLogoutServiceBinding(singleLogoutBinding)
324+
.singleLogoutServiceResponseLocation(singleLogoutResponseUrl)
325+
.build();
326+
}
327+
}
328+
329+
private interface SetBytes {
330+
331+
void setBytes(PreparedStatement ps, int index, byte[] bytes) throws SQLException;
332+
333+
}
334+
335+
private interface GetBytes {
336+
337+
byte[] getBytes(ResultSet rs, String columnName) throws SQLException;
338+
339+
}
340+
341+
}

0 commit comments

Comments
 (0)