Skip to content

Commit

Permalink
Use a concurrently safe map for the underlying router map
Browse files Browse the repository at this point in the history
  • Loading branch information
imthatgin committed Apr 1, 2024
1 parent c01d7a4 commit 0cd59da
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 20 deletions.
11 changes: 7 additions & 4 deletions esync/srvsync/esync_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/yohamta/donburi"
"github.com/yohamta/donburi/component"
"golang.org/x/sync/errgroup"
"nhooyr.io/websocket"
"reflect"
"slices"
"sync"
Expand Down Expand Up @@ -74,15 +75,17 @@ func NetworkSync(world donburi.World, entity *donburi.Entity, components ...donb
// This is done by serializing all the components of the entity, and preparing a network bundle for the clients.
func DoSync() error {
errs, _ := errgroup.WithContext(context.Background())
for _, client := range router.Peers() {
snapshot := buildSnapshot(client, world)

client := client
router.PeerMap().Range(func(key *websocket.Conn, client *router.NetworkClient) bool {
snapshot := buildSnapshot(client, world)
errs.Go(func() error {
err := client.SendMessage(snapshot)
return err
})
}

return true
})

return errs.Wait()
}

Expand Down
38 changes: 38 additions & 0 deletions internal/syncx/map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package syncx

import "sync"

type Map[K comparable, V any] struct {
m sync.Map
}

func (m *Map[K, V]) Delete(key K) {
m.m.Delete(key)
}

func (m *Map[K, V]) Load(key K) (value V, ok bool) {
v, ok := m.m.Load(key)
if !ok {
return value, ok
}
return v.(V), ok
}

func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
v, loaded := m.m.LoadAndDelete(key)
if !loaded {
return value, loaded
}
return v.(V), loaded
}

func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
a, loaded := m.m.LoadOrStore(key, value)
return a.(V), loaded
}

func (m *Map[K, V]) Range(f func(key K, value V) bool) {
m.m.Range(func(key, value any) bool { return f(key.(K), value.(V)) })
}

func (m *Map[K, V]) Store(key K, value V) { m.m.Store(key, value) }
46 changes: 46 additions & 0 deletions internal/syncx/map_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package syncx_test

import (
"github.com/leap-fish/necs/internal/syncx"
"github.com/stretchr/testify/assert"
"testing"
)

func TestMap(t *testing.T) {
m := &syncx.Map[int, string]{}
// Test LoadAndStore
actual, loaded := m.LoadOrStore(1, "value1")
assert.False(t, loaded, "Expected loaded=false")
assert.Equal(t, "value1", actual, "Expected actual value 'value1'")

// Test Load
actualValue, ok := m.Load(1)
assert.True(t, ok, "Expected ok=true")
assert.Equal(t, "value1", actualValue, "Expected actual value 'value1'")

// Test Store
m.Store(2, "value2")
actualValue, ok = m.Load(2)
assert.True(t, ok, "Expected ok=true")
assert.Equal(t, "value2", actualValue, "Expected actual value 'value2'")

// Test Delete
m.Delete(1)
_, ok = m.Load(1)
assert.False(t, ok, "Expected ok=false for key 1 after deletion")

// Test LoadAndDelete
actualValue, loaded = m.LoadAndDelete(2)
assert.True(t, loaded, "Expected loaded=true")
assert.Equal(t, "value2", actualValue, "Expected actual value 'value2'")

// Test Range
m.Store(3, "value3")
m.Store(4, "value4")
var count int
m.Range(func(key int, value string) bool {
count++
return true
})
assert.Equal(t, 2, count, "Expected 2 iterations")
}
48 changes: 32 additions & 16 deletions router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"errors"
"fmt"
"github.com/leap-fish/necs/internal/syncx"
"github.com/leap-fish/necs/typeid"
"github.com/leap-fish/necs/typemapper"
"nhooyr.io/websocket"
Expand All @@ -25,8 +26,8 @@ var (

callbacks = make(map[reflect.Type][]any)

connMap = map[*websocket.Conn]string{}
clientMap = map[*websocket.Conn]*NetworkClient{}
idMap = syncx.Map[*websocket.Conn, string]{}
clientMap = syncx.Map[*websocket.Conn, *NetworkClient]{}
)

// On adds a callback to be called whenever the specified message type T is received.
Expand Down Expand Up @@ -92,37 +93,49 @@ func ProcessMessage(sender *NetworkClient, msg []byte) error {
}

func Client(conn *websocket.Conn) *NetworkClient {
if _, ok := clientMap[conn]; ok {
return clientMap[conn]
client, ok := clientMap.Load(conn)
if ok {
return client
}

clientMap[conn] = NewNetworkClient(context.Background(), conn)
return clientMap[conn]
clientMap.Store(conn, NewNetworkClient(context.Background(), conn))
// Ignore because we know it's set
r, _ := clientMap.Load(conn)
return r
}

func Id(client *NetworkClient) string {

if _, ok := connMap[client.Conn]; ok {
return connMap[client.Conn]
id, ok := idMap.Load(client.Conn)
if ok {
return id
}

bytes := make([]byte, 16)
_, _ = rand.Read(bytes)
id := fmt.Sprintf("%d:%x", len(connMap), bytes[:10])
id = fmt.Sprintf("%x", bytes[:10])

connMap[client.Conn] = id
idMap.Store(client.Conn, id)
return id
}

// Peers returns a new slice of NetworkClient pointers from the underlying map.
// Use PeerMap if you are able to as this avoids this kind of duplication.
func Peers() []*NetworkClient {
peers := make([]*NetworkClient, 0, len(clientMap))
for _, client := range clientMap {
peers = append(peers, client)
}
var peers []*NetworkClient

clientMap.Range(func(key *websocket.Conn, value *NetworkClient) bool {
peers = append(peers, value)
return true
})
return peers
}

// PeerMap returns a pointer to the underlying peer map.
func PeerMap() *syncx.Map[*websocket.Conn, *NetworkClient] {
return &clientMap
}

func Broadcast(msg any) error {
payload, err := Serialize(msg)
if err != nil {
Expand Down Expand Up @@ -164,8 +177,8 @@ func CallDisconnect(sender *websocket.Conn, err error) {
go callback(client, err)
}

delete(connMap, sender)
delete(clientMap, sender)
idMap.Delete(sender)
clientMap.Delete(sender)
}

func CallError(sender *websocket.Conn, err error) {
Expand All @@ -181,4 +194,7 @@ func ResetRouter() {
disconnectCallbacks = []func(sender *NetworkClient, err error){}
errorCallbacks = []func(sender *NetworkClient, err error){}
callbacks = make(map[reflect.Type][]any)

idMap = syncx.Map[*websocket.Conn, string]{}
clientMap = syncx.Map[*websocket.Conn, *NetworkClient]{}
}

0 comments on commit 0cd59da

Please sign in to comment.