diff --git a/mock/func_mock.go b/mock/func_mock.go new file mode 100644 index 000000000..9a3753ded --- /dev/null +++ b/mock/func_mock.go @@ -0,0 +1,66 @@ +package mock + +import ( + "errors" + "reflect" + "testing" +) + +var ErrFuncMockNotFunc = errors.New("not a function") + +func FuncMockFor(fn interface{}) (*FuncMock, error) { + typ := reflect.TypeOf(fn) + if typ == nil || typ.Kind() != reflect.Func { + return nil, ErrFuncMockNotFunc + } + + return &FuncMock{ + mock: Mock{}, + typ: typ, + }, nil +} + +type FuncMock struct { + mock Mock + typ reflect.Type +} + +func (m *FuncMock) Build() interface{} { + return reflect.MakeFunc(m.typ, func(args []reflect.Value) []reflect.Value { + argsAsInterface := make([]interface{}, len(args)) + for i, arg := range args { + argsAsInterface[i] = arg.Interface() + } + outs := m.mock.MethodCalled("func", argsAsInterface...) + res := make([]reflect.Value, m.typ.NumOut()) + for i := 0; i < m.typ.NumOut(); i++ { + val := outs.Get(i) + if val == nil { + res[i] = reflect.Zero(m.typ.Out(i)) + continue + } + res[i] = reflect.ValueOf(val) + } + return res + }).Interface() +} + +func (m *FuncMock) On(args ...interface{}) *Call { + return m.mock.On("func", args...) +} + +func (m *FuncMock) AssertExpectations(t *testing.T) { + m.mock.AssertExpectations(t) +} + +func (m *FuncMock) AssertNotCalled(t *testing.T, arguments ...interface{}) { + m.mock.AssertNotCalled(t, "func", arguments...) +} + +func (m *FuncMock) AssertCalled(t *testing.T, arguments ...interface{}) { + m.mock.AssertCalled(t, "func", arguments...) +} + +func (m *FuncMock) AssertNumberOfCalls(t *testing.T, expectedCalls int) { + m.mock.AssertNumberOfCalls(t, "func", expectedCalls) +} diff --git a/mock/func_mock_test.go b/mock/func_mock_test.go new file mode 100644 index 000000000..15599a480 --- /dev/null +++ b/mock/func_mock_test.go @@ -0,0 +1,134 @@ +package mock + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +var ExampleError = errors.New("test") + +type TestFuncTypeNoArgumentsNoOuts func() + +type TestFuncTypeArgumentsNoOuts func(a, b string) + +type TestFuncTypeNoArgumentsOuts func() (string, error) + +type TestFuncTypeArgumentsOuts func(a, b string) (string, error) + +func TestFuncMockFor(t *testing.T) { + t.Run("success", func(t *testing.T) { + var typ TestFuncTypeNoArgumentsNoOuts + res, err := FuncMockFor(typ) + assert.NoError(t, err) + assert.IsType(t, &FuncMock{}, res) + }) + t.Run("fail", func(t *testing.T) { + t.Run("not a func", func(t *testing.T) { + var typ interface{} + res, err := FuncMockFor(typ) + assert.Error(t, err) + assert.Nil(t, res) + }) + }) +} + +func TestFuncMock(t *testing.T) { + t.Run("with no arguments and with no outputs", func(t *testing.T) { + var typ TestFuncTypeNoArgumentsNoOuts + funcMock, err := FuncMockFor(typ) + assert.NoError(t, err) + assert.IsType(t, &FuncMock{}, funcMock) + defer funcMock.AssertNumberOfCalls(t, 1) + + funcMock.On().Return() + fn := funcMock.Build().(TestFuncTypeNoArgumentsNoOuts) + assert.NotPanics(t, func() { + fn() + }) + }) + t.Run("with arguments and with no outputs", func(t *testing.T) { + var typ TestFuncTypeArgumentsNoOuts + funcMock, err := FuncMockFor(typ) + assert.NoError(t, err) + assert.IsType(t, &FuncMock{}, funcMock) + funcMock.On("a", "b").Return() + defer funcMock.AssertNumberOfCalls(t, 1) + + funcMock.On().Return() + fn := funcMock.Build().(TestFuncTypeArgumentsNoOuts) + assert.NotPanics(t, func() { + fn("a", "b") + }) + }) + t.Run("with no arguments and with outputs", func(t *testing.T) { + t.Run("with no error", func(t *testing.T) { + var typ TestFuncTypeNoArgumentsOuts + funcMock, err := FuncMockFor(typ) + assert.NoError(t, err) + assert.IsType(t, &FuncMock{}, funcMock) + funcMock.On().Return("test", nil) + defer funcMock.AssertNumberOfCalls(t, 1) + + funcMock.On().Return() + fn := funcMock.Build().(TestFuncTypeNoArgumentsOuts) + assert.NotPanics(t, func() { + res, err := fn() + assert.ErrorIs(t, err, nil) + assert.Equal(t, "test", res) + }) + }) + + t.Run("with error", func(t *testing.T) { + var typ TestFuncTypeNoArgumentsOuts + funcMock, err := FuncMockFor(typ) + assert.NoError(t, err) + assert.IsType(t, &FuncMock{}, funcMock) + funcMock.On().Return("test", ExampleError) + defer funcMock.AssertNumberOfCalls(t, 1) + + funcMock.On().Return() + fn := funcMock.Build().(TestFuncTypeNoArgumentsOuts) + assert.NotPanics(t, func() { + res, err := fn() + assert.ErrorIs(t, err, ExampleError) + assert.Equal(t, "test", res) + }) + }) + }) + t.Run("with arguments and with outputs", func(t *testing.T) { + t.Run("with no error", func(t *testing.T) { + var typ TestFuncTypeArgumentsOuts + funcMock, err := FuncMockFor(typ) + assert.NoError(t, err) + assert.IsType(t, &FuncMock{}, funcMock) + funcMock.On().Return("test", ExampleError) + defer funcMock.AssertNumberOfCalls(t, 1) + + funcMock.On("1", "2").Return("1 2", nil) + fn := funcMock.Build().(TestFuncTypeArgumentsOuts) + assert.NotPanics(t, func() { + res, err := fn("1", "2") + assert.NoError(t, err) + assert.Equal(t, "1 2", res) + }) + }) + t.Run("with error", func(t *testing.T) { + var typ TestFuncTypeArgumentsOuts + funcMock, err := FuncMockFor(typ) + assert.NoError(t, err) + assert.IsType(t, &FuncMock{}, funcMock) + funcMock.On().Return("test", ExampleError) + defer funcMock.AssertNumberOfCalls(t, 1) + + funcMock.On("1", "2").Return("1 2", ExampleError) + fn := funcMock.Build().(TestFuncTypeArgumentsOuts) + assert.NotPanics(t, func() { + res, err := fn("1", "2") + assert.ErrorIs(t, err, ExampleError) + assert.Equal(t, "1 2", res) + }) + }) + }) +}