From bdd42dd427e6a4fc9da263f1edb4c5f3f1ade9d8 Mon Sep 17 00:00:00 2001 From: Adam Chen Date: Wed, 6 May 2026 23:29:40 +0000 Subject: [PATCH 1/2] fix(core): add synchronization for race conditions Due to the addition of hot-reloading config, and the realization in #99, nylon has many sections of code which exhibit race conditions. This PR tries to address some of the obvious race conditions. --- .github/workflows/go-test.yml | 4 +- core/entrypoint.go | 151 +---------------------------- core/ipc_handler.go | 75 +++++++-------- core/nylon.go | 159 ++++++++++++++++++++++++++++++- core/nylon_apply.go | 35 +++++-- core/nylon_distribution.go | 2 +- core/nylon_endpoints.go | 62 +++++++----- core/nylon_scheduler_test.go | 7 +- core/nylon_tc.go | 24 +++-- core/router.go | 28 ++++-- integration/apply_config_test.go | 8 +- integration/harness.go | 44 ++++++--- integration/ipc_test.go | 92 +++++++----------- integration/race_test.go | 121 +++++++++++++++++++++++ integration/routing_test.go | 5 +- state/config.go | 13 +++ 16 files changed, 506 insertions(+), 324 deletions(-) create mode 100644 integration/race_test.go diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml index 7fdccbd..b853699 100644 --- a/.github/workflows/go-test.yml +++ b/.github/workflows/go-test.yml @@ -25,7 +25,7 @@ jobs: with: go-version: ${{ env.GO_VERSION }} - name: Run unit test - run: go run gotest.tools/gotestsum@latest -- -tags=router_test ./... + run: go run gotest.tools/gotestsum@latest -- --race -tags=router_test ./... integration: runs-on: ubuntu-latest steps: @@ -35,7 +35,7 @@ jobs: with: go-version: ${{ env.GO_VERSION }} - name: Run integration - run: go run gotest.tools/gotestsum@latest -- -tags=integration ./integration/... + run: go run gotest.tools/gotestsum@latest -- --race -tags=integration ./integration/... e2e: runs-on: ubuntu-latest steps: diff --git a/core/entrypoint.go b/core/entrypoint.go index 863e8a8..206872d 100644 --- a/core/entrypoint.go +++ b/core/entrypoint.go @@ -1,26 +1,15 @@ package core import ( - "context" - "errors" "fmt" "log" "log/slog" "net/http" "os" - "os/signal" - "path" - "reflect" - "runtime" "runtime/trace" - "syscall" - "time" - "github.com/encodeous/nylon/perf" "github.com/encodeous/nylon/state" - "github.com/encodeous/tint" "github.com/goccy/go-yaml" - slogmulti "github.com/samber/slog-multi" ) func setupDebugging() { @@ -135,146 +124,12 @@ func Bootstrap(centralPath, nodePath, logPath string, verbose bool) { if err != nil { panic(err) } - err = Start(*centralCfg, *nodeCfg, level, centralPath, nil, nil) + n, err := NewNylon(*centralCfg, *nodeCfg, level, centralPath, nil) if err != nil { panic(err) } -} - -func Start(ccfg state.CentralCfg, ncfg state.LocalCfg, logLevel slog.Level, configPath string, aux map[string]any, initNylon **Nylon) error { - ctx, cancel := context.WithCancelCause(context.Background()) - - dispatch := make(chan func() error, 128) - - handlers := make([]slog.Handler, 0) - if state.DBG_log_json { - handlers = append(handlers, - slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ - Level: logLevel, - }), - ) - } else { - handlers = append(handlers, - tint.NewHandler(os.Stderr, &tint.Options{ - Level: logLevel, - AddSource: false, - CustomPrefix: string(ncfg.Id), - ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { - if attr.Key == "time" { - return slog.Attr{} - } - return attr - }, - })) - } - - if ncfg.LogPath != "" { - err := os.MkdirAll(path.Dir(ncfg.LogPath), 0700) - if err != nil { - return err - } - f, err := os.OpenFile(ncfg.LogPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0700) - if err != nil { - return err - } - handlers = append(handlers, slog.NewTextHandler(f, &slog.HandlerOptions{Level: logLevel})) - } - - logger := slog.New( - slogmulti.Fanout(handlers...)) - - if ncfg.InterfaceName == "" { - ncfg.InterfaceName = "nylon" - } - - n := &Nylon{ - Trace: &NylonTrace{}, - ConfigState: state.ConfigState{ - CentralCfg: ccfg, - LocalCfg: ncfg, - }, - Context: ctx, - Cancel: cancel, - DispatchChannel: dispatch, - Log: logger, - ConfigPath: configPath, - AuxConfig: aux, - } - - n.Log.Info("init modules") - - err := n.Init() + err = n.Start() if err != nil { - return err - } - if initNylon != nil { - *initNylon = n - } - n.Log.Info("init modules complete") - - n.Log.Info("Nylon has been initialized. To gracefully exit, send SIGINT or Ctrl+C.") - - c := make(chan os.Signal, 1) - signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) - go func() { - select { - case _ = <-c: - n.Cancel(errors.New("received shutdown signal")) - case <-ctx.Done(): - return - } - }() - - err = MainLoop(n, dispatch) - if err != nil { - return err - } - return nil -} - -func MainLoop(n *Nylon, dispatch <-chan func() error) error { - n.Log.Debug("started main loop") - for { - select { - case fun := <-dispatch: - if fun == nil { - goto endLoop - } - //n.Log.Debug("start") - start := time.Now() - err := fun() - if err != nil { - n.Log.Error("error occurred during dispatch: ", "error", err) - n.Cancel(err) - } - elapsed := time.Since(start) - perf.DispatchLatency.Add(float64(elapsed.Microseconds())) - if elapsed > time.Millisecond*4 { - n.Log.Warn("dispatch took a long time!", "fun", runtime.FuncForPC(reflect.ValueOf(fun).Pointer()).Name(), "elapsed", elapsed, "len", len(dispatch)) - } - //n.Log.Debug("done", "elapsed", elapsed) - case <-n.Context.Done(): - goto endLoop - } + panic(err) } -endLoop: - n.Log.Info("stopped main loop", "reason", context.Cause(n.Context).Error()) - Stop(n) - return nil -} - -func Stop(n *Nylon) { - n.cleanupOnce.Do(func() { - n.Cancel(context.Canceled) - if n.DispatchChannel != nil { - close(n.DispatchChannel) - n.DispatchChannel = nil - } - n.Log.Info("cleaning up modules") - err := n.Cleanup() - if err != nil { - n.Log.Error("error occurred during Stop: ", "error", err) - } - n.Log.Info("stopped") - }) } diff --git a/core/ipc_handler.go b/core/ipc_handler.go index f847059..86a2328 100644 --- a/core/ipc_handler.go +++ b/core/ipc_handler.go @@ -59,18 +59,37 @@ func HandleNylonIPC(n *Nylon, rw *bufio.ReadWriter) error { } return device.ErrIPCStatusHandled } - var resp *protocol.IpcResponse - switch req.Request.(type) { - case *protocol.IpcRequest_Status: - resp = handleStatus(n, req.GetStatus()) - case *protocol.IpcRequest_Probe: - resp = handleIPCProbe(n, req.GetProbe()) - case *protocol.IpcRequest_Reload: - resp = handleIPCReload(n, req.GetReload()) - case *protocol.IpcRequest_Trace: + + // trace is blocking, so we dont dispatch + if _, ok := req.Request.(*protocol.IpcRequest_Trace); ok { return handleTrace(n, rw) - default: - resp = errResponse("unknown method") + } + + done := make(chan *protocol.IpcResponse, 1) + n.Dispatch(func() error { + var resp *protocol.IpcResponse + switch req.Request.(type) { + case *protocol.IpcRequest_Status: + resp = handleStatus(n, req.GetStatus()) + case *protocol.IpcRequest_Probe: + resp = handleIPCProbe(n, req.GetProbe()) + case *protocol.IpcRequest_Reload: + resp = handleIPCReload(n, req.GetReload()) + default: + resp = errResponse("unknown method") + } + done <- resp + return nil + }) + + var resp *protocol.IpcResponse + select { + case resp = <-done: + case <-n.Context.Done(): + resp = errResponse("nylon shutting down") + case <-time.After(1 * time.Second): + // nylon is too busy to handle IPC requests + resp = errResponse("timed out waiting for dispatch") } if err := writeResponse(rw, resp); err != nil { return err @@ -231,7 +250,7 @@ func buildRouteTables(n *Nylon) *protocol.RouteTables { slices.SortFunc(tables.Selected, func(a, b *protocol.SelRoute) int { return comparePubRoute(a.PubRoute, b.PubRoute) }) - for prefix, route := range n.router.ForwardTable.All() { + for prefix, route := range n.router.ForwardTable.Load().All() { tables.Forward = append(tables.Forward, &protocol.RouteTableEntry{ Prefix: prefix.String(), Nh: string(route.Nh), @@ -239,7 +258,7 @@ func buildRouteTables(n *Nylon) *protocol.RouteTables { }) } sortRouteTableEntries(tables.Forward) - for prefix, route := range n.router.ExitTable.All() { + for prefix, route := range n.router.ExitTable.Load().All() { tables.Exit = append(tables.Exit, &protocol.RouteTableEntry{ Prefix: prefix.String(), Nh: string(route.Nh), @@ -401,7 +420,7 @@ func handleIPCProbe(n *Nylon, req *protocol.ProbeRequest) *protocol.IpcResponse for _, ep := range neigh.Eps { nep := ep.AsNylonEndpoint() addr := nep.DynEP.Value - err := n.Probe(neigh.Id, nep) + err := n.Probe(neigh.Id, nep, true) r := &protocol.EndpointProbeResult{Address: addr, Success: err == nil} if err != nil { r.Error = err.Error() @@ -420,10 +439,12 @@ func handleIPCReload(n *Nylon, req *protocol.ReloadRequest) *protocol.IpcRespons return errResponse(fmt.Sprintf("read file: %v", err)) } var cfg state.CentralCfg - if err := yaml.Unmarshal(data, &cfg); err != nil { + if err = yaml.Unmarshal(data, &cfg); err != nil { return errResponse(fmt.Sprintf("parse config: %v", err)) } - result, err := applyCentralConfigSync(n, cfg) + // We're running on the dispatch goroutine, so call ApplyCentralConfig + // directly rather than re-dispatching (which would deadlock). + result, err := n.ApplyCentralConfig(&cfg) msg := "" if err != nil { msg = err.Error() @@ -448,28 +469,6 @@ func handleIPCReload(n *Nylon, req *protocol.ReloadRequest) *protocol.IpcRespons } } -func applyCentralConfigSync(n *Nylon, cfg state.CentralCfg) (ApplyResult, error) { - type result struct { - applyResult ApplyResult - err error - } - done := make(chan result, 1) - n.Dispatch(func() error { - applyResult, err := n.ApplyCentralConfig(cfg) - done <- result{applyResult: applyResult, err: err} - return nil - }) - - select { - case r := <-done: - return r.applyResult, r.err - case <-n.Context.Done(): - return ApplyRejected, context.Cause(n.Context) - case <-time.After(30 * time.Second): - return ApplyRejected, fmt.Errorf("timed out waiting for config reload") - } -} - func handleTrace(n *Nylon, rw *bufio.ReadWriter) error { if !state.DBG_trace_tc { if err := writeResponse(rw, errResponse("tracing not enabled; restart with --dbg-trace-tc")); err != nil { diff --git a/core/nylon.go b/core/nylon.go index 47fc13f..4016300 100644 --- a/core/nylon.go +++ b/core/nylon.go @@ -2,17 +2,28 @@ package core import ( "context" + "errors" "log/slog" "net" "net/netip" + "os" + "os/signal" + "path" + "reflect" + "runtime" "sync" + "sync/atomic" + "syscall" "time" + "github.com/encodeous/nylon/perf" "github.com/encodeous/nylon/polyamide/device" "github.com/encodeous/nylon/polyamide/tun" "github.com/encodeous/nylon/state" + "github.com/encodeous/tint" "github.com/gaissmai/bart" "github.com/jellydator/ttlcache/v3" + slogmulti "github.com/samber/slog-multi" ) type Nylon struct { @@ -23,15 +34,16 @@ type Nylon struct { RouterState *state.RouterState AppliedSystem AppliedSystemState PingBuf *ttlcache.Cache[uint64, EpPing] + PeerMap atomic.Pointer[map[state.NyPublicKey]state.NodeId] router struct { LastStarvationRequest time.Time IO map[state.NodeId]*IOPending // ForwardTable contains the full routing table - ForwardTable bart.Table[RouteTableEntry] + ForwardTable atomic.Pointer[bart.Table[RouteTableEntry]] // ExitTable contains only routes to services hosted on this node - ExitTable bart.Table[RouteTableEntry] + ExitTable atomic.Pointer[bart.Table[RouteTableEntry]] log *slog.Logger } @@ -61,6 +73,75 @@ type AppliedSystemState struct { Peers map[state.NodeId]state.NyPublicKey } +func NewNylon(ccfg state.CentralCfg, ncfg state.LocalCfg, logLevel slog.Level, configPath string, aux map[string]any) (*Nylon, error) { + ctx, cancel := context.WithCancelCause(context.Background()) + + dispatch := make(chan func() error, 128) + + handlers := make([]slog.Handler, 0) + if state.DBG_log_json { + handlers = append(handlers, + slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ + Level: logLevel, + }), + ) + } else { + handlers = append(handlers, + tint.NewHandler(os.Stderr, &tint.Options{ + Level: logLevel, + AddSource: false, + CustomPrefix: string(ncfg.Id), + ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { + if attr.Key == "time" { + return slog.Attr{} + } + return attr + }, + })) + } + + if ncfg.LogPath != "" { + err := os.MkdirAll(path.Dir(ncfg.LogPath), 0700) + if err != nil { + return nil, err + } + f, err := os.OpenFile(ncfg.LogPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0700) + if err != nil { + return nil, err + } + handlers = append(handlers, slog.NewTextHandler(f, &slog.HandlerOptions{Level: logLevel})) + } + + logger := slog.New( + slogmulti.Fanout(handlers...)) + + if ncfg.InterfaceName == "" { + ncfg.InterfaceName = "nylon" + } + + n := &Nylon{ + Trace: &NylonTrace{}, + ConfigState: state.ConfigState{ + CentralCfg: ccfg, + LocalCfg: ncfg, + }, + Context: ctx, + Cancel: cancel, + DispatchChannel: dispatch, + Log: logger, + ConfigPath: configPath, + AuxConfig: aux, + } + + n.Log.Info("init modules") + + err := n.Init() + if err != nil { + return nil, err + } + return n, nil +} + func (n *Nylon) Init() error { n.Log.Debug("init nylon") @@ -78,7 +159,7 @@ func (n *Nylon) Init() error { if n.AppliedSystem.Peers == nil { n.AppliedSystem.Peers = make(map[state.NodeId]state.NyPublicKey) } - err = n.reconcileRouterState(n.CentralCfg) + err = n.reconcileRouterState(&n.CentralCfg) if err != nil { return err } @@ -143,6 +224,78 @@ func (n *Nylon) Init() error { return nil } +func (n *Nylon) Start() error { + n.Log.Info("init modules complete") + + n.Log.Info("Nylon has been initialized. To gracefully exit, send SIGINT or Ctrl+C.") + + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + go func() { + select { + case _ = <-c: + n.Cancel(errors.New("received shutdown signal")) + case <-n.Context.Done(): + return + } + }() + + err := n.mainLoop() + if err != nil { + return err + } + return nil +} + +func (n *Nylon) Stop() { + n.cleanupOnce.Do(func() { + n.Cancel(context.Canceled) + if n.Context.Err() == nil { + select { + case n.DispatchChannel <- nil: // instead of close(), which can cause a data race + case <-n.Context.Done(): + } + } + }) +} + +func (n *Nylon) mainLoop() error { + n.Log.Debug("started main loop") + for { + select { + case fun := <-n.DispatchChannel: + if fun == nil { + goto endLoop + } + //n.Log.Debug("start") + start := time.Now() + err := fun() + if err != nil { + n.Log.Error("error occurred during dispatch: ", "error", err) + n.Cancel(err) + } + elapsed := time.Since(start) + perf.DispatchLatency.Add(float64(elapsed.Microseconds())) + if elapsed > time.Millisecond*4 { + n.Log.Warn("dispatch took a long time!", "fun", runtime.FuncForPC(reflect.ValueOf(fun).Pointer()).Name(), "elapsed", elapsed, "len", len(n.DispatchChannel)) + } + //n.Log.Debug("done", "elapsed", elapsed) + case <-n.Context.Done(): + goto endLoop + } + } +endLoop: + n.Log.Info("stopped main loop", "reason", context.Cause(n.Context).Error()) + n.Stop() + n.Log.Info("cleaning up modules") + err := n.Cleanup() + if err != nil { + n.Log.Error("error occurred during Stop: ", "error", err) + } + n.Log.Info("stopped") + return nil +} + func (n *Nylon) Cleanup() error { n.PingBuf.Stop() for _, ph := range n.GetNode(n.LocalCfg.Id).Prefixes { diff --git a/core/nylon_apply.go b/core/nylon_apply.go index df468b6..a501ab2 100644 --- a/core/nylon_apply.go +++ b/core/nylon_apply.go @@ -19,9 +19,13 @@ const ( ApplyRestartRequired ApplyResult = "restart_required" ) -func (n *Nylon) ApplyCentralConfig(next state.CentralCfg) (ApplyResult, error) { - state.ExpandCentralConfig(&next) - if err := state.CentralConfigValidator(&next); err != nil { +func (n *Nylon) ApplyCentralConfig(cfg *state.CentralCfg) (ApplyResult, error) { + err, next := cfg.Clone() + if err != nil { + return ApplyRejected, err + } + state.ExpandCentralConfig(next) + if err := state.CentralConfigValidator(next); err != nil { return ApplyRejected, err } if !next.IsRouter(n.LocalCfg.Id) { @@ -35,7 +39,7 @@ func (n *Nylon) ApplyCentralConfig(next state.CentralCfg) (ApplyResult, error) { return ApplyRejected, err } n.reconcileAdvertisedPrefixes(next) - n.CentralCfg = next + n.CentralCfg = *next if err := n.SyncWireGuard(); err != nil { return ApplyRejected, err @@ -48,7 +52,7 @@ func (n *Nylon) ApplyCentralConfig(next state.CentralCfg) (ApplyResult, error) { return ApplyApplied, nil } -func (n *Nylon) reconcileRouterState(next state.CentralCfg) error { +func (n *Nylon) reconcileRouterState(next *state.CentralCfg) error { desired := make(map[state.NodeId]state.RouterCfg) for _, peer := range next.GetPeers(n.LocalCfg.Id) { if !next.IsRouter(peer) { @@ -61,14 +65,17 @@ func (n *Nylon) reconcileRouterState(next state.CentralCfg) error { for _, neigh := range n.RouterState.Neighbours { cfg, ok := desired[neigh.Id] if !ok { + // remove old neighbours delete(n.router.IO, neigh.Id) continue } + // configure existing neighbours reconcileConfiguredEndpoints(neigh, cfg.Endpoints) neighs = append(neighs, neigh) delete(desired, neigh.Id) } + // create new neighbours ids := make([]state.NodeId, 0, len(desired)) for id := range desired { ids = append(ids, id) @@ -87,6 +94,16 @@ func (n *Nylon) reconcileRouterState(next state.CentralCfg) error { neighs = append(neighs, stNeigh) } n.RouterState.Neighbours = neighs + + // rebuild pubkey to peer's id mapping + pubkeyMap := make(map[state.NyPublicKey]state.NodeId) + for _, x := range next.Routers { + pubkeyMap[x.PubKey] = x.Id + } + for _, x := range next.Clients { + pubkeyMap[x.PubKey] = x.Id + } + n.PeerMap.Store(new(pubkeyMap)) return nil } @@ -104,8 +121,8 @@ func reconcileConfiguredEndpoints(neigh *state.Neighbour, desired []*state.Dynam eps = append(eps, ep) continue } + // only keep if desired if desiredEp, ok := desiredByValue[nep.DynEP.Value]; ok { - nep.DynEP = desiredEp eps = append(eps, ep) seen[desiredEp.Value] = struct{}{} } @@ -119,7 +136,7 @@ func reconcileConfiguredEndpoints(neigh *state.Neighbour, desired []*state.Dynam neigh.Eps = eps } -func (n *Nylon) reconcileAdvertisedPrefixes(next state.CentralCfg) { +func (n *Nylon) reconcileAdvertisedPrefixes(next *state.CentralCfg) { cur := n.GetRouter(n.LocalCfg.Id) nextRouter := next.GetRouter(n.LocalCfg.Id) @@ -146,7 +163,7 @@ func (n *Nylon) reconcileAdvertisedPrefixes(next state.CentralCfg) { for prefix, desired := range desiredLocal { if _, ok := currentLocal[prefix]; !ok { - n.Log.Info("starting prefix healthcheck", "prefix", prefix) + n.Log.Debug("starting prefix healthcheck", "prefix", prefix) desired.Start(n.Log) } n.RouterState.Advertised[prefix] = state.Advertisement{ @@ -162,7 +179,7 @@ func (n *Nylon) reconcileAdvertisedPrefixes(next state.CentralCfg) { func (n *Nylon) startAdvertisedPrefixHealth() { for _, ph := range n.GetNode(n.LocalCfg.Id).Prefixes { - n.Log.Info("starting prefix healthcheck", "prefix", ph.GetPrefix()) + n.Log.Debug("starting prefix healthcheck", "prefix", ph.GetPrefix()) ph.Start(n.Log) } } diff --git a/core/nylon_distribution.go b/core/nylon_distribution.go index 951bba3..7ca4d5b 100644 --- a/core/nylon_distribution.go +++ b/core/nylon_distribution.go @@ -100,7 +100,7 @@ func checkForConfigUpdates(n *Nylon) error { return nil } n.Log.Info("Found a new config update in repo", "repo", repo) - result, err := n.ApplyCentralConfig(*config) + result, err := n.ApplyCentralConfig(config) if err != nil { n.Log.Error("failed to apply central config update", "repo", repo, "result", result, "err", err) return nil diff --git a/core/nylon_endpoints.go b/core/nylon_endpoints.go index 0e917c9..976b2b4 100644 --- a/core/nylon_endpoints.go +++ b/core/nylon_endpoints.go @@ -3,6 +3,7 @@ package core import ( "math/rand/v2" "slices" + "sync" "time" "github.com/encodeous/nylon/polyamide/conn" @@ -16,7 +17,7 @@ type EpPing struct { TimeSent time.Time } -func (n *Nylon) Probe(node state.NodeId, ep *state.NylonEndpoint) error { +func (n *Nylon) Probe(node state.NodeId, ep *state.NylonEndpoint, waitErr bool) error { token := rand.Uint64() ping := &protocol.Ny{ Type: &protocol.Ny_ProbeOp{ @@ -31,14 +32,27 @@ func (n *Nylon) Probe(node state.NodeId, ep *state.NylonEndpoint) error { if err != nil { return err } - err = n.SendNylon(ping, nep, peer) - if err != nil { - return err - } - n.PingBuf.Set(token, EpPing{ - TimeSent: time.Now(), - }, ttlcache.DefaultTTL) + wg := sync.WaitGroup{} + wg.Add(1) + + var sendErr error + go func() { + defer wg.Done() + sendErr = n.SendNylon(ping, nep, peer) + if sendErr != nil { + return + } + + n.PingBuf.Set(token, EpPing{ + TimeSent: time.Now(), + }, ttlcache.DefaultTTL) + }() + + if waitErr { + wg.Wait() + return sendErr + } return nil } @@ -69,7 +83,7 @@ func handleProbe(n *Nylon, pkt *protocol.Ny_Probe, endpoint conn.Endpoint, peer } } -func handleProbePing(n *Nylon, node state.NodeId, ep conn.Endpoint) { +func handleProbePing(n *Nylon, node state.NodeId, wgEndpoint conn.Endpoint) { if node == n.LocalCfg.Id { return } @@ -78,11 +92,11 @@ func handleProbePing(n *Nylon, node state.NodeId, ep conn.Endpoint) { for _, dep := range neigh.Eps { dep := dep.AsNylonEndpoint() ap, err := dep.DynEP.Get() - if err == nil && ap == ep.DstIPPort() && neigh.Id == node { + if err == nil && ap == wgEndpoint.DstIPPort() && neigh.Id == node { // we have a link // refresh wireguard ep - dep.WgEndpoint = ep + dep.WgEndpoint = wgEndpoint if !dep.IsActive() { n.UpdateNeighbour(node) @@ -99,7 +113,7 @@ func handleProbePing(n *Nylon, node state.NodeId, ep conn.Endpoint) { // create a new link if we dont have a link for _, neigh := range n.RouterState.Neighbours { if neigh.Id == node { - newEp := state.NewEndpoint(state.NewDynamicEndpoint(ep.DstIPPort().String()), true, ep) + newEp := state.NewEndpoint(state.NewDynamicEndpoint(wgEndpoint.DstIPPort().String()), true, wgEndpoint) newEp.Renew() neigh.Eps = append(neigh.Eps, newEp) // push route update to improve convergence time @@ -112,8 +126,8 @@ func handleProbePing(n *Nylon, node state.NodeId, ep conn.Endpoint) { func handleProbePong(n *Nylon, node state.NodeId, token uint64, ep conn.Endpoint) { // check if link exists for _, neigh := range n.RouterState.Neighbours { - for _, dpLink := range neigh.Eps { - dpLink := dpLink.AsNylonEndpoint() + for _, dep := range neigh.Eps { + dpLink := dep.AsNylonEndpoint() ap, err := dpLink.DynEP.Get() if err == nil && ap == ep.DstIPPort() && neigh.Id == node { linkHealth, ok := n.PingBuf.GetAndDelete(token) @@ -144,12 +158,10 @@ func (n *Nylon) probeLinks(active bool) error { for _, neigh := range n.RouterState.Neighbours { for _, ep := range neigh.Eps { if ep.IsActive() == active { - go func() { - err := n.Probe(neigh.Id, ep.AsNylonEndpoint()) - if err != nil { - n.Log.Debug("probe failed", "err", err.Error()) - } - }() + err := n.Probe(neigh.Id, ep.AsNylonEndpoint(), false) + if err != nil { + n.Log.Debug("probe failed", "err", err.Error()) + } } } } @@ -184,12 +196,10 @@ func (n *Nylon) probeNew() error { // add the link to the neighbour dpl := state.NewEndpoint(ep, false, nil) neigh.Eps = append(neigh.Eps, dpl) - go func() { - err := n.Probe(peer, dpl) - if err != nil { - //n.Log.Debug("discovery probe failed", "err", err.Error()) - } - }() + err := n.Probe(peer, dpl, false) + if err != nil { + //n.Log.Debug("discovery probe failed", "err", err.Error()) + } } } } diff --git a/core/nylon_scheduler_test.go b/core/nylon_scheduler_test.go index 0afb75e..7ec40f3 100644 --- a/core/nylon_scheduler_test.go +++ b/core/nylon_scheduler_test.go @@ -3,6 +3,7 @@ package core import ( "context" "sync" + "sync/atomic" "testing" "time" ) @@ -18,7 +19,7 @@ func TestDispatch(t *testing.T) { Cancel: cancel, } - var called bool + var called atomic.Bool go func() { select { @@ -32,13 +33,13 @@ func TestDispatch(t *testing.T) { }() n.Dispatch(func() error { - called = true + called.Store(true) return nil }) time.Sleep(150 * time.Millisecond) - if !called { + if !called.Load() { t.Fatal("Dispatch function was not executed") } } diff --git a/core/nylon_tc.go b/core/nylon_tc.go index 272d633..750d06a 100644 --- a/core/nylon_tc.go +++ b/core/nylon_tc.go @@ -53,7 +53,7 @@ func (n *Nylon) InstallTC() { }) // forward only outgoing packets based on the routing table n.Device.InstallFilter(func(dev *device.Device, packet *device.TCElement) (device.TCAction, error) { - entry, ok := n.router.ForwardTable.Lookup(packet.GetDst()) + entry, ok := n.router.ForwardTable.Load().Lookup(packet.GetDst()) if ok && !packet.Incoming() { if entry.Blackhole { return device.TcDrop, nil @@ -69,7 +69,7 @@ func (n *Nylon) InstallTC() { } else { // forward packets based on the routing table n.Device.InstallFilter(func(dev *device.Device, packet *device.TCElement) (device.TCAction, error) { - entry, ok := n.router.ForwardTable.Lookup(packet.GetDst()) + entry, ok := n.router.ForwardTable.Load().Lookup(packet.GetDst()) if ok { if entry.Blackhole { return device.TcDrop, nil @@ -107,7 +107,7 @@ func (n *Nylon) InstallTC() { // bounce back packets destined for the current node n.Device.InstallFilter(func(dev *device.Device, packet *device.TCElement) (device.TCAction, error) { - entry, ok := n.router.ExitTable.Lookup(packet.GetDst()) + entry, ok := n.router.ExitTable.Load().Lookup(packet.GetDst()) // we should only accept packets destined to us, but not our passive clients if ok && entry.Nh == n.LocalCfg.Id { if state.DBG_trace_tc { @@ -160,6 +160,7 @@ func (n *Nylon) SendNylonBundle(pkt *protocol.TransportBundle, endpoint conn.End } func (n *Nylon) handleNylonPacket(packet []byte, endpoint conn.Endpoint, peer *device.Peer) { + // we need to be careful here, since this function is called on the dataplane bundle := &protocol.TransportBundle{} err := proto.Unmarshal(packet, bundle) if err != nil { @@ -168,8 +169,12 @@ func (n *Nylon) handleNylonPacket(packet []byte, endpoint conn.Endpoint, peer *d return } - neigh := n.FindNodeBy(state.NyPublicKey(peer.GetPublicKey())) - if neigh == nil { + nt := n.PeerMap.Load() + if nt == nil { + return // not loaded yet + } + neigh, ok := (*nt)[state.NyPublicKey(peer.GetPublicKey())] + if !ok { // this should not be possible panic("impossible state, peer added, but not a node in the network") } @@ -185,18 +190,19 @@ func (n *Nylon) handleNylonPacket(packet []byte, endpoint conn.Endpoint, peer *d switch pkt.Type.(type) { case *protocol.Ny_SeqnoRequestOp: n.Dispatch(func() error { - return n.routerHandleSeqnoRequest(*neigh, pkt.GetSeqnoRequestOp()) + return n.routerHandleSeqnoRequest(neigh, pkt.GetSeqnoRequestOp()) }) case *protocol.Ny_RouteOp: n.Dispatch(func() error { - return n.routerHandleRouteUpdate(*neigh, pkt.GetRouteOp()) + return n.routerHandleRouteUpdate(neigh, pkt.GetRouteOp()) }) case *protocol.Ny_AckRetractOp: n.Dispatch(func() error { - return n.routerHandleAckRetract(*neigh, pkt.GetAckRetractOp()) + return n.routerHandleAckRetract(neigh, pkt.GetAckRetractOp()) }) case *protocol.Ny_ProbeOp: - handleProbe(n, pkt.GetProbeOp(), endpoint, peer, *neigh) + // we don't want to wait for dispatch before responding to this packet + handleProbe(n, pkt.GetProbeOp(), endpoint, peer, neigh) } } } diff --git a/core/router.go b/core/router.go index 2265236..60fa2f9 100644 --- a/core/router.go +++ b/core/router.go @@ -99,32 +99,42 @@ func (n *Nylon) UpdateNeighbour(neigh state.NodeId) { func (n *Nylon) TableInsertRoute(prefix netip.Prefix, route state.SelRoute) { nh := route.Nh + nf := n.router.ForwardTable.Load().Clone() + ne := n.router.ExitTable.Load().Clone() if route.Metric == state.INF { - n.router.ForwardTable.Insert(prefix, RouteTableEntry{ + nf.Insert(prefix, RouteTableEntry{ Nh: nh, Blackhole: true, }) - n.router.ExitTable.Delete(prefix) + ne.Delete(prefix) + n.router.ForwardTable.Store(nf) + n.router.ExitTable.Store(ne) return } peer := n.Device.LookupPeer(device.NoisePublicKey(n.GetNode(nh).PubKey)) - n.router.ForwardTable.Insert(prefix, RouteTableEntry{ + nf.Insert(prefix, RouteTableEntry{ Nh: nh, Peer: peer, }) if route.Nh == n.LocalCfg.Id { - n.router.ExitTable.Insert(prefix, RouteTableEntry{ + ne.Insert(prefix, RouteTableEntry{ Nh: nh, Peer: peer, }) } else { - n.router.ExitTable.Delete(prefix) + ne.Delete(prefix) } + n.router.ForwardTable.Store(nf) + n.router.ExitTable.Store(ne) } func (n *Nylon) TableDeleteRoute(prefix netip.Prefix) { - n.router.ForwardTable.Delete(prefix) - n.router.ExitTable.Delete(prefix) + nf := n.router.ForwardTable.Load().Clone() + ne := n.router.ExitTable.Load().Clone() + nf.Delete(prefix) + ne.Delete(prefix) + n.router.ForwardTable.Store(nf) + n.router.ExitTable.Store(ne) } type IOPending struct { @@ -159,8 +169,8 @@ func (n *Nylon) InitRouter() error { n.router.log = n.Log.With("module", log.ScopeRouter) n.router.log.Debug("init router") n.router.IO = make(map[state.NodeId]*IOPending) - n.router.ForwardTable = bart.Table[RouteTableEntry]{} - n.router.ExitTable = bart.Table[RouteTableEntry]{} + n.router.ForwardTable.Store(new(bart.Table[RouteTableEntry]{})) + n.router.ExitTable.Store(new(bart.Table[RouteTableEntry]{})) n.RouterState = &state.RouterState{ Id: n.LocalCfg.Id, SelfSeqno: make(map[netip.Prefix]uint16), diff --git a/integration/apply_config_test.go b/integration/apply_config_test.go index a0cf639..c483549 100644 --- a/integration/apply_config_test.go +++ b/integration/apply_config_test.go @@ -51,7 +51,7 @@ func TestApplyCentralConfigRemovesNeighbourFromLiveNode(t *testing.T) { "b, c", } - a := vh.Nylons[vh.IndexOf("a")] + a := vh.Nylons[vh.IndexOf("a")].Load() a.Dispatch(func() error { beforeC := a.RouterState.GetNeighbour("c") if beforeC == nil || len(beforeC.Eps) == 0 { @@ -60,7 +60,7 @@ func TestApplyCentralConfigRemovesNeighbourFromLiveNode(t *testing.T) { return nil } keptEndpoint := beforeC.Eps[0] - result, err := a.ApplyCentralConfig(next) + result, err := a.ApplyCentralConfig(&next) afterC := a.RouterState.GetNeighbour("c") apply = applyResult{ @@ -118,7 +118,7 @@ func TestApplyCentralConfigLocalNodeRemovedRequiresRestart(t *testing.T) { errs := vh.Start() defer vh.Stop() - a := vh.Nylons[vh.IndexOf("a")] + a := vh.Nylons[vh.IndexOf("a")].Load() next := vh.Central next.Timestamp++ next.Routers = next.Routers[1:] @@ -130,7 +130,7 @@ func TestApplyCentralConfigLocalNodeRemovedRequiresRestart(t *testing.T) { var centralUnchanged bool var bStillNeighbour bool a.Dispatch(func() error { - result, err = a.ApplyCentralConfig(next) + result, err = a.ApplyCentralConfig(&next) centralUnchanged = a.CentralCfg.Timestamp == vh.Central.Timestamp bStillNeighbour = a.RouterState.GetNeighbour("b") != nil close(done) diff --git a/integration/harness.go b/integration/harness.go index aa6caa0..39475ed 100644 --- a/integration/harness.go +++ b/integration/harness.go @@ -14,6 +14,7 @@ import ( "runtime/pprof" "slices" "sync" + "sync/atomic" "time" "github.com/encodeous/nylon/core" @@ -106,10 +107,12 @@ type VirtualHarness struct { Cancel context.CancelCauseFunc Local []state.LocalCfg Net *InMemoryNetwork - Nylons []*core.Nylon + Nylons []atomic.Pointer[core.Nylon] Links []*VirtualLink + linksMu sync.RWMutex Endpoints map[string]state.NodeId UntrackedRouting bool + LogLevel *slog.Level } func (v *VirtualHarness) IndexOf(id state.NodeId) int { @@ -151,7 +154,9 @@ func (v *VirtualHarness) AddLink(from, to string) *VirtualLink { V1: bindtest.ChannelEndpoint2(netip.MustParseAddrPort(from)), V2: bindtest.ChannelEndpoint2(netip.MustParseAddrPort(to)), } + v.linksMu.Lock() v.Links = append(v.Links, link) + v.linksMu.Unlock() return link } @@ -159,7 +164,7 @@ func (v *VirtualHarness) Start() chan error { ctx, cancel := context.WithCancelCause(context.Background()) v.Context = ctx v.Cancel = cancel - v.Nylons = make([]*core.Nylon, len(v.Central.Routers)) + v.Nylons = make([]atomic.Pointer[core.Nylon], len(v.Central.Routers)) errChan := make(chan error, 128) // a large number so we dont get blocked vn := &InMemoryNetwork{} v.Net = vn @@ -181,16 +186,25 @@ func (v *VirtualHarness) Start() chan error { idx := v.IndexOf(n) v.Central.Routers[idx].Endpoints = append(v.Central.Routers[idx].Endpoints, state.NewDynamicEndpoint(e)) } + if v.LogLevel == nil { + v.LogLevel = new(slog.LevelDebug) + } startDelay := 0 * time.Millisecond for idx, rt := range v.Central.Routers { sd := startDelay go func() { time.Sleep(sd) labels := pprof.Labels("nylon node", string(rt.Id)) + n, err := core.NewNylon(v.Central, v.Local[idx], *v.LogLevel, "", map[string]any{ + "vnet": vn, + }) + if err != nil { + errChan <- err + return + } + v.Nylons[idx].Store(n) pprof.Do(context.Background(), labels, func(_ context.Context) { - cErr := core.Start(v.Central, v.Local[idx], slog.LevelDebug, "", map[string]any{ - "vnet": vn, - }, &v.Nylons[idx]) + cErr := n.Start() if cErr != nil { errChan <- cErr return @@ -203,7 +217,7 @@ func (v *VirtualHarness) Start() chan error { for { started := true for idx, _ := range v.Central.Routers { - if v.Nylons[idx] == nil { + if v.Nylons[idx].Load() == nil { started = false break } @@ -228,7 +242,7 @@ func (v *VirtualHarness) Stop() { println("Stopping VirtualHarness") v.Cancel(fmt.Errorf("stopping harness")) for idx, _ := range v.Central.Routers { - core.Stop(v.Nylons[idx]) + v.Nylons[idx].Load().Stop() } v.Net.Stop() println("Stopped VirtualHarness") @@ -252,14 +266,14 @@ type InMemoryNetwork struct { SelfHandler PacketFilter // packet filter for handling packets destined for the current node TransitHandler PacketFilter // packet filter for handling packets passing through the current node EpOutMapping OutMapping - ready bool + ready atomic.Bool readyCond *sync.Cond } func (i *InMemoryNetwork) WaitForReady() { i.readyCond.L.Lock() defer i.readyCond.L.Unlock() - for !i.ready { + for !i.ready.Load() { i.readyCond.Wait() } } @@ -267,7 +281,7 @@ func (i *InMemoryNetwork) WaitForReady() { func (i *InMemoryNetwork) Ready() { i.Lock() defer i.Unlock() - i.ready = true + i.ready.Store(true) i.readyCond.Broadcast() } @@ -313,13 +327,19 @@ func (i *InMemoryNetwork) virtualRouteTable(node state.NodeId, src, dst netip.Ad func (i *InMemoryNetwork) virtualInternet(pkt []byte, len int, from, to bindtest.ChannelEndpoint2) { // simulate network conditions + i.cfg.linksMu.RLock() idx := slices.IndexFunc(i.cfg.Links, func(link *VirtualLink) bool { return link.Edge.V1 == from && link.Edge.V2 == to }) - if idx == -1 { + var link *VirtualLink + if idx != -1 { + link = i.cfg.Links[idx] + } + i.cfg.linksMu.RUnlock() + if link == nil { return // no connection, dropped packet } - i.cfg.Links[idx].simulate(pkt, len, from, to, i) + link.simulate(pkt, len, from, to, i) } func (i *InMemoryNetwork) Bind(node state.NodeId) conn.Bind { diff --git a/integration/ipc_test.go b/integration/ipc_test.go index c3e2c8e..d6b8600 100644 --- a/integration/ipc_test.go +++ b/integration/ipc_test.go @@ -63,49 +63,37 @@ func ipcCall(t *testing.T, n *core.Nylon, req *protocol.IpcRequest) *protocol.Ip func TestIPCStatus(t *testing.T) { defer goleak.VerifyNone(t) - vh, errs := setupTwoNodeHarness(t) + vh, _ := setupTwoNodeHarness(t) defer vh.Stop() // Wait for links to come up time.Sleep(3 * time.Second) - a := vh.Nylons[vh.IndexOf("a")] - done := make(chan *protocol.IpcResponse, 1) - a.Dispatch(func() error { - resp := ipcCall(t, a, &protocol.IpcRequest{ - Request: &protocol.IpcRequest_Status{Status: &protocol.StatusRequest{}}, - }) - done <- resp - return nil + a := vh.Nylons[vh.IndexOf("a")].Load() + resp := ipcCall(t, a, &protocol.IpcRequest{ + Request: &protocol.IpcRequest_Status{Status: &protocol.StatusRequest{}}, }) - select { - case resp := <-done: - assert.True(t, resp.Ok) - s := resp.GetStatus() - require.NotNil(t, s.GetNode()) - assert.Equal(t, "a", s.GetNode().NodeId) - assert.NotEmpty(t, s.GetNode().PublicKey) - assert.Equal(t, int32(1), s.GetNode().GetStats().NeighbourCount) - assert.GreaterOrEqual(t, s.GetNode().GetStats().SelectedRouteCount, int32(1)) - assert.GreaterOrEqual(t, s.GetNode().GetStats().AdvertisedPrefixCount, int32(1)) - - require.Len(t, s.GetNeighbours(), 1) - peer := s.GetNeighbours()[0] - assert.Equal(t, "b", peer.PeerId) - assert.NotEmpty(t, peer.PublicKey) - require.NotEmpty(t, peer.GetEndpoints()) - assert.GreaterOrEqual(t, len(peer.GetEndpoints()), 2) - assert.NotEmpty(t, peer.GetEndpoints()[0].Address) - - assert.GreaterOrEqual(t, len(s.GetRoutes().GetSelected()), 1) - assert.GreaterOrEqual(t, len(s.GetRoutes().GetForward()), 1) - assert.GreaterOrEqual(t, len(s.GetFeasibilityDistances()), 1) - case err := <-errs: - t.Fatal(err) - case <-time.After(10 * time.Second): - t.Fatal("timeout") - } + assert.True(t, resp.Ok) + s := resp.GetStatus() + require.NotNil(t, s.GetNode()) + assert.Equal(t, "a", s.GetNode().NodeId) + assert.NotEmpty(t, s.GetNode().PublicKey) + assert.Equal(t, int32(1), s.GetNode().GetStats().NeighbourCount) + assert.GreaterOrEqual(t, s.GetNode().GetStats().SelectedRouteCount, int32(1)) + assert.GreaterOrEqual(t, s.GetNode().GetStats().AdvertisedPrefixCount, int32(1)) + + require.Len(t, s.GetNeighbours(), 1) + peer := s.GetNeighbours()[0] + assert.Equal(t, "b", peer.PeerId) + assert.NotEmpty(t, peer.PublicKey) + require.NotEmpty(t, peer.GetEndpoints()) + assert.GreaterOrEqual(t, len(peer.GetEndpoints()), 2) + assert.NotEmpty(t, peer.GetEndpoints()[0].Address) + + assert.GreaterOrEqual(t, len(s.GetRoutes().GetSelected()), 1) + assert.GreaterOrEqual(t, len(s.GetRoutes().GetForward()), 1) + assert.GreaterOrEqual(t, len(s.GetFeasibilityDistances()), 1) } func TestIPCReloadConfig(t *testing.T) { @@ -121,7 +109,7 @@ func TestIPCReloadConfig(t *testing.T) { tmpFile := t.TempDir() + "/central.yaml" require.NoError(t, os.WriteFile(tmpFile, cfgData, 0600)) - a := vh.Nylons[vh.IndexOf("a")] + a := vh.Nylons[vh.IndexOf("a")].Load() a.ConfigPath = tmpFile done := make(chan *protocol.IpcResponse, 1) go func() { @@ -145,30 +133,18 @@ func TestIPCReloadConfig(t *testing.T) { func TestIPCProbeNonNeighbour(t *testing.T) { defer goleak.VerifyNone(t) - vh, errs := setupTwoNodeHarness(t) + vh, _ := setupTwoNodeHarness(t) defer vh.Stop() time.Sleep(1 * time.Second) - a := vh.Nylons[vh.IndexOf("a")] - done := make(chan *protocol.IpcResponse, 1) - a.Dispatch(func() error { - resp := ipcCall(t, a, &protocol.IpcRequest{ - Request: &protocol.IpcRequest_Probe{Probe: &protocol.ProbeRequest{PeerId: "nonexistent"}}, - }) - done <- resp - return nil + a := vh.Nylons[vh.IndexOf("a")].Load() + resp := ipcCall(t, a, &protocol.IpcRequest{ + Request: &protocol.IpcRequest_Probe{Probe: &protocol.ProbeRequest{PeerId: "nonexistent"}}, }) - select { - case resp := <-done: - assert.False(t, resp.Ok) - assert.Contains(t, resp.Error, "not a neighbour") - case err := <-errs: - t.Fatal(err) - case <-time.After(10 * time.Second): - t.Fatal("timeout") - } + assert.False(t, resp.Ok) + assert.Contains(t, resp.Error, "not a neighbour") } func TestIPCTraceDisabled(t *testing.T) { @@ -178,7 +154,7 @@ func TestIPCTraceDisabled(t *testing.T) { time.Sleep(1 * time.Second) - a := vh.Nylons[vh.IndexOf("a")] + a := vh.Nylons[vh.IndexOf("a")].Load() done := make(chan *protocol.IpcResponse, 1) a.Dispatch(func() error { // Trace should fail since DBG_trace_tc is false by default @@ -207,7 +183,7 @@ func TestIPCMalformedRequest(t *testing.T) { time.Sleep(1 * time.Second) - a := vh.Nylons[vh.IndexOf("a")] + a := vh.Nylons[vh.IndexOf("a")].Load() done := make(chan bool, 1) a.Dispatch(func() error { // Send garbage @@ -243,7 +219,7 @@ func TestIPCSocketResponseHasNoUAPIErrnoTrailer(t *testing.T) { time.Sleep(1 * time.Second) - a := vh.Nylons[vh.IndexOf("a")] + a := vh.Nylons[vh.IndexOf("a")].Load() client, server := net.Pipe() defer client.Close() go a.Device.IpcHandle(server) diff --git a/integration/race_test.go b/integration/race_test.go new file mode 100644 index 0000000..37b3d0f --- /dev/null +++ b/integration/race_test.go @@ -0,0 +1,121 @@ +//go:build integration + +package integration + +import ( + "fmt" + "log/slog" + "net/netip" + "testing" + "time" + + "github.com/encodeous/nylon/state" + "go.uber.org/goleak" +) + +func TestRapidToggleConfig(t *testing.T) { + defer goleak.VerifyNone(t) + + vh := &VirtualHarness{} + vh.LogLevel = new(slog.LevelInfo) + a1 := "192.168.1.1:1234" + vh.NewNode("a", "10.0.0.1/32") + b1 := "192.168.1.2:1234" + vh.NewNode("b", "10.0.0.2/32") + vh.Central.Graph = []string{ + "a, b", + } + vh.Endpoints = map[string]state.NodeId{ + a1: "a", + b1: "b", + } + vh.AddLink(a1, b1) + vh.AddLink(b1, a1) + + errs := vh.Start() + + vn := vh.Net + vn.SelfHandler = func(node state.NodeId, src, dst netip.Addr, data []byte) bool { + return true + } + + // Wait for initial convergence. + time.Sleep(3 * time.Second) + + _, baseCfg := vh.Central.Clone() + _, extraCfg := vh.Central.Clone() + + aIdx := vh.IndexOf("a") + extraRouter := extraCfg.Routers[aIdx] + for i := 0; i < 50; i++ { + extraRouter.Prefixes = append(extraRouter.Prefixes, state.PrefixHealthWrapper{ + PrefixHealth: &state.StaticPrefixHealth{ + Prefix: netip.MustParsePrefix(fmt.Sprintf("10.1.0.%d/32", i)), + Metric: 0, + }, + }) + } + extraCfg.Routers[aIdx] = extraRouter + + a := vh.Nylons[vh.IndexOf("a")].Load() + + // Break shared *DynamicEndpoint pointers from startup. + { + done := make(chan struct{}) + a.Dispatch(func() error { + _, nc := a.CentralCfg.Clone() + a.CentralCfg = *nc + close(done) + return nil + }) + <-done + } + + // Send packets every 10ms to trigger ForwardTable.Lookup via TC filter. + stop := make(chan struct{}) + go func() { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-vh.Context.Done(): + return + case <-ticker.C: + vn.Send("a", "10.0.0.1", "10.0.0.2", []byte{1}, 64) + } + } + }() + + // Rapidly toggle config until test duration expires. + deadline := time.After(5 * time.Second) + for i := 0; ; i++ { + cfg := extraCfg + if i%2 == 1 { + cfg = baseCfg + } + cfg.Timestamp = baseCfg.Timestamp + int64(i) + 2 + done := make(chan struct{}) + _, ccfg := cfg.Clone() + a.Dispatch(func() error { + defer close(done) + _, err := a.ApplyCentralConfig(ccfg) + if err != nil { + return err + } + return nil + }) + select { + case <-done: + time.Sleep(10 * time.Millisecond) + case <-deadline: + goto end + case err := <-errs: + t.Fatalf("harness error: %v", err) + } + } +end: + close(stop) + vh.Stop() +} diff --git a/integration/routing_test.go b/integration/routing_test.go index ea2c12c..f07085e 100644 --- a/integration/routing_test.go +++ b/integration/routing_test.go @@ -3,11 +3,12 @@ package integration import ( - "github.com/encodeous/nylon/state" - "go.uber.org/goleak" "net/netip" "testing" "time" + + "github.com/encodeous/nylon/state" + "go.uber.org/goleak" ) func TestInProcessRouting(t *testing.T) { diff --git a/state/config.go b/state/config.go index b8bde37..083493e 100644 --- a/state/config.go +++ b/state/config.go @@ -7,6 +7,7 @@ import ( "slices" "strings" + "github.com/goccy/go-yaml" "go4.org/netipx" ) @@ -65,6 +66,18 @@ type LocalCfg struct { PostDown []string `yaml:"post_down,omitempty"` // a list of commands executed in order after the nylon interface is brought down } +func (c *CentralCfg) Clone() (error, *CentralCfg) { + data, err := yaml.Marshal(c) + if err != nil { + return err, nil + } + var dst CentralCfg + if err = yaml.Unmarshal(data, &dst); err != nil { + return err, nil + } + return nil, &dst +} + // GetPrefixes returns all unique prefixes from all nodes func (c *CentralCfg) GetPrefixes() []netip.Prefix { prefixMap := make(map[netip.Prefix]bool) From ee87d7140688c9f6b80c1293ce703b9399b01ad0 Mon Sep 17 00:00:00 2001 From: Adam Chen Date: Thu, 7 May 2026 00:37:57 +0000 Subject: [PATCH 2/2] fix(core): add synchronization to endpoint.go --- state/endpoint.go | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/state/endpoint.go b/state/endpoint.go index c404471..7b509e3 100644 --- a/state/endpoint.go +++ b/state/endpoint.go @@ -153,6 +153,7 @@ func (ep *DynamicEndpoint) MarshalYAML() (interface{}, error) { } type NylonEndpoint struct { + sync.RWMutex // this mutex is for rtt smoothing and metric calculation history []time.Duration histSort []time.Duration dirty bool @@ -198,12 +199,20 @@ func (n *Neighbour) BestEndpoint() Endpoint { return best } -func (u *NylonEndpoint) IsActive() bool { +func (u *NylonEndpoint) isActiveUnlocked() bool { return time.Since(u.lastHeardBack) <= LinkDeadThreshold } +func (u *NylonEndpoint) IsActive() bool { + u.RLock() + defer u.RUnlock() + return u.isActiveUnlocked() +} + func (u *NylonEndpoint) Renew() { - if !u.IsActive() { + u.Lock() + defer u.Unlock() + if !u.isActiveUnlocked() { u.history = u.history[:0] u.expRTT = math.Inf(1) u.dirty = true @@ -226,6 +235,8 @@ func NewEndpoint(endpoint *DynamicEndpoint, remoteInit bool, wgEndpoint conn.End } func (u *NylonEndpoint) calcR() (time.Duration, time.Duration, time.Duration) { + u.Lock() + defer u.Unlock() if len(u.history) < MinimumConfidenceWindow { return time.Second * 1, time.Second * 1, time.Second * 1 } @@ -265,6 +276,8 @@ func (u *NylonEndpoint) StabilizedPing() time.Duration { } func (u *NylonEndpoint) UpdatePing(ping time.Duration) { + u.Lock() + defer u.Unlock() // sometimes our system clock is not fast enough, so ping is 0 if ping == 0 { ping = time.Microsecond * 100 @@ -276,7 +289,7 @@ func (u *NylonEndpoint) UpdatePing(ping time.Duration) { u.expRTT = f } u.expRTT = alpha*f + (1-alpha)*u.expRTT - u.history = append(u.history, u.FilteredPing()) + u.history = append(u.history, time.Duration(int64(u.expRTT))) if len(u.history) > WindowSamples { u.history = u.history[1:] }