2020import java .lang .reflect .InvocationHandler ;
2121import java .lang .reflect .InvocationTargetException ;
2222import java .lang .reflect .Method ;
23+ import java .util .Arrays ;
2324import java .util .HashMap ;
2425import java .util .HashSet ;
2526import java .util .List ;
2627import java .util .Map ;
2728import java .util .Set ;
29+ import java .util .concurrent .ConcurrentHashMap ;
30+ import java .util .concurrent .ConcurrentMap ;
31+ import java .util .function .Function ;
2832import java .util .stream .Collectors ;
2933import net .bytebuddy .ByteBuddy ;
34+ import net .bytebuddy .description .modifier .Visibility ;
3035import net .bytebuddy .dynamic .loading .ClassLoadingStrategy ;
36+ import net .bytebuddy .implementation .FieldAccessor ;
3137import net .bytebuddy .implementation .InvocationHandlerAdapter ;
3238import net .bytebuddy .matcher .ElementMatchers ;
3339import org .openqa .selenium .Alert ;
183189@ Beta
184190public class WebDriverDecorator <T extends WebDriver > {
185191
192+ protected static class Definition {
193+ private final Class <?> decoratedClass ;
194+ private final Class <?> originalClass ;
195+
196+ public Definition (Decorated <?> decorated ) {
197+ this .decoratedClass = decorated .getClass ();
198+ this .originalClass = decorated .getOriginal ().getClass ();
199+ }
200+
201+ @ Override
202+ public boolean equals (Object o ) {
203+ if (o == null || getClass () != o .getClass ()) return false ;
204+ Definition definition = (Definition ) o ;
205+ // intentionally an identity check, to ensure we get no false positive lookup due to an
206+ // unknown implementation of decoratedClass.equals or originalClass.equals
207+ return (decoratedClass == definition .decoratedClass )
208+ && (originalClass == definition .originalClass );
209+ }
210+
211+ @ Override
212+ public int hashCode () {
213+ return Arrays .hashCode (
214+ new int [] {
215+ System .identityHashCode (decoratedClass ), System .identityHashCode (originalClass )
216+ });
217+ }
218+ }
219+
220+ public interface HasTarget <Z > {
221+ Decorated <Z > getTarget ();
222+
223+ void setTarget (Decorated <Z > target );
224+ }
225+
226+ protected static class ProxyFactory <T > {
227+ private final Class <? extends T > clazz ;
228+
229+ private ProxyFactory (Class <? extends T > clazz ) {
230+ this .clazz = clazz ;
231+ }
232+
233+ public T newInstance (Decorated <T > target ) {
234+ T instance ;
235+ try {
236+ instance = (T ) clazz .newInstance ();
237+ } catch (ReflectiveOperationException e ) {
238+ throw new AssertionError ("Unable to create new proxy" , e );
239+ }
240+
241+ // ensure we can later find the target to call
242+ //noinspection unchecked
243+ ((HasTarget <T >) instance ).setTarget (target );
244+
245+ return instance ;
246+ }
247+ }
248+
249+ private final ConcurrentMap <Definition , ProxyFactory <?>> cache ;
250+
186251 private final Class <T > targetWebDriverClass ;
187252
188253 private Decorated <T > decorated ;
@@ -194,6 +259,7 @@ public WebDriverDecorator() {
194259
195260 public WebDriverDecorator (Class <T > targetClass ) {
196261 this .targetWebDriverClass = targetClass ;
262+ this .cache = new ConcurrentHashMap <>();
197263 }
198264
199265 public final T decorate (T original ) {
@@ -295,18 +361,36 @@ private Object decorateResult(Object toDecorate) {
295361 return toDecorate ;
296362 }
297363
298- protected final <Z > Z createProxy (final Decorated <Z > decorated , Class <Z > clazz ) {
299- Set <Class <?>> decoratedInterfaces = extractInterfaces (decorated );
300- Set <Class <?>> originalInterfaces = extractInterfaces (decorated .getOriginal ());
301- Map <Class <?>, InvocationHandler > derivedInterfaces =
302- deriveAdditionalInterfaces (decorated .getOriginal ());
364+ protected final <Z > Z createProxy (final Decorated <Z > decorated , Class <? extends Z > clazz ) {
365+ @ SuppressWarnings ("unchecked" )
366+ ProxyFactory <Z > factory =
367+ (ProxyFactory <Z >)
368+ cache .computeIfAbsent (
369+ new Definition (decorated ), (key ) -> createProxyFactory (key , decorated , clazz ));
370+
371+ return factory .newInstance (decorated );
372+ }
373+
374+ protected final <Z > ProxyFactory <? extends Z > createProxyFactory (
375+ Definition definition , final Decorated <Z > sample , Class <? extends Z > clazz ) {
376+ Set <Class <?>> decoratedInterfaces = extractInterfaces (definition .decoratedClass );
377+ Set <Class <?>> originalInterfaces = extractInterfaces (definition .originalClass );
378+ // all samples with the same definition should have the same derivedInterfaces
379+ Map <Class <?>, Function <Z , InvocationHandler >> derivedInterfaces =
380+ deriveAdditionalInterfaces (sample .getOriginal ());
303381
304382 final InvocationHandler handler =
305383 (proxy , method , args ) -> {
384+ // Lookup the instance to call, to reuse the clazz and handler.
385+ @ SuppressWarnings ("unchecked" )
386+ Decorated <Z > instance = ((HasTarget <Z >) proxy ).getTarget ();
387+ if (instance == null ) {
388+ throw new AssertionError ("Failed to get instance to call" );
389+ }
306390 try {
307391 if (method .getDeclaringClass ().equals (Object .class )
308392 || decoratedInterfaces .contains (method .getDeclaringClass ())) {
309- return method .invoke (decorated , args );
393+ return method .invoke (instance , args );
310394 }
311395 // Check if the class in which the method resides, implements any one of the
312396 // interfaces that we extracted from the decorated class.
@@ -317,9 +401,9 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
317401 eachInterface .isAssignableFrom (method .getDeclaringClass ()));
318402
319403 if (isCompatible ) {
320- decorated .beforeCall (method , args );
321- Object result = decorated .call (method , args );
322- decorated .afterCall (method , result , args );
404+ instance .beforeCall (method , args );
405+ Object result = instance .call (method , args );
406+ instance .afterCall (method , result , args );
323407 return result ;
324408 }
325409
@@ -333,19 +417,24 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
333417 eachInterface .isAssignableFrom (method .getDeclaringClass ()));
334418
335419 if (isCompatible ) {
336- return derivedInterfaces .get (method .getDeclaringClass ()).invoke (proxy , method , args );
420+ return derivedInterfaces
421+ .get (method .getDeclaringClass ())
422+ .apply (instance .getOriginal ())
423+ .invoke (proxy , method , args );
337424 }
338425
339- return method .invoke (decorated .getOriginal (), args );
426+ return method .invoke (instance .getOriginal (), args );
340427 } catch (InvocationTargetException e ) {
341- return decorated .onError (method , e , args );
428+ return instance .onError (method , e , args );
342429 }
343430 };
344431
345432 Set <Class <?>> allInterfaces = new HashSet <>();
346433 allInterfaces .addAll (decoratedInterfaces );
347434 allInterfaces .addAll (originalInterfaces );
348435 allInterfaces .addAll (derivedInterfaces .keySet ());
436+ // ensure a decorated driver can get decorated again
437+ allInterfaces .remove (HasTarget .class );
349438 Class <?>[] allInterfacesArray = allInterfaces .toArray (new Class <?>[0 ]);
350439
351440 Class <? extends Z > proxy =
@@ -354,20 +443,15 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
354443 .implement (allInterfacesArray )
355444 .method (ElementMatchers .any ())
356445 .intercept (InvocationHandlerAdapter .of (handler ))
446+ .defineField ("target" , Decorated .class , Visibility .PRIVATE )
447+ .implement (HasTarget .class )
448+ .intercept (FieldAccessor .ofField ("target" ))
357449 .make ()
358450 .load (clazz .getClassLoader (), ClassLoadingStrategy .Default .WRAPPER )
359451 .getLoaded ()
360452 .asSubclass (clazz );
361453
362- try {
363- return proxy .newInstance ();
364- } catch (ReflectiveOperationException e ) {
365- throw new IllegalStateException ("Unable to create new proxy" , e );
366- }
367- }
368-
369- static Set <Class <?>> extractInterfaces (final Object object ) {
370- return extractInterfaces (object .getClass ());
454+ return new ProxyFactory <Z >(proxy );
371455 }
372456
373457 private static Set <Class <?>> extractInterfaces (final Class <?> clazz ) {
@@ -393,43 +477,46 @@ private static void extractInterfaces(final Set<Class<?>> collector, final Class
393477 extractInterfaces (collector , clazz .getSuperclass ());
394478 }
395479
396- private Map <Class <?>, InvocationHandler > deriveAdditionalInterfaces (Object object ) {
397- Map <Class <?>, InvocationHandler > handlers = new HashMap <>();
480+ private < Z > Map <Class <?>, Function < Z , InvocationHandler >> deriveAdditionalInterfaces (Z sample ) {
481+ Map <Class <?>, Function < Z , InvocationHandler > > handlers = new HashMap <>();
398482
399- if (object instanceof WebDriver && !(object instanceof WrapsDriver )) {
483+ if (sample instanceof WebDriver && !(sample instanceof WrapsDriver )) {
400484 handlers .put (
401485 WrapsDriver .class ,
402- (proxy , method , args ) -> {
403- if ("getWrappedDriver" .equals (method .getName ())) {
404- return object ;
405- }
406- throw new UnsupportedOperationException (method .getName ());
407- });
486+ (instance ) ->
487+ (proxy , method , args ) -> {
488+ if ("getWrappedDriver" .equals (method .getName ())) {
489+ return instance ;
490+ }
491+ throw new UnsupportedOperationException (method .getName ());
492+ });
408493 }
409494
410- if (object instanceof WebElement && !(object instanceof WrapsElement )) {
495+ if (sample instanceof WebElement && !(sample instanceof WrapsElement )) {
411496 handlers .put (
412497 WrapsElement .class ,
413- (proxy , method , args ) -> {
414- if ("getWrappedElement" .equals (method .getName ())) {
415- return object ;
416- }
417- throw new UnsupportedOperationException (method .getName ());
418- });
498+ (instance ) ->
499+ (proxy , method , args ) -> {
500+ if ("getWrappedElement" .equals (method .getName ())) {
501+ return instance ;
502+ }
503+ throw new UnsupportedOperationException (method .getName ());
504+ });
419505 }
420506
421507 try {
422- Method toJson = object .getClass ().getDeclaredMethod ("toJson" );
508+ Method toJson = sample .getClass ().getDeclaredMethod ("toJson" );
423509 toJson .setAccessible (true );
424510
425511 handlers .put (
426512 JsonSerializer .class ,
427- ((proxy , method , args ) -> {
428- if ("toJson" .equals (method .getName ())) {
429- return toJson .invoke (object );
430- }
431- throw new UnsupportedOperationException (method .getName ());
432- }));
513+ (instance ) ->
514+ ((proxy , method , args ) -> {
515+ if ("toJson" .equals (method .getName ())) {
516+ return toJson .invoke (instance );
517+ }
518+ throw new UnsupportedOperationException (method .getName ());
519+ }));
433520 } catch (NoSuchMethodException e ) {
434521 // Fine. Just fall through
435522 }
0 commit comments