@@ -339,6 +339,10 @@ private Map<Integer, Object[]> partitionArguments(final DataSourceProvider dataS
339
339
340
340
// TODO: currently only implemented for single shardKey argument as first argument!
341
341
final List <Object > originalArgument = (List <Object >) args [0 ];
342
+ if (originalArgument == null || originalArgument .isEmpty ()) {
343
+ throw new IllegalArgumentException ("ShardKey (first argument) of sproc '" + name + "' not defined" );
344
+ }
345
+
342
346
List <Object > partitionedArgument = null ;
343
347
Object [] partitionedArguments = null ;
344
348
int shardId ;
@@ -383,28 +387,27 @@ private static class Call implements Callable<Object> {
383
387
private final StoredProcedure sproc ;
384
388
private final DataSource shardDs ;
385
389
private final Object [] params ;
386
- private final Object [] originalArgs ;
390
+ private final InvocationContext invocation ;
387
391
388
392
public Call (final StoredProcedure sproc , final DataSource shardDs , final Object [] params ,
389
- final Object [] originalArgs ) {
393
+ final InvocationContext invocation ) {
390
394
this .sproc = sproc ;
391
395
this .shardDs = shardDs ;
392
396
this .params = params ;
393
- this .originalArgs = originalArgs ;
397
+ this .invocation = invocation ;
394
398
}
395
399
396
400
@ Override
397
401
public Object call () throws Exception {
398
- return sproc .executor .executeSProc (shardDs , sproc .getQuery (), params , sproc .getTypes (), originalArgs ,
402
+ return sproc .executor .executeSProc (shardDs , sproc .getQuery (), params , sproc .getTypes (), invocation ,
399
403
sproc .returnType );
400
404
}
401
405
402
406
}
403
407
404
408
private static ExecutorService parallelThreadPool = Executors .newCachedThreadPool ();
405
409
406
- @ SuppressWarnings ({ "unchecked" , "rawtypes" })
407
- public Object execute (final DataSourceProvider dp , final Object [] args ) {
410
+ public Object execute (final DataSourceProvider dp , final InvocationContext invocation ) {
408
411
409
412
List <Integer > shardIds = null ;
410
413
Map <Integer , Object []> partitionedArguments = null ;
@@ -413,17 +416,17 @@ public Object execute(final DataSourceProvider dp, final Object[] args) {
413
416
shardIds = dp .getDistinctShardIds ();
414
417
} else {
415
418
if (autoPartition ) {
416
- partitionedArguments = partitionArguments (dp , args );
419
+ partitionedArguments = partitionArguments (dp , invocation . getArgs () );
417
420
shardIds = Lists .newArrayList (partitionedArguments .keySet ());
418
421
} else {
419
- shardIds = Lists .newArrayList (getShardId (args ));
422
+ shardIds = Lists .newArrayList (getShardId (invocation . getArgs () ));
420
423
}
421
424
}
422
425
423
426
if (partitionedArguments == null ) {
424
427
partitionedArguments = Maps .newHashMap ();
425
428
for (final int shardId : shardIds ) {
426
- partitionedArguments .put (shardId , args );
429
+ partitionedArguments .put (shardId , invocation . getArgs () );
427
430
}
428
431
}
429
432
@@ -434,7 +437,7 @@ public Object execute(final DataSourceProvider dp, final Object[] args) {
434
437
435
438
} catch (final SQLException e ) {
436
439
throw new CannotGetJdbcConnectionException ("Failed to acquire connection for virtual shard "
437
- + shardIds .get (0 ) + " translates to" + dp . getDataSourceId ( shardIds . get ( 0 )) + " for " + name , e );
440
+ + shardIds .get (0 ) + " for " + name , e );
438
441
}
439
442
440
443
final List <Object []> paramValues = Lists .newArrayList ();
@@ -459,7 +462,7 @@ public Object execute(final DataSourceProvider dp, final Object[] args) {
459
462
}
460
463
461
464
// most common case: only one shard and no argument partitioning
462
- return executor .executeSProc (firstDs , getQuery (), paramValues .get (0 ), getTypes (), args , returnType );
465
+ return executor .executeSProc (firstDs , getQuery (), paramValues .get (0 ), getTypes (), invocation , returnType );
463
466
} else {
464
467
Map <Integer , SameConnectionDatasource > transactionalDatasources = null ;
465
468
try {
@@ -471,11 +474,11 @@ public Object execute(final DataSourceProvider dp, final Object[] args) {
471
474
Object sprocResult = null ;
472
475
final long start = System .currentTimeMillis ();
473
476
if (parallel ) {
474
- sprocResult = executeInParallel (dp , args , shardIds , paramValues , transactionalDatasources , results ,
475
- sprocResult );
477
+ sprocResult = executeInParallel (dp , invocation , shardIds , paramValues , transactionalDatasources ,
478
+ results , sprocResult );
476
479
} else {
477
- sprocResult = executeSequential (dp , args , shardIds , paramValues , transactionalDatasources , results ,
478
- sprocResult );
480
+ sprocResult = executeSequential (dp , invocation , shardIds , paramValues , transactionalDatasources ,
481
+ results , sprocResult );
479
482
}
480
483
481
484
if (LOG .isTraceEnabled ()) {
@@ -528,9 +531,10 @@ public Object execute(final DataSourceProvider dp, final Object[] args) {
528
531
}
529
532
530
533
@ SuppressWarnings ({ "rawtypes" , "unchecked" })
531
- private Object executeSequential (final DataSourceProvider dp , final Object [] args , final List <Integer > shardIds ,
532
- final List <Object []> paramValues , final Map <Integer , SameConnectionDatasource > transactionalDatasources ,
533
- final List <?> results , Object sprocResult ) {
534
+ private Object executeSequential (final DataSourceProvider dp , final InvocationContext invocation ,
535
+ final List <Integer > shardIds , final List <Object []> paramValues ,
536
+ final Map <Integer , SameConnectionDatasource > transactionalDatasources , final List <?> results ,
537
+ Object sprocResult ) {
534
538
DataSource shardDs ;
535
539
int i = 0 ;
536
540
final List <String > exceptions = Lists .newArrayList ();
@@ -543,7 +547,7 @@ private Object executeSequential(final DataSourceProvider dp, final Object[] arg
543
547
544
548
sprocResult = null ;
545
549
try {
546
- sprocResult = executor .executeSProc (shardDs , getQuery (), paramValues .get (i ), getTypes (), args ,
550
+ sprocResult = executor .executeSProc (shardDs , getQuery (), paramValues .get (i ), getTypes (), invocation ,
547
551
returnType );
548
552
} catch (final Exception e ) {
549
553
@@ -568,9 +572,10 @@ private Object executeSequential(final DataSourceProvider dp, final Object[] arg
568
572
}
569
573
570
574
@ SuppressWarnings ({ "rawtypes" , "unchecked" })
571
- private Object executeInParallel (final DataSourceProvider dp , final Object [] args , final List <Integer > shardIds ,
572
- final List <Object []> paramValues , final Map <Integer , SameConnectionDatasource > transactionalDatasources ,
573
- final List <?> results , Object sprocResult ) {
575
+ private Object executeInParallel (final DataSourceProvider dp , final InvocationContext invocation ,
576
+ final List <Integer > shardIds , final List <Object []> paramValues ,
577
+ final Map <Integer , SameConnectionDatasource > transactionalDatasources , final List <?> results ,
578
+ Object sprocResult ) {
574
579
DataSource shardDs ;
575
580
final Map <Integer , FutureTask <Object >> tasks = Maps .newHashMapWithExpectedSize (shardIds .size ());
576
581
FutureTask <Object > task ;
@@ -582,7 +587,7 @@ private Object executeInParallel(final DataSourceProvider dp, final Object[] arg
582
587
LOG .debug (getDebugLog (paramValues .get (i )));
583
588
}
584
589
585
- task = new FutureTask <Object >(new Call (this , shardDs , paramValues .get (i ), args ));
590
+ task = new FutureTask <Object >(new Call (this , shardDs , paramValues .get (i ), invocation ));
586
591
tasks .put (shardId , task );
587
592
parallelThreadPool .execute (task );
588
593
i ++;
0 commit comments