Skip to content

Commit 9a8fb1e

Browse files
committed
Add FieldExtractor implementation for Java records
Resolves #4159
1 parent 310c280 commit 9a8fb1e

File tree

2 files changed

+186
-0
lines changed

2 files changed

+186
-0
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright 2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.batch.item.file.transform;
18+
19+
import java.lang.reflect.InvocationTargetException;
20+
import java.lang.reflect.RecordComponent;
21+
import java.util.ArrayList;
22+
import java.util.Arrays;
23+
import java.util.List;
24+
25+
import org.springframework.lang.Nullable;
26+
import org.springframework.util.Assert;
27+
28+
/**
29+
* This is a field extractor for a Java record. By default, it will extract all record
30+
* components, unless a subset is selected using {@link #setNames(String...)}.
31+
*
32+
* @author Mahmoud Ben Hassine
33+
* @since 5.0
34+
*/
35+
public class RecordFieldExtractor<T> implements FieldExtractor<T> {
36+
37+
private List<String> names;
38+
39+
private Class<? extends T> targetType;
40+
41+
private RecordComponent[] recordComponents;
42+
43+
public RecordFieldExtractor(Class<? extends T> targetType) {
44+
Assert.notNull(targetType, "target type must not be null");
45+
Assert.isTrue(targetType.isRecord(), "target type must be a record");
46+
this.targetType = targetType;
47+
this.recordComponents = this.targetType.getRecordComponents();
48+
this.names = getRecordComponentNames();
49+
}
50+
51+
/**
52+
* Set the names of record components to extract.
53+
* @param names of record component to be extracted.
54+
*/
55+
public void setNames(String... names) {
56+
Assert.notNull(names, "Names must not be null");
57+
Assert.notEmpty(names, "Names must not be empty");
58+
validate(names);
59+
this.names = Arrays.stream(names).toList();
60+
}
61+
62+
/**
63+
* @see FieldExtractor#extract(Object)
64+
*/
65+
@Override
66+
public Object[] extract(T item) {
67+
List<Object> values = new ArrayList<>();
68+
for (String componentName : this.names) {
69+
RecordComponent recordComponent = getRecordComponentByName(componentName);
70+
Object value;
71+
try {
72+
value = recordComponent.getAccessor().invoke(item);
73+
values.add(value);
74+
}
75+
catch (IllegalAccessException | InvocationTargetException e) {
76+
throw new RuntimeException("Unable to extract value for record component " + componentName, e);
77+
}
78+
}
79+
return values.toArray();
80+
}
81+
82+
private List<String> getRecordComponentNames() {
83+
return Arrays.stream(this.recordComponents).map(recordComponent -> recordComponent.getName()).toList();
84+
}
85+
86+
private void validate(String[] names) {
87+
for (String name : names) {
88+
if (getRecordComponentByName(name) == null) {
89+
throw new IllegalArgumentException(
90+
"Component '" + name + "' is not defined in record " + targetType.getName());
91+
}
92+
}
93+
}
94+
95+
@Nullable
96+
private RecordComponent getRecordComponentByName(String name) {
97+
return Arrays.stream(this.recordComponents).filter(recordComponent -> recordComponent.getName().equals(name))
98+
.findFirst().orElse(null);
99+
}
100+
101+
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright 2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.batch.item.file.transform;
17+
18+
import org.junit.Assert;
19+
import org.junit.Test;
20+
21+
/**
22+
* @author Mahmoud Ben Hassine
23+
*/
24+
public class RecordFieldExtractorTests {
25+
26+
@Test(expected = IllegalArgumentException.class)
27+
public void testSetupWithNullTargetType() {
28+
new RecordFieldExtractor<>(null);
29+
}
30+
31+
@Test(expected = IllegalArgumentException.class)
32+
public void testSetupWithNonRecordTargetType() {
33+
new RecordFieldExtractor<>(NonRecordType.class);
34+
}
35+
36+
@Test
37+
public void testExtractFields() {
38+
// given
39+
RecordFieldExtractor<Person> recordFieldExtractor = new RecordFieldExtractor<>(Person.class);
40+
Person person = new Person(1, "foo");
41+
42+
// when
43+
Object[] fields = recordFieldExtractor.extract(person);
44+
45+
// then
46+
Assert.assertNotNull(fields);
47+
Assert.assertArrayEquals(new Object[] { 1, "foo" }, fields);
48+
}
49+
50+
@Test
51+
public void testExtractFieldsSubset() {
52+
// given
53+
RecordFieldExtractor<Person> recordFieldExtractor = new RecordFieldExtractor<>(Person.class);
54+
recordFieldExtractor.setNames("name");
55+
Person person = new Person(1, "foo");
56+
57+
// when
58+
Object[] fields = recordFieldExtractor.extract(person);
59+
60+
// then
61+
Assert.assertNotNull(fields);
62+
Assert.assertArrayEquals(new Object[] { "foo" }, fields);
63+
}
64+
65+
@Test(expected = IllegalArgumentException.class)
66+
public void testInvalidComponentName() {
67+
// given
68+
RecordFieldExtractor<Person> recordFieldExtractor = new RecordFieldExtractor<>(Person.class);
69+
recordFieldExtractor.setNames("nonExistent");
70+
Person person = new Person(1, "foo");
71+
72+
// when
73+
recordFieldExtractor.extract(person);
74+
75+
// then
76+
// expected exception
77+
}
78+
79+
public record Person(int id, String name) {
80+
}
81+
82+
public class NonRecordType {
83+
}
84+
85+
}

0 commit comments

Comments
 (0)