Skip to content

Commit

Permalink
fetch_server: support digest function and blake3 (#6382)
Browse files Browse the repository at this point in the history
Co-authored-by: Brandon Duffany <[email protected]>
  • Loading branch information
sluongng and bduffany authored Apr 23, 2024
1 parent 5c3630a commit d532563
Show file tree
Hide file tree
Showing 3 changed files with 477 additions and 58 deletions.
22 changes: 21 additions & 1 deletion server/remote_asset/fetch_server/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "fetch_server",
Expand All @@ -23,3 +23,23 @@ go_library(
"@org_golang_google_protobuf//types/known/durationpb",
],
)

# Go test target for the fetch_server library, built from fetch_server_test.go.
go_test(
name = "fetch_server_test",
srcs = ["fetch_server_test.go"],
deps = [
":fetch_server",
"//proto:remote_asset_go_proto",
"//proto:remote_execution_go_proto",
"//proto:resource_go_proto",
"//server/remote_cache/byte_stream_server",
"//server/remote_cache/digest",
"//server/testutil/testenv",
"//server/util/prefix",
"//server/util/scratchspace",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
"@org_golang_google_genproto_googleapis_bytestream//:bytestream",
"@org_golang_google_grpc//:go_default_library",
],
)
187 changes: 130 additions & 57 deletions server/remote_asset/fetch_server/fetch_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ import (
)

const (
checksumQualifier = "checksum.sri"
ChecksumQualifier = "checksum.sri"
sha256Prefix = "sha256-"
blake3Prefix = "blake3-"
maxHTTPTimeout = 60 * time.Minute
)

Expand Down Expand Up @@ -64,9 +65,6 @@ func checkPreconditions(env environment.Env) error {
if env.GetCache() == nil {
return status.FailedPreconditionError("missing Cache")
}
if env.GetByteStreamClient() == nil {
return status.FailedPreconditionError("missing ByteStreamClient")
}
return nil
}

Expand Down Expand Up @@ -109,52 +107,43 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
return nil, err
}

var expectedSHA256 string

storageFunc := req.GetDigestFunction()
if storageFunc == repb.DigestFunction_UNKNOWN {
storageFunc = repb.DigestFunction_SHA256
}
var checksumFunc repb.DigestFunction_Value
var expectedChecksum string
for _, qualifier := range req.GetQualifiers() {
if qualifier.GetName() == checksumQualifier && strings.HasPrefix(qualifier.GetValue(), sha256Prefix) {
b64sha256 := strings.TrimPrefix(qualifier.GetValue(), sha256Prefix)
sha256, err := base64.StdEncoding.DecodeString(b64sha256)
if err != nil {
return nil, status.FailedPreconditionErrorf("Error decoding qualifier %q: %s", qualifier.GetName(), err.Error())
}
blobDigest := &repb.Digest{
Hash: fmt.Sprintf("%x", sha256),
// The digest size is unknown since the client only sends up
// the hash. We can look up the size using the Metadata API,
// which looks up only using the hash, so the size we pass here
// doesn't matter.
SizeBytes: 1,
var prefix string
if qualifier.GetName() == ChecksumQualifier {
if strings.HasPrefix(qualifier.GetValue(), sha256Prefix) {
checksumFunc = repb.DigestFunction_SHA256
prefix = sha256Prefix
} else if strings.HasPrefix(qualifier.GetValue(), blake3Prefix) {
checksumFunc = repb.DigestFunction_BLAKE3
prefix = blake3Prefix
}
expectedSHA256 = blobDigest.Hash
cacheRN := digest.NewResourceName(blobDigest, req.GetInstanceName(), rspb.CacheType_CAS, repb.DigestFunction_SHA256)

log.CtxInfof(ctx, "Looking up %s in cache", blobDigest.Hash)

// Lookup metadata to get the correct digest size to be returned to
// the client.
cache := p.env.GetCache()
md, err := cache.Metadata(ctx, cacheRN.ToProto())
}
if prefix != "" {
b64hash := strings.TrimPrefix(qualifier.GetValue(), prefix)
decodedHash, err := base64.StdEncoding.DecodeString(b64hash)
if err != nil {
log.CtxInfof(ctx, "FetchServer failed to get metadata for %s: %s", expectedSHA256, err)
continue
return nil, status.FailedPreconditionErrorf("Error decoding qualifier %q: %s", qualifier.GetName(), err.Error())
}
blobDigest.SizeBytes = md.DigestSizeBytes
expectedChecksum = fmt.Sprintf("%x", decodedHash)
break
}
}
if len(expectedChecksum) != 0 {
blobDigest := p.findBlobInCache(ctx, req.GetInstanceName(), checksumFunc, expectedChecksum)
// If the requested digest function differs from the one used by the
// checksum SRI, then after locating the cached blob via the checksum,
// re-upload that blob using the requested digest function.
if blobDigest != nil && checksumFunc != storageFunc {
blobDigest = p.rewriteToCache(ctx, blobDigest, req.GetInstanceName(), checksumFunc, storageFunc)
}

// Even though we successfully fetched metadata, we need to renew
// the cache entry (using Contains()) to ensure that it doesn't
// expire by the time the client requests it from cache.
cacheRN = digest.NewResourceName(blobDigest, req.GetInstanceName(), rspb.CacheType_CAS, repb.DigestFunction_SHA256)
exists, err := cache.Contains(ctx, cacheRN.ToProto())
if err != nil {
log.CtxErrorf(ctx, "Failed to renew %s: %s", digest.String(blobDigest), err)
continue
}
if !exists {
log.CtxInfof(ctx, "Blob %s expired before we could renew it", digest.String(blobDigest))
continue
}
log.CtxInfof(ctx, "FetchServer found %s in cache", digest.String(blobDigest))
if blobDigest != nil {
return &rapb.FetchBlobResponse{
Status: &statuspb.Status{Code: int32(gcodes.OK)},
BlobDigest: blobDigest,
Expand All @@ -172,7 +161,7 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
if err != nil {
return nil, status.InvalidArgumentErrorf("unparsable URI: %q", uri)
}
blobDigest, err := mirrorToCache(ctx, p.env.GetByteStreamClient(), req.GetInstanceName(), httpClient, uri, expectedSHA256)
blobDigest, err := mirrorToCache(ctx, p.env.GetByteStreamClient(), req.GetInstanceName(), httpClient, uri, storageFunc, checksumFunc, expectedChecksum)
if err != nil {
lastFetchErr = err
log.CtxWarningf(ctx, "Failed to mirror %q to cache: %s", uri, err)
Expand Down Expand Up @@ -203,12 +192,81 @@ func (p *FetchServer) FetchDirectory(ctx context.Context, req *rapb.FetchDirecto
return nil, status.UnimplementedError("FetchDirectory is not yet implemented")
}

// rewriteToCache re-uploads a blob that is already stored in the CAS under
// the fromFunc digest function so that it is also stored under toFunc.
//
// The blob is read from the cache using blobDigest and fromFunc, copied to a
// temporary file, and uploaded through the ByteStream client using toFunc.
// It returns the digest produced by that upload, or nil if any step fails.
// Failures are logged rather than returned, so a nil result means the caller
// should fall back to another way of obtaining the blob.
func (p *FetchServer) rewriteToCache(ctx context.Context, blobDigest *repb.Digest, instanceName string, fromFunc, toFunc repb.DigestFunction_Value) *repb.Digest {
	cacheRN := digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, fromFunc)
	cache := p.env.GetCache()
	reader, err := cache.Reader(ctx, cacheRN.ToProto(), 0, 0)
	if err != nil {
		log.CtxErrorf(ctx, "Failed to get cache reader for %s: %s", digest.String(blobDigest), err)
		return nil
	}
	// Always release the cache reader, including on the error paths below.
	// The close error is ignored: the stream is read-only and there is
	// nothing actionable to do if releasing it fails.
	defer reader.Close()

	tmpFilePath, err := tempCopy(reader)
	if err != nil {
		log.CtxErrorf(ctx, "Failed to copy from reader to temp for %s: %s", digest.String(blobDigest), err)
		return nil
	}
	// The temp file is only needed for the duration of the upload.
	defer func() {
		if err := os.Remove(tmpFilePath); err != nil {
			log.CtxErrorf(ctx, "Failed to remove temp file: %s", err)
		}
	}()

	bsClient := p.env.GetByteStreamClient()
	storageDigest, err := cachetools.UploadFile(ctx, bsClient, instanceName, toFunc, tmpFilePath)
	if err != nil {
		log.CtxErrorf(ctx, "Failed to re-upload blob with new digestFunc %s for %s: %s", toFunc, digest.String(blobDigest), err)
		return nil
	}
	return storageDigest
}

// findBlobInCache looks for a CAS entry whose hash is expectedChecksum under
// the checksumFunc digest function.
//
// On a hit it returns a digest carrying that hash and the size reported by
// the cache metadata. It returns nil when the metadata lookup fails, when the
// follow-up Contains call fails, or when Contains reports the blob as absent;
// each of those outcomes is logged.
func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string, checksumFunc repb.DigestFunction_Value, expectedChecksum string) *repb.Digest {
	// Only the hash is known at this point. The size is a placeholder: the
	// Metadata API looks entries up by hash alone, so any value works here.
	d := &repb.Digest{Hash: expectedChecksum, SizeBytes: 1}
	rn := digest.NewResourceName(d, instanceName, rspb.CacheType_CAS, checksumFunc)
	log.CtxDebugf(ctx, "Looking up %s in cache", d.Hash)

	// Fetch the metadata so the real blob size can be returned to the client.
	c := p.env.GetCache()
	meta, err := c.Metadata(ctx, rn.ToProto())
	if err != nil {
		log.CtxInfof(ctx, "FetchServer failed to get metadata for %s: %s", expectedChecksum, err)
		return nil
	}
	d.SizeBytes = meta.DigestSizeBytes

	// A successful metadata lookup is not enough: Contains() renews the
	// entry so it does not expire before the client reads it from cache.
	rn = digest.NewResourceName(d, instanceName, rspb.CacheType_CAS, checksumFunc)
	found, err := c.Contains(ctx, rn.ToProto())
	switch {
	case err != nil:
		log.CtxErrorf(ctx, "Failed to renew %s: %s", digest.String(d), err)
		return nil
	case !found:
		log.CtxInfof(ctx, "Blob %s expired before we could renew it", digest.String(d))
		return nil
	}

	log.CtxDebugf(ctx, "FetchServer found %s in cache", digest.String(d))
	return d
}

// mirrorToCache uploads the contents at the given URI to the given cache,
// returning the digest. The fetched contents are checked against the given
// expectedSHA256 (if non-empty), and if there is a mismatch then an error is
// expectedChecksum (if non-empty), and if there is a mismatch then an error is
// returned.
func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteInstanceName string, httpClient *http.Client, uri, expectedSHA256 string) (*repb.Digest, error) {
log.CtxInfof(ctx, "Fetching %s", uri)
func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteInstanceName string, httpClient *http.Client, uri string, storageFunc repb.DigestFunction_Value, checksumFunc repb.DigestFunction_Value, expectedChecksum string) (*repb.Digest, error) {
log.CtxDebugf(ctx, "Fetching %s", uri)
rsp, err := httpClient.Get(uri)
if err != nil {
return nil, status.UnavailableErrorf("failed to fetch %q: HTTP GET failed: %s", uri, err)
Expand All @@ -218,12 +276,12 @@ func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteIn
return nil, status.UnavailableErrorf("failed to fetch %q: HTTP %s", uri, err)
}

// If we know what the SHA256 should be and the content length is known,
// If we know what the hash should be and the content length is known,
// then we know the full digest, and can pipe directly from the HTTP
// response to cache.
if expectedSHA256 != "" && rsp.ContentLength >= 0 {
d := &repb.Digest{Hash: expectedSHA256, SizeBytes: rsp.ContentLength}
rn := digest.NewResourceName(d, remoteInstanceName, rspb.CacheType_CAS, repb.DigestFunction_SHA256)
if checksumFunc == storageFunc && expectedChecksum != "" && rsp.ContentLength >= 0 {
d := &repb.Digest{Hash: expectedChecksum, SizeBytes: rsp.ContentLength}
rn := digest.NewResourceName(d, remoteInstanceName, rspb.CacheType_CAS, storageFunc)
if _, err := cachetools.UploadFromReader(ctx, bsClient, rn, rsp.Body); err != nil {
return nil, status.UnavailableErrorf("failed to upload %s to cache: %s", digest.String(d), err)
}
Expand All @@ -246,14 +304,29 @@ func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteIn
log.Errorf("Failed to remove temp file: %s", err)
}
}()
blobDigest, err := cachetools.UploadFile(ctx, bsClient, remoteInstanceName, repb.DigestFunction_SHA256, tmpFilePath)

// If the requested digest function differs from the one used by the
// checksum SRI, verify the downloaded file against the checksum SRI before
// storing it in our cache.
if checksumFunc != storageFunc {
checksumDigestRN, err := cachetools.ComputeFileDigest(tmpFilePath, remoteInstanceName, checksumFunc)
if err != nil {
return nil, status.UnavailableErrorf("failed to compute checksum digest: %s", err)
}
if expectedChecksum != "" && checksumDigestRN.GetDigest().GetHash() != expectedChecksum {
return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, checksumDigestRN.GetDigest().Hash, expectedChecksum)
}
}
blobDigest, err := cachetools.UploadFile(ctx, bsClient, remoteInstanceName, storageFunc, tmpFilePath)
if err != nil {
return nil, status.UnavailableErrorf("failed to add object to cache: %s", err)
}
if expectedSHA256 != "" && blobDigest.Hash != expectedSHA256 {
return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, blobDigest.Hash, expectedSHA256)
// If the requested digest function is the same as the one used by the
// checksum SRI, verify the downloaded file's checksum after storing it in our cache.
if checksumFunc == storageFunc && expectedChecksum != "" && blobDigest.Hash != expectedChecksum {
return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, blobDigest.Hash, expectedChecksum)
}
log.CtxInfof(ctx, "Mirrored %s to cache (digest: %s)", uri, digest.String(blobDigest))
log.CtxDebugf(ctx, "Mirrored %s to cache (digest: %s)", uri, digest.String(blobDigest))
return blobDigest, nil
}

Expand Down
Loading

0 comments on commit d532563

Please sign in to comment.