|
| 1 | +package httptest |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "io" |
| 6 | + "net/http" |
| 7 | + "strings" |
| 8 | +) |
| 9 | + |
| 10 | +// --------------------------------------------------------------------------- |
| 11 | + |
| 12 | +type TransportComposer interface { |
| 13 | + Compose(base http.RoundTripper) http.RoundTripper |
| 14 | +} |
| 15 | + |
| 16 | +type Executor interface { |
| 17 | + Exec(ctx *Context, code string) |
| 18 | +} |
| 19 | + |
| 20 | +// --------------------------------------------------------------------------- |
| 21 | + |
| 22 | +func mimeType(ct string) string { |
| 23 | + |
| 24 | + if ct == "form" { |
| 25 | + return "application/x-www-form-urlencoded" |
| 26 | + } |
| 27 | + if ct == "binary" { |
| 28 | + return "application/octet-stream" |
| 29 | + } |
| 30 | + if strings.Index(ct, "/") < 0 { |
| 31 | + return "application/" + ct |
| 32 | + } |
| 33 | + return ct |
| 34 | +} |
| 35 | + |
| 36 | +// --------------------------------------------------------------------------- |
| 37 | + |
| 38 | +type Request struct { |
| 39 | + method string |
| 40 | + url string |
| 41 | + auth TransportComposer |
| 42 | + ctx *Context |
| 43 | + header http.Header |
| 44 | + bodyType string |
| 45 | + body string |
| 46 | +} |
| 47 | + |
| 48 | +func NewRequest(ctx *Context, method, url string) *Request { |
| 49 | + |
| 50 | + ctx.DeleteVar("resp") |
| 51 | + |
| 52 | + p := &Request{ |
| 53 | + ctx: ctx, |
| 54 | + method: method, |
| 55 | + url: url, |
| 56 | + header: make(http.Header), |
| 57 | + } |
| 58 | + ctx.Log(" ====>", method, url) |
| 59 | + return p |
| 60 | +} |
| 61 | + |
| 62 | +func (p *Request) WithAuth(v interface{}) *Request { |
| 63 | + |
| 64 | + if v == nil { |
| 65 | + p.auth = nil |
| 66 | + return p |
| 67 | + } |
| 68 | + if name, ok := v.(string); ok { |
| 69 | + auth, ok := p.ctx.auths[name] |
| 70 | + if !ok { |
| 71 | + p.ctx.Fatal("WithAuth failed: auth not found -", name) |
| 72 | + } |
| 73 | + p.auth = auth |
| 74 | + return p |
| 75 | + } |
| 76 | + if auth, ok := v.(TransportComposer); ok { |
| 77 | + p.auth = auth |
| 78 | + return p |
| 79 | + } |
| 80 | + p.ctx.Fatal("WithAuth failed: invalid auth -", v) |
| 81 | + return p |
| 82 | +} |
| 83 | + |
| 84 | +func (p *Request) WithHeader(key string, values ...string) *Request { |
| 85 | + |
| 86 | + p.header[key] = values |
| 87 | + return p |
| 88 | +} |
| 89 | + |
| 90 | +func (p *Request) WithBody(bodyType, body string) *Request { |
| 91 | + |
| 92 | + p.bodyType = mimeType(bodyType) |
| 93 | + p.body = body |
| 94 | + return p |
| 95 | +} |
| 96 | + |
| 97 | +func (p *Request) WithBodyf(bodyType, format string, v ...interface{}) *Request { |
| 98 | + |
| 99 | + p.bodyType = mimeType(bodyType) |
| 100 | + p.body = fmt.Sprintf(format, v...) |
| 101 | + return p |
| 102 | +} |
| 103 | + |
| 104 | +func mergeHeader(to, from http.Header) { |
| 105 | + |
| 106 | + for k, v := range from { |
| 107 | + to[k] = v |
| 108 | + } |
| 109 | +} |
| 110 | + |
| 111 | +func (p *Request) send() (resp *http.Response, err error) { |
| 112 | + |
| 113 | + var body io.Reader |
| 114 | + if len(p.body) > 0 { |
| 115 | + body = strings.NewReader(p.body) |
| 116 | + } |
| 117 | + req, err := p.ctx.newRequest(p.method, p.url, body) |
| 118 | + if err != nil { |
| 119 | + p.ctx.Fatal("http.NewRequest failed:", p.method, p.url, p.body, err) |
| 120 | + return |
| 121 | + } |
| 122 | + |
| 123 | + mergeHeader(req.Header, p.ctx.DefaultHeader) |
| 124 | + |
| 125 | + if body != nil { |
| 126 | + if p.bodyType != "" { |
| 127 | + req.Header.Set("Content-Type", p.bodyType) |
| 128 | + } |
| 129 | + req.ContentLength = int64(len(p.body)) |
| 130 | + } |
| 131 | + |
| 132 | + mergeHeader(req.Header, p.header) |
| 133 | + |
| 134 | + t := p.ctx.transport |
| 135 | + if p.auth != nil { |
| 136 | + t = p.auth.Compose(t) |
| 137 | + } |
| 138 | + |
| 139 | + c := &http.Client{Transport: t} |
| 140 | + return c.Do(req) |
| 141 | +} |
| 142 | + |
| 143 | +func (p *Request) Ret(code int) (resp *Response) { |
| 144 | + |
| 145 | + resp1, err := p.send() |
| 146 | + resp = newResponse(p, resp1, err) |
| 147 | + p.ctx.MatchVar("resp", map[string]interface{}{ |
| 148 | + "body": resp.BodyObj, |
| 149 | + "header": resp.Header, |
| 150 | + "code": float64(resp.StatusCode), |
| 151 | + }) |
| 152 | + return resp.matchCode(code) |
| 153 | +} |
| 154 | + |
| 155 | +// --------------------------------------------------------------------------- |
| 156 | + |
| 157 | +type TestingT interface { |
| 158 | + Fatal(args ...interface{}) |
| 159 | + Log(args ...interface{}) |
| 160 | +} |
| 161 | + |
| 162 | +type NilTestingT struct {} |
| 163 | + |
| 164 | +func (p NilTestingT) Fatal(args ...interface{}) {} |
| 165 | +func (p NilTestingT) Log(args ...interface{}) {} |
| 166 | + |
| 167 | +// --------------------------------------------------------------------------- |
| 168 | + |
| 169 | +type Context struct { |
| 170 | + TestingT |
| 171 | + varsMgr |
| 172 | + hostsMgr |
| 173 | + transport http.RoundTripper |
| 174 | + auths map[string]TransportComposer |
| 175 | + DefaultHeader http.Header |
| 176 | + MatchResponseError func(message string, req *Request, resp *Response) |
| 177 | +} |
| 178 | + |
| 179 | +func New(t TestingT) *Context { |
| 180 | + |
| 181 | + auths := make(map[string]TransportComposer) |
| 182 | + p := &Context{ |
| 183 | + TestingT: t, |
| 184 | + auths: auths, |
| 185 | + transport: http.DefaultTransport, |
| 186 | + DefaultHeader: make(http.Header), |
| 187 | + MatchResponseError: matchRespError, |
| 188 | + } |
| 189 | + p.initHostsMgr() |
| 190 | + p.initVarsMgr() |
| 191 | + return p |
| 192 | +} |
| 193 | + |
| 194 | +func (p *Context) SetTransport(transport http.RoundTripper) { |
| 195 | + |
| 196 | + p.transport = transport |
| 197 | +} |
| 198 | + |
| 199 | +func (p *Context) SetAuth(name string, auth TransportComposer) { |
| 200 | + |
| 201 | + p.auths[name] = auth |
| 202 | +} |
| 203 | + |
| 204 | +func (p *Context) Exec(executor Executor, code string) *Context { |
| 205 | + |
| 206 | + executor.Exec(p, code) |
| 207 | + return p |
| 208 | +} |
| 209 | + |
| 210 | +func (p *Context) Request(method, url string) *Request { |
| 211 | + |
| 212 | + return NewRequest(p, method, url) |
| 213 | +} |
| 214 | + |
| 215 | +func (p *Context) Requestf(method, format string, v ...interface{}) *Request { |
| 216 | + |
| 217 | + url := fmt.Sprintf(format, v...) |
| 218 | + return NewRequest(p, method, url) |
| 219 | +} |
| 220 | + |
| 221 | +// --------------------------------------------------------------------------- |
| 222 | + |
0 commit comments