From 9a337bba2e70aa6e953d0f78b5875decbea72f50 Mon Sep 17 00:00:00 2001 From: Jacob Gadikian Date: Fri, 20 Dec 2024 15:06:36 +0700 Subject: [PATCH] memory management --- internal/api/iterator.go | 2 +- internal/runtime/hostfunctions.go | 456 +++++++++++++++++++++++++++--- internal/runtime/memory.go | 131 +++++++++ lib_libwasmvm_test.go | 2 +- 4 files changed, 543 insertions(+), 48 deletions(-) create mode 100644 internal/runtime/memory.go diff --git a/internal/api/iterator.go b/internal/api/iterator.go index c9a768b40..2f997e707 100644 --- a/internal/api/iterator.go +++ b/internal/api/iterator.go @@ -28,7 +28,7 @@ var ( func startCall() uint64 { latestCallIDMutex.Lock() defer latestCallIDMutex.Unlock() - latestCallID += 1 + latestCallID++ return latestCallID } diff --git a/internal/runtime/hostfunctions.go b/internal/runtime/hostfunctions.go index 850b854d6..5752314fe 100644 --- a/internal/runtime/hostfunctions.go +++ b/internal/runtime/hostfunctions.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "sync" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" @@ -11,131 +12,494 @@ import ( "github.com/CosmWasm/wasmvm/v2/types" ) -// Assume these are your runtime-level interfaces: -// - db is a KVStore for contract storage -// - api is GoAPI for address manipulations -// - querier is a Querier for external queries +const ( + // Maximum number of iterators per contract call + maxIteratorsPerCall = 100 + // Gas costs for iterator operations + gasCostIteratorCreate = 2000 + gasCostIteratorNext = 100 +) + +// RuntimeEnvironment holds the environment for contract execution type RuntimeEnvironment struct { DB types.KVStore API *types.GoAPI Querier types.Querier + Memory *MemoryAllocator + Gas types.GasMeter + GasUsed types.Gas // Track gas usage internally + + // Iterator management + iteratorsMutex sync.RWMutex + iterators map[uint64]map[uint64]types.Iterator + nextIterID uint64 + nextCallID uint64 +} + +// NewRuntimeEnvironment creates a new runtime environment +func NewRuntimeEnvironment(db types.KVStore, api *types.GoAPI, querier types.Querier) *RuntimeEnvironment { + return &RuntimeEnvironment{ + DB: db, + API: api, + Querier: querier, + iterators: make(map[uint64]map[uint64]types.Iterator), + } +} + +// StartCall starts a new contract call and returns a call ID +func (e *RuntimeEnvironment) StartCall() uint64 { + e.iteratorsMutex.Lock() + defer e.iteratorsMutex.Unlock() + + e.nextCallID++ + e.iterators[e.nextCallID] = make(map[uint64]types.Iterator) + return e.nextCallID +} + +// StoreIterator stores an iterator and returns its ID +func (e *RuntimeEnvironment) StoreIterator(callID uint64, iter types.Iterator) uint64 { + e.iteratorsMutex.Lock() + defer e.iteratorsMutex.Unlock() + + e.nextIterID++ + if e.iterators[callID] == nil { + e.iterators[callID] = make(map[uint64]types.Iterator) + } + e.iterators[callID][e.nextIterID] = iter + return e.nextIterID +} + +// GetIterator retrieves an iterator by its IDs +func (e *RuntimeEnvironment) GetIterator(callID, iterID uint64) types.Iterator { + e.iteratorsMutex.RLock() + defer e.iteratorsMutex.RUnlock() + + if callMap, exists := e.iterators[callID]; exists { + return callMap[iterID] + } + return nil +} + +// EndCall cleans up all iterators for a call +func (e *RuntimeEnvironment) EndCall(callID uint64) { + e.iteratorsMutex.Lock() + defer e.iteratorsMutex.Unlock() + + delete(e.iterators, callID) +} + +// IteratorID represents a unique identifier for an iterator +type IteratorID struct { + CallID uint64 + IteratorID uint64 } -// Example host function: Get key from DB +// hostGet implements db_get func hostGet(ctx context.Context, mod api.Module, keyPtr uint32, keyLen uint32) (dataPtr uint32, dataLen uint32) { - env := ctx.Value("env").(*RuntimeEnvironment) // Assume you set this in config + env := ctx.Value("env").(*RuntimeEnvironment) mem := mod.Memory() - keyBytes, ok := mem.Read(keyPtr, keyLen) - if !ok { - panic("failed to read key from memory") + key, err := ReadMemory(mem, keyPtr, keyLen) + if err != nil { + panic(fmt.Sprintf("failed to read key from memory: %v", err)) } - value := env.DB.Get(keyBytes) + value := env.DB.Get(key) if len(value) == 0 { - // Return (0,0) for no data return 0, 0 } - // Write the value back to memory. In a real scenario, you need an allocator. - // For simplicity, assume we have a fixed offset for result writes: - offset := uint32(2048) // Just an example offset - if !mem.Write(offset, value) { - panic("failed to write value to memory") + // Allocate memory for the result + offset, err := env.Memory.Allocate(mem, uint32(len(value))) + if err != nil { + panic(fmt.Sprintf("failed to allocate memory: %v", err)) + } + + if err := WriteMemory(mem, offset, value); err != nil { + panic(fmt.Sprintf("failed to write value to memory: %v", err)) } + return offset, uint32(len(value)) } -// Example host function: Set key in DB +// hostSet implements db_set func hostSet(ctx context.Context, mod api.Module, keyPtr, keyLen, valPtr, valLen uint32) { env := ctx.Value("env").(*RuntimeEnvironment) mem := mod.Memory() - key, ok := mem.Read(keyPtr, keyLen) - if !ok { - panic("failed to read key from memory") + key, err := ReadMemory(mem, keyPtr, keyLen) + if err != nil { + panic(fmt.Sprintf("failed to read key from memory: %v", err)) } - val, ok2 := mem.Read(valPtr, valLen) - if !ok2 { - panic("failed to read value from memory") + + val, err := ReadMemory(mem, valPtr, valLen) + if err != nil { + panic(fmt.Sprintf("failed to read value from memory: %v", err)) } env.DB.Set(key, val) } -// Example host function: HumanizeAddress +// hostHumanizeAddress implements api_humanize_address func hostHumanizeAddress(ctx context.Context, mod api.Module, addrPtr, addrLen uint32) (resPtr, resLen uint32) { env := ctx.Value("env").(*RuntimeEnvironment) mem := mod.Memory() - addrBytes, ok := mem.Read(addrPtr, addrLen) - if !ok { - panic("failed to read address from memory") + addr, err := ReadMemory(mem, addrPtr, addrLen) + if err != nil { + panic(fmt.Sprintf("failed to read address from memory: %v", err)) } - human, _, err := env.API.HumanizeAddress(addrBytes) + human, _, err := env.API.HumanizeAddress(addr) if err != nil { - // On error, you might return 0,0 or handle differently return 0, 0 } - offset := uint32(4096) // Some offset for writing back - if !mem.Write(offset, []byte(human)) { - panic("failed to write humanized address") + // Allocate memory for the result + offset, err := env.Memory.Allocate(mem, uint32(len(human))) + if err != nil { + panic(fmt.Sprintf("failed to allocate memory: %v", err)) } + + if err := WriteMemory(mem, offset, []byte(human)); err != nil { + panic(fmt.Sprintf("failed to write humanized address: %v", err)) + } + return offset, uint32(len(human)) } -// Example host function: QueryExternal +// hostQueryExternal implements querier_query func hostQueryExternal(ctx context.Context, mod api.Module, reqPtr, reqLen, gasLimit uint32) (resPtr, resLen uint32) { env := ctx.Value("env").(*RuntimeEnvironment) mem := mod.Memory() - req, ok := mem.Read(reqPtr, reqLen) - if !ok { - panic("failed to read query request") + req, err := ReadMemory(mem, reqPtr, reqLen) + if err != nil { + panic(fmt.Sprintf("failed to read query request: %v", err)) } res := types.RustQuery(env.Querier, req, uint64(gasLimit)) serialized, err := json.Marshal(res) if err != nil { - // handle error, maybe return 0,0 return 0, 0 } - offset := uint32(8192) // Another offset - if !mem.Write(offset, serialized) { - panic("failed to write query response") + // Allocate memory for the result + offset, err := env.Memory.Allocate(mem, uint32(len(serialized))) + if err != nil { + panic(fmt.Sprintf("failed to allocate memory: %v", err)) + } + + if err := WriteMemory(mem, offset, serialized); err != nil { + panic(fmt.Sprintf("failed to write query response: %v", err)) } + return offset, uint32(len(serialized)) } -// RegisterHostFunctions registers all host functions into a module named "env". -// The wasm code must import them from "env". +// hostCanonicalizeAddress implements api_canonicalize_address +func hostCanonicalizeAddress(ctx context.Context, mod api.Module, addrPtr, addrLen uint32) (resPtr, resLen uint32) { + env := ctx.Value("env").(*RuntimeEnvironment) + mem := mod.Memory() + + addr, err := ReadMemory(mem, addrPtr, addrLen) + if err != nil { + panic(fmt.Sprintf("failed to read address from memory: %v", err)) + } + + canonical, _, err := env.API.CanonicalizeAddress(string(addr)) + if err != nil { + return 0, 0 + } + + // Allocate memory for the result + offset, err := env.Memory.Allocate(mem, uint32(len(canonical))) + if err != nil { + panic(fmt.Sprintf("failed to allocate memory: %v", err)) + } + + if err := WriteMemory(mem, offset, []byte(canonical)); err != nil { + panic(fmt.Sprintf("failed to write canonicalized address: %v", err)) + } + + return offset, uint32(len(canonical)) +} + +// hostValidateAddress implements api_validate_address +func hostValidateAddress(ctx context.Context, mod api.Module, addrPtr, addrLen uint32) uint32 { + env := ctx.Value("env").(*RuntimeEnvironment) + mem := mod.Memory() + + addr, err := ReadMemory(mem, addrPtr, addrLen) + if err != nil { + panic(fmt.Sprintf("failed to read address from memory: %v", err)) + } + + _, err = env.API.ValidateAddress(string(addr)) + if err != nil { + return 0 // Return 0 for invalid address + } + + return 1 // Return 1 for valid address +} + +// hostScan implements db_scan +func hostScan(ctx context.Context, mod api.Module, startPtr, startLen, endPtr, endLen uint32, order uint32) (uint64, uint64, uint32) { + env := ctx.Value("env").(*RuntimeEnvironment) + mem := mod.Memory() + + // Check gas for iterator creation + if env.GasUsed+gasCostIteratorCreate > env.Gas.GasConsumed() { + return 0, 0, 1 // Return error code 1 for out of gas + } + env.GasUsed += gasCostIteratorCreate + + start, err := ReadMemory(mem, startPtr, startLen) + if err != nil { + panic(fmt.Sprintf("failed to read start key from memory: %v", err)) + } + + end, err := ReadMemory(mem, endPtr, endLen) + if err != nil { + panic(fmt.Sprintf("failed to read end key from memory: %v", err)) + } + + // Check iterator limits + callID := env.StartCall() + if len(env.iterators[callID]) >= maxIteratorsPerCall { + return 0, 0, 2 // Return error code 2 for too many iterators + } + + // Get iterator from DB with order + var iter types.Iterator + if order == 1 { + iter = env.DB.ReverseIterator(start, end) + } else { + iter = env.DB.Iterator(start, end) + } + if iter == nil { + return 0, 0, 3 // Return error code 3 for iterator creation failure + } + + // Store iterator in the environment + iterID := env.StoreIterator(callID, iter) + + return callID, iterID, 0 // Return 0 for success +} + +// hostNext implements db_next +func hostNext(ctx context.Context, mod api.Module, callID, iterID uint64) (keyPtr, keyLen, valPtr, valLen, errCode uint32) { + env := ctx.Value("env").(*RuntimeEnvironment) + mem := mod.Memory() + + // Check gas for iterator next operation + if env.GasUsed+gasCostIteratorNext > env.Gas.GasConsumed() { + return 0, 0, 0, 0, 1 // Return error code 1 for out of gas + } + env.GasUsed += gasCostIteratorNext + + // Get iterator from environment + iter := env.GetIterator(callID, iterID) + if iter == nil { + return 0, 0, 0, 0, 2 // Return error code 2 for invalid iterator + } + + // Check if there are more items + if !iter.Valid() { + return 0, 0, 0, 0, 0 // Return 0 for end of iteration + } + + // Get key and value + key := iter.Key() + value := iter.Value() + + // Allocate memory for key + keyOffset, err := env.Memory.Allocate(mem, uint32(len(key))) + if err != nil { + panic(fmt.Sprintf("failed to allocate memory for key: %v", err)) + } + if err := WriteMemory(mem, keyOffset, key); err != nil { + panic(fmt.Sprintf("failed to write key to memory: %v", err)) + } + + // Allocate memory for value + valOffset, err := env.Memory.Allocate(mem, uint32(len(value))) + if err != nil { + panic(fmt.Sprintf("failed to allocate memory for value: %v", err)) + } + if err := WriteMemory(mem, valOffset, value); err != nil { + panic(fmt.Sprintf("failed to write value to memory: %v", err)) + } + + // Move to next item + iter.Next() + + return keyOffset, uint32(len(key)), valOffset, uint32(len(value)), 0 +} + +// hostNextKey implements db_next_key +func hostNextKey(ctx context.Context, mod api.Module, callID, iterID uint64) (keyPtr, keyLen, errCode uint32) { + env := ctx.Value("env").(*RuntimeEnvironment) + mem := mod.Memory() + + // Check gas for iterator next operation + if env.GasUsed+gasCostIteratorNext > env.Gas.GasConsumed() { + return 0, 0, 1 // Return error code 1 for out of gas + } + env.GasUsed += gasCostIteratorNext + + // Get iterator from environment + iter := env.GetIterator(callID, iterID) + if iter == nil { + return 0, 0, 2 // Return error code 2 for invalid iterator + } + + // Check if there are more items + if !iter.Valid() { + return 0, 0, 0 // Return 0 for end of iteration + } + + // Get key + key := iter.Key() + + // Allocate memory for key + keyOffset, err := env.Memory.Allocate(mem, uint32(len(key))) + if err != nil { + panic(fmt.Sprintf("failed to allocate memory for key: %v", err)) + } + if err := WriteMemory(mem, keyOffset, key); err != nil { + panic(fmt.Sprintf("failed to write key to memory: %v", err)) + } + + // Move to next item + iter.Next() + + return keyOffset, uint32(len(key)), 0 +} + +// hostNextValue implements db_next_value +func hostNextValue(ctx context.Context, mod api.Module, callID, iterID uint64) (valPtr, valLen, errCode uint32) { + env := ctx.Value("env").(*RuntimeEnvironment) + mem := mod.Memory() + + // Check gas for iterator next operation + if env.GasUsed+gasCostIteratorNext > env.Gas.GasConsumed() { + return 0, 0, 1 // Return error code 1 for out of gas + } + env.GasUsed += gasCostIteratorNext + + // Get iterator from environment + iter := env.GetIterator(callID, iterID) + if iter == nil { + return 0, 0, 2 // Return error code 2 for invalid iterator + } + + // Check if there are more items + if !iter.Valid() { + return 0, 0, 0 // Return 0 for end of iteration + } + + // Get value + value := iter.Value() + + // Allocate memory for value + valOffset, err := env.Memory.Allocate(mem, uint32(len(value))) + if err != nil { + panic(fmt.Sprintf("failed to allocate memory for value: %v", err)) + } + if err := WriteMemory(mem, valOffset, value); err != nil { + panic(fmt.Sprintf("failed to write value to memory: %v", err)) + } + + // Move to next item + iter.Next() + + return valOffset, uint32(len(value)), 0 +} + +// hostCloseIterator implements db_close_iterator +func hostCloseIterator(ctx context.Context, mod api.Module, callID, iterID uint64) { + env := ctx.Value("env").(*RuntimeEnvironment) + + // Get iterator from environment + iter := env.GetIterator(callID, iterID) + if iter == nil { + return + } + + // Close the iterator + iter.Close() + + // Remove from environment + env.iteratorsMutex.Lock() + defer env.iteratorsMutex.Unlock() + + if callMap, exists := env.iterators[callID]; exists { + delete(callMap, iterID) + } +} + +// RegisterHostFunctions registers all host functions into a module named "env" func RegisterHostFunctions(r wazero.Runtime, env *RuntimeEnvironment) (wazero.CompiledModule, error) { + // Initialize memory allocator if not already done + if env.Memory == nil { + env.Memory = NewMemoryAllocator(65536) // Start at 64KB offset + } + // Build a module that exports these functions hostModBuilder := r.NewHostModuleBuilder("env") - // Example: Registering hostGet as "db_get" + // Register memory management functions + RegisterMemoryManagement(hostModBuilder, env.Memory) + + // Register DB functions hostModBuilder.NewFunctionBuilder(). WithFunc(hostGet). Export("db_get") - // Similarly for hostSet hostModBuilder.NewFunctionBuilder(). WithFunc(hostSet). Export("db_set") - // For humanize address + // Register API functions hostModBuilder.NewFunctionBuilder(). WithFunc(hostHumanizeAddress). Export("api_humanize_address") - // For queries + hostModBuilder.NewFunctionBuilder(). + WithFunc(hostCanonicalizeAddress). + Export("api_canonicalize_address") + + hostModBuilder.NewFunctionBuilder(). + WithFunc(hostValidateAddress). + Export("api_validate_address") + + // Register Query functions hostModBuilder.NewFunctionBuilder(). WithFunc(hostQueryExternal). Export("querier_query") + // Register Iterator functions + hostModBuilder.NewFunctionBuilder(). + WithFunc(hostScan). + Export("db_scan") + + hostModBuilder.NewFunctionBuilder(). + WithFunc(hostNext). + Export("db_next") + + hostModBuilder.NewFunctionBuilder(). + WithFunc(hostNextKey). + Export("db_next_key") + + hostModBuilder.NewFunctionBuilder(). + WithFunc(hostNextValue). + Export("db_next_value") + + hostModBuilder.NewFunctionBuilder(). + WithFunc(hostCloseIterator). + Export("db_close_iterator") + // Compile the host module compiled, err := hostModBuilder.Compile(context.Background()) if err != nil { diff --git a/internal/runtime/memory.go b/internal/runtime/memory.go new file mode 100644 index 000000000..d8d9fac4f --- /dev/null +++ b/internal/runtime/memory.go @@ -0,0 +1,131 @@ +package runtime + +import ( + "context" + "fmt" + "sync" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" +) + +// MemoryAllocator manages memory allocation for the WASM module +type MemoryAllocator struct { + mu sync.Mutex + // Start of the memory region we manage + heapStart uint32 + // Current position of the allocation pointer + current uint32 + // Map of allocated memory regions: offset -> size + allocations map[uint32]uint32 + // List of freed regions that can be reused + freeList []memoryRegion +} + +type memoryRegion struct { + offset uint32 + size uint32 +} + +// NewMemoryAllocator creates a new memory allocator starting at the specified offset +func NewMemoryAllocator(heapStart uint32) *MemoryAllocator { + return &MemoryAllocator{ + heapStart: heapStart, + current: heapStart, + allocations: make(map[uint32]uint32), + freeList: make([]memoryRegion, 0), + } +} + +// Allocate allocates a new region of memory of the specified size +// Returns the offset where the memory was allocated +func (m *MemoryAllocator) Allocate(mem api.Memory, size uint32) (uint32, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // Try to reuse a freed region first + for i, region := range m.freeList { + if region.size >= size { + // Remove this region from free list + m.freeList = append(m.freeList[:i], m.freeList[i+1:]...) + m.allocations[region.offset] = size + return region.offset, nil + } + } + + // No suitable freed region found, allocate new memory + offset := m.current + + // Check if we have enough memory + memSize := mem.Size() + if offset+size > memSize { + // Try to grow memory + pages := ((offset + size - memSize) / uint32(65536)) + 1 + newSize, ok := mem.Grow(pages) + if !ok { + return 0, fmt.Errorf("failed to grow memory") + } + if newSize == 0 { + return 0, fmt.Errorf("failed to grow memory: maximum memory size exceeded") + } + } + + m.current += size + m.allocations[offset] = size + return offset, nil +} + +// Free releases the memory at the specified offset +func (m *MemoryAllocator) Free(offset uint32) error { + m.mu.Lock() + defer m.mu.Unlock() + + size, exists := m.allocations[offset] + if !exists { + return fmt.Errorf("attempt to free unallocated memory at offset %d", offset) + } + + delete(m.allocations, offset) + m.freeList = append(m.freeList, memoryRegion{offset: offset, size: size}) + return nil +} + +// WriteMemory writes data to memory at the specified offset +func WriteMemory(mem api.Memory, offset uint32, data []byte) error { + if !mem.Write(offset, data) { + return fmt.Errorf("failed to write %d bytes at offset %d", len(data), offset) + } + return nil +} + +// ReadMemory reads data from memory at the specified offset and length +func ReadMemory(mem api.Memory, offset, length uint32) ([]byte, error) { + data, ok := mem.Read(offset, length) + if !ok { + return nil, fmt.Errorf("failed to read %d bytes at offset %d", length, offset) + } + return data, nil +} + +// RegisterMemoryManagement sets up memory management for the module +func RegisterMemoryManagement(builder wazero.HostModuleBuilder, allocator *MemoryAllocator) { + // Allocate memory + builder.NewFunctionBuilder(). + WithFunc(func(_ context.Context, mod api.Module, size uint32) uint32 { + offset, err := allocator.Allocate(mod.Memory(), size) + if err != nil { + panic(err) // In real code, handle this better + } + return offset + }). + Export("allocate") + + // Free memory + builder.NewFunctionBuilder(). + WithFunc(func(_ context.Context, _ api.Module, offset uint32) { + if err := allocator.Free(offset); err != nil { + panic(err) // In real code, handle this better + } + }). + Export("deallocate") +} diff --git a/lib_libwasmvm_test.go b/lib_libwasmvm_test.go index 388b31d8f..e3fbb7fdb 100644 --- a/lib_libwasmvm_test.go +++ b/lib_libwasmvm_test.go @@ -94,7 +94,7 @@ func TestStoreCode(t *testing.T) { // Nil { - var wasm []byte = nil + var wasm []byte _, _, err := vm.StoreCode(wasm, TESTING_GAS_LIMIT) require.ErrorContains(t, err, "Null/Nil argument: wasm") }