1
+ package org .springdoc .core .customisers ;
2
+
3
+ import java .lang .reflect .Field ;
4
+ import java .lang .reflect .Modifier ;
5
+ import java .lang .reflect .Type ;
6
+ import java .util .ArrayList ;
7
+ import java .util .Arrays ;
8
+ import java .util .Collections ;
9
+ import java .util .List ;
10
+ import java .util .Map ;
11
+ import java .util .Optional ;
12
+ import java .util .Set ;
13
+ import java .util .stream .Collectors ;
14
+
15
+ import com .querydsl .core .types .Path ;
16
+ import io .swagger .v3 .core .converter .ModelConverters ;
17
+ import io .swagger .v3 .core .converter .ResolvedSchema ;
18
+ import io .swagger .v3 .core .util .PrimitiveType ;
19
+ import io .swagger .v3 .oas .models .Operation ;
20
+ import io .swagger .v3 .oas .models .media .Schema ;
21
+ import io .swagger .v3 .oas .models .parameters .Parameter ;
22
+ import org .apache .commons .lang3 .StringUtils ;
23
+ import org .slf4j .Logger ;
24
+ import org .slf4j .LoggerFactory ;
25
+ import org .springdoc .core .customizers .OperationCustomizer ;
26
+
27
+ import org .springframework .core .LocalVariableTableParameterNameDiscoverer ;
28
+ import org .springframework .core .MethodParameter ;
29
+ import org .springframework .data .querydsl .binding .QuerydslBinderCustomizer ;
30
+ import org .springframework .data .querydsl .binding .QuerydslBindings ;
31
+ import org .springframework .data .querydsl .binding .QuerydslBindingsFactory ;
32
+ import org .springframework .data .querydsl .binding .QuerydslPredicate ;
33
+ import org .springframework .data .util .CastUtils ;
34
+ import org .springframework .data .util .ClassTypeInformation ;
35
+ import org .springframework .data .util .TypeInformation ;
36
+ import org .springframework .web .method .HandlerMethod ;
37
+
38
+ /**
39
+ * @author Gibah Joseph
40
+
41
+ * Mar, 2020
42
+ **/
43
+ public class QuerydslPredicateOperationCustomizer implements OperationCustomizer {
44
+ private static final Logger LOGGER = LoggerFactory .getLogger (QuerydslPredicateOperationCustomizer .class );
45
+ private QuerydslBindingsFactory querydslBindingsFactory ;
46
+ private LocalVariableTableParameterNameDiscoverer localVariableTableParameterNameDiscoverer ;
47
+
48
+ public QuerydslPredicateOperationCustomizer (QuerydslBindingsFactory querydslBindingsFactory , LocalVariableTableParameterNameDiscoverer localVariableTableParameterNameDiscoverer ) {
49
+ this .querydslBindingsFactory = querydslBindingsFactory ;
50
+ this .localVariableTableParameterNameDiscoverer = localVariableTableParameterNameDiscoverer ;
51
+ }
52
+
53
+ @ Override
54
+ public Operation customize (Operation operation , HandlerMethod handlerMethod ) {
55
+ if (operation .getParameters () == null ) {
56
+ return operation ;
57
+ }
58
+
59
+ MethodParameter [] methodParameters = handlerMethod .getMethodParameters ();
60
+ String [] methodParameterNames = this .localVariableTableParameterNameDiscoverer .getParameterNames (handlerMethod .getMethod ());
61
+ String [] reflectionParametersNames = Arrays .stream (methodParameters ).map (MethodParameter ::getParameterName ).toArray (String []::new );
62
+ if (methodParameterNames == null ) {
63
+ methodParameterNames = reflectionParametersNames ;
64
+ }
65
+ int parametersLength = methodParameters .length ;
66
+ List <Parameter > parametersToAddToOperation = new ArrayList <>();
67
+ for (int i = 0 ; i < parametersLength ; i ++) {
68
+ MethodParameter parameter = methodParameters [i ];
69
+ QuerydslPredicate predicate = parameter .getParameterAnnotation (QuerydslPredicate .class );
70
+
71
+ if (predicate == null ) {
72
+ continue ;
73
+ }
74
+
75
+ List <io .swagger .v3 .oas .models .parameters .Parameter > operationParameters = operation .getParameters ();
76
+ QuerydslBindings bindings = extractQdslBindings (predicate );
77
+
78
+ Set <String > fieldsToAdd = Arrays .stream (predicate .root ().getDeclaredFields ()).map (Field ::getName ).collect (Collectors .toSet ());
79
+
80
+ Map <String , Object > pathSpecMap = getPathSpec (bindings , "pathSpecs" );
81
+ //remove blacklisted fields
82
+ Set <String > blacklist = getFieldValues (bindings , "blackList" );
83
+ fieldsToAdd .removeIf (blacklist ::contains );
84
+
85
+ Set <String > whiteList = getFieldValues (bindings , "whiteList" );
86
+ Set <String > aliases = getFieldValues (bindings , "aliases" );
87
+
88
+ fieldsToAdd .addAll (aliases );
89
+ fieldsToAdd .addAll (whiteList );
90
+ for (String fieldName : fieldsToAdd ) {
91
+ Type type = getFieldType (fieldName , pathSpecMap , predicate .root ());
92
+ io .swagger .v3 .oas .models .parameters .Parameter newParameter = buildParam (type , fieldName );
93
+
94
+ parametersToAddToOperation .add (newParameter );
95
+ }
96
+ }
97
+ operation .getParameters ().addAll (parametersToAddToOperation );
98
+ return operation ;
99
+ }
100
+
101
+ private QuerydslBindings extractQdslBindings (QuerydslPredicate predicate ) {
102
+ ClassTypeInformation <?> classTypeInformation = ClassTypeInformation .from (predicate .root ());
103
+ TypeInformation <?> domainType = classTypeInformation .getRequiredActualType ();
104
+
105
+ Optional <Class <? extends QuerydslBinderCustomizer <?>>> bindingsAnnotation = Optional .of (predicate )
106
+ .map (QuerydslPredicate ::bindings )
107
+ .map (CastUtils ::cast );
108
+
109
+ return bindingsAnnotation
110
+ .map (it -> querydslBindingsFactory .createBindingsFor (domainType , it ))
111
+ .orElseGet (() -> querydslBindingsFactory .createBindingsFor (domainType ));
112
+ }
113
+
114
+ private Set <String > getFieldValues (QuerydslBindings instance , String fieldName ) {
115
+ try {
116
+ Field field = instance .getClass ().getDeclaredField (fieldName );
117
+ if (Modifier .isPrivate (field .getModifiers ())) {
118
+ field .setAccessible (true );
119
+ }
120
+ return (Set <String >) field .get (instance );
121
+ } catch (NoSuchFieldException | IllegalAccessException e ) {
122
+ LOGGER .warn ("NoSuchFieldException or IllegalAccessException occurred : {}" , e .getMessage ());
123
+ }
124
+ return Collections .emptySet ();
125
+ }
126
+
127
+ private Map <String , Object > getPathSpec (QuerydslBindings instance , String fieldName ) {
128
+ try {
129
+ Field field = instance .getClass ().getDeclaredField (fieldName );
130
+ if (Modifier .isPrivate (field .getModifiers ())) {
131
+ field .setAccessible (true );
132
+ }
133
+ return (Map <String , Object >) field .get (instance );
134
+ } catch (NoSuchFieldException | IllegalAccessException e ) {
135
+ LOGGER .warn ("NoSuchFieldException or IllegalAccessException occurred : {}" , e .getMessage ());
136
+ }
137
+ return Collections .emptyMap ();
138
+ }
139
+
140
+ private Optional <Path <?>> getPathFromPathSpec (Object instance ) {
141
+ try {
142
+ if (instance == null ) {
143
+ return Optional .empty ();
144
+ }
145
+ Field field = instance .getClass ().getDeclaredField ("path" );
146
+ if (Modifier .isPrivate (field .getModifiers ())) {
147
+ field .setAccessible (true );
148
+ }
149
+ return (Optional <Path <?>>) field .get (instance );
150
+ } catch (NoSuchFieldException | IllegalAccessException e ) {
151
+ LOGGER .warn ("NoSuchFieldException or IllegalAccessException occurred : {}" , e .getMessage ());
152
+ }
153
+ return Optional .empty ();
154
+ }
155
+
156
+ /***
157
+ * Tries to figure out the Type of the field. It first checks the Qdsl pathSpecMap before checking the root class. Defaults to String.class
158
+ * @param fieldName The name of the field used as reference to get the type
159
+ * @param pathSpecMap The Qdsl path specifications as defined in the resolved bindings
160
+ * @param root The root type where the paths are gotten
161
+ * @return The type of the field. Returns
162
+ */
163
+ private Type getFieldType (String fieldName , Map <String , Object > pathSpecMap , Class <?> root ) {
164
+ try {
165
+ Object pathAndBinding = pathSpecMap .get (fieldName );
166
+ Optional <Path <?>> path = getPathFromPathSpec (pathAndBinding );
167
+
168
+ Type genericType ;
169
+ Field declaredField = null ;
170
+ if (path .isPresent ()) {
171
+ genericType = path .get ().getType ();
172
+ } else {
173
+ declaredField = root .getDeclaredField (fieldName );
174
+ genericType = declaredField .getGenericType ();
175
+ }
176
+ if (genericType != null ) {
177
+ return genericType ;
178
+ }
179
+ } catch (NoSuchFieldException e ) {
180
+ LOGGER .warn ("Field {} not found on {} : {}" , fieldName , root .getName (), e .getMessage ());
181
+ }
182
+ return String .class ;
183
+ }
184
+
185
+ /***
186
+ * Constructs the parameter
187
+ * @param type The type of the parameter
188
+ * @param name The name of the parameter
189
+ * @return The swagger parameter
190
+ */
191
+ private io .swagger .v3 .oas .models .parameters .Parameter buildParam (Type type , String name ) {
192
+ io .swagger .v3 .oas .models .parameters .Parameter parameter = new io .swagger .v3 .oas .models .parameters .Parameter ();
193
+
194
+ if (StringUtils .isBlank (parameter .getName ())) {
195
+ parameter .setName (name );
196
+ }
197
+
198
+ if (StringUtils .isBlank (parameter .getIn ())) {
199
+ parameter .setIn ("query" );
200
+ }
201
+
202
+ if (parameter .getSchema () == null ) {
203
+ Schema <?> schema = null ;
204
+ PrimitiveType primitiveType = PrimitiveType .fromType (type );
205
+ if (primitiveType != null ) {
206
+ schema = primitiveType .createProperty ();
207
+ } else {
208
+ ResolvedSchema resolvedSchema = ModelConverters .getInstance ()
209
+ .resolveAsResolvedSchema (
210
+ new io .swagger .v3 .core .converter .AnnotatedType (type ).resolveAsRef (true ));
211
+ // could not resolve the schema or this schema references other schema
212
+ // we dont want this since there's no reference to the components in order to register a new schema if it doesnt already exist
213
+ // defaulting to string
214
+ if (resolvedSchema == null || !resolvedSchema .referencedSchemas .isEmpty ()) {
215
+ schema = PrimitiveType .fromType (String .class ).createProperty ();
216
+ } else {
217
+ schema = resolvedSchema .schema ;
218
+ }
219
+ }
220
+ parameter .setSchema (schema );
221
+ }
222
+ return parameter ;
223
+ }
224
+ }
0 commit comments