Skip to content

Commit

Permalink
Separate checksumFunc and storageFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
sluongng committed Apr 17, 2024
1 parent fb9e249 commit 93c847f
Showing 1 changed file with 54 additions and 24 deletions.
78 changes: 54 additions & 24 deletions server/remote_asset/fetch_server/fetch_server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fetch_server

import (
"bytes"
"context"
"encoding/base64"
"fmt"
Expand Down Expand Up @@ -107,25 +108,20 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
return nil, err
}

digestFunction := req.GetDigestFunction()
if digestFunction == repb.DigestFunction_UNKNOWN {
digestFunction = repb.DigestFunction_SHA256
storageFunc := req.GetDigestFunction()
if storageFunc == repb.DigestFunction_UNKNOWN {
storageFunc = repb.DigestFunction_SHA256
}
var checksumFunc repb.DigestFunction_Value
var hash string
for _, qualifier := range req.GetQualifiers() {
var prefix string
if qualifier.GetName() == ChecksumQualifier {
if strings.HasPrefix(qualifier.GetValue(), sha256Prefix) {
if digestFunction != repb.DigestFunction_SHA256 {
log.Warningf("FetchBlob request came with %s digest function but SHA256 checksum.sri: %s", digestFunction, qualifier.GetValue())
}
digestFunction = repb.DigestFunction_SHA256
checksumFunc = repb.DigestFunction_SHA256
prefix = sha256Prefix
} else if strings.HasPrefix(qualifier.GetValue(), blake3Prefix) {
if digestFunction != repb.DigestFunction_BLAKE3 {
log.Warningf("FetchBlob request came with %s digest function but BLAKE3 checksum.sri: %s", digestFunction, qualifier.GetValue())
}
digestFunction = repb.DigestFunction_BLAKE3
checksumFunc = repb.DigestFunction_BLAKE3
prefix = blake3Prefix
}
}
Expand All @@ -140,7 +136,7 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
}
}
if len(hash) != 0 {
blobDigest := p.findBlobInCache(ctx, req.GetInstanceName(), digestFunction, hash)
blobDigest := p.findBlobInCache(ctx, req.GetInstanceName(), checksumFunc, hash, storageFunc)
if blobDigest != nil {
return &rapb.FetchBlobResponse{
Status: &statuspb.Status{Code: int32(gcodes.OK)},
Expand All @@ -159,7 +155,7 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
if err != nil {
return nil, status.InvalidArgumentErrorf("unparsable URI: %q", uri)
}
blobDigest, err := mirrorToCache(ctx, p.env.GetByteStreamClient(), req.GetInstanceName(), digestFunction, httpClient, uri, hash)
blobDigest, err := mirrorToCache(ctx, p.env.GetByteStreamClient(), req.GetInstanceName(), httpClient, uri, storageFunc, checksumFunc, hash)
if err != nil {
lastFetchErr = err
log.CtxWarningf(ctx, "Failed to mirror %q to cache: %s", uri, err)
Expand Down Expand Up @@ -190,17 +186,17 @@ func (p *FetchServer) FetchDirectory(ctx context.Context, req *rapb.FetchDirecto
return nil, status.UnimplementedError("FetchDirectory is not yet implemented")
}

func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string, digestFunction repb.DigestFunction_Value, hash string) *repb.Digest {
func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string, checksumFunc repb.DigestFunction_Value, checksum string, storageFunc repb.DigestFunction_Value) *repb.Digest {
blobDigest := &repb.Digest{
Hash: hash,
Hash: checksum,
// The digest size is unknown since the client only sends up
// the hash. We can look up the size using the Metadata API,
// which looks up only using the hash, so the size we pass here
// doesn't matter.
SizeBytes: 1,
}
expectedHash := blobDigest.Hash
cacheRN := digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, digestFunction)
cacheRN := digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, checksumFunc)
log.CtxInfof(ctx, "Looking up %s in cache", blobDigest.Hash)

// Lookup metadata to get the correct digest size to be returned to
Expand All @@ -216,7 +212,7 @@ func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string,
// Even though we successfully fetched metadata, we need to renew
// the cache entry (using Contains()) to ensure that it doesn't
// expire by the time the client requests it from cache.
cacheRN = digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, digestFunction)
cacheRN = digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, checksumFunc)
exists, err := cache.Contains(ctx, cacheRN.ToProto())
if err != nil {
log.CtxErrorf(ctx, "Failed to renew %s: %s", digest.String(blobDigest), err)
Expand All @@ -226,6 +222,25 @@ func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string,
log.CtxInfof(ctx, "Blob %s expired before we could renew it", digest.String(blobDigest))
return nil
}

// If the digestFunc is supplied and differs from the checksum sri,
// then after looking up the cached blob using the checksum sri, re-upload
// that blob using the requested digestFunc.
if checksumFunc != storageFunc {
b, err := cache.Get(ctx, cacheRN.ToProto())
if err != nil {
log.CtxErrorf(ctx, "Failed to get cache reader for %s: %s", digest.String(blobDigest), err)
return nil
}
bsClient := p.env.GetByteStreamClient()
storageDigest, err := cachetools.UploadBlob(ctx, bsClient, instanceName, storageFunc, bytes.NewReader(b))
if err != nil {
log.CtxErrorf(ctx, "Failed to re-upload blob with new digestFunc %s for %s: %s", storageFunc, digest.String(blobDigest), err)
return nil
}
blobDigest = storageDigest
}

log.CtxInfof(ctx, "FetchServer found %s in cache", digest.String(blobDigest))
return blobDigest
}
Expand All @@ -234,7 +249,7 @@ func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string,
// returning the digest. The fetched contents are checked against the given
// expectedHash (if non-empty), and if there is a mismatch then an error is
// returned.
func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteInstanceName string, digestFunc repb.DigestFunction_Value, httpClient *http.Client, uri, expectedHash string) (*repb.Digest, error) {
func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteInstanceName string, httpClient *http.Client, uri string, storageFunc repb.DigestFunction_Value, checksumFunc repb.DigestFunction_Value, checksum string) (*repb.Digest, error) {
log.CtxInfof(ctx, "Fetching %s", uri)
rsp, err := httpClient.Get(uri)
if err != nil {
Expand All @@ -248,9 +263,9 @@ func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteIn
// If we know what the hash should be and the content length is known,
// then we know the full digest, and can pipe directly from the HTTP
// response to cache.
if expectedHash != "" && rsp.ContentLength >= 0 {
d := &repb.Digest{Hash: expectedHash, SizeBytes: rsp.ContentLength}
rn := digest.NewResourceName(d, remoteInstanceName, rspb.CacheType_CAS, digestFunc)
if checksumFunc == storageFunc && checksum != "" && rsp.ContentLength >= 0 {
d := &repb.Digest{Hash: checksum, SizeBytes: rsp.ContentLength}
rn := digest.NewResourceName(d, remoteInstanceName, rspb.CacheType_CAS, storageFunc)
if _, err := cachetools.UploadFromReader(ctx, bsClient, rn, rsp.Body); err != nil {
return nil, status.UnavailableErrorf("failed to upload %s to cache: %s", digest.String(d), err)
}
Expand All @@ -273,12 +288,27 @@ func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteIn
log.Errorf("Failed to remove temp file: %s", err)
}
}()
blobDigest, err := cachetools.UploadFile(ctx, bsClient, remoteInstanceName, digestFunc, tmpFilePath)

// If the requested digestFunc is supplied and differs from the checksum sri,
// verify the downloaded file against the checksum sri before storing it in
// our cache.
if checksumFunc != storageFunc {
checksumDigestRN, err := cachetools.ComputeFileDigest(tmpFilePath, remoteInstanceName, checksumFunc)
if err != nil {
return nil, status.UnavailableErrorf("failed to compute checksum digest: %s", err)
}
if checksum != "" && checksumDigestRN.GetDigest().Hash != checksum {
return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, checksumDigestRN.GetDigest().Hash, checksum)
}
}
blobDigest, err := cachetools.UploadFile(ctx, bsClient, remoteInstanceName, storageFunc, tmpFilePath)
if err != nil {
return nil, status.UnavailableErrorf("failed to add object to cache: %s", err)
}
if expectedHash != "" && blobDigest.Hash != expectedHash {
return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, blobDigest.Hash, expectedHash)
// If the requested digestFunc is the same as the checksum sri,
// verify the expected checksum of the downloaded file after storing it in our cache.
if checksumFunc == storageFunc && checksum != "" && blobDigest.Hash != checksum {
return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, blobDigest.Hash, checksum)
}
log.CtxInfof(ctx, "Mirrored %s to cache (digest: %s)", uri, digest.String(blobDigest))
return blobDigest, nil
Expand Down

0 comments on commit 93c847f

Please sign in to comment.