|
1 | 1 | package rethinkdb |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "encoding/binary" |
4 | 5 | "encoding/json" |
5 | 6 | "fmt" |
| 7 | + "gopkg.in/check.v1" |
| 8 | + "net" |
6 | 9 | "reflect" |
7 | 10 | "sync" |
8 | 11 | "time" |
@@ -338,43 +341,41 @@ func (m *Mock) Query(ctx context.Context, q Query) (*Cursor, error) { |
338 | 341 | return nil, query.Error |
339 | 342 | } |
340 | 343 |
|
| 344 | + var conn *Connection = nil |
| 345 | + responseVal := reflect.ValueOf(query.Response) |
| 346 | + if responseVal.Kind() == reflect.Chan || responseVal.Kind() == reflect.Func { |
| 347 | + conn = newConnection(newMockConn(query.Response), "mock", &ConnectOpts{}) |
| 348 | + |
| 349 | + query.Query.Type = p.Query_CONTINUE |
| 350 | + query.Query.Token = conn.nextToken() |
| 351 | + |
| 352 | + conn.runConnection() |
| 353 | + } |
| 354 | + |
| 355 | + if ctx == nil { |
| 356 | + ctx = context.Background() |
| 357 | + } |
| 358 | + |
341 | 359 | // Build cursor and return |
342 | | - c := newCursor(ctx, nil, "", query.Query.Token, query.Query.Term, query.Query.Opts) |
| 360 | + c := newCursor(ctx, conn, "", query.Query.Token, query.Query.Term, query.Query.Opts) |
343 | 361 | c.finished = true |
344 | 362 | c.fetching = false |
345 | 363 | c.isAtom = true |
346 | 364 |
|
347 | | - responseVal := reflect.ValueOf(query.Response) |
348 | 365 | if responseVal.Kind() == reflect.Slice || responseVal.Kind() == reflect.Array { |
349 | 366 | for i := 0; i < responseVal.Len(); i++ { |
350 | | - c.buffer = append(c.buffer, getMockValue(responseVal.Index(i).Interface())) |
| 367 | + c.buffer = append(c.buffer, responseVal.Index(i).Interface()) |
351 | 368 | } |
| 369 | + } else if conn != nil { |
| 370 | + conn.cursors[query.Query.Token] = c |
| 371 | + c.finished = false |
352 | 372 | } else { |
353 | | - c.buffer = append(c.buffer, getMockValue(query.Response)) |
| 373 | + c.buffer = append(c.buffer, query.Response) |
354 | 374 | } |
355 | 375 |
|
356 | 376 | return c, nil |
357 | 377 | } |
358 | 378 |
|
359 | | -// getMockValue turns some responses to delayedData: values of "chan |
360 | | -// interface{}" type will turn to delayed data that produce data when there is |
361 | | -// an element available on the channel. Values of "func() interface{}" type |
362 | | -// will produce data by calling the function. |
363 | | -func getMockValue(val interface{}) interface{} { |
364 | | - switch v := val.(type) { |
365 | | - case chan interface{}: |
366 | | - return delayedData{ |
367 | | - f: func() interface{} { return <-v }, |
368 | | - } |
369 | | - case func() interface{}: |
370 | | - return delayedData{ |
371 | | - f: v, |
372 | | - } |
373 | | - default: |
374 | | - return val |
375 | | - } |
376 | | -} |
377 | | - |
378 | 379 | func (m *Mock) Exec(ctx context.Context, q Query) error { |
379 | 380 | _, err := m.Query(ctx, q) |
380 | 381 |
|
@@ -422,3 +423,82 @@ func (m *Mock) queries() []MockQuery { |
422 | 423 | defer m.mu.Unlock() |
423 | 424 | return append([]MockQuery{}, m.Queries...) |
424 | 425 | } |
| 426 | + |
| 427 | +type mockConn struct { |
| 428 | + c *check.C |
| 429 | + mu sync.Mutex |
| 430 | + value []byte |
| 431 | + tokens chan int64 |
| 432 | + valueGetter func() []interface{} |
| 433 | +} |
| 434 | + |
| 435 | +func newMockConn(responseGetter interface{}) *mockConn { |
| 436 | + c := &mockConn{tokens: make(chan int64, 1)} |
| 437 | + switch g := responseGetter.(type) { |
| 438 | + case chan []interface{}: |
| 439 | + c.valueGetter = func() []interface{} { return <-g } |
| 440 | + case func() []interface{}: |
| 441 | + c.valueGetter = g |
| 442 | + default: |
| 443 | + panic(fmt.Sprintf("unsupported value generator type: %T", responseGetter)) |
| 444 | + } |
| 445 | + return c |
| 446 | +} |
| 447 | + |
| 448 | +func (c *mockConn) Read(b []byte) (n int, err error) { |
| 449 | + c.mu.Lock() |
| 450 | + defer c.mu.Unlock() |
| 451 | + |
| 452 | + if c.value == nil { |
| 453 | + values := c.valueGetter() |
| 454 | + |
| 455 | + jresps := make([]json.RawMessage, len(values)) |
| 456 | + for i := range values { |
| 457 | + jresps[i], err = json.Marshal(values[i]) |
| 458 | + if err != nil { |
| 459 | + panic(fmt.Sprintf("failed to encode response: %v", err)) |
| 460 | + } |
| 461 | + } |
| 462 | + |
| 463 | + token := <-c.tokens |
| 464 | + resp := Response{ |
| 465 | + Token: token, |
| 466 | + Responses: jresps, |
| 467 | + Type: p.Response_SUCCESS_PARTIAL, |
| 468 | + } |
| 469 | + if values == nil { |
| 470 | + resp.Type = p.Response_SUCCESS_SEQUENCE |
| 471 | + } |
| 472 | + |
| 473 | + c.value, err = json.Marshal(resp) |
| 474 | + if err != nil { |
| 475 | + panic(fmt.Sprintf("failed to encode response: %v", err)) |
| 476 | + } |
| 477 | + |
| 478 | + if len(b) != respHeaderLen { |
| 479 | + panic("wrong header len") |
| 480 | + } |
| 481 | + binary.LittleEndian.PutUint64(b[:8], uint64(token)) |
| 482 | + binary.LittleEndian.PutUint32(b[8:], uint32(len(c.value))) |
| 483 | + return len(b), nil |
| 484 | + } else { |
| 485 | + copy(b, c.value) |
| 486 | + c.value = nil |
| 487 | + return len(b), nil |
| 488 | + } |
| 489 | +} |
| 490 | + |
| 491 | +func (c *mockConn) Write(b []byte) (n int, err error) { |
| 492 | + if len(b) < 8 { |
| 493 | + panic("bad socket write") |
| 494 | + } |
| 495 | + token := int64(binary.LittleEndian.Uint64(b[:8])) |
| 496 | + c.tokens <- token |
| 497 | + return len(b), nil |
| 498 | +} |
| 499 | +func (c *mockConn) Close() error { panic("not implemented") } |
| 500 | +func (c *mockConn) LocalAddr() net.Addr { panic("not implemented") } |
| 501 | +func (c *mockConn) RemoteAddr() net.Addr { panic("not implemented") } |
| 502 | +func (c *mockConn) SetDeadline(t time.Time) error { panic("not implemented") } |
| 503 | +func (c *mockConn) SetReadDeadline(t time.Time) error { panic("not implemented") } |
| 504 | +func (c *mockConn) SetWriteDeadline(t time.Time) error { panic("not implemented") } |
0 commit comments