@@ -250,6 +250,7 @@ type SimulatedExecutionTime struct {
250250 MinimumExecutionTime time.Duration
251251 RandomExecutionTime time.Duration
252252 Errors []error
253+ Responses []interface {}
253254 // Keep error after execution. The error will continue to be returned until
254255 // it is cleared.
255256 KeepError bool
@@ -678,47 +679,57 @@ func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, e
678679 return result , nil
679680}
680681
681- func (s * inMemSpannerServer ) simulateExecutionTime (method string , req interface {}) error {
682+ func (s * inMemSpannerServer ) simulateExecutionTime (method string , req interface {}) ( interface {}, error ) {
682683 s .mu .Lock ()
684+ defer s .mu .Unlock ()
685+
686+ // Check if the server is stopped
683687 if s .stopped {
684- s .mu .Unlock ()
685- return gstatus .Error (codes .Unavailable , "server has been stopped" )
688+ return nil , gstatus .Error (codes .Unavailable , "server has been stopped" )
686689 }
690+
691+ // Send the request to the receivedRequests channel
687692 s .receivedRequests <- req
688- s .mu .Unlock ()
689- s .ready ()
690- s .mu .Lock ()
693+
694+ // Check for a simulated error
691695 if s .err != nil {
692696 err := s .err
693697 s .err = nil
694- s .mu .Unlock ()
695- return err
698+ return nil , err
696699 }
700+
701+ // Check for a simulated execution time
697702 executionTime , ok := s .executionTimes [method ]
698- s .mu .Unlock ()
699703 if ok {
700704 var randTime int64
701705 if executionTime .RandomExecutionTime > 0 {
702706 randTime = rand .Int63n (int64 (executionTime .RandomExecutionTime ))
703707 }
704708 totalExecutionTime := time .Duration (int64 (executionTime .MinimumExecutionTime ) + randTime )
705709 <- time .After (totalExecutionTime )
706- s .mu .Lock ()
710+
711+ // Check for errors in the execution time
707712 if len (executionTime .Errors ) > 0 {
708713 err := executionTime .Errors [0 ]
709714 if ! executionTime .KeepError {
710715 executionTime .Errors = executionTime .Errors [1 :]
711716 }
712- s .mu .Unlock ()
713- return err
717+ return nil , err
718+ }
719+
720+ // Check for responses in the execution time
721+ if len (executionTime .Responses ) > 0 {
722+ response := executionTime .Responses [0 ]
723+ executionTime .Responses = executionTime .Responses [1 :]
724+ return response , nil
714725 }
715- s .mu .Unlock ()
716726 }
717- return nil
727+
728+ return nil , nil
718729}
719730
720731func (s * inMemSpannerServer ) CreateSession (ctx context.Context , req * spannerpb.CreateSessionRequest ) (* spannerpb.Session , error ) {
721- if err := s .simulateExecutionTime (MethodCreateSession , req ); err != nil {
732+ if _ , err := s .simulateExecutionTime (MethodCreateSession , req ); err != nil {
722733 return nil , err
723734 }
724735 if req .Database == "" {
@@ -750,7 +761,7 @@ func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.C
750761}
751762
752763func (s * inMemSpannerServer ) BatchCreateSessions (ctx context.Context , req * spannerpb.BatchCreateSessionsRequest ) (* spannerpb.BatchCreateSessionsResponse , error ) {
753- if err := s .simulateExecutionTime (MethodBatchCreateSession , req ); err != nil {
764+ if _ , err := s .simulateExecutionTime (MethodBatchCreateSession , req ); err != nil {
754765 return nil , err
755766 }
756767 if req .Database == "" {
@@ -792,7 +803,7 @@ func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spann
792803}
793804
794805func (s * inMemSpannerServer ) GetSession (ctx context.Context , req * spannerpb.GetSessionRequest ) (* spannerpb.Session , error ) {
795- if err := s .simulateExecutionTime (MethodGetSession , req ); err != nil {
806+ if _ , err := s .simulateExecutionTime (MethodGetSession , req ); err != nil {
796807 return nil , err
797808 }
798809 if req .Name == "" {
@@ -833,7 +844,7 @@ func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.Li
833844}
834845
835846func (s * inMemSpannerServer ) DeleteSession (ctx context.Context , req * spannerpb.DeleteSessionRequest ) (* emptypb.Empty , error ) {
836- if err := s .simulateExecutionTime (MethodDeleteSession , req ); err != nil {
847+ if _ , err := s .simulateExecutionTime (MethodDeleteSession , req ); err != nil {
837848 return nil , err
838849 }
839850 if req .Name == "" {
@@ -850,7 +861,7 @@ func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.D
850861}
851862
852863func (s * inMemSpannerServer ) ExecuteSql (ctx context.Context , req * spannerpb.ExecuteSqlRequest ) (* spannerpb.ResultSet , error ) {
853- if err := s .simulateExecutionTime (MethodExecuteSql , req ); err != nil {
864+ if _ , err := s .simulateExecutionTime (MethodExecuteSql , req ); err != nil {
854865 return nil , err
855866 }
856867 if req .Sql == "SELECT 1" {
@@ -905,7 +916,7 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec
905916}
906917
907918func (s * inMemSpannerServer ) ExecuteStreamingSql (req * spannerpb.ExecuteSqlRequest , stream spannerpb.Spanner_ExecuteStreamingSqlServer ) error {
908- if err := s .simulateExecutionTime (MethodExecuteStreamingSql , req ); err != nil {
919+ if _ , err := s .simulateExecutionTime (MethodExecuteStreamingSql , req ); err != nil {
909920 return err
910921 }
911922 return s .executeStreamingSQL (req , stream )
@@ -984,7 +995,7 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques
984995}
985996
986997func (s * inMemSpannerServer ) ExecuteBatchDml (ctx context.Context , req * spannerpb.ExecuteBatchDmlRequest ) (* spannerpb.ExecuteBatchDmlResponse , error ) {
987- if err := s .simulateExecutionTime (MethodExecuteBatchDml , req ); err != nil {
998+ if _ , err := s .simulateExecutionTime (MethodExecuteBatchDml , req ); err != nil {
988999 return nil , err
9891000 }
9901001 if req .Session == "" {
@@ -1047,7 +1058,7 @@ func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadReques
10471058}
10481059
10491060func (s * inMemSpannerServer ) StreamingRead (req * spannerpb.ReadRequest , stream spannerpb.Spanner_StreamingReadServer ) error {
1050- if err := s .simulateExecutionTime (MethodStreamingRead , req ); err != nil {
1061+ if _ , err := s .simulateExecutionTime (MethodStreamingRead , req ); err != nil {
10511062 return err
10521063 }
10531064 sqlReq := & spannerpb.ExecuteSqlRequest {
@@ -1066,7 +1077,7 @@ func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream sp
10661077}
10671078
10681079func (s * inMemSpannerServer ) BeginTransaction (ctx context.Context , req * spannerpb.BeginTransactionRequest ) (* spannerpb.Transaction , error ) {
1069- if err := s .simulateExecutionTime (MethodBeginTransaction , req ); err != nil {
1080+ if _ , err := s .simulateExecutionTime (MethodBeginTransaction , req ); err != nil {
10701081 return nil , err
10711082 }
10721083 if req .Session == "" {
@@ -1085,7 +1096,8 @@ func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerp
10851096}
10861097
10871098func (s * inMemSpannerServer ) Commit (ctx context.Context , req * spannerpb.CommitRequest ) (* spannerpb.CommitResponse , error ) {
1088- if err := s .simulateExecutionTime (MethodCommitTransaction , req ); err != nil {
1099+ mockResponse , err := s .simulateExecutionTime (MethodCommitTransaction , req )
1100+ if err != nil {
10891101 return nil , err
10901102 }
10911103 if req .Session == "" {
@@ -1107,8 +1119,11 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe
11071119 } else {
11081120 return nil , gstatus .Error (codes .InvalidArgument , "Missing transaction in commit request" )
11091121 }
1110- s .removeTransaction (tx )
1111- resp := & spannerpb.CommitResponse {CommitTimestamp : getCurrentTimestamp ()}
1122+ resp , ok := mockResponse .(* spannerpb.CommitResponse )
1123+ if ! ok {
1124+ resp = & spannerpb.CommitResponse {CommitTimestamp : getCurrentTimestamp ()}
1125+ s .removeTransaction (tx )
1126+ }
11121127 if req .ReturnCommitStats {
11131128 resp .CommitStats = & spannerpb.CommitResponse_CommitStats {
11141129 MutationCount : int64 (1 ),
@@ -1142,7 +1157,7 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba
11421157}
11431158
11441159func (s * inMemSpannerServer ) PartitionQuery (ctx context.Context , req * spannerpb.PartitionQueryRequest ) (* spannerpb.PartitionResponse , error ) {
1145- if err := s .simulateExecutionTime (MethodPartitionQuery , req ); err != nil {
1160+ if _ , err := s .simulateExecutionTime (MethodPartitionQuery , req ); err != nil {
11461161 return nil , err
11471162 }
11481163 s .mu .Lock ()
@@ -1214,7 +1229,7 @@ func DecodeResumeToken(t []byte) (uint64, error) {
12141229}
12151230
12161231func (s * inMemSpannerServer ) BatchWrite (req * spannerpb.BatchWriteRequest , stream spannerpb.Spanner_BatchWriteServer ) error {
1217- if err := s .simulateExecutionTime (MethodBatchWrite , req ); err != nil {
1232+ if _ , err := s .simulateExecutionTime (MethodBatchWrite , req ); err != nil {
12181233 return err
12191234 }
12201235 return s .batchWrite (req , stream )
0 commit comments