Skip to content

Commit 2dc33cc

Browse files
committed
s3cache: use s3util.Client where useful
1 parent 65c55db commit 2dc33cc

File tree

1 file changed

+26
-47
lines changed

1 file changed

+26
-47
lines changed

s3cache/s3cache.go

Lines changed: 26 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
package s3cache
44

55
import (
6+
"bytes"
67
"context"
78
"errors"
89
"expvar"
910
"fmt"
10-
"io"
11+
"io/fs"
1112
"os"
1213
"path"
1314
"runtime"
@@ -78,6 +79,7 @@ type Cache struct {
7879
initOnce sync.Once
7980
push *taskgroup.Group
8081
start func(taskgroup.Task) *taskgroup.Group
82+
client *s3util.Client
8183

8284
getLocalHit expvar.Int // count of Get hits in the local cache
8385
getFaultHit expvar.Int // count of Get hits faulted in from S3
@@ -92,11 +94,14 @@ type Cache struct {
9294
func (s *Cache) init() {
9395
s.initOnce.Do(func() {
9496
s.push, s.start = taskgroup.New(nil).Limit(s.uploadConcurrency())
97+
s.client = &s3util.Client{Client: s.S3Client, Bucket: s.S3Bucket}
9598
})
9699
}
97100

98101
// Get implements the corresponding callback of the cache protocol.
99102
func (s *Cache) Get(ctx context.Context, actionID string) (objectID, diskPath string, _ error) {
103+
s.init()
104+
100105
objID, diskPath, err := s.Local.Get(ctx, actionID)
101106
if err == nil && objID != "" && diskPath != "" {
102107
s.getLocalHit.Add(1)
@@ -105,29 +110,22 @@ func (s *Cache) Get(ctx context.Context, actionID string) (objectID, diskPath st
105110

106111
// Reaching here, either we got a cache miss or an error reading from local.
107112
// Try reading the action from S3.
108-
act, err := s.S3Client.GetObject(ctx, &s3.GetObjectInput{
109-
Bucket: &s.S3Bucket,
110-
Key: s.actionKey(actionID),
111-
})
113+
action, err := s.client.GetData(ctx, s.actionKey(actionID))
112114
if err != nil {
113-
if s3util.IsNotExist(err) {
115+
if errors.Is(err, fs.ErrNotExist) {
114116
s.getFaultMiss.Add(1)
115117
return "", "", nil // cache miss, OK
116118
}
117119
return "", "", fmt.Errorf("[s3] read action %s: %w", actionID, err)
118120
}
119121

120122
// We got an action hit remotely, try to update the local copy.
121-
objectID, mtime, err := parseAction(act.Body)
122-
act.Body.Close()
123+
objectID, mtime, err := parseAction(action)
123124
if err != nil {
124125
return "", "", err
125126
}
126127

127-
obj, err := s.S3Client.GetObject(ctx, &s3.GetObjectInput{
128-
Bucket: &s.S3Bucket,
129-
Key: s.objectKey(objectID),
130-
})
128+
object, err := s.client.GetData(ctx, s.objectKey(objectID))
131129
if err != nil {
132130
// At this point we know the action exists, so if we can't read the
133131
// object report it as an error rather than a cache miss.
@@ -137,12 +135,11 @@ func (s *Cache) Get(ctx context.Context, actionID string) (objectID, diskPath st
137135

138136
// Now we should have the body; poke it into the local cache. Preserve the
139137
// modification timestamp recorded with the original action.
140-
defer obj.Body.Close()
141138
diskPath, err = s.Local.Put(ctx, gocache.Object{
142139
ActionID: actionID,
143140
ObjectID: objectID,
144-
Size: *obj.ContentLength,
145-
Body: obj.Body,
141+
Size: int64(len(object)),
142+
Body: bytes.NewReader(object),
146143
ModTime: mtime,
147144
})
148145
return objectID, diskPath, err
@@ -181,11 +178,8 @@ func (s *Cache) Put(ctx context.Context, obj gocache.Object) (diskPath string, _
181178
}
182179

183180
// Stage 2: Write the action record.
184-
if _, err := s.S3Client.PutObject(sctx, &s3.PutObjectInput{
185-
Bucket: &s.S3Bucket,
186-
Key: s.actionKey(obj.ActionID),
187-
Body: strings.NewReader(fmt.Sprintf("%s %d", obj.ObjectID, mtime.UnixNano())),
188-
}); err != nil {
181+
if err := s.client.Put(ctx, s.actionKey(obj.ActionID),
182+
strings.NewReader(fmt.Sprintf("%s %d", obj.ObjectID, mtime.UnixNano()))); err != nil {
189183
gocache.Logf(ctx, "write action %s: %v", obj.ActionID, err)
190184
return err
191185
}
@@ -234,39 +228,28 @@ func (s *Cache) maybePutObject(ctx context.Context, objectID, diskPath, etag str
234228
return time.Time{}, err
235229
}
236230

237-
key := s.objectKey(objectID)
238-
if _, err := s.S3Client.HeadObject(ctx, &s3.HeadObjectInput{
239-
Bucket: &s.S3Bucket,
240-
Key: key,
241-
IfMatch: &etag,
242-
}); err == nil {
243-
s.putS3Found.Add(1)
244-
return fi.ModTime(), nil // already present and matching
245-
}
246-
247-
if _, err := s.S3Client.PutObject(ctx, &s3.PutObjectInput{
248-
Bucket: &s.S3Bucket,
249-
Key: s.objectKey(objectID),
250-
Body: f,
251-
}); err != nil {
231+
written, err := s.client.PutCond(ctx, s.objectKey(objectID), etag, f)
232+
if err != nil {
252233
s.putS3Error.Add(1)
253234
gocache.Logf(ctx, "[s3] put object %s: %v", objectID, err)
254235
return fi.ModTime(), err
255236
}
237+
if written {
238+
s.putS3Found.Add(1)
239+
return fi.ModTime(), nil // already present and matching
240+
}
256241
s.putS3Object.Add(1)
257242
return fi.ModTime(), nil
258243
}
259244

260245
// makeKey assembles a complete key from the specified parts, including the key
261-
// prefix if one is defined. The result is a pointer for compatibility with the
262-
// S3 client library.
263-
func (s *Cache) makeKey(parts ...string) *string {
264-
key := path.Join(s.KeyPrefix, path.Join(parts...))
265-
return &key
246+
// prefix if one is defined.
247+
func (s *Cache) makeKey(parts ...string) string {
248+
return path.Join(s.KeyPrefix, path.Join(parts...))
266249
}
267250

268-
func (s *Cache) actionKey(id string) *string { return s.makeKey("action", id[:2], id) }
269-
func (s *Cache) objectKey(id string) *string { return s.makeKey("object", id[:2], id) }
251+
func (s *Cache) actionKey(id string) string { return s.makeKey("action", id[:2], id) }
252+
func (s *Cache) objectKey(id string) string { return s.makeKey("object", id[:2], id) }
270253

271254
func (s *Cache) uploadConcurrency() int {
272255
if s.UploadConcurrency <= 0 {
@@ -275,11 +258,7 @@ func (s *Cache) uploadConcurrency() int {
275258
return s.UploadConcurrency
276259
}
277260

278-
func parseAction(r io.Reader) (objectID string, mtime time.Time, _ error) {
279-
data, err := io.ReadAll(r)
280-
if err != nil {
281-
return "", time.Time{}, err
282-
}
261+
func parseAction(data []byte) (objectID string, mtime time.Time, _ error) {
283262
fs := strings.Fields(string(data))
284263
if len(fs) != 2 {
285264
return "", time.Time{}, errors.New("invalid action record")

0 commit comments

Comments
 (0)