|
| 1 | +package retrieval |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + "fmt" |
| 7 | + "io" |
| 8 | + "net/http" |
| 9 | + "net/http/httptest" |
| 10 | + "net/url" |
| 11 | + "testing" |
| 12 | + |
| 13 | + prime "github.com/ipld/go-ipld-prime" |
| 14 | + "github.com/multiformats/go-multihash" |
| 15 | + "github.com/storacha/go-ucanto/core/dag/blockstore" |
| 16 | + "github.com/storacha/go-ucanto/core/delegation" |
| 17 | + "github.com/storacha/go-ucanto/core/invocation" |
| 18 | + "github.com/storacha/go-ucanto/core/ipld" |
| 19 | + "github.com/storacha/go-ucanto/core/receipt" |
| 20 | + "github.com/storacha/go-ucanto/core/receipt/fx" |
| 21 | + "github.com/storacha/go-ucanto/core/result" |
| 22 | + "github.com/storacha/go-ucanto/core/result/failure" |
| 23 | + "github.com/storacha/go-ucanto/core/schema" |
| 24 | + ed25519 "github.com/storacha/go-ucanto/principal/ed25519/signer" |
| 25 | + "github.com/storacha/go-ucanto/server" |
| 26 | + "github.com/storacha/go-ucanto/server/retrieval" |
| 27 | + "github.com/storacha/go-ucanto/testing/fixtures" |
| 28 | + "github.com/storacha/go-ucanto/testing/helpers" |
| 29 | + "github.com/storacha/go-ucanto/testing/helpers/printer" |
| 30 | + thttp "github.com/storacha/go-ucanto/transport/http" |
| 31 | + "github.com/storacha/go-ucanto/ucan" |
| 32 | + "github.com/storacha/go-ucanto/validator" |
| 33 | + "github.com/stretchr/testify/require" |
| 34 | +) |
| 35 | + |
| 36 | +type serveCaveats struct { |
| 37 | + Digest []byte |
| 38 | + Range []int |
| 39 | +} |
| 40 | + |
| 41 | +var serveTS = helpers.Must(prime.LoadSchemaBytes([]byte(` |
| 42 | + type ServeCaveats struct { |
| 43 | + digest Bytes |
| 44 | + range [Int] |
| 45 | + } |
| 46 | + type ServeOk struct { |
| 47 | + digest Bytes |
| 48 | + range [Int] |
| 49 | + } |
| 50 | +`))) |
| 51 | + |
| 52 | +func (sc serveCaveats) ToIPLD() (ipld.Node, error) { |
| 53 | + return ipld.WrapWithRecovery(&sc, serveTS.TypeByName("ServeCaveats")) |
| 54 | +} |
| 55 | + |
| 56 | +type serveOk struct { |
| 57 | + Digest []byte |
| 58 | + Range []int |
| 59 | +} |
| 60 | + |
| 61 | +func (so serveOk) ToIPLD() (ipld.Node, error) { |
| 62 | + return ipld.WrapWithRecovery(&so, serveTS.TypeByName("ServeOk")) |
| 63 | +} |
| 64 | + |
| 65 | +var serveCaveatsReader = schema.Struct[serveCaveats](serveTS.TypeByName("ServeCaveats"), nil) |
| 66 | + |
| 67 | +var serve = validator.NewCapability( |
| 68 | + "content/serve", |
| 69 | + schema.DIDString(), |
| 70 | + serveCaveatsReader, |
| 71 | + validator.DefaultDerives, |
| 72 | +) |
| 73 | + |
| 74 | +func mkDelegationChain(t *testing.T, rootIssuer ucan.Signer, endAudience ucan.Principal, can ucan.Ability, len int) delegation.Delegation { |
| 75 | + require.GreaterOrEqual(t, len, 1) |
| 76 | + |
| 77 | + var dlg delegation.Delegation |
| 78 | + var proof delegation.Delegation |
| 79 | + |
| 80 | + iss := rootIssuer |
| 81 | + aud, err := ed25519.Generate() |
| 82 | + require.NoError(t, err) |
| 83 | + |
| 84 | + for range len - 1 { |
| 85 | + var opts []delegation.Option |
| 86 | + if proof != nil { |
| 87 | + opts = append(opts, delegation.WithProof(delegation.FromDelegation(proof))) |
| 88 | + } |
| 89 | + dlg, err = delegation.Delegate( |
| 90 | + iss, |
| 91 | + aud, |
| 92 | + []ucan.Capability[ucan.NoCaveats]{ |
| 93 | + ucan.NewCapability(can, rootIssuer.DID().String(), ucan.NoCaveats{}), |
| 94 | + }, |
| 95 | + opts..., |
| 96 | + ) |
| 97 | + require.NoError(t, err) |
| 98 | + iss = aud |
| 99 | + aud, err = ed25519.Generate() |
| 100 | + require.NoError(t, err) |
| 101 | + proof = dlg |
| 102 | + } |
| 103 | + |
| 104 | + var opts []delegation.Option |
| 105 | + if proof != nil { |
| 106 | + opts = append(opts, delegation.WithProof(delegation.FromDelegation(proof))) |
| 107 | + } |
| 108 | + dlg, err = delegation.Delegate( |
| 109 | + iss, |
| 110 | + endAudience, |
| 111 | + []ucan.Capability[ucan.NoCaveats]{ |
| 112 | + ucan.NewCapability(can, rootIssuer.DID().String(), ucan.NoCaveats{}), |
| 113 | + }, |
| 114 | + opts..., |
| 115 | + ) |
| 116 | + require.NoError(t, err) |
| 117 | + |
| 118 | + return dlg |
| 119 | +} |
| 120 | + |
| 121 | +func calcHeadersSize(h http.Header) int { |
| 122 | + var buf bytes.Buffer |
| 123 | + h.Write(&buf) |
| 124 | + return buf.Len() |
| 125 | +} |
| 126 | + |
| 127 | +var kb = 1024 |
| 128 | + |
| 129 | +// newRetrievalHTTPServer creates a HTTP server that will send a 431 response |
| 130 | +// when HTTP headers exceed 2KiB, but otherwise calls the UCAN server as usual |
| 131 | +func newRetrievalHTTPServer(t *testing.T, server server.ServerView[retrieval.Service]) *httptest.Server { |
| 132 | + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 133 | + t.Logf("-> %s %s", r.Method, r.URL) |
| 134 | + printer.PrintHeaders(t, r.Header) |
| 135 | + size := calcHeadersSize(r.Header) |
| 136 | + t.Logf("Total size of headers: %s", printer.SprintBytes(t, size)) |
| 137 | + |
| 138 | + if size > 2*kb { |
| 139 | + t.Logf("<- %d %s", http.StatusRequestHeaderFieldsTooLarge, http.StatusText(http.StatusRequestHeaderFieldsTooLarge)) |
| 140 | + w.WriteHeader(http.StatusRequestHeaderFieldsTooLarge) |
| 141 | + return |
| 142 | + } |
| 143 | + |
| 144 | + resp, err := server.Request(r.Context(), thttp.NewInboundRequest(r.URL, r.Body, r.Header)) |
| 145 | + require.NoError(t, err) |
| 146 | + |
| 147 | + t.Logf("<- %d %s", resp.Status(), http.StatusText(resp.Status())) |
| 148 | + printer.PrintHeaders(t, resp.Headers()) |
| 149 | + t.Logf("Total size of headers: %s", printer.SprintBytes(t, calcHeadersSize(resp.Headers()))) |
| 150 | + |
| 151 | + for name, values := range resp.Headers() { |
| 152 | + for _, value := range values { |
| 153 | + w.Header().Add(name, value) |
| 154 | + } |
| 155 | + } |
| 156 | + w.WriteHeader(resp.Status()) |
| 157 | + body := resp.Body() |
| 158 | + if body != nil { |
| 159 | + // log out the "not extended" dag-json response for debugging purposes |
| 160 | + if resp.Status() == http.StatusNotExtended { |
| 161 | + bodyBytes, err := io.ReadAll(body) |
| 162 | + require.NoError(t, err) |
| 163 | + t.Logf("Body: %s", string(bodyBytes)) |
| 164 | + body = io.NopCloser(bytes.NewReader(bodyBytes)) |
| 165 | + } |
| 166 | + _, err := io.Copy(w, body) |
| 167 | + require.NoError(t, err) |
| 168 | + } |
| 169 | + })) |
| 170 | +} |
| 171 | + |
| 172 | +type testDelegationCache struct { |
| 173 | + t *testing.T |
| 174 | + data map[string]delegation.Delegation |
| 175 | +} |
| 176 | + |
| 177 | +func (c *testDelegationCache) Get(ctx context.Context, root ipld.Link) (delegation.Delegation, bool, error) { |
| 178 | + d, ok := c.data[root.String()] |
| 179 | + if ok { |
| 180 | + c.t.Logf("CACHE HIT: %s", root.String()) |
| 181 | + } else { |
| 182 | + c.t.Logf("CACHE MISS: %s", root.String()) |
| 183 | + } |
| 184 | + return d, ok, nil |
| 185 | +} |
| 186 | + |
| 187 | +func (c *testDelegationCache) Put(ctx context.Context, d delegation.Delegation) error { |
| 188 | + c.data[d.Link().String()] = d |
| 189 | + c.t.Logf("CACHE PUT: %s", d.Link().String()) |
| 190 | + return nil |
| 191 | +} |
| 192 | + |
| 193 | +func newTestDelegationCache(t *testing.T) *testDelegationCache { |
| 194 | + return &testDelegationCache{t: t, data: map[string]delegation.Delegation{}} |
| 195 | +} |
| 196 | + |
| 197 | +func TestExecute(t *testing.T) { |
| 198 | + chainLengths := []int{1, 5, 10} |
| 199 | + for _, length := range chainLengths { |
| 200 | + t.Run(fmt.Sprintf("retrieval via partitioned request (proof chain of %d delegations)", length), func(t *testing.T) { |
| 201 | + dlg := mkDelegationChain(t, fixtures.Service, fixtures.Alice, serve.Can(), length) |
| 202 | + data := helpers.RandomBytes(512) |
| 203 | + |
| 204 | + // create a retrieval server that will send bytes back for an authorized |
| 205 | + // UCAN invocation sent in HTTP headers of the GET request |
| 206 | + server, err := retrieval.NewServer( |
| 207 | + fixtures.Service, |
| 208 | + retrieval.WithServiceMethod( |
| 209 | + serve.Can(), |
| 210 | + retrieval.Provide( |
| 211 | + serve, |
| 212 | + func(ctx context.Context, cap ucan.Capability[serveCaveats], inv invocation.Invocation, ictx server.InvocationContext, req retrieval.Request) (result.Result[serveOk, failure.IPLDBuilderFailure], fx.Effects, retrieval.Response, error) { |
| 213 | + t.Logf("Handling %s: %s", serve.Can(), req.URL.String()) |
| 214 | + t.Log("Invocation:") |
| 215 | + printer.PrintDelegation(t, inv, 0) |
| 216 | + nb := cap.Nb() |
| 217 | + result := result.Ok[serveOk, failure.IPLDBuilderFailure](serveOk(nb)) |
| 218 | + start, end := nb.Range[0], nb.Range[1] |
| 219 | + length := end - start + 1 |
| 220 | + headers := http.Header{} |
| 221 | + headers.Set("Content-Length", fmt.Sprintf("%d", length)) |
| 222 | + headers.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, len(data))) |
| 223 | + response := retrieval.Response{ |
| 224 | + Status: http.StatusPartialContent, |
| 225 | + Headers: headers, |
| 226 | + Body: io.NopCloser(bytes.NewReader(data[start : end+1])), |
| 227 | + } |
| 228 | + return result, nil, response, nil |
| 229 | + }, |
| 230 | + ), |
| 231 | + ), |
| 232 | + retrieval.WithDelegationCache(newTestDelegationCache(t)), |
| 233 | + ) |
| 234 | + require.NoError(t, err) |
| 235 | + |
| 236 | + httpServer := newRetrievalHTTPServer(t, server) |
| 237 | + defer httpServer.Close() |
| 238 | + |
| 239 | + // make a UCAN authorized retrieval request for some bytes from the data |
| 240 | + |
| 241 | + // identify the data |
| 242 | + digest, err := multihash.Sum(data, multihash.SHA2_256, -1) |
| 243 | + require.NoError(t, err) |
| 244 | + |
| 245 | + // specify the byte range we want to receive (inclusive) |
| 246 | + contentRange := []int{100, 200} |
| 247 | + |
| 248 | + url, err := url.Parse(httpServer.URL) |
| 249 | + require.NoError(t, err) |
| 250 | + |
| 251 | + // the URL doesn't really have a consequence on this test, but it can be |
| 252 | + // used to idenitfy the data if not done so in the invocation caveats |
| 253 | + conn, err := NewConnection(fixtures.Service, url.JoinPath("blob", "z"+digest.B58String())) |
| 254 | + require.NoError(t, err) |
| 255 | + |
| 256 | + inv, err := serve.Invoke( |
| 257 | + fixtures.Alice, |
| 258 | + fixtures.Service, |
| 259 | + fixtures.Service.DID().String(), |
| 260 | + serveCaveats{Digest: digest, Range: contentRange}, |
| 261 | + delegation.WithProof(delegation.FromDelegation(dlg)), |
| 262 | + ) |
| 263 | + require.NoError(t, err) |
| 264 | + |
| 265 | + // send the invocation, and receive the execution response _as well as_ the |
| 266 | + // HTTP response! |
| 267 | + xRes, hRes, err := Execute(t.Context(), inv, conn) |
| 268 | + require.NoError(t, err) |
| 269 | + require.NotNil(t, xRes) |
| 270 | + require.NotNil(t, hRes) |
| 271 | + |
| 272 | + rcptLink, ok := xRes.Get(inv.Link()) |
| 273 | + require.True(t, ok) |
| 274 | + |
| 275 | + bs, err := blockstore.NewBlockReader(blockstore.WithBlocksIterator(xRes.Blocks())) |
| 276 | + require.NoError(t, err) |
| 277 | + |
| 278 | + rcpt, err := receipt.NewAnyReceipt(rcptLink, bs) |
| 279 | + require.NoError(t, err) |
| 280 | + |
| 281 | + // verify the receipt is not an error, and that the info matches the |
| 282 | + // invocation caveats |
| 283 | + o, x := result.Unwrap(rcpt.Out()) |
| 284 | + require.Nil(t, x) |
| 285 | + |
| 286 | + sok, err := ipld.Rebind[serveOk](o, serveTS.TypeByName("ServeOk")) |
| 287 | + require.NoError(t, err) |
| 288 | + require.Equal(t, digest, multihash.Multihash(sok.Digest)) |
| 289 | + require.Equal(t, []int{100, 200}, sok.Range) |
| 290 | + |
| 291 | + // verify the data in the HTTP body is what we asked for |
| 292 | + body, err := io.ReadAll(hRes.Body()) |
| 293 | + require.NoError(t, err) |
| 294 | + require.Equal(t, data[100:200+1], body) |
| 295 | + }) |
| 296 | + } |
| 297 | +} |
0 commit comments