Skip to content

Commit

Permalink
llbsolver: fix recompute test and avoid struct copy
Browse files Browse the repository at this point in the history
Signed-off-by: Tonis Tiigi <[email protected]>
(cherry picked from commit d1a3df3)
  • Loading branch information
tonistiigi committed Dec 3, 2024
1 parent 64293f9 commit ec39add
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 23 deletions.
24 changes: 11 additions & 13 deletions solver/llbsolver/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (dpc *detectPrunedCacheID) Load(op *pb.Op, md *pb.OpMetadata, opt *solver.V

func Load(ctx context.Context, def *pb.Definition, polEngine SourcePolicyEvaluator, opts ...LoadOpt) (solver.Edge, error) {
return loadLLB(ctx, def, polEngine, func(dgst digest.Digest, op *op, load func(digest.Digest) (solver.Vertex, error)) (solver.Vertex, error) {
vtx, err := newVertex(dgst, &op.Op, op.Metadata, load, opts...)
vtx, err := newVertex(dgst, op.Op, op.Metadata, load, opts...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -234,15 +234,12 @@ func recomputeDigests(ctx context.Context, all map[digest.Digest]*op, visited ma
return newDgst, nil
}

// op is a private wrapper around pb.Op that includes its metadata.
type op struct {
pb.Op
*pb.Op
Metadata *pb.OpMetadata
}

func (o *op) Unmarshal(data []byte) error {
return o.Op.UnmarshalVT(data)
}

// loadLLB loads LLB.
// fn is executed sequentially.
func loadLLB(ctx context.Context, def *pb.Definition, polEngine SourcePolicyEvaluator, fn func(digest.Digest, *op, func(digest.Digest) (solver.Vertex, error)) (solver.Vertex, error)) (solver.Edge, error) {
Expand All @@ -255,19 +252,20 @@ func loadLLB(ctx context.Context, def *pb.Definition, polEngine SourcePolicyEval
var lastDgst digest.Digest

for _, dt := range def.Def {
var op op
if err := op.Unmarshal(dt); err != nil {
var pbop pb.Op
if err := pbop.Unmarshal(dt); err != nil {
return solver.Edge{}, errors.Wrap(err, "failed to parse llb proto op")
}
dgst := digest.FromBytes(dt)
if polEngine != nil {
if _, err := polEngine.Evaluate(ctx, op.GetSource()); err != nil {
if _, err := polEngine.Evaluate(ctx, pbop.GetSource()); err != nil {
return solver.Edge{}, errors.Wrap(err, "error evaluating the source policy")
}
}
op.Metadata = def.Metadata[string(dgst)]

allOps[dgst] = &op
allOps[dgst] = &op{
Op: &pbop,
Metadata: def.Metadata[string(dgst)],
}
lastDgst = dgst
}

Expand Down Expand Up @@ -309,7 +307,7 @@ func loadLLB(ctx context.Context, def *pb.Definition, polEngine SourcePolicyEval
return nil, errors.Errorf("invalid missing input digest %s", dgst)
}

if err := opsutils.Validate(&op.Op); err != nil {
if err := opsutils.Validate(op.Op); err != nil {
return nil, err
}

Expand Down
20 changes: 10 additions & 10 deletions solver/llbsolver/vertex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,20 @@ func TestRecomputeDigests(t *testing.T) {
require.NoError(t, err)
op2Digest := digest.FromBytes(op2Data)

all := map[digest.Digest]*pb.Op{
newDigest: op1,
op2Digest: op2,
all := map[digest.Digest]*op{
newDigest: {Op: op1},
op2Digest: {Op: op2},
}
visited := map[digest.Digest]digest.Digest{oldDigest: newDigest}

updated, err := recomputeDigests(context.Background(), all, visited, op2Digest)
require.NoError(t, err)
require.Len(t, visited, 2)
require.Len(t, all, 2)
assert.Equal(t, op1, all[newDigest])
assert.Equal(t, op1, all[newDigest].Op)
require.Equal(t, newDigest, visited[oldDigest])
require.Equal(t, op1, all[newDigest])
assert.Equal(t, op2, all[updated])
require.Equal(t, op1, all[newDigest].Op)
assert.Equal(t, op2, all[updated].Op)
require.Equal(t, newDigest, digest.Digest(op2.Inputs[0].Digest))
assert.NotEqual(t, op2Digest, updated)
}
Expand Down Expand Up @@ -88,14 +88,14 @@ func TestIngestDigest(t *testing.T) {
// Read the definition from the test data and ensure it uses the
// canonical digests after recompute.
var lastDgst digest.Digest
all := map[digest.Digest]*pb.Op{}
all := map[digest.Digest]*op{}
for _, in := range def.Def {
op := new(pb.Op)
err := op.Unmarshal(in)
opNew := new(pb.Op)
err := opNew.Unmarshal(in)
require.NoError(t, err)

lastDgst = digest.FromBytes(in)
all[lastDgst] = op
all[lastDgst] = &op{Op: opNew}
}
fmt.Println(all, lastDgst)

Expand Down

0 comments on commit ec39add

Please sign in to comment.