diff --git a/server/remote_asset/fetch_server/fetch_server.go b/server/remote_asset/fetch_server/fetch_server.go
index 5ac088b95d5..7874cf0ad53 100644
--- a/server/remote_asset/fetch_server/fetch_server.go
+++ b/server/remote_asset/fetch_server/fetch_server.go
@@ -1,6 +1,7 @@
 package fetch_server
 
 import (
+	"bytes"
 	"context"
 	"encoding/base64"
 	"fmt"
@@ -107,25 +108,20 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
 		return nil, err
 	}
 
-	digestFunction := req.GetDigestFunction()
-	if digestFunction == repb.DigestFunction_UNKNOWN {
-		digestFunction = repb.DigestFunction_SHA256
+	storageFunc := req.GetDigestFunction()
+	if storageFunc == repb.DigestFunction_UNKNOWN {
+		storageFunc = repb.DigestFunction_SHA256
 	}
+	var checksumFunc repb.DigestFunction_Value
 	var hash string
 	for _, qualifier := range req.GetQualifiers() {
 		var prefix string
 		if qualifier.GetName() == ChecksumQualifier {
 			if strings.HasPrefix(qualifier.GetValue(), sha256Prefix) {
-				if digestFunction != repb.DigestFunction_SHA256 {
-					log.Warningf("FetchBlob request came with %s digest function but SHA256 checksum.sri: %s", digestFunction, qualifier.GetValue())
-				}
-				digestFunction = repb.DigestFunction_SHA256
+				checksumFunc = repb.DigestFunction_SHA256
 				prefix = sha256Prefix
 			} else if strings.HasPrefix(qualifier.GetValue(), blake3Prefix) {
-				if digestFunction != repb.DigestFunction_BLAKE3 {
-					log.Warningf("FetchBlob request came with %s digest function but BLAKE3 checksum.sri: %s", digestFunction, qualifier.GetValue())
-				}
-				digestFunction = repb.DigestFunction_BLAKE3
+				checksumFunc = repb.DigestFunction_BLAKE3
 				prefix = blake3Prefix
 			}
 		}
@@ -140,7 +136,7 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
 		}
 	}
 	if len(hash) != 0 {
-		blobDigest := p.findBlobInCache(ctx, req.GetInstanceName(), digestFunction, hash)
+		blobDigest := p.findBlobInCache(ctx, req.GetInstanceName(), checksumFunc, hash, storageFunc)
 		if blobDigest != nil {
 			return &rapb.FetchBlobResponse{
 				Status: &statuspb.Status{Code: int32(gcodes.OK)},
@@ -159,7 +155,7 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
 		if err != nil {
 			return nil, status.InvalidArgumentErrorf("unparsable URI: %q", uri)
 		}
-		blobDigest, err := mirrorToCache(ctx, p.env.GetByteStreamClient(), req.GetInstanceName(), digestFunction, httpClient, uri, hash)
+		blobDigest, err := mirrorToCache(ctx, p.env.GetByteStreamClient(), req.GetInstanceName(), httpClient, uri, storageFunc, checksumFunc, hash)
 		if err != nil {
 			lastFetchErr = err
 			log.CtxWarningf(ctx, "Failed to mirror %q to cache: %s", uri, err)
@@ -190,9 +186,9 @@ func (p *FetchServer) FetchDirectory(ctx context.Context, req *rapb.FetchDirecto
 	return nil, status.UnimplementedError("FetchDirectory is not yet implemented")
 }
 
-func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string, digestFunction repb.DigestFunction_Value, hash string) *repb.Digest {
+func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string, checksumFunc repb.DigestFunction_Value, checksum string, storageFunc repb.DigestFunction_Value) *repb.Digest {
 	blobDigest := &repb.Digest{
-		Hash: hash,
+		Hash: checksum,
 		// The digest size is unknown since the client only sends up
 		// the hash. We can look up the size using the Metadata API,
 		// which looks up only using the hash, so the size we pass here
@@ -200,7 +196,7 @@ func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string,
 		SizeBytes: 1,
 	}
 	expectedHash := blobDigest.Hash
-	cacheRN := digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, digestFunction)
+	cacheRN := digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, checksumFunc)
 	log.CtxInfof(ctx, "Looking up %s in cache", blobDigest.Hash)
 
 	// Lookup metadata to get the correct digest size to be returned to
@@ -216,7 +212,7 @@ func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string,
 	// Even though we successfully fetched metadata, we need to renew
 	// the cache entry (using Contains()) to ensure that it doesn't
 	// expire by the time the client requests it from cache.
-	cacheRN = digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, digestFunction)
+	cacheRN = digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, checksumFunc)
 	exists, err := cache.Contains(ctx, cacheRN.ToProto())
 	if err != nil {
 		log.CtxErrorf(ctx, "Failed to renew %s: %s", digest.String(blobDigest), err)
@@ -226,6 +222,25 @@ func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string,
 		log.CtxInfof(ctx, "Blob %s expired before we could renew it", digest.String(blobDigest))
 		return nil
 	}
+
+	// If the requested digest function (storageFunc) differs from the
+	// checksum SRI's digest function, re-upload the cached blob (found via
+	// the checksum SRI) under the requested digest function.
+	if checksumFunc != storageFunc {
+		b, err := cache.Get(ctx, cacheRN.ToProto())
+		if err != nil {
+			log.CtxErrorf(ctx, "Failed to read cached blob %s: %s", digest.String(blobDigest), err)
+			return nil
+		}
+		bsClient := p.env.GetByteStreamClient()
+		storageDigest, err := cachetools.UploadBlob(ctx, bsClient, instanceName, storageFunc, bytes.NewReader(b))
+		if err != nil {
+			log.CtxErrorf(ctx, "Failed to re-upload blob %s with digest function %s: %s", digest.String(blobDigest), storageFunc, err)
+			return nil
+		}
+		blobDigest = storageDigest
+	}
+
 	log.CtxInfof(ctx, "FetchServer found %s in cache", digest.String(blobDigest))
 	return blobDigest
 }
@@ -234,7 +249,7 @@ func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string,
 // returning the digest. The fetched contents are checked against the given
 // expectedHash (if non-empty), and if there is a mismatch then an error is
 // returned.
-func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteInstanceName string, digestFunc repb.DigestFunction_Value, httpClient *http.Client, uri, expectedHash string) (*repb.Digest, error) {
+func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteInstanceName string, httpClient *http.Client, uri string, storageFunc repb.DigestFunction_Value, checksumFunc repb.DigestFunction_Value, checksum string) (*repb.Digest, error) {
 	log.CtxInfof(ctx, "Fetching %s", uri)
 	rsp, err := httpClient.Get(uri)
 	if err != nil {
@@ -248,9 +263,9 @@ func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteIn
 	// If we know what the hash should be and the content length is known,
 	// then we know the full digest, and can pipe directly from the HTTP
 	// response to cache.
-	if expectedHash != "" && rsp.ContentLength >= 0 {
-		d := &repb.Digest{Hash: expectedHash, SizeBytes: rsp.ContentLength}
-		rn := digest.NewResourceName(d, remoteInstanceName, rspb.CacheType_CAS, digestFunc)
+	if checksumFunc == storageFunc && checksum != "" && rsp.ContentLength >= 0 {
+		d := &repb.Digest{Hash: checksum, SizeBytes: rsp.ContentLength}
+		rn := digest.NewResourceName(d, remoteInstanceName, rspb.CacheType_CAS, storageFunc)
 		if _, err := cachetools.UploadFromReader(ctx, bsClient, rn, rsp.Body); err != nil {
 			return nil, status.UnavailableErrorf("failed to upload %s to cache: %s", digest.String(d), err)
 		}
@@ -273,12 +288,27 @@ func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteIn
 			log.Errorf("Failed to remove temp file: %s", err)
 		}
 	}()
-	blobDigest, err := cachetools.UploadFile(ctx, bsClient, remoteInstanceName, digestFunc, tmpFilePath)
+
+	// If the requested digest function differs from the checksum SRI's
+	// digest function, verify the downloaded file against the checksum SRI
+	// before storing it in our cache.
+	if checksumFunc != storageFunc {
+		checksumDigestRN, err := cachetools.ComputeFileDigest(tmpFilePath, remoteInstanceName, checksumFunc)
+		if err != nil {
+			return nil, status.UnavailableErrorf("failed to compute checksum digest: %s", err)
+		}
+		if checksum != "" && checksumDigestRN.GetDigest().Hash != checksum {
+			return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, checksumDigestRN.GetDigest().Hash, checksum)
+		}
+	}
+	blobDigest, err := cachetools.UploadFile(ctx, bsClient, remoteInstanceName, storageFunc, tmpFilePath)
 	if err != nil {
 		return nil, status.UnavailableErrorf("failed to add object to cache: %s", err)
 	}
-	if expectedHash != "" && blobDigest.Hash != expectedHash {
-		return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, blobDigest.Hash, expectedHash)
+	// If the requested digest function matches the checksum SRI's digest
+	// function, verify the expected checksum after storing the file in our cache.
+	if checksumFunc == storageFunc && checksum != "" && blobDigest.Hash != checksum {
+		return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, blobDigest.Hash, checksum)
 	}
 	log.CtxInfof(ctx, "Mirrored %s to cache (digest: %s)", uri, digest.String(blobDigest))
 	return blobDigest, nil
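For reviewers, a minimal client-side sketch (not part of this diff) of the case the checksumFunc/storageFunc split targets: the checksum.sri qualifier carries a SHA-256 digest while the request asks for BLAKE3 as the storage digest function. The proto import paths, gRPC target, and URI below are assumptions for illustration only.

// Hypothetical client sketch: FetchBlob with a SHA-256 checksum.sri qualifier
// but a BLAKE3 storage digest function (the checksumFunc != storageFunc path).
// Import paths, target address, and URI are illustrative assumptions.
package main

import (
	"context"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	rapb "github.com/buildbuddy-io/buildbuddy/proto/remote_asset"
	repb "github.com/buildbuddy-io/buildbuddy/proto/remote_execution"
)

func main() {
	// SRI checksums are "<algo>-<base64 of the raw digest bytes>"; here we
	// pretend we already know the file contents so we can compute one.
	payload := []byte("hello world\n")
	sum := sha256.Sum256(payload)
	sri := "sha256-" + base64.StdEncoding.EncodeToString(sum[:])

	conn, err := grpc.Dial("localhost:1985", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("dial: %s", err)
	}
	defer conn.Close()

	req := &rapb.FetchBlobRequest{
		Uris:           []string{"https://example.com/hello.txt"},
		Qualifiers:     []*rapb.Qualifier{{Name: "checksum.sri", Value: sri}},
		DigestFunction: repb.DigestFunction_BLAKE3, // storage digest differs from the SHA-256 SRI
	}
	rsp, err := rapb.NewFetchClient(conn).FetchBlob(context.Background(), req)
	if err != nil {
		log.Fatalf("FetchBlob: %s", err)
	}
	// The returned digest is keyed by the requested (BLAKE3) digest function,
	// even though the lookup/verification used the SHA-256 checksum SRI.
	fmt.Printf("fetched blob: %s/%d\n", rsp.GetBlobDigest().GetHash(), rsp.GetBlobDigest().GetSizeBytes())
}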