diff --git a/internal/engine/compiler/engine_cache.go b/internal/engine/compiler/engine_cache.go index 204eb1b8a5..de6a7e3dd0 100644 --- a/internal/engine/compiler/engine_cache.go +++ b/internal/engine/compiler/engine_cache.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" + "hash/crc32" "io" "runtime" @@ -14,6 +15,8 @@ import ( "github.com/tetratelabs/wazero/internal/wasm" ) +var crc = crc32.MakeTable(crc32.Castagnoli) + func (e *engine) deleteCompiledModule(module *wasm.Module) { e.mux.Lock() defer e.mux.Unlock() @@ -130,6 +133,9 @@ func serializeCompiledModule(wazeroVersion string, cm *compiledModule) io.Reader buf.Write(u64.LeBytes(uint64(cm.executable.Len()))) // Append the native code. buf.Write(cm.executable.Bytes()) + // Append checksum. + checksum := crc32.Checksum(cm.executable.Bytes(), crc) + buf.Write(u32.LeBytes(checksum)) return bytes.NewReader(buf.Bytes()) } @@ -209,6 +215,13 @@ func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser, modul return } + expected := crc32.Checksum(cm.executable.Bytes(), crc) + if _, err = io.ReadFull(reader, eightBytes[:4]); err != nil { + return nil, false, fmt.Errorf("compilationcache: could not read checksum: %v", err) + } else if checksum := binary.LittleEndian.Uint32(eightBytes[:4]); expected != checksum { + return nil, false, fmt.Errorf("compilationcache: checksum mismatch (expected %d, got %d)", expected, checksum) + } + if runtime.GOARCH == "arm64" { // On arm64, we cannot give all of rwx at the same time, so we change it to exec. if err = platform.MprotectRX(cm.executable.Bytes()); err != nil { diff --git a/internal/engine/compiler/engine_cache_test.go b/internal/engine/compiler/engine_cache_test.go index c26433a316..612da0ca61 100644 --- a/internal/engine/compiler/engine_cache_test.go +++ b/internal/engine/compiler/engine_cache_test.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/binary" "errors" + "hash/crc32" "io" "math" "testing" @@ -20,6 +21,11 @@ import ( var testVersion = "" +func crcf(b []byte) []byte { + c := crc32.Checksum(b, crc) + return u32.LeBytes(c) +} + func concat(ins ...[]byte) (ret []byte) { for _, in := range ins { ret = append(ret, in...) @@ -49,12 +55,13 @@ func TestSerializeCompiledModule(t *testing.T) { []byte(wazeroMagic), []byte{byte(len(testVersion))}, []byte(testVersion), - []byte{0}, // ensure termination. - u32.LeBytes(1), // number of functions. - u64.LeBytes(12345), // stack pointer ceil. - u64.LeBytes(0), // offset. - u64.LeBytes(5), // length of code. - []byte{1, 2, 3, 4, 5}, // code. + []byte{0}, // ensure termination. + u32.LeBytes(1), // number of functions. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(0), // offset. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + crcf([]byte{1, 2, 3, 4, 5}), // crc of code. ), }, { @@ -71,12 +78,13 @@ func TestSerializeCompiledModule(t *testing.T) { []byte(wazeroMagic), []byte{byte(len(testVersion))}, []byte(testVersion), - []byte{1}, // ensure termination. - u32.LeBytes(1), // number of functions. - u64.LeBytes(12345), // stack pointer ceil. - u64.LeBytes(0), // offset. - u64.LeBytes(5), // length of code. - []byte{1, 2, 3, 4, 5}, // code. + []byte{1}, // ensure termination. + u32.LeBytes(1), // number of functions. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(0), // offset. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + crcf([]byte{1, 2, 3, 4, 5}), // crc of code. ), }, { @@ -103,8 +111,9 @@ func TestSerializeCompiledModule(t *testing.T) { u64.LeBytes(0xffffffff), // stack pointer ceil. u64.LeBytes(5), // offset. // Executable. - u64.LeBytes(8), // length of code. - []byte{1, 2, 3, 4, 5, 1, 2, 3}, // code. + u64.LeBytes(8), // length of code. + []byte{1, 2, 3, 4, 5, 1, 2, 3}, // code. + crcf([]byte{1, 2, 3, 4, 5, 1, 2, 3}), // crc of code. ), }, } @@ -151,7 +160,25 @@ func TestDeserializeCompiledModule(t *testing.T) { expStaleCache: true, }, { - name: "one function", + name: "invalid crc", + in: concat( + []byte(wazeroMagic), + []byte{byte(len(testVersion))}, + []byte(testVersion), + []byte{0}, // ensure termination. + u32.LeBytes(1), // number of functions. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(0), // offset. + // Executable. + u64.LeBytes(5), // size. + []byte{1, 2, 3, 4, 5}, // machine code. + crcf([]byte{1, 2, 3, 4}), // crc of code. + ), + expStaleCache: false, + expErr: "compilationcache: checksum mismatch (expected 1397854123, got 691047668)", + }, + { + name: "missing crc", in: concat( []byte(wazeroMagic), []byte{byte(len(testVersion))}, @@ -164,6 +191,24 @@ func TestDeserializeCompiledModule(t *testing.T) { u64.LeBytes(5), // size. []byte{1, 2, 3, 4, 5}, // machine code. ), + expStaleCache: false, + expErr: "compilationcache: could not read checksum: EOF", + }, + { + name: "one function", + in: concat( + []byte(wazeroMagic), + []byte{byte(len(testVersion))}, + []byte(testVersion), + []byte{0}, // ensure termination. + u32.LeBytes(1), // number of functions. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(0), // offset. + // Executable. + u64.LeBytes(5), // size. + []byte{1, 2, 3, 4, 5}, // machine code. + crcf([]byte{1, 2, 3, 4, 5}), // crc of code. + ), expCompiledModule: &compiledModule{ compiledCode: &compiledCode{ executable: makeCodeSegment(1, 2, 3, 4, 5), @@ -181,12 +226,13 @@ func TestDeserializeCompiledModule(t *testing.T) { []byte(wazeroMagic), []byte{byte(len(testVersion))}, []byte(testVersion), - []byte{1}, // ensure termination. - u32.LeBytes(1), // number of functions. - u64.LeBytes(12345), // stack pointer ceil. - u64.LeBytes(0), // offset. - u64.LeBytes(5), // length of code. - []byte{1, 2, 3, 4, 5}, // code. + []byte{1}, // ensure termination. + u32.LeBytes(1), // number of functions. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(0), // offset. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + crcf([]byte{1, 2, 3, 4, 5}), // crc of code. ), expCompiledModule: &compiledModule{ compiledCode: &compiledCode{ @@ -213,8 +259,9 @@ func TestDeserializeCompiledModule(t *testing.T) { u64.LeBytes(0xffffffff), // stack pointer ceil. u64.LeBytes(7), // offset. // Executable. - u64.LeBytes(10), // size. - []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // machine code. + u64.LeBytes(10), // size. + []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // machine code. + crcf([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), // crc of code. ), importedFunctionCount: 1, expCompiledModule: &compiledModule{ @@ -322,8 +369,9 @@ func TestEngine_getCompiledModuleFromCache(t *testing.T) { u64.LeBytes(0xffffffff), // stack pointer ceil. u64.LeBytes(5), // offset. // executables. - u64.LeBytes(10), // length of code. - []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // code. + u64.LeBytes(10), // length of code. + []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // code. + crcf([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), // code. ) tests := []struct { @@ -475,6 +523,7 @@ func TestEngine_addCompiledModuleToCache(t *testing.T) { u64.LeBytes(0), // offset. u64.LeBytes(3), // size of executable. []byte{1, 2, 3}, + crcf([]byte{1, 2, 3}), // code. ), actual) require.NoError(t, content.Close()) }) diff --git a/internal/engine/wazevo/engine_cache.go b/internal/engine/wazevo/engine_cache.go index fedc4d057d..f7c0450aed 100644 --- a/internal/engine/wazevo/engine_cache.go +++ b/internal/engine/wazevo/engine_cache.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/binary" "fmt" + "hash/crc32" "io" "runtime" "unsafe" @@ -21,6 +22,8 @@ import ( "github.com/tetratelabs/wazero/internal/wasm" ) +var crc = crc32.MakeTable(crc32.Castagnoli) + // fileCacheKey returns a key for the file cache. // In order to avoid collisions with the existing compiler, we do not use m.ID directly, // but instead we rehash it with magic. @@ -145,6 +148,9 @@ func serializeCompiledModule(wazeroVersion string, cm *compiledModule) io.Reader buf.Write(u64.LeBytes(uint64(len(cm.executable)))) // Append the native code. buf.Write(cm.executable) + // Append checksum. + checksum := crc32.Checksum(cm.executable, crc) + buf.Write(u32.LeBytes(checksum)) if sm := cm.sourceMap; len(sm.executableOffsets) > 0 { buf.WriteByte(1) // indicates that source map is present. l := len(sm.wasmBinaryOffsets) @@ -226,6 +232,13 @@ func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser) (cm * return nil, false, err } + expected := crc32.Checksum(executable, crc) + if _, err = io.ReadFull(reader, eightBytes[:4]); err != nil { + return nil, false, fmt.Errorf("compilationcache: could not read checksum: %v", err) + } else if checksum := binary.LittleEndian.Uint32(eightBytes[:4]); expected != checksum { + return nil, false, fmt.Errorf("compilationcache: checksum mismatch (expected %d, got %d)", expected, checksum) + } + if runtime.GOARCH == "arm64" { // On arm64, we cannot give all of rwx at the same time, so we change it to exec. if err = platform.MprotectRX(executable); err != nil { diff --git a/internal/engine/wazevo/engine_cache_test.go b/internal/engine/wazevo/engine_cache_test.go index 8402f301f9..4a7fe4636c 100644 --- a/internal/engine/wazevo/engine_cache_test.go +++ b/internal/engine/wazevo/engine_cache_test.go @@ -3,6 +3,7 @@ package wazevo import ( "bytes" "crypto/sha256" + "hash/crc32" "io" "testing" @@ -14,6 +15,11 @@ import ( var testVersion = "0.0.1" +func crcf(b []byte) []byte { + c := crc32.Checksum(b, crc) + return u32.LeBytes(c) +} + func TestSerializeCompiledModule(t *testing.T) { tests := []struct { in *compiledModule @@ -28,11 +34,12 @@ func TestSerializeCompiledModule(t *testing.T) { magic, []byte{byte(len(testVersion))}, []byte(testVersion), - u32.LeBytes(1), // number of functions. - u64.LeBytes(0), // offset. - u64.LeBytes(5), // length of code. - []byte{1, 2, 3, 4, 5}, // code. - []byte{0}, // no source map. + u32.LeBytes(1), // number of functions. + u64.LeBytes(0), // offset. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + crcf([]byte{1, 2, 3, 4, 5}), // crc for the code. + []byte{0}, // no source map. ), }, { @@ -44,11 +51,12 @@ func TestSerializeCompiledModule(t *testing.T) { magic, []byte{byte(len(testVersion))}, []byte(testVersion), - u32.LeBytes(1), // number of functions. - u64.LeBytes(0), // offset. - u64.LeBytes(5), // length of code. - []byte{1, 2, 3, 4, 5}, // code. - []byte{0}, // no source map. + u32.LeBytes(1), // number of functions. + u64.LeBytes(0), // offset. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + crcf([]byte{1, 2, 3, 4, 5}), // crc for the code. + []byte{0}, // no source map. ), }, { @@ -66,9 +74,10 @@ func TestSerializeCompiledModule(t *testing.T) { // Function index = 1. u64.LeBytes(5), // offset. // Executable. - u64.LeBytes(8), // length of code. - []byte{1, 2, 3, 4, 5, 1, 2, 3}, // code. - []byte{0}, // no source map. + u64.LeBytes(8), // length of code. + []byte{1, 2, 3, 4, 5, 1, 2, 3}, // code. + crcf([]byte{1, 2, 3, 4, 5, 1, 2, 3}), // crc for the code. + []byte{0}, // no source map. ), }, } @@ -140,9 +149,10 @@ func TestDeserializeCompiledModule(t *testing.T) { u32.LeBytes(1), // number of functions. u64.LeBytes(0), // offset. // Executable. - u64.LeBytes(5), // size. - []byte{1, 2, 3, 4, 5}, // machine code. - []byte{0}, // no source map. + u64.LeBytes(5), // size. + []byte{1, 2, 3, 4, 5}, // machine code. + crcf([]byte{1, 2, 3, 4, 5}), // machine code. + []byte{0}, // no source map. ), expCompiledModule: &compiledModule{ executables: &executables{executable: []byte{1, 2, 3, 4, 5}}, @@ -163,9 +173,10 @@ func TestDeserializeCompiledModule(t *testing.T) { // Function index = 1. u64.LeBytes(7), // offset. // Executable. - u64.LeBytes(10), // size. - []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // machine code. - []byte{0}, // no source map. + u64.LeBytes(10), // size. + []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // machine code. + crcf([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), // crc for machine code. + []byte{0}, // no source map. ), importedFunctionCount: 1, expCompiledModule: &compiledModule{ @@ -206,7 +217,27 @@ func TestDeserializeCompiledModule(t *testing.T) { expErr: "compilationcache: error reading executable (len=5): EOF", }, { - name: "no source map presence", + name: "bad crc", + in: concat( + magic, + []byte{byte(len(testVersion))}, + []byte(testVersion), + u32.LeBytes(1), // number of functions. + u64.LeBytes(0), // offset. + // Executable. + u64.LeBytes(5), // size. + []byte{1, 2, 3, 4, 5}, // machine code. + []byte{1, 2, 3, 4}, // crc for machine code. + ), + expCompiledModule: &compiledModule{ + executables: &executables{executable: []byte{1, 2, 3, 4, 5}}, + functionOffsets: []int{0}, + }, + expStaleCache: false, + expErr: "compilationcache: checksum mismatch (expected 1397854123, got 67305985)", + }, + { + name: "missing crc", in: concat( magic, []byte{byte(len(testVersion))}, @@ -222,6 +253,26 @@ func TestDeserializeCompiledModule(t *testing.T) { functionOffsets: []int{0}, }, expStaleCache: false, + expErr: "compilationcache: could not read checksum: EOF", + }, + { + name: "no source map presence", + in: concat( + magic, + []byte{byte(len(testVersion))}, + []byte(testVersion), + u32.LeBytes(1), // number of functions. + u64.LeBytes(0), // offset. + // Executable. + u64.LeBytes(5), // size. + []byte{1, 2, 3, 4, 5}, // machine code. + crcf([]byte{1, 2, 3, 4, 5}), // crc for machine code. + ), + expCompiledModule: &compiledModule{ + executables: &executables{executable: []byte{1, 2, 3, 4, 5}}, + functionOffsets: []int{0}, + }, + expStaleCache: false, expErr: "compilationcache: error reading source map presence: EOF", }, }