diff --git a/app/app.go b/app/app.go index 7c7c02c4d..51926f3ab 100644 --- a/app/app.go +++ b/app/app.go @@ -664,7 +664,7 @@ func wireCoreWorkflow(ctx context.Context, life *lifecycle.Manager, conf Config, // Consensus consensusController, err := consensus.NewConsensusController( ctx, eth2Cl, p2pNode, sender, peers, p2pKey, - deadlineFunc, gaterFunc, consensusDebugger) + deadlineFunc, gaterFunc, consensusDebugger, featureset.Enabled(featureset.ChainSplitHalt)) if err != nil { return err } diff --git a/core/consensus/controller.go b/core/consensus/controller.go index 67dc3d2ed..76201b4bc 100644 --- a/core/consensus/controller.go +++ b/core/consensus/controller.go @@ -40,11 +40,11 @@ type consensusController struct { // NewConsensusController creates a new consensus controller with the default consensus protocol. func NewConsensusController(ctx context.Context, eth2Cl eth2wrap.Client, p2pNode host.Host, sender *p2p.Sender, peers []p2p.Peer, p2pKey *k1.PrivateKey, deadlineFunc core.DeadlineFunc, - gaterFunc core.DutyGaterFunc, debugger Debugger, + gaterFunc core.DutyGaterFunc, debugger Debugger, compareAttestations bool, ) (core.ConsensusController, error) { qbftDeadliner := core.NewDeadliner(ctx, "consensus.qbft", deadlineFunc) - defaultConsensus, err := qbft.NewConsensus(ctx, eth2Cl, p2pNode, sender, peers, p2pKey, qbftDeadliner, gaterFunc, debugger.AddInstance) + defaultConsensus, err := qbft.NewConsensus(ctx, eth2Cl, p2pNode, sender, peers, p2pKey, qbftDeadliner, gaterFunc, debugger.AddInstance, compareAttestations) if err != nil { return nil, err } diff --git a/core/consensus/controller_test.go b/core/consensus/controller_test.go index 3d102fdd4..8572465e0 100644 --- a/core/consensus/controller_test.go +++ b/core/consensus/controller_test.go @@ -66,7 +66,7 @@ func TestConsensusController(t *testing.T) { bmock, err := beaconmock.New(ctx) require.NoError(t, err) - controller, err := consensus.NewConsensusController(ctx, bmock, hosts[0], new(p2p.Sender), peers, 
p2pkeys[0], deadlineFunc, gaterFunc, debugger) + controller, err := consensus.NewConsensusController(ctx, bmock, hosts[0], new(p2p.Sender), peers, p2pkeys[0], deadlineFunc, gaterFunc, debugger, false) require.NoError(t, err) require.NotNil(t, controller) diff --git a/core/consensus/instance/instance_io.go b/core/consensus/instance/instance_io.go index 8d3845ab5..b3e6b74bb 100644 --- a/core/consensus/instance/instance_io.go +++ b/core/consensus/instance/instance_io.go @@ -12,7 +12,7 @@ import ( ) const ( - RecvBufferSize = 512 // Allow buffering some initial messages when this node is late to start an instance. + RecvBufferSize = 100 // Allow buffering some initial messages when this node is late to start an instance. ) // NewIO returns a new instanceIO. diff --git a/core/consensus/qbft/msg.go b/core/consensus/qbft/msg.go index c67563a50..8ad601704 100644 --- a/core/consensus/qbft/msg.go +++ b/core/consensus/qbft/msg.go @@ -41,7 +41,7 @@ func newMsg(pbMsg *pbv1.QBFTMsg, justification []*pbv1.QBFTMsg, values map[[32]b } } - var justImpls []qbft.Msg[core.Duty, [32]byte] + var justImpls []qbft.Msg[core.Duty, [32]byte, proto.Message] for _, j := range justification { impl, err := newMsg(j, nil, values) @@ -62,7 +62,7 @@ func newMsg(pbMsg *pbv1.QBFTMsg, justification []*pbv1.QBFTMsg, values map[[32]b }, nil } -// Msg wraps *pbv1.QBFTMsg and justifications and implements qbft.Msg[core.Duty, [32]byte]. +// Msg wraps *pbv1.QBFTMsg and justifications and implements qbft.Msg[core.Duty, [32]byte, proto.Message]. 
type Msg struct { msg *pbv1.QBFTMsg valueHash [32]byte @@ -70,7 +70,7 @@ type Msg struct { values map[[32]byte]*anypb.Any justificationProtos []*pbv1.QBFTMsg - justification []qbft.Msg[core.Duty, [32]byte] + justification []qbft.Msg[core.Duty, [32]byte, proto.Message] } func (m Msg) Type() qbft.MsgType { @@ -118,7 +118,7 @@ func (m Msg) PreparedValue() [32]byte { return m.preparedValueHash } -func (m Msg) Justification() []qbft.Msg[core.Duty, [32]byte] { +func (m Msg) Justification() []qbft.Msg[core.Duty, [32]byte, proto.Message] { return m.justification } @@ -227,4 +227,4 @@ func toHash32(val []byte) ([32]byte, bool) { return resp, true } -var _ qbft.Msg[core.Duty, [32]byte] = Msg{} // Interface assertion +var _ qbft.Msg[core.Duty, [32]byte, proto.Message] = Msg{} // Interface assertion diff --git a/core/consensus/qbft/qbft.go b/core/consensus/qbft/qbft.go index 76d47ea51..77caa588e 100644 --- a/core/consensus/qbft/qbft.go +++ b/core/consensus/qbft/qbft.go @@ -5,6 +5,7 @@ package qbft import ( "context" "fmt" + "slices" "strings" "sync" "time" @@ -36,20 +37,22 @@ import ( type subscriber func(ctx context.Context, duty core.Duty, value proto.Message) error +var supportedCompareDuties = []core.DutyType{core.DutyAttester} + // newDefinition returns a qbft definition (this is constant across all consensus instances). 
func newDefinition(nodes int, subs func() []subscriber, roundTimer timer.RoundTimer, - decideCallback func(qcommit []qbft.Msg[core.Duty, [32]byte]), -) qbft.Definition[core.Duty, [32]byte] { - quorum := qbft.Definition[int, int]{Nodes: nodes}.Quorum() + decideCallback func(qcommit []qbft.Msg[core.Duty, [32]byte, proto.Message]), compareAttestations bool, +) qbft.Definition[core.Duty, [32]byte, proto.Message] { + quorum := qbft.Definition[core.Duty, [32]byte, proto.Message]{Nodes: nodes}.Quorum() - return qbft.Definition[core.Duty, [32]byte]{ + return qbft.Definition[core.Duty, [32]byte, proto.Message]{ // IsLeader is a deterministic leader election function. IsLeader: func(duty core.Duty, round, process int64) bool { return leader(duty, round, nodes) == process }, // Decide sends consensus output to subscribers. - Decide: func(ctx context.Context, duty core.Duty, _ [32]byte, qcommit []qbft.Msg[core.Duty, [32]byte]) { + Decide: func(ctx context.Context, duty core.Duty, _ [32]byte, qcommit []qbft.Msg[core.Duty, [32]byte, proto.Message]) { msg, ok := qcommit[0].(Msg) if !ok { log.Error(ctx, "Internal error: Invalid message type in qcommit. 
This indicates a consensus protocol bug that should be reported", nil, z.Str("got_type", fmt.Sprintf("%T", qcommit[0]))) @@ -77,18 +80,84 @@ func newDefinition(nodes int, subs func() []subscriber, roundTimer timer.RoundTi } }, + Compare: func(ctx context.Context, msg qbft.Msg[core.Duty, [32]byte, proto.Message], inputValueSourceCh <-chan proto.Message, inputValueSource proto.Message, returnErrCh chan error, returnProtoCh chan proto.Message) { + if !compareAttestations { + returnErrCh <- nil + return + } + + if !slices.Contains(supportedCompareDuties, msg.Instance().Type) { + returnErrCh <- nil + return + } + + attLeaderAnyPbProto, err := msg.ValueSource() + if err != nil { + returnErrCh <- errors.Wrap(err, "msg has no value source", z.Any("msg", msg)) + return + } + + attLeaderAnyPb, ok := attLeaderAnyPbProto.(*anypb.Any) + if !ok { + returnErrCh <- errors.New("protoMessage interface to *anypb.Any struct", z.Any("attLeaderAnyPbProto", attLeaderAnyPbProto)) + return + } + + attLeaderSetProto, err := attLeaderAnyPb.UnmarshalNew() + if err != nil { + returnErrCh <- errors.Wrap(err, "unmarshal *anypb.Any", z.Any("attLeaderAnyPb", attLeaderAnyPb)) + return + } + + attLeaderSet, ok := attLeaderSetProto.(*pbv1.UnsignedDataSet) + if !ok { + returnErrCh <- errors.New("protoMessage interface to *pbv1.UnsignedDataSet struct", z.Any("attLeaderSetProto", attLeaderSetProto)) + return + } + + switch msg.Instance().Type { + case core.DutyAttester: + if inputValueSource == nil { + select { + case <-ctx.Done(): + returnErrCh <- errors.New("timeout on waiting for local value") + return + case inputValueSource = <-inputValueSourceCh: + returnProtoCh <- inputValueSource + } + } + + attLocalSet, ok := inputValueSource.(*pbv1.UnsignedDataSet) + if !ok { + returnErrCh <- errors.New("inputValueSource to pbv1.UnsignedDataSet") + return + } + + err = attestationChecker(ctx, attLeaderSet, attLocalSet) + if err != nil { + returnErrCh <- errors.Wrap(err, "attestation checker failed") + return + 
} + default: + returnErrCh <- errors.New("bug: checking not supported duty", z.Any("duty", msg.Instance().Type)) + return + } + + returnErrCh <- nil + }, + NewTimer: roundTimer.Timer, // LogUponRule logs upon rules at debug level. LogUponRule: func(ctx context.Context, _ core.Duty, _, round int64, - _ qbft.Msg[core.Duty, [32]byte], uponRule qbft.UponRule, + _ qbft.Msg[core.Duty, [32]byte, proto.Message], uponRule qbft.UponRule, ) { log.Debug(ctx, "QBFT upon rule triggered", z.Any("rule", uponRule), z.I64("round", round)) }, // LogRoundChange logs round changes at debug level. LogRoundChange: func(ctx context.Context, duty core.Duty, process, round, newRound int64, //nolint:revive // keep process variable name for clarity - uponRule qbft.UponRule, msgs []qbft.Msg[core.Duty, [32]byte], + uponRule qbft.UponRule, msgs []qbft.Msg[core.Duty, [32]byte, proto.Message], ) { fields := []z.Field{ z.Any("rule", uponRule), @@ -108,24 +177,13 @@ func newDefinition(nodes int, subs func() []subscriber, roundTimer timer.RoundTi log.Debug(ctx, "QBFT round changed", fields...) }, - LogUnjust: func(ctx context.Context, _ core.Duty, _ int64, msg qbft.Msg[core.Duty, [32]byte]) { + LogUnjust: func(ctx context.Context, _ core.Duty, _ int64, msg qbft.Msg[core.Duty, [32]byte, proto.Message]) { log.Warn(ctx, "Unjustified consensus message from peer", nil, z.Any("type", msg.Type()), z.I64("peer", msg.Source()), ) }, - LogDebug: func(ctx context.Context, _ core.Duty, _ int64, msg qbft.Msg[core.Duty, [32]byte], logMsg string) { - if msg != nil { - log.Debug(ctx, logMsg, - z.Any("type", msg.Type()), - z.I64("peer", msg.Source()), - ) - } else { - log.Debug(ctx, logMsg) - } - }, - // Nodes is the number of nodes. 
Nodes: nodes, @@ -134,9 +192,62 @@ func newDefinition(nodes int, subs func() []subscriber, roundTimer timer.RoundTi } } +func attestationChecker(ctx context.Context, attLeaderSet *pbv1.UnsignedDataSet, attLocalSet *pbv1.UnsignedDataSet) error { + attLocalSetCore, err := core.UnsignedDataSetFromProto(core.DutyAttester, attLocalSet) + if err != nil { + return errors.Wrap(err, "attLocal to unsigned data set duty attester") + } + + attLeaderSetCore, err := core.UnsignedDataSetFromProto(core.DutyAttester, attLeaderSet) + if err != nil { + return errors.Wrap(err, "attLeader to unsigned data set duty attester") + } + + for attLeaderKey, attLeaderData := range attLeaderSetCore { + attLocalData, ok := attLocalSetCore[attLeaderKey] + if !ok { + log.Warn(ctx, "", errors.New("no local attestation found, skipping"), z.Any("pk", attLeaderKey)) + continue + } + + attLeaderAttestationData, ok := attLeaderData.(core.AttestationData) + if !ok { + return errors.New("unable to parse leader unsigned data to core attestation data", z.Any("data", attLeaderData)) + } + + attLocalAttestationData, ok := attLocalData.(core.AttestationData) + if !ok { + return errors.New("unable to parse local unsigned data to core attestation data", z.Any("data", attLocalData)) + } + + mismatch := "" + if attLeaderAttestationData.Data.Source.Epoch != attLocalAttestationData.Data.Source.Epoch { + mismatch += "leader attestation source epoch differs from local source epoch;" + } + + if attLeaderAttestationData.Data.Source.Root != attLocalAttestationData.Data.Source.Root { + mismatch += "leader attestation source root differs from local source root;" + } + + if attLeaderAttestationData.Data.Target.Epoch != attLocalAttestationData.Data.Target.Epoch { + mismatch += "leader attestation target epoch differs from local target epoch;" + } + + if attLeaderAttestationData.Data.Target.Root != attLocalAttestationData.Data.Target.Root { + mismatch += "leader attestation target root differs from local target root;" + } + + 
if mismatch != "" { + return errors.New(mismatch, z.Any("public_key", attLeaderKey), z.Any("leader", attLeaderAttestationData.Data), z.Any("local", attLocalAttestationData.Data)) + } + } + + return nil +} + // NewConsensus returns a new consensus QBFT component. func NewConsensus(ctx context.Context, eth2Cl eth2wrap.Client, p2pNode host.Host, sender *p2p.Sender, peers []p2p.Peer, p2pKey *k1.PrivateKey, - deadliner core.Deadliner, gaterFunc core.DutyGaterFunc, snifferFunc func(*pbv1.SniffedConsensusInstance), + deadliner core.Deadliner, gaterFunc core.DutyGaterFunc, snifferFunc func(*pbv1.SniffedConsensusInstance), compareAttestations bool, ) (*Consensus, error) { // Extract peer pubkeys. keys := make(map[int64]*k1.PublicKey) @@ -164,18 +275,19 @@ func NewConsensus(ctx context.Context, eth2Cl eth2wrap.Client, p2pNode host.Host } c := &Consensus{ - p2pNode: p2pNode, - sender: sender, - peers: peers, - peerLabels: labels, - privkey: p2pKey, - pubkeys: keys, - deadliner: deadliner, - snifferFunc: snifferFunc, - gaterFunc: gaterFunc, - dropFilter: log.Filter(), - timerFunc: timer.GetRoundTimerFunc(genesisTime, slotDuration), - metrics: metrics.NewConsensusMetrics(protocols.QBFTv2ProtocolID), + p2pNode: p2pNode, + sender: sender, + peers: peers, + peerLabels: labels, + privkey: p2pKey, + pubkeys: keys, + deadliner: deadliner, + snifferFunc: snifferFunc, + gaterFunc: gaterFunc, + dropFilter: log.Filter(), + timerFunc: timer.GetRoundTimerFunc(genesisTime, slotDuration), + metrics: metrics.NewConsensusMetrics(protocols.QBFTv2ProtocolID), + compareAttestations: compareAttestations, } c.mutable.instances = make(map[core.Duty]*instance.IO[Msg]) @@ -185,19 +297,20 @@ func NewConsensus(ctx context.Context, eth2Cl eth2wrap.Client, p2pNode host.Host // Consensus implements core.Consensus & priority.coreConsensus. 
type Consensus struct { // Immutable state - p2pNode host.Host - sender *p2p.Sender - peerLabels []string - peers []p2p.Peer - pubkeys map[int64]*k1.PublicKey - privkey *k1.PrivateKey - subs []subscriber - deadliner core.Deadliner - snifferFunc func(*pbv1.SniffedConsensusInstance) - gaterFunc core.DutyGaterFunc - dropFilter z.Field // Filter buffer overflow errors (possible DDoS) - timerFunc timer.RoundTimerFunc - metrics metrics.ConsensusMetrics + p2pNode host.Host + sender *p2p.Sender + peerLabels []string + peers []p2p.Peer + pubkeys map[int64]*k1.PublicKey + privkey *k1.PrivateKey + subs []subscriber + deadliner core.Deadliner + snifferFunc func(*pbv1.SniffedConsensusInstance) + gaterFunc core.DutyGaterFunc + dropFilter z.Field // Filter buffer overflow errors (possible DDoS) + timerFunc timer.RoundTimerFunc + metrics metrics.ConsensusMetrics + compareAttestations bool // Mutable state mutable struct { @@ -318,6 +431,14 @@ func (c *Consensus) propose(ctx context.Context, duty core.Duty, value proto.Mes return errors.New("input channel full") } + if c.compareAttestations { + select { + case inst.VerifyCh <- value: + default: + return errors.New("input channel full") + } + } + // Instrument consensus duration using decidedAt output. 
proposedAt := time.Now() @@ -397,12 +518,6 @@ func (c *Consensus) runInstance(parent context.Context, duty core.Duty) (err err z.Any("timer", string(roundTimer.Type())), ) - log.Debug(ctx, "QBFT fetching instance IO", - z.Any("peer", p2p.PeerName(c.p2pNode.ID())), - z.Any("peers", c.peerLabels), - z.Any("timer", string(roundTimer.Type())), - ) - inst := c.getInstanceIO(duty) defer func() { @@ -411,19 +526,8 @@ func (c *Consensus) runInstance(parent context.Context, duty core.Duty) (err err var span trace.Span - log.Debug(ctx, "QBFT starting duty trace", - z.Any("peer", p2p.PeerName(c.p2pNode.ID())), - z.Any("peers", c.peerLabels), - z.Any("timer", string(roundTimer.Type())), - ) - ctx, span = core.StartDutyTrace(ctx, duty, "core/qbft.runInstance") - log.Debug(ctx, "QBFT checking if duty is expired", - z.Any("peer", p2p.PeerName(c.p2pNode.ID())), - z.Any("peers", c.peerLabels), - z.Any("timer", string(roundTimer.Type())), - ) if !c.deadliner.Add(duty) { span.AddEvent("Expired Duty Skipped") log.Warn(ctx, "Skipping consensus for expired duty", nil) @@ -431,12 +535,6 @@ func (c *Consensus) runInstance(parent context.Context, duty core.Duty) (err err return nil } - log.Debug(ctx, "QBFT getting peer index", - z.Any("peer", p2p.PeerName(c.p2pNode.ID())), - z.Any("peers", c.peerLabels), - z.Any("timer", string(roundTimer.Type())), - ) - peerIdx, err := c.getPeerIdx() if err != nil { return err @@ -458,7 +556,7 @@ func (c *Consensus) runInstance(parent context.Context, duty core.Duty) (err err span.End() }() - decideCallback := func(qcommit []qbft.Msg[core.Duty, [32]byte]) { + decideCallback := func(qcommit []qbft.Msg[core.Duty, [32]byte, proto.Message]) { round := qcommit[0].Round() decided = true @@ -485,16 +583,10 @@ func (c *Consensus) runInstance(parent context.Context, duty core.Duty) (err err cancel() } - log.Debug(ctx, "QBFT create new definition", - z.Any("peer", p2p.PeerName(c.p2pNode.ID())), - z.Any("peers", c.peerLabels), - z.Any("timer", 
string(roundTimer.Type())), - ) - // Create a new qbft definition for this instance. - def := newDefinition(len(c.peers), c.subscribers, roundTimer, decideCallback) + def := newDefinition(len(c.peers), c.subscribers, roundTimer, decideCallback, c.compareAttestations) origLogRoundChange := def.LogRoundChange - def.LogRoundChange = func(ctx context.Context, instance core.Duty, process, round, newRound int64, uponRule qbft.UponRule, msgs []qbft.Msg[core.Duty, [32]byte]) { + def.LogRoundChange = func(ctx context.Context, instance core.Duty, process, round, newRound int64, uponRule qbft.UponRule, msgs []qbft.Msg[core.Duty, [32]byte, proto.Message]) { if origLogRoundChange != nil { origLogRoundChange(ctx, instance, process, round, newRound, uponRule, msgs) } @@ -503,43 +595,25 @@ func (c *Consensus) runInstance(parent context.Context, duty core.Duty) (err err span.SetAttributes(attribute.Int64("new_round", newRound)) } - log.Debug(ctx, "QBFT create new transport", - z.Any("peer", p2p.PeerName(c.p2pNode.ID())), - z.Any("peers", c.peerLabels), - z.Any("timer", string(roundTimer.Type())), - ) - // Create a new transport that handles sending and receiving for this instance. - t := newTransport(c, c.privkey, inst.ValueCh, make(chan qbft.Msg[core.Duty, [32]byte]), newSniffer(int64(def.Nodes), peerIdx)) + t := newTransport(c, c.privkey, inst.ValueCh, make(chan qbft.Msg[core.Duty, [32]byte, proto.Message]), newSniffer(int64(def.Nodes), peerIdx)) // Provide sniffed buffer to snifferFunc at the end. defer func() { c.snifferFunc(t.SnifferInstance()) }() - log.Debug(ctx, "QBFT start a receiving go routine", - z.Any("peer", p2p.PeerName(c.p2pNode.ID())), - z.Any("peers", c.peerLabels), - z.Any("timer", string(roundTimer.Type())), - ) - // Start a receiving goroutine. 
go t.ProcessReceives(ctx, c.getRecvBuffer(duty)) // Create a qbft transport from the transport - qt := qbft.Transport[core.Duty, [32]byte]{ + qt := qbft.Transport[core.Duty, [32]byte, proto.Message]{ Broadcast: t.Broadcast, Receive: t.RecvBuffer(), } - log.Debug(ctx, "QBFT run", - z.Any("peer", p2p.PeerName(c.p2pNode.ID())), - z.Any("peers", c.peerLabels), - z.Any("timer", string(roundTimer.Type())), - ) - // Run the algo, blocking until the context is cancelled. - err = qbft.Run(ctx, def, qt, duty, peerIdx, inst.HashCh) + err = qbft.Run(ctx, def, qt, duty, peerIdx, inst.HashCh, inst.VerifyCh) if err != nil && !isContextErr(err) { span.AddEvent("qbft.Error") c.metrics.IncConsensusError() @@ -613,15 +687,8 @@ func (c *Consensus) handle(ctx context.Context, _ peer.ID, req proto.Message) (p return nil, false, errors.New("duty expired", z.Any("duty", duty), c.dropFilter) } - recvBuffer := c.getRecvBuffer(duty) - log.Debug(ctx, "QBFT recv buffer enqueue", - z.Any("duty", duty), - z.Int("buffer_len", len(recvBuffer)), - z.Int("buffer_cap", cap(recvBuffer)), - ) - select { - case recvBuffer <- msg: + case c.getRecvBuffer(duty) <- msg: return nil, false, nil case <-ctx.Done(): return nil, false, errors.Wrap(ctx.Err(), "timeout enqueuing receive buffer", @@ -730,7 +797,7 @@ type roundStep struct { } // groupRoundMessages groups messages by type and returns which peers were present and missing for each type. -func groupRoundMessages(msgs []qbft.Msg[core.Duty, [32]byte], peers int, round int64, leader int) []roundStep { +func groupRoundMessages(msgs []qbft.Msg[core.Duty, [32]byte, proto.Message], peers int, round int64, leader int) []roundStep { // checkPeers returns two slices of peer indexes, one with peers // present with the message type and one with messing peers. 
checkPeers := func(typ qbft.MsgType) (present []int, missing []int) { diff --git a/core/consensus/qbft/qbft_internal_test.go b/core/consensus/qbft/qbft_internal_test.go index 72f830109..419e6a73e 100644 --- a/core/consensus/qbft/qbft_internal_test.go +++ b/core/consensus/qbft/qbft_internal_test.go @@ -31,7 +31,7 @@ func TestDebugRoundChange(t *testing.T) { tests := []struct { name string - msgs []qbft.Msg[core.Duty, [32]byte] + msgs []qbft.Msg[core.Duty, [32]byte, proto.Message] round int64 leader int }{ @@ -45,7 +45,7 @@ func TestDebugRoundChange(t *testing.T) { }, { name: "quorum", - msgs: []qbft.Msg[core.Duty, [32]byte]{ + msgs: []qbft.Msg[core.Duty, [32]byte, proto.Message]{ m(0, qbft.MsgRoundChange), m(1, qbft.MsgRoundChange), m(2, qbft.MsgRoundChange), @@ -125,7 +125,7 @@ func (t testMsg) PreparedValue() [32]byte { panic("implement me") } -func (t testMsg) Justification() []qbft.Msg[core.Duty, [32]byte] { +func (t testMsg) Justification() []qbft.Msg[core.Duty, [32]byte, proto.Message] { panic("implement me") } diff --git a/core/consensus/qbft/qbft_test.go b/core/consensus/qbft/qbft_test.go index dff346556..137cc8550 100644 --- a/core/consensus/qbft/qbft_test.go +++ b/core/consensus/qbft/qbft_test.go @@ -142,7 +142,7 @@ func testQBFTConsensus(t *testing.T, threshold, nodes int) { bmock, err := beaconmock.New(t.Context(), beaconmock.WithGenesisTime(time.Time{})) require.NoError(t, err) - c, err := qbft.NewConsensus(t.Context(), bmock, hosts[i], new(p2p.Sender), peers, p2pkeys[i], deadliner, gaterFunc, sniffer) + c, err := qbft.NewConsensus(t.Context(), bmock, hosts[i], new(p2p.Sender), peers, p2pkeys[i], deadliner, gaterFunc, sniffer, false) require.NoError(t, err) c.Subscribe(func(_ context.Context, _ core.Duty, set core.UnsignedDataSet) error { results <- set diff --git a/core/consensus/qbft/sniffed_internal_test.go b/core/consensus/qbft/sniffed_internal_test.go index 25b6eab7f..517fbf04e 100644 --- a/core/consensus/qbft/sniffed_internal_test.go +++ 
b/core/consensus/qbft/sniffed_internal_test.go @@ -79,9 +79,9 @@ func testSniffedInstance(ctx context.Context, t *testing.T, instance *pbv1.Sniff return nil }} - }, timer.NewIncreasingRoundTimer(), func(qcommit []qbft.Msg[core.Duty, [32]byte]) {}) + }, timer.NewIncreasingRoundTimer(), func(qcommit []qbft.Msg[core.Duty, [32]byte, proto.Message]) {}, false) - recvBuffer := make(chan qbft.Msg[core.Duty, [32]byte], len(instance.GetMsgs())) + recvBuffer := make(chan qbft.Msg[core.Duty, [32]byte, proto.Message], len(instance.GetMsgs())) var duty core.Duty @@ -102,10 +102,10 @@ func testSniffedInstance(ctx context.Context, t *testing.T, instance *pbv1.Sniff } // Create a qbft transport from the transport - qt := qbft.Transport[core.Duty, [32]byte]{ + qt := qbft.Transport[core.Duty, [32]byte, proto.Message]{ Broadcast: func(context.Context, qbft.MsgType, core.Duty, int64, int64, [32]byte, int64, [32]byte, - []qbft.Msg[core.Duty, [32]byte], + []qbft.Msg[core.Duty, [32]byte, proto.Message], ) error { return nil }, @@ -113,7 +113,7 @@ func testSniffedInstance(ctx context.Context, t *testing.T, instance *pbv1.Sniff } // Run the algo, blocking until the context is cancelled. 
- err := qbft.Run(ctx, def, qt, duty, instance.GetPeerIdx(), qbft.InputValue([32]byte{1})) + err := qbft.Run(ctx, def, qt, duty, instance.GetPeerIdx(), qbft.InputValue([32]byte{1}), qbft.InputValueSource(proto.Message(newRandomQBFTMsg(t)))) if expectDecided { require.ErrorIs(t, err, context.Canceled) } else { diff --git a/core/consensus/qbft/strategysim_internal_test.go b/core/consensus/qbft/strategysim_internal_test.go index 3499be79f..5b104e7ac 100644 --- a/core/consensus/qbft/strategysim_internal_test.go +++ b/core/consensus/qbft/strategysim_internal_test.go @@ -23,6 +23,7 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/timestamppb" @@ -359,7 +360,7 @@ func testStrategySimulator(t *testing.T, conf ssConfig, syncer zapcore.WriteSync var ( peerIDs []peerID - transports []qbft.Transport[core.Duty, [32]byte] + transports []qbft.Transport[core.Duty, [32]byte, proto.Message] ) for peerIdx := range conf.latencyPerPeer { peerIDs = append(peerIDs, peerID{Idx: peerIdx, OK: true}) @@ -376,7 +377,7 @@ func testStrategySimulator(t *testing.T, conf ssConfig, syncer zapcore.WriteSync def := newSimDefinition( len(conf.latencyPerPeer), conf.roundTimerFunc(clock), - func(qcommit []qbft.Msg[core.Duty, [32]byte]) { + func(qcommit []qbft.Msg[core.Duty, [32]byte, proto.Message]) { res = result{ PeerIdx: p.Idx, Decided: true, @@ -390,6 +391,7 @@ func testStrategySimulator(t *testing.T, conf ssConfig, syncer zapcore.WriteSync // Setup unique non-zero value per peer valCh := make(chan [32]byte, 1) + valSrcCh := make(chan proto.Message, 1) enqueueValue := func() { var val [32]byte @@ -419,7 +421,7 @@ func testStrategySimulator(t *testing.T, conf ssConfig, syncer zapcore.WriteSync log.Debug(ctx, "Starting peer") - err := qbft.Run(ctx, def, transports[p.Idx], core.Duty{Slot: uint64(conf.seed)}, p.Idx, valCh) + err := qbft.Run(ctx, 
def, transports[p.Idx], core.Duty{Slot: uint64(conf.seed)}, p.Idx, valCh, valSrcCh) if err != nil && !errors.Is(err, context.Canceled) { return res, err } @@ -484,22 +486,24 @@ func gosched() { } func newSimDefinition(nodes int, roundTimer timer.RoundTimer, - decideCallback func(qcommit []qbft.Msg[core.Duty, [32]byte]), -) qbft.Definition[core.Duty, [32]byte] { - quorum := qbft.Definition[int, int]{Nodes: nodes}.Quorum() + decideCallback func(qcommit []qbft.Msg[core.Duty, [32]byte, proto.Message]), +) qbft.Definition[core.Duty, [32]byte, proto.Message] { + quorum := qbft.Definition[int, int, int]{Nodes: nodes}.Quorum() - return qbft.Definition[core.Duty, [32]byte]{ + return qbft.Definition[core.Duty, [32]byte, proto.Message]{ IsLeader: func(duty core.Duty, round, process int64) bool { return leader(duty, round, nodes) == process }, - Decide: func(ctx context.Context, duty core.Duty, _ [32]byte, qcommit []qbft.Msg[core.Duty, [32]byte]) { + Decide: func(ctx context.Context, duty core.Duty, _ [32]byte, qcommit []qbft.Msg[core.Duty, [32]byte, proto.Message]) { decideCallback(qcommit) }, + Compare: func(ctx context.Context, qcommit qbft.Msg[core.Duty, [32]byte, proto.Message], inputValueSourceCh <-chan proto.Message, inputValueSource proto.Message, returnErr chan error, returnRes chan proto.Message) { + returnErr <- nil + }, NewTimer: roundTimer.Timer, - LogUnjust: func(context.Context, core.Duty, int64, qbft.Msg[core.Duty, [32]byte]) {}, - LogDebug: func(context.Context, core.Duty, int64, qbft.Msg[core.Duty, [32]byte], string) {}, + LogUnjust: func(context.Context, core.Duty, int64, qbft.Msg[core.Duty, [32]byte, proto.Message]) {}, LogRoundChange: func(ctx context.Context, duty core.Duty, process, - round, newRound int64, uponRule qbft.UponRule, msgs []qbft.Msg[core.Duty, [32]byte], + round, newRound int64, uponRule qbft.UponRule, msgs []qbft.Msg[core.Duty, [32]byte, proto.Message], ) { fields := []z.Field{ z.Any("rule", uponRule), @@ -520,7 +524,7 @@ func 
newSimDefinition(nodes int, roundTimer timer.RoundTimer, }, // LogUponRule logs upon rules at debug level. LogUponRule: func(ctx context.Context, _ core.Duty, _, round int64, - _ qbft.Msg[core.Duty, [32]byte], uponRule qbft.UponRule, + _ qbft.Msg[core.Duty, [32]byte, proto.Message], uponRule qbft.UponRule, ) { log.Debug(ctx, "QBFT upon rule triggered", z.Any("rule", uponRule), z.I64("round", round)) }, @@ -540,7 +544,7 @@ type result struct { } type tuple struct { - Msg qbft.Msg[core.Duty, [32]byte] + Msg qbft.Msg[core.Duty, [32]byte, proto.Message] To int64 Arrive time.Time } @@ -568,7 +572,7 @@ type transportSimulator struct { instances map[int64]*transportInstance } -func (s *transportSimulator) enqueue(msg qbft.Msg[core.Duty, [32]byte]) { +func (s *transportSimulator) enqueue(msg qbft.Msg[core.Duty, [32]byte, proto.Message]) { s.mu.Lock() defer s.mu.Unlock() @@ -615,7 +619,7 @@ func (s *transportSimulator) processBuffer() { s.buffer = remaining } -func (s *transportSimulator) instance(peerIdx int64) qbft.Transport[core.Duty, [32]byte] { +func (s *transportSimulator) instance(peerIdx int64) qbft.Transport[core.Duty, [32]byte, proto.Message] { s.mu.Lock() defer s.mu.Unlock() @@ -624,12 +628,12 @@ func (s *transportSimulator) instance(peerIdx int64) qbft.Transport[core.Duty, [ inst = &transportInstance{ transportSimulator: s, peerIdx: peerIdx, - receive: make(chan qbft.Msg[core.Duty, [32]byte], 1000), + receive: make(chan qbft.Msg[core.Duty, [32]byte, proto.Message], 1000), } s.instances[peerIdx] = inst } - return qbft.Transport[core.Duty, [32]byte]{ + return qbft.Transport[core.Duty, [32]byte, proto.Message]{ Broadcast: inst.Broadcast, Receive: inst.Receive(), } @@ -639,12 +643,12 @@ type transportInstance struct { *transportSimulator peerIdx int64 - receive chan qbft.Msg[core.Duty, [32]byte] + receive chan qbft.Msg[core.Duty, [32]byte, proto.Message] } func (i *transportInstance) Broadcast(_ context.Context, typ qbft.MsgType, duty core.Duty, source int64, round 
int64, value [32]byte, - pr int64, pv [32]byte, justification []qbft.Msg[core.Duty, [32]byte], + pr int64, pv [32]byte, justification []qbft.Msg[core.Duty, [32]byte, proto.Message], ) error { dummy, _ := anypb.New(timestamppb.Now()) values := map[[32]byte]*anypb.Any{ @@ -686,7 +690,7 @@ func (i *transportInstance) Broadcast(_ context.Context, typ qbft.MsgType, return nil } -func (i *transportInstance) Receive() <-chan qbft.Msg[core.Duty, [32]byte] { +func (i *transportInstance) Receive() <-chan qbft.Msg[core.Duty, [32]byte, proto.Message] { return i.receive } diff --git a/core/consensus/qbft/transport.go b/core/consensus/qbft/transport.go index a09c43973..ced773030 100644 --- a/core/consensus/qbft/transport.go +++ b/core/consensus/qbft/transport.go @@ -12,8 +12,6 @@ import ( "google.golang.org/protobuf/types/known/anypb" "github.com/obolnetwork/charon/app/errors" - "github.com/obolnetwork/charon/app/log" - "github.com/obolnetwork/charon/app/z" "github.com/obolnetwork/charon/core" pbv1 "github.com/obolnetwork/charon/core/corepb/v1" "github.com/obolnetwork/charon/core/qbft" @@ -29,7 +27,7 @@ type transport struct { // Immutable state broadcaster broadcaster privkey *k1.PrivateKey - recvBuffer chan qbft.Msg[core.Duty, [32]byte] // Instance inner receive buffer. + recvBuffer chan qbft.Msg[core.Duty, [32]byte, proto.Message] // Instance inner receive buffer. sniffer *sniffer // Mutable state @@ -40,7 +38,7 @@ type transport struct { // newTransport creates a new qbftTransport. func newTransport(broadcaster broadcaster, privkey *k1.PrivateKey, valueCh <-chan proto.Message, - recvBuffer chan qbft.Msg[core.Duty, [32]byte], sniffer *sniffer, + recvBuffer chan qbft.Msg[core.Duty, [32]byte, proto.Message], sniffer *sniffer, ) *transport { return &transport{ broadcaster: broadcaster, @@ -94,7 +92,7 @@ func (t *transport) getValue(hash [32]byte) (*anypb.Any, error) { // Broadcast creates a msg and sends it to all peers (including self). 
func (t *transport) Broadcast(ctx context.Context, typ qbft.MsgType, duty core.Duty, peerIdx int64, round int64, valueHash [32]byte, pr int64, pvHash [32]byte, - justification []qbft.Msg[core.Duty, [32]byte], + justification []qbft.Msg[core.Duty, [32]byte, proto.Message], ) error { // Get all hashes var hashes [][32]byte @@ -143,12 +141,6 @@ func (t *transport) Broadcast(ctx context.Context, typ qbft.MsgType, duty core.D } }() - log.Debug(ctx, "QBFT broadcasting msg", - z.Str("duty", duty.String()), - z.I64("round", round), - z.Str("type", typ.String()), - z.I64("peer_idx", peerIdx)) - return t.broadcaster.Broadcast(ctx, msg.ToConsensusMsg()) } @@ -165,12 +157,6 @@ func (t *transport) ProcessReceives(ctx context.Context, outerBuffer chan Msg) { case <-ctx.Done(): return case t.recvBuffer <- msg: - log.Debug(ctx, "QBFT received msg", - z.Str("duty", msg.msg.GetDuty().String()), - z.I64("round", msg.msg.GetRound()), - z.Str("type", qbft.MsgType(msg.msg.GetType()).String()), - z.I64("peer_idx", msg.msg.GetPeerIdx())) - t.sniffer.Add(msg.ToConsensusMsg()) } } @@ -183,7 +169,7 @@ func (t *transport) SnifferInstance() *pbv1.SniffedConsensusInstance { } // RecvBuffer returns the inner receive buffer. -func (t *transport) RecvBuffer() chan qbft.Msg[core.Duty, [32]byte] { +func (t *transport) RecvBuffer() chan qbft.Msg[core.Duty, [32]byte, proto.Message] { return t.recvBuffer } @@ -191,7 +177,7 @@ func (t *transport) RecvBuffer() chan qbft.Msg[core.Duty, [32]byte] { // and wrapping that in a msg type. 
func createMsg(typ qbft.MsgType, duty core.Duty, peerIdx int64, round int64, vHash [32]byte, pr int64, pvHash [32]byte, - values map[[32]byte]*anypb.Any, justification []qbft.Msg[core.Duty, [32]byte], + values map[[32]byte]*anypb.Any, justification []qbft.Msg[core.Duty, [32]byte, proto.Message], privkey *k1.PrivateKey, ) (Msg, error) { pbMsg := &pbv1.QBFTMsg{ diff --git a/core/deadline.go b/core/deadline.go index dbee972c3..663e711fb 100644 --- a/core/deadline.go +++ b/core/deadline.go @@ -4,7 +4,6 @@ package core import ( "context" - "sync" "testing" "time" @@ -20,9 +19,7 @@ import ( const ( // marginFactor defines the fraction of the slot duration to use as a margin. // This is to consider network delays and other factors that may affect the timing. - marginFactor = 12 - expiredBufferSize = 10 - tickerInterval = time.Second + marginFactor = 12 ) // DeadlineFunc is a function that returns the deadline for a duty. @@ -42,15 +39,19 @@ type Deadliner interface { C() <-chan Duty } +// deadlineInput represents the input to inputChan. +type deadlineInput struct { + duty Duty + success chan<- bool +} + // deadliner implements the Deadliner interface. type deadliner struct { - lock sync.Mutex label string - deadlineFunc DeadlineFunc - duties map[Duty]time.Time - expiredChan chan Duty + inputChan chan deadlineInput + deadlineChan chan Duty clock clockwork.Clock - done chan struct{} + quit chan struct{} } // NewDeadlinerForT returns a Deadline for use in tests. @@ -63,7 +64,7 @@ func NewDeadlinerForT(ctx context.Context, t *testing.T, deadlineFunc DeadlineFu // NewDeadliner returns a new instance of Deadline. // // It also starts a goroutine which is responsible for reading and storing duties, -// and sending the deadlined duty to receiver's expiredChan until the context is closed. +// and sending the deadlined duty to receiver's deadlineChan until the context is closed. 
func NewDeadliner(ctx context.Context, label string, deadlineFunc DeadlineFunc) Deadliner { return newDeadliner(ctx, label, deadlineFunc, clockwork.NewRealClock()) } @@ -112,103 +113,128 @@ func NewDutyDeadlineFunc(ctx context.Context, eth2Cl eth2wrap.Client) (DeadlineF // newDeadliner returns a new Deadliner, this is for internal use only. func newDeadliner(ctx context.Context, label string, deadlineFunc DeadlineFunc, clock clockwork.Clock) Deadliner { + // outputBuffer big enough to support all duty types, which can expire at the same time + // while external consumer is synchronously adding duties (so not reading output). + const outputBuffer = 10 + d := &deadliner{ label: label, - deadlineFunc: deadlineFunc, - duties: make(map[Duty]time.Time), - expiredChan: make(chan Duty, expiredBufferSize), + inputChan: make(chan deadlineInput), // Not buffering this since writer wait for response. + deadlineChan: make(chan Duty, outputBuffer), clock: clock, - done: make(chan struct{}), + quit: make(chan struct{}), } - go d.run(ctx) + go d.run(ctx, deadlineFunc) return d } -func (d *deadliner) run(ctx context.Context) { - defer close(d.done) +func (d *deadliner) run(ctx context.Context, deadlineFunc DeadlineFunc) { + duties := make(map[Duty]bool) + currDuty, currDeadline := getCurrDuty(duties, deadlineFunc) + currTimer := d.clock.NewTimer(currDeadline.Sub(d.clock.Now())) - // The simple approach does not require a min-heap or priority queue to store the duties and their deadlines, - // but it is sufficient for our use case as the number of duties is expected to be small. - // A disadvantage of this approach is the expiration precision is rounded to the nearest second. 
- timer := d.clock.NewTicker(tickerInterval) - defer timer.Stop() + defer func() { + close(d.quit) + currTimer.Stop() + }() + + setCurrState := func() { + currTimer.Stop() + + currDuty, currDeadline = getCurrDuty(duties, deadlineFunc) + currTimer = d.clock.NewTimer(currDeadline.Sub(d.clock.Now())) + } + // TODO(dhruv): optimise getCurrDuty and updating current state if earlier deadline detected, + // using min heap or ordered map for { select { case <-ctx.Done(): return - case <-timer.Chan(): - // Get all expired duties at the current time. - expiredDuties := d.getExpiredDuties(d.clock.Now()) - if len(expiredDuties) == 0 { + case input := <-d.inputChan: + deadline, canExpire := deadlineFunc(input.duty) + if !canExpire { + // Drop duties that never expire + input.success <- false + continue + } + + expired := deadline.Before(d.clock.Now()) + + input.success <- !expired + + // Ignore expired duties + if expired { continue } - log.Debug(ctx, "Deadliner.run() got expired duties", z.Int("count", len(expiredDuties))) + duties[input.duty] = true - for _, expiredDuty := range expiredDuties { - // Send the expired duty to the receiver. - select { - case <-ctx.Done(): - return - case d.expiredChan <- expiredDuty: - } + if deadline.Before(currDeadline) { + setCurrState() } + case <-currTimer.Chan(): + // Send deadlined duty to receiver. + select { + case <-ctx.Done(): + return + case d.deadlineChan <- currDuty: + default: + log.Warn(ctx, "Deadliner output channel full", nil, + z.Str("label", d.label), + z.Any("duty", currDuty), + ) + } + + delete(duties, currDuty) + setCurrState() } } } // Add adds a duty to be notified of the deadline. It returns true if the duty was added successfully. func (d *deadliner) Add(duty Duty) bool { - log.Debug(context.Background(), "Deadliner.Add()", z.Any("duty", duty)) + success := make(chan bool) select { - case <-d.done: - // Run goroutine has stopped, ignore new duties. 
- return false -	default: -	} - -	deadline, canExpire := d.deadlineFunc(duty) -	if !canExpire { -		// Drop duties that never expire + case <-d.quit: return false + case d.inputChan <- deadlineInput{duty: duty, success: success}: } - expired := deadline.Before(d.clock.Now()) - if expired { - // Drop expired duties + select { + case <-d.quit: return false + case ok := <-success: + return ok } - - d.lock.Lock() - defer d.lock.Unlock() - - d.duties[duty] = deadline - - return true } // C returns the deadline channel. func (d *deadliner) C() <-chan Duty { - return d.expiredChan + return d.deadlineChan } -// getExpiredDuties selects all expired duties. -func (d *deadliner) getExpiredDuties(now time.Time) []Duty { - expiredDuties := []Duty{} +// getCurrDuty gets the duty to process next along with the duty deadline. It selects the duty with the earliest deadline. +func getCurrDuty(duties map[Duty]bool, deadlineFunc DeadlineFunc) (Duty, time.Time) { + var currDuty Duty - d.lock.Lock() - defer d.lock.Unlock() + currDeadline := time.Date(9999, 1, 1, 0, 0, 0, 0, time.UTC) + + for duty := range duties { + dutyDeadline, ok := deadlineFunc(duty) + if !ok { + // Ignore the duties that never expire. + continue + } - for duty, deadline := range d.duties { - if deadline.Before(now) { - expiredDuties = append(expiredDuties, duty) - delete(d.duties, duty) + if currDeadline.After(dutyDeadline) { + currDuty = duty + currDeadline = dutyDeadline } } - return expiredDuties + return currDuty, currDeadline } diff --git a/core/deadline_test.go b/core/deadline_test.go index 756416571..258158784 100644 --- a/core/deadline_test.go +++ b/core/deadline_test.go @@ -18,6 +18,8 @@ import ( "github.com/obolnetwork/charon/testutil/beaconmock" ) +//go:generate go test .
+ func TestDeadliner(t *testing.T) { ctx, cancel := context.WithCancel(t.Context()) defer cancel() @@ -43,10 +45,6 @@ func TestDeadliner(t *testing.T) { deadliner := core.NewDeadlinerForT(ctx, t, deadlineFuncProvider(), clock) - // Wait for the run goroutine to be waiting on the ticker before interacting. - err := clock.BlockUntilContext(ctx, 1) - require.NoError(t, err) - wg := &sync.WaitGroup{} // Add our duties to the deadliner. @@ -75,9 +73,8 @@ func TestDeadliner(t *testing.T) { } } - // Advance clock past the latest deadline to trigger expiration of all non-expired duties. - // Use maxSlot+1 because Before() is strict (not <=). - clock.Advance(time.Duration(maxSlot+1) * time.Second) + // Advance clock to trigger deadline of all non-expired duties. + clock.Advance(time.Duration(maxSlot) * time.Second) var actualDuties []core.Duty for range len(nonExpiredDuties) { diff --git a/core/qbft/qbft.go b/core/qbft/qbft.go index 199736ad8..f1093f02a 100644 --- a/core/qbft/qbft.go +++ b/core/qbft/qbft.go @@ -12,39 +12,41 @@ import ( "time" "github.com/obolnetwork/charon/app/errors" + "github.com/obolnetwork/charon/app/log" ) // Transport abstracts the transport layer between processes in the consensus system. -type Transport[I any, V comparable] struct { +type Transport[I any, V comparable, C any] struct { // Broadcast sends a message with the provided fields to all other // processes in the system (including this process). // // Note that a non-nil error exits the algorithm. - Broadcast func(ctx context.Context, typ MsgType, instance I, source int64, round int64, value V, pr int64, pv V, justification []Msg[I, V]) error + Broadcast func(ctx context.Context, typ MsgType, instance I, source int64, round int64, value V, pr int64, pv V, justification []Msg[I, V, C]) error // Receive returns a stream of messages received // from other processes in the system (including this process). 
- Receive <-chan Msg[I, V] + Receive <-chan Msg[I, V, C] } // Definition defines the consensus system parameters that are external to the qbft algorithm. // This remains constant across multiple instances of consensus (calls to Run). -type Definition[I any, V comparable] struct { +type Definition[I any, V comparable, C any] struct { // IsLeader is a deterministic leader election function. IsLeader func(instance I, round, process int64) bool // NewTimer returns a new timer channel and stop function for the round. NewTimer func(round int64) (<-chan time.Time, func()) + // Compare is called when leader proposes value and we compare it with our local value. + // Compare is an opt-in feature that should instantly return nil on returnErr channel if it is not turned on. + Compare func(ctx context.Context, qcommit Msg[I, V, C], inputValueSourceCh <-chan C, inputValueSource C, returnErr chan error, returnValue chan C) // Decide is called when consensus has been reached on a value. - Decide func(ctx context.Context, instance I, value V, qcommit []Msg[I, V]) + Decide func(ctx context.Context, instance I, value V, qcommit []Msg[I, V, C]) // LogUponRule allows debug logging of triggered upon rules on message receipt. - LogUponRule func(ctx context.Context, instance I, process, round int64, msg Msg[I, V], uponRule UponRule) + LogUponRule func(ctx context.Context, instance I, process, round int64, msg Msg[I, V, C], uponRule UponRule) // LogRoundChange allows debug logging of round changes. // It includes the rule that triggered it and all received round messages. - LogRoundChange func(ctx context.Context, instance I, process, round, newRound int64, uponRule UponRule, msgs []Msg[I, V]) + LogRoundChange func(ctx context.Context, instance I, process, round, newRound int64, uponRule UponRule, msgs []Msg[I, V, C]) // LogUnjust allows debug logging of unjust messages. 
- LogUnjust func(ctx context.Context, instance I, process int64, msg Msg[I, V]) - // LogDebug allows arbitrary debug logging. - LogDebug func(ctx context.Context, instance I, process int64, msg Msg[I, V], logMsg string) + LogUnjust func(ctx context.Context, instance I, process int64, msg Msg[I, V, C]) // Nodes is the total number of nodes/processes participating in consensus. Nodes int @@ -54,13 +56,13 @@ type Definition[I any, V comparable] struct { // Quorum returns the quorum count for the system. // See IBFT 2.0 paper for correct formula: https://arxiv.org/pdf/1909.10194.pdf -func (d Definition[I, V]) Quorum() int { +func (d Definition[I, V, C]) Quorum() int { return int(math.Ceil(float64(d.Nodes*2) / 3)) } // Faulty returns the maximum number of faulty/byzantium nodes supported in the system. // See IBFT 2.0 paper for correct formula: https://arxiv.org/pdf/1909.10194.pdf -func (d Definition[I, V]) Faulty() int { +func (d Definition[I, V, C]) Faulty() int { return int(math.Floor(float64(d.Nodes-1) / 3)) } @@ -98,7 +100,7 @@ var typeLabels = map[MsgType]string{ } // Msg defines the inter process messages. -type Msg[I any, V comparable] interface { +type Msg[I any, V comparable, C any] interface { // Type of the message. Type() MsgType // Instance identifies the consensus instance. @@ -109,12 +111,14 @@ type Msg[I any, V comparable] interface { Round() int64 // Value being proposed, usually a hash. Value() V + // ValueSource being proposed, usually the value that was hashed and is returned in Value(). + ValueSource() (C, error) // PreparedRound is the justified prepared round. PreparedRound() int64 // PreparedValue is the justified prepared value. PreparedValue() V // Justification is the set of messages that explicitly justifies this message. - Justification() []Msg[I, V] + Justification() []Msg[I, V, C] } // UponRule defines the event based rules that are triggered when messages are received. 
@@ -154,6 +158,12 @@ type dedupKey struct { Round int64 } +// errors +var ( + errCompare = errors.New("compare leader value with local value failed") + errTimeout = errors.New("timeout") +) + // InputValue is a convenience function to create a populated input value channel. func InputValue[V comparable](inputValue V) <-chan V { ch := make(chan V, 1) @@ -162,10 +172,19 @@ func InputValue[V comparable](inputValue V) <-chan V { return ch } +// InputValueSource is a convenience function to create a populated input value source channel. +func InputValueSource[C any](inputValueSource C) <-chan C { + ch := make(chan C, 1) + ch <- inputValueSource + + return ch +} + // Run executes the consensus algorithm until the context closed. // The generic type I is the instance of consensus and can be anything. // The generic type V is the arbitrary data value being proposed; it only requires an Equal method. -func Run[I any, V comparable](ctx context.Context, d Definition[I, V], t Transport[I, V], instance I, process int64, inputValueCh <-chan V) (err error) { +// The generic type C is the compare value, used to compare leader's proposed value with local value and can be anything. +func Run[I any, V comparable, C any](ctx context.Context, d Definition[I, V, C], t Transport[I, V, C], instance I, process int64, inputValueCh <-chan V, inputValueSourceCh <-chan C) (err error) { defer func() { // Panics are used for assertions and sanity checks to reduce lines of code // and to improve readability. Catch them here. @@ -183,12 +202,14 @@ func Run[I any, V comparable](ctx context.Context, d Definition[I, V], t Transpo var ( round int64 = 1 inputValue V - ppjCache []Msg[I, V] // Cached pre-prepare justification for the current round (nil value is unset). + inputValueSource C + ppjCache []Msg[I, V, C] // Cached pre-prepare justification for the current round (nil value is unset). 
preparedRound int64 preparedValue V - preparedJustification []Msg[I, V] - qCommit []Msg[I, V] - buffer = make(map[int64][]Msg[I, V]) + compareFailureRound int64 + preparedJustification []Msg[I, V, C] + qCommit []Msg[I, V, C] + buffer = make(map[int64][]Msg[I, V, C]) dedupRules = make(map[dedupKey]bool) timerChan <-chan time.Time stopTimer func() @@ -197,7 +218,7 @@ func Run[I any, V comparable](ctx context.Context, d Definition[I, V], t Transpo // === Helpers == // broadcastMsg broadcasts a non-ROUND-CHANGE message for current round. - broadcastMsg := func(typ MsgType, value V, justification []Msg[I, V]) error { + broadcastMsg := func(typ MsgType, value V, justification []Msg[I, V, C]) error { return t.Broadcast(ctx, typ, instance, process, round, value, 0, zeroVal[V](), justification) } @@ -211,7 +232,7 @@ func Run[I any, V comparable](ctx context.Context, d Definition[I, V], t Transpo // broadcastOwnPrePrepare broadcasts a PRE-PREPARE message with current state // and our own input value if present, otherwise it caches the justification // to be used when the input value becomes available. - broadcastOwnPrePrepare := func(justification []Msg[I, V]) error { + broadcastOwnPrePrepare := func(justification []Msg[I, V, C]) error { if justification == nil { panic("bug: justification must not be nil") } else if ppjCache != nil { @@ -228,13 +249,11 @@ func Run[I any, V comparable](ctx context.Context, d Definition[I, V], t Transpo } // bufferMsg adds the message to each process' FIFO queue. 
- bufferMsg := func(msg Msg[I, V]) { + bufferMsg := func(msg Msg[I, V, C]) { fifo := buffer[msg.Source()] fifo = append(fifo, msg) if len(fifo) > d.FIFOLimit { - d.LogDebug(ctx, instance, process, msg, "new QBFT message - FIFO Limit reached") - fifo = fifo[len(fifo)-d.FIFOLimit:] } @@ -269,83 +288,52 @@ func Run[I any, V comparable](ctx context.Context, d Definition[I, V], t Transpo // === Algorithm === { // Algorithm 1:11 - d.LogDebug(ctx, instance, process, nil, "QBFT check if leader") - if d.IsLeader(instance, round, process) { // Note round==1 at this point. - d.LogDebug(ctx, instance, process, nil, "QBFT leader, broadcast own pre-prepare") - - err := broadcastOwnPrePrepare([]Msg[I, V]{}) // Empty justification since round==1 - - d.LogDebug(ctx, instance, process, nil, "QBFT leader, broadcast own pre-prepare finished") - + err := broadcastOwnPrePrepare([]Msg[I, V, C]{}) // Empty justification since round==1 if err != nil { return err } } - d.LogDebug(ctx, instance, process, nil, "QBFT new timer") timerChan, stopTimer = d.NewTimer(round) } - d.LogDebug(ctx, instance, process, nil, "QBFT start handling events") // Handle events until cancelled. for { var err error select { case inputValue = <-inputValueCh: - d.LogDebug(ctx, instance, process, nil, "QBFT new inputValue - step 1 received") - if isZeroVal(inputValue) { return errors.New("zero input value not supported") } if ppjCache != nil { // Broadcast the pre-prepare now that we have a input value using the cached justification. - d.LogDebug(ctx, instance, process, nil, "QBFT new inputValue - step 2 broadcast") - err = broadcastMsg(MsgPrePrepare, inputValue, ppjCache) - - d.LogDebug(ctx, instance, process, nil, "QBFT new inputValue - step 3 broadcast finished") } inputValueCh = nil // Don't read from this channel again. 
- d.LogDebug(ctx, instance, process, nil, "QBFT new inputValue - step 2-4 finished") - case msg := <-t.Receive: - d.LogDebug(ctx, instance, process, msg, "new QBFT message - received") // Just send Qcommit if consensus already decided if len(qCommit) > 0 { - d.LogDebug(ctx, instance, process, msg, "new QBFT message - qCommits greater than 0") - if msg.Source() != process && msg.Type() == MsgRoundChange { // Algorithm 3:17 - d.LogDebug(ctx, instance, process, msg, "new QBFT message - consensus already decided, broadcast") - err = broadcastMsg(MsgDecided, qCommit[0].Value(), qCommit) - - d.LogDebug(ctx, instance, process, msg, "new QBFT message - consensus already decided, broadcast finished") } break } - if !isJustified(d, instance, msg) { // Drop unjust messages - d.LogDebug(ctx, instance, process, msg, "new QBFT message - unjustified, dropping") + if !isJustified(d, instance, msg, compareFailureRound) { // Drop unjust messages d.LogUnjust(ctx, instance, process, msg) - break } - d.LogDebug(ctx, instance, process, msg, "new QBFT message - buffering") bufferMsg(msg) - d.LogDebug(ctx, instance, process, msg, "new QBFT message - buffered") - - d.LogDebug(ctx, instance, process, msg, "new QBFT message - classifying") rule, justification := classify(d, instance, round, process, buffer, msg) if rule == UponNothing || isDuplicatedRule(rule, msg.Round()) { - d.LogDebug(ctx, instance, process, msg, "new QBFT message - classified as duplicate") // Do nothing more if no rule or duplicate rule was triggered break } @@ -354,80 +342,75 @@ func Run[I any, V comparable](ctx context.Context, d Definition[I, V], t Transpo switch rule { case UponJustifiedPrePrepare: // Algorithm 2:1 - d.LogDebug(ctx, instance, process, msg, "UponJustifiedPrePrepare QBFT step 1 - starting procedure") // Applicable to current or future rounds (since justified) changeRound(msg.Round(), rule) - d.LogDebug(ctx, instance, process, msg, "UponJustifiedPrePrepare QBFT step 2 - stop timer") stopTimer() 
timerChan, stopTimer = d.NewTimer(round) - d.LogDebug(ctx, instance, process, msg, "UponJustifiedPrePrepare QBFT step 3 - broadcast") - err = broadcastMsg(MsgPrepare, msg.Value(), nil) - d.LogDebug(ctx, instance, process, msg, "UponJustifiedPrePrepare QBFT step 3 - broadcast finished") - + var errC error + + inputValueSource, errC = compare(ctx, d, msg, inputValueSourceCh, inputValueSource, timerChan) + if errC != nil { + switch { + case errors.Is(errC, errCompare): + compareFailureRound = msg.Round() + case errors.Is(errC, errTimeout): + // As compare function is blocking on waiting local data, round might timeout in the meantime. + // If this happens, we trigger round change. + // Algorithm 3:1 + changeRound(round+1, UponRoundTimeout) + + stopTimer() + timerChan, stopTimer = d.NewTimer(round) + + err = broadcastRoundChange() + default: + err = errors.New("bug: expected only comparison or timeout error") + } + } else { + err = broadcastMsg(MsgPrepare, msg.Value(), nil) + } case UponQuorumPrepares: // Algorithm 2:4 - d.LogDebug(ctx, instance, process, msg, "UponQuorumPrepares QBFT step 1 - starting procedure") // Only applicable to current round preparedRound = round /* == msg.Round*/ preparedValue = msg.Value() preparedJustification = justification - d.LogDebug(ctx, instance, process, msg, "UponQuorumPrepares QBFT step 2 - broadcast") - err = broadcastMsg(MsgCommit, preparedValue, nil) - d.LogDebug(ctx, instance, process, msg, "UponQuorumPrepares QBFT step 3 - broadcast finished") - case UponQuorumCommits, UponJustifiedDecided: // Algorithm 2:8 - d.LogDebug(ctx, instance, process, msg, "UponQuorumCommits, UponJustifiedDecided QBFT step 1 - starting procedure") // Applicable to any round (since can be justified) changeRound(msg.Round(), rule) qCommit = justification - d.LogDebug(ctx, instance, process, msg, "UponQuorumCommits, UponJustifiedDecided QBFT step 2 - stop timer") stopTimer() timerChan = nil - d.LogDebug(ctx, instance, process, msg, "UponQuorumCommits, 
UponJustifiedDecided QBFT step 3 - decide") d.Decide(ctx, instance, msg.Value(), justification) - d.LogDebug(ctx, instance, process, msg, "UponQuorumCommits, UponJustifiedDecided QBFT step 4 - decided") case UponFPlus1RoundChanges: // Algorithm 3:5 - d.LogDebug(ctx, instance, process, msg, "UponFPlus1RoundChanges QBFT step 1 - starting procedure") // Only applicable to future rounds changeRound(nextMinRound(d, justification, round /* < msg.Round */), rule) - d.LogDebug(ctx, instance, process, msg, "UponFPlus1RoundChanges QBFT step 2 - stop timer") stopTimer() timerChan, stopTimer = d.NewTimer(round) - d.LogDebug(ctx, instance, process, msg, "UponFPlus1RoundChanges QBFT step 3 - broadcast round change") - err = broadcastRoundChange() - d.LogDebug(ctx, instance, process, msg, "UponFPlus1RoundChanges QBFT step 4 - broadcasted round change") - case UponQuorumRoundChanges: // Algorithm 3:11 - d.LogDebug(ctx, instance, process, msg, "UponQuorumRoundChanges QBFT step 1 - starting procedure") // Only applicable to current round (round > 1) - if _, pv, ok := getSingleJustifiedPrPv(d, justification); ok { + pr, pv, ok := getSingleJustifiedPrPv(d, justification) + if ok && compareFailureRound != pr { // Send pre-prepare using prepared value (not our own input value) - d.LogDebug(ctx, instance, process, msg, "UponQuorumRoundChanges QBFT step 2 - broadcast pre-prepare with prepared value") - err = broadcastMsg(MsgPrePrepare, pv, justification) } else { // Send pre-prepare using our own input value - d.LogDebug(ctx, instance, process, msg, "UponQuorumRoundChanges QBFT step 2 - broadcast pre-prepare with own input value") - err = broadcastOwnPrePrepare(justification) } - - d.LogDebug(ctx, instance, process, msg, "UponQuorumRoundChanges QBFT step 3 - broadcasted pre-prepare") case UponUnjustQuorumRoundChanges: - d.LogDebug(ctx, instance, process, msg, "UponUnjustQuorumRoundChanges QBFT step 1 - starting procedure") // Ignore bug or byzantine default: @@ -435,34 +418,56 @@ func 
Run[I any, V comparable](ctx context.Context, d Definition[I, V], t Transpo } case <-timerChan: // Algorithm 3:1 - d.LogDebug(ctx, instance, process, nil, "RoundTimeout QBFT - step 1 starting procedure") changeRound(round+1, UponRoundTimeout) - d.LogDebug(ctx, instance, process, nil, "RoundTimeout QBFT - step 2 stop timer") stopTimer() timerChan, stopTimer = d.NewTimer(round) - d.LogDebug(ctx, instance, process, nil, "RoundTimeout QBFT - step 3 broadcast round change") - err = broadcastRoundChange() - d.LogDebug(ctx, instance, process, nil, "RoundTimeout QBFT - step 4 broadcast finished") - case <-ctx.Done(): // Cancelled - d.LogDebug(ctx, instance, process, nil, "QBFT event handling context done") return ctx.Err() } if err != nil { // Errors are considered fatal. - d.LogDebug(ctx, instance, process, nil, "QBFT fatal error") return err } } } +func compare[I any, V comparable, C any](ctx context.Context, d Definition[I, V, C], msg Msg[I, V, C], inputValueSourceCh <-chan C, inputValueSource C, timerChan <-chan time.Time) (C, error) { + compareErr := make(chan error, 1) + compareValue := make(chan C, 1) + + ctxCompare, cancel := context.WithCancel(ctx) + defer cancel() + + // d.Compare has 2 roles: + // 1. Read from the inputValueSourceCh (if inputValueSource is empty). If it reads from the channel, it returns the value on compareValue channel. + // 2. Compare the value read from inputValueSourceCh (or inputValueSource if it is not empty) to the value proposed by the leader. + // If comparison or any other unexpected error occurs, the error is returned on compareErr channel.
+ go d.Compare(ctxCompare, msg, inputValueSourceCh, inputValueSource, compareErr, compareValue) + + for { + select { + case err := <-compareErr: + if err != nil { + log.Warn(ctx, errCompare.Error(), err) + return inputValueSource, errCompare + } + + return inputValueSource, nil + case inputValueSource = <-compareValue: + case <-timerChan: + log.Warn(ctx, "", errors.New("timeout waiting for local data, used for comparing with leader's proposed data")) + return inputValueSource, errTimeout + } + } +} + // extractRoundMsgs returns all messages from the provided round. -func extractRoundMsgs[I any, V comparable](buffer map[int64][]Msg[I, V], round int64) []Msg[I, V] { - var resp []Msg[I, V] +func extractRoundMsgs[I any, V comparable, C any](buffer map[int64][]Msg[I, V, C], round int64) []Msg[I, V, C] { + var resp []Msg[I, V, C] for _, msgs := range buffer { for _, msg := range msgs { @@ -476,7 +481,7 @@ func extractRoundMsgs[I any, V comparable](buffer map[int64][]Msg[I, V], round i } // classify returns the rule triggered upon receipt of the last message and its justifications. -func classify[I any, V comparable](d Definition[I, V], instance I, round, process int64, buffer map[int64][]Msg[I, V], msg Msg[I, V]) (UponRule, []Msg[I, V]) { +func classify[I any, V comparable, C any](d Definition[I, V, C], instance I, round, process int64, buffer map[int64][]Msg[I, V, C], msg Msg[I, V, C]) (UponRule, []Msg[I, V, C]) { switch msg.Type() { case MsgDecided: return UponJustifiedDecided, msg.Justification() @@ -554,7 +559,7 @@ func classify[I any, V comparable](d Definition[I, V], instance I, round, proces // nextMinRound implements algorithm 3:6 and returns the next minimum round // from received round change messages. 
-func nextMinRound[I any, V comparable](d Definition[I, V], frc []Msg[I, V], round int64) int64 { +func nextMinRound[I any, V comparable, C any](d Definition[I, V, C], frc []Msg[I, V, C], round int64) int64 { // Get all RoundChange messages with round (rj) higher than current round (ri) if len(frc) < d.Faulty()+1 { panic("bug: Frc too short") @@ -579,11 +584,11 @@ func nextMinRound[I any, V comparable](d Definition[I, V], frc []Msg[I, V], roun } // isJustified returns true if message is justified or if it does not need justification. -func isJustified[I any, V comparable](d Definition[I, V], instance I, msg Msg[I, V]) bool { +func isJustified[I any, V comparable, C any](d Definition[I, V, C], instance I, msg Msg[I, V, C], compareFailureRound int64) bool { //nolint:revive // `case MsgPrepare` and `case MsgCommit` having same result is not an issue, it improves readability. switch msg.Type() { case MsgPrePrepare: - return isJustifiedPrePrepare(d, instance, msg) + return isJustifiedPrePrepare(d, instance, msg, compareFailureRound) case MsgPrepare: return true case MsgCommit: @@ -599,7 +604,7 @@ func isJustified[I any, V comparable](d Definition[I, V], instance I, msg Msg[I, // isJustifiedRoundChange returns true if the ROUND_CHANGE message's // prepared round and value is justified. 
-func isJustifiedRoundChange[I any, V comparable](d Definition[I, V], msg Msg[I, V]) bool { +func isJustifiedRoundChange[I any, V comparable, C any](d Definition[I, V, C], msg Msg[I, V, C]) bool { if msg.Type() != MsgRoundChange { panic("bug: not a round change message") } @@ -619,7 +624,7 @@ func isJustifiedRoundChange[I any, V comparable](d Definition[I, V], msg Msg[I, return false } - uniq := uniqSource[I, V]() + uniq := uniqSource[I, V, C]() for _, prepare := range prepares { if !uniq(prepare) { return false @@ -643,7 +648,7 @@ func isJustifiedRoundChange[I any, V comparable](d Definition[I, V], msg Msg[I, // isJustifiedDecided returns true if the decided message is justified by quorum COMMIT messages // of identical round and value. -func isJustifiedDecided[I any, V comparable](d Definition[I, V], msg Msg[I, V]) bool { +func isJustifiedDecided[I any, V comparable, C any](d Definition[I, V, C], msg Msg[I, V, C]) bool { if msg.Type() != MsgDecided { panic("bug: not a decided message") } @@ -655,7 +660,7 @@ func isJustifiedDecided[I any, V comparable](d Definition[I, V], msg Msg[I, V]) } // isJustifiedPrePrepare returns true if the PRE-PREPARE message is justified. -func isJustifiedPrePrepare[I any, V comparable](d Definition[I, V], instance I, msg Msg[I, V]) bool { +func isJustifiedPrePrepare[I any, V comparable, C any](d Definition[I, V, C], instance I, msg Msg[I, V, C], compareFailureRound int64) bool { if msg.Type() != MsgPrePrepare { panic("bug: not a preprepare message") } @@ -664,7 +669,8 @@ func isJustifiedPrePrepare[I any, V comparable](d Definition[I, V], instance I, return false } - if msg.Round() == 1 { + // Justified if PrePrepare is the first round OR if comparison failed previous round. 
+ if msg.Round() == 1 || (msg.Round() == compareFailureRound+1) { return true } @@ -682,7 +688,7 @@ func isJustifiedPrePrepare[I any, V comparable](d Definition[I, V], instance I, // containsJustifiedQrc implements algorithm 4:1 and returns true and pv if // the messages contains a justified quorum ROUND_CHANGEs (Qrc). -func containsJustifiedQrc[I any, V comparable](d Definition[I, V], justification []Msg[I, V], round int64) (V, bool) { +func containsJustifiedQrc[I any, V comparable, C any](d Definition[I, V, C], justification []Msg[I, V, C], round int64) (V, bool) { qrc := filterRoundChange(justification, round) if len(qrc) < d.Quorum() { return zeroVal[V](), false @@ -732,12 +738,12 @@ func containsJustifiedQrc[I any, V comparable](d Definition[I, V], justification // getSingleJustifiedPrPv extracts the single justified Pr and Pv from quorum // PREPARES in list of messages. It expects only one possible combination. -func getSingleJustifiedPrPv[I any, V comparable](d Definition[I, V], msgs []Msg[I, V]) (int64, V, bool) { +func getSingleJustifiedPrPv[I any, V comparable, C any](d Definition[I, V, C], msgs []Msg[I, V, C]) (int64, V, bool) { var ( pr int64 pv V count int - uniq = uniqSource[I, V]() + uniq = uniqSource[I, V, C]() ) for _, msg := range msgs { @@ -763,7 +769,7 @@ func getSingleJustifiedPrPv[I any, V comparable](d Definition[I, V], msgs []Msg[ } // getJustifiedQrc implements algorithm 4:1 and returns a justified quorum ROUND_CHANGEs (Qrc). -func getJustifiedQrc[I any, V comparable](d Definition[I, V], all []Msg[I, V], round int64) ([]Msg[I, V], bool) { +func getJustifiedQrc[I any, V comparable, C any](d Definition[I, V, C], all []Msg[I, V, C], round int64) ([]Msg[I, V, C], bool) { if qrc, ok := quorumNullPrepared(d, all, round); ok { // Return any quorum null pv ROUND_CHANGE messages as Qrc. 
return qrc, true @@ -774,11 +780,11 @@ func getJustifiedQrc[I any, V comparable](d Definition[I, V], all []Msg[I, V], r for _, prepares := range getPrepareQuorums(d, all) { // See if we have quorum ROUND-CHANGE with HIGHEST_PREPARED(qrc) == prepares.Round. var ( - qrc []Msg[I, V] + qrc []Msg[I, V, C] hasHighestPrepared bool pr = prepares[0].Round() pv = prepares[0].Value() - uniq = uniqSource[I, V]() + uniq = uniqSource[I, V, C]() ) for _, rc := range roundChanges { if rc.PreparedRound() > pr { @@ -807,8 +813,8 @@ func getJustifiedQrc[I any, V comparable](d Definition[I, V], all []Msg[I, V], r // getFPlus1RoundChanges returns true and Faulty+1 ROUND-CHANGE messages (Frc) with // the rounds higher than the provided round. It returns the highest round // per process in order to jump furthest. -func getFPlus1RoundChanges[I any, V comparable](d Definition[I, V], all []Msg[I, V], round int64) ([]Msg[I, V], bool) { - highestBySource := make(map[int64]Msg[I, V]) +func getFPlus1RoundChanges[I any, V comparable, C any](d Definition[I, V, C], all []Msg[I, V, C], round int64) ([]Msg[I, V, C], bool) { + highestBySource := make(map[int64]Msg[I, V, C]) for _, msg := range all { if msg.Type() != MsgRoundChange { @@ -834,7 +840,7 @@ func getFPlus1RoundChanges[I any, V comparable](d Definition[I, V], all []Msg[I, return nil, false } - var resp []Msg[I, V] + var resp []Msg[I, V, C] for _, msg := range highestBySource { resp = append(resp, msg) } @@ -843,26 +849,26 @@ func getFPlus1RoundChanges[I any, V comparable](d Definition[I, V], all []Msg[I, } // preparedKey defines the round and value of set of identical PREPARE messages. -type preparedKey[I any, V comparable] struct { +type preparedKey[I any, V comparable, C any] struct { round int64 value V } // getPrepareQuorums returns all sets of quorum PREPARE messages // with identical rounds and values. 
-func getPrepareQuorums[I any, V comparable](d Definition[I, V], all []Msg[I, V]) [][]Msg[I, V] { - sets := make(map[preparedKey[I, V]]map[int64]Msg[I, V]) // map[preparedKey]map[process]Msg +func getPrepareQuorums[I any, V comparable, C any](d Definition[I, V, C], all []Msg[I, V, C]) [][]Msg[I, V, C] { + sets := make(map[preparedKey[I, V, C]]map[int64]Msg[I, V, C]) // map[preparedKey]map[process]Msg for _, msg := range all { // Flatten to get PREPARES included as ROUND-CHANGE justifications. if msg.Type() != MsgPrepare { continue } - key := preparedKey[I, V]{round: msg.Round(), value: msg.Value()} + key := preparedKey[I, V, C]{round: msg.Round(), value: msg.Value()} msgs, ok := sets[key] if !ok { - msgs = make(map[int64]Msg[I, V]) + msgs = make(map[int64]Msg[I, V, C]) } msgs[msg.Source()] = msg @@ -870,14 +876,14 @@ func getPrepareQuorums[I any, V comparable](d Definition[I, V], all []Msg[I, V]) } // Return all quorums - var quorums [][]Msg[I, V] + var quorums [][]Msg[I, V, C] for _, msgs := range sets { if len(msgs) < d.Quorum() { continue } - var quorum []Msg[I, V] + var quorum []Msg[I, V, C] for _, msg := range msgs { quorum = append(quorum, msg) } @@ -890,7 +896,7 @@ func getPrepareQuorums[I any, V comparable](d Definition[I, V], all []Msg[I, V]) // quorumNullPrepared implements condition J1 and returns Qrc and true if a quorum // of round changes messages (Qrc) for the round have null prepared round and value. -func quorumNullPrepared[I any, V comparable](d Definition[I, V], all []Msg[I, V], round int64) ([]Msg[I, V], bool) { +func quorumNullPrepared[I any, V comparable, C any](d Definition[I, V, C], all []Msg[I, V, C], round int64) ([]Msg[I, V, C], bool) { var ( nullPr int64 nullPv V @@ -902,21 +908,21 @@ func quorumNullPrepared[I any, V comparable](d Definition[I, V], all []Msg[I, V] } // filterByRoundAndValue returns the messages matching the type and value. 
-func filterByRoundAndValue[I any, V comparable](msgs []Msg[I, V], typ MsgType, round int64, value V) []Msg[I, V] { +func filterByRoundAndValue[I any, V comparable, C any](msgs []Msg[I, V, C], typ MsgType, round int64, value V) []Msg[I, V, C] { return filterMsgs(msgs, typ, round, &value, nil, nil) } // filterRoundChange returns all round change messages for the provided round. -func filterRoundChange[I any, V comparable](msgs []Msg[I, V], round int64) []Msg[I, V] { +func filterRoundChange[I any, V comparable, C any](msgs []Msg[I, V, C], round int64) []Msg[I, V, C] { return filterMsgs(msgs, MsgRoundChange, round, nil, nil, nil) } // filterMsgs returns one message per process matching the provided type and round // and optional value, pr, pv. -func filterMsgs[I any, V comparable](msgs []Msg[I, V], typ MsgType, round int64, value *V, pr *int64, pv *V) []Msg[I, V] { +func filterMsgs[I any, V comparable, C any](msgs []Msg[I, V, C], typ MsgType, round int64, value *V, pr *int64, pv *V) []Msg[I, V, C] { var ( - resp []Msg[I, V] - uniq = uniqSource[I, V]() + resp []Msg[I, V, C] + uniq = uniqSource[I, V, C]() ) for _, msg := range msgs { @@ -961,8 +967,8 @@ func isZeroVal[V comparable](v V) bool { // flatten returns the buffer as a list containing all the buffered messages // as well as all their justifications. -func flatten[I any, V comparable](buffer map[int64][]Msg[I, V]) []Msg[I, V] { - var resp []Msg[I, V] +func flatten[I any, V comparable, C any](buffer map[int64][]Msg[I, V, C]) []Msg[I, V, C] { + var resp []Msg[I, V, C] for _, msgs := range buffer { for _, msg := range msgs { @@ -980,7 +986,7 @@ func flatten[I any, V comparable](buffer map[int64][]Msg[I, V]) []Msg[I, V] { } // uniqSource returns a function that returns true if the message is from a unique source. 
-func uniqSource[I any, V comparable](msgs ...Msg[I, V]) func(Msg[I, V]) bool { +func uniqSource[I any, V comparable, C any](msgs ...Msg[I, V, C]) func(Msg[I, V, C]) bool { dedup := make(map[int64]bool) for _, msg := range msgs { if dedup[msg.Source()] { @@ -990,7 +996,7 @@ func uniqSource[I any, V comparable](msgs ...Msg[I, V]) func(Msg[I, V]) bool { dedup[msg.Source()] = true } - return func(msg Msg[I, V]) bool { + return func(msg Msg[I, V, C]) bool { if dedup[msg.Source()] { return false } diff --git a/core/qbft/qbft_internal_test.go b/core/qbft/qbft_internal_test.go index f28064df3..a078d6742 100644 --- a/core/qbft/qbft_internal_test.go +++ b/core/qbft/qbft_internal_test.go @@ -293,15 +293,15 @@ func testQBFT(t *testing.T, test test) { var ( ctx, cancel = context.WithCancel(context.Background()) clock = new(fakeClock) - receives = make(map[int64]chan Msg[int64, int64]) - broadcast = make(chan Msg[int64, int64]) - resultChan = make(chan []Msg[int64, int64], n) + receives = make(map[int64]chan Msg[int64, int64, int64]) + broadcast = make(chan Msg[int64, int64, int64]) + resultChan = make(chan []Msg[int64, int64, int64], n) runChan = make(chan error, n) ) defer cancel() isLeader := makeIsLeader(n) - defs := Definition[int64, int64]{ + defs := Definition[int64, int64, int64]{ IsLeader: isLeader, NewTimer: func(round int64) (<-chan time.Time, func()) { d := time.Second @@ -311,20 +311,23 @@ func testQBFT(t *testing.T, test test) { return clock.NewTimer(d) }, - Decide: func(_ context.Context, instance int64, value int64, qcommit []Msg[int64, int64]) { + Decide: func(_ context.Context, instance int64, value int64, qcommit []Msg[int64, int64, int64]) { resultChan <- qcommit }, - LogRoundChange: func(ctx context.Context, instance int64, process, round, newRound int64, rule UponRule, msgs []Msg[int64, int64]) { + Compare: func(ctx context.Context, qcommit Msg[int64, int64, int64], inputValueSourceCh <-chan int64, inputValueSource int64, returnErr chan error, returnRes 
chan int64) { + returnErr <- nil + }, + LogRoundChange: func(ctx context.Context, instance int64, process, round, newRound int64, rule UponRule, msgs []Msg[int64, int64, int64]) { t.Logf("%s %v@%d change to %d ~= %v", clock.NowStr(), process, round, newRound, rule) }, - LogUponRule: func(_ context.Context, instance int64, process, round int64, msg Msg[int64, int64], rule UponRule) { + LogUponRule: func(_ context.Context, instance int64, process, round int64, msg Msg[int64, int64, int64], rule UponRule) { t.Logf("%s %d => %v@%d -> %v@%d ~= %v", clock.NowStr(), msg.Source(), msg.Type(), msg.Round(), process, round, rule) if round > maxRound { cancel() } }, - LogUnjust: func(_ context.Context, instance int64, process int64, msg Msg[int64, int64]) { + LogUnjust: func(_ context.Context, instance int64, process int64, msg Msg[int64, int64, int64]) { if test.Fuzz { return // Ignore unjust messages when fuzzing. } @@ -332,23 +335,16 @@ func testQBFT(t *testing.T, test test) { t.Logf("Unjust: %#v", msg) cancel() }, - LogDebug: func(_ context.Context, instance int64, process int64, msg Msg[int64, int64], logMsg string) { - if test.Fuzz { - return // Ignore debug messages when fuzzing. - } - - t.Logf("Debug: %s - %#v", logMsg, msg) - }, Nodes: n, FIFOLimit: fifoLimit, } for i := int64(1); i <= n; i++ { - receive := make(chan Msg[int64, int64], 1000) + receive := make(chan Msg[int64, int64, int64], 1000) receives[i] = receive - trans := Transport[int64, int64]{ + trans := Transport[int64, int64, int64]{ Broadcast: func(ctx context.Context, typ MsgType, instance int64, source int64, round int64, value int64, - pr int64, pv int64, justify []Msg[int64, int64], + pr int64, pv int64, justify []Msg[int64, int64, int64], ) error { if round > maxRound { return errors.New("max round reach") @@ -395,6 +391,7 @@ func testQBFT(t *testing.T, test test) { // - or expect multiple rounds // - or otherwise only the leader of round 1. 
vChan := make(chan int64, 1) + vsChan := make(chan int64, 1) if delay, ok := test.ValueDelay[i]; ok { go func() { @@ -411,7 +408,7 @@ func testQBFT(t *testing.T, test test) { go func() { vChan <- i }() } - runChan <- Run(ctx, defs, trans, test.Instance, i, vChan) + runChan <- Run(ctx, defs, trans, test.Instance, i, vChan, vsChan) }(i) } @@ -420,7 +417,7 @@ func testQBFT(t *testing.T, test test) { } var ( - results = make(map[int64]Msg[int64, int64]) + results = make(map[int64]Msg[int64, int64, int64]) count int decided bool done int @@ -496,7 +493,7 @@ func testQBFT(t *testing.T, test test) { } // fuzz broadcasts random messages from the peer every 100ms (10/round). -func fuzz(ctx context.Context, clock *fakeClock, broadcast chan Msg[int64, int64], instance, peerIdx int64) { +func fuzz(ctx context.Context, clock *fakeClock, broadcast chan Msg[int64, int64, int64], instance, peerIdx int64) { for { timer, stop := clock.NewTimer(time.Millisecond * 100) select { @@ -524,7 +521,7 @@ func randomMsg(instance, peerIdx int64) msg { } // bcast delays the message broadcast by between 1x and 2x jitterMS and drops messages. -func bcast(t *testing.T, broadcast chan Msg[int64, int64], msg Msg[int64, int64], jitterMS int, clock *fakeClock) { +func bcast(t *testing.T, broadcast chan Msg[int64, int64, int64], msg Msg[int64, int64, int64], jitterMS int, clock *fakeClock) { t.Helper() if jitterMS == 0 { @@ -545,8 +542,8 @@ func bcast(t *testing.T, broadcast chan Msg[int64, int64], msg Msg[int64, int64] // newMsg returns a new message to be broadcast. 
func newMsg(typ MsgType, instance int64, source int64, round int64, value int64, valueSource int64, - pr int64, pv int64, justify []Msg[int64, int64], -) Msg[int64, int64] { + pr int64, pv int64, justify []Msg[int64, int64, int64], +) Msg[int64, int64, int64] { var msgs []msg for _, j := range justify { @@ -568,7 +565,7 @@ func newMsg(typ MsgType, instance int64, source int64, round int64, value int64, } } -var _ Msg[int64, int64] = msg{} +var _ Msg[int64, int64, int64] = msg{} type msg struct { msgType MsgType @@ -614,8 +611,8 @@ func (m msg) PreparedValue() int64 { return m.pv } -func (m msg) Justification() []Msg[int64, int64] { - var resp []Msg[int64, int64] +func (m msg) Justification() []Msg[int64, int64, int64] { + var resp []Msg[int64, int64, int64] for _, msg := range m.justify { resp = append(resp, msg) } @@ -640,12 +637,12 @@ func TestIsJustifiedPrePrepare(t *testing.T) { {msgType: 2, instance: 1, peerIdx: 2, round: 2, value: 2, pr: 0, pv: 0}, }} - def := Definition[int64, int64]{ + def := Definition[int64, int64, int64]{ IsLeader: makeIsLeader(n), Nodes: n, } - ok := isJustifiedPrePrepare(def, instance, preprepare) + ok := isJustifiedPrePrepare(def, instance, preprepare, 0) require.True(t, ok) } @@ -654,7 +651,7 @@ func TestFormulas(t *testing.T) { assert := func(t *testing.T, n, q, f int) { t.Helper() - d := Definition[any, int64]{Nodes: n} + d := Definition[any, int64, any]{Nodes: n} require.Equalf(t, q, d.Quorum(), "Quorum given N=%d", n) require.Equalf(t, f, d.Faulty(), "Faulty given N=%d", n) } @@ -700,7 +697,7 @@ func TestDuplicatePrePreparesRules(t *testing.T) { leader = 2 ) - newPreprepare := func(round int64) Msg[int64, int64] { + newPreprepare := func(round int64) Msg[int64, int64, int64] { return msg{ msgType: MsgPrePrepare, peerIdx: leader, @@ -713,7 +710,7 @@ func TestDuplicatePrePreparesRules(t *testing.T) { def.IsLeader = func(_ int64, _ int64, process int64) bool { return process == leader } - def.LogUponRule = func(ctx context.Context, 
instance int64, process, round int64, msg Msg[int64, int64], uponRule UponRule) { + def.LogUponRule = func(ctx context.Context, instance int64, process, round int64, msg Msg[int64, int64, int64], uponRule UponRule) { log.Info(ctx, "UponRule", z.Str("rule", uponRule.String()), z.I64("round", msg.Round())) require.Equal(t, uponRule, UponJustifiedPrePrepare) @@ -728,8 +725,11 @@ func TestDuplicatePrePreparesRules(t *testing.T) { require.Fail(t, "unexpected round", "round=%d", round) } + def.Compare = func(ctx context.Context, qcommit Msg[int64, int64, int64], inputValueSourceCh <-chan int64, inputValueSource int64, returnErr chan error, returnValue chan int64) { + returnErr <- nil + } - rChan := make(chan Msg[int64, int64], 2) + rChan := make(chan Msg[int64, int64, int64], 2) rChan <- newPreprepare(1) rChan <- newPreprepare(2) @@ -737,22 +737,255 @@ func TestDuplicatePrePreparesRules(t *testing.T) { transport := noopTransport transport.Receive = rChan - _ = Run(ctx, def, transport, 0, noLeader, InputValue(int64(1))) + _ = Run(ctx, def, transport, 0, noLeader, InputValue(int64(1)), InputValueSource(int64(2))) } // noopTransport is a transport that does nothing. -var noopTransport = Transport[int64, int64]{ - Broadcast: func(context.Context, MsgType, int64, int64, int64, int64, int64, int64, []Msg[int64, int64]) error { +var noopTransport = Transport[int64, int64, int64]{ + Broadcast: func(context.Context, MsgType, int64, int64, int64, int64, int64, int64, []Msg[int64, int64, int64]) error { return nil }, } // noopDef is a definition that does nothing. 
-var noopDef = Definition[int64, int64]{
+var noopDef = Definition[int64, int64, int64]{
 	IsLeader: func(int64, int64, int64) bool { return false },
 	NewTimer: func(int64) (<-chan time.Time, func()) { return nil, func() {} },
-	LogUponRule:    func(context.Context, int64, int64, int64, Msg[int64, int64], UponRule) {},
-	LogRoundChange: func(context.Context, int64, int64, int64, int64, UponRule, []Msg[int64, int64]) {},
-	LogUnjust:      func(context.Context, int64, int64, Msg[int64, int64]) {},
-	LogDebug:       func(context.Context, int64, int64, Msg[int64, int64], string) {},
+	LogUponRule:    func(context.Context, int64, int64, int64, Msg[int64, int64, int64], UponRule) {},
+	LogRoundChange: func(context.Context, int64, int64, int64, int64, UponRule, []Msg[int64, int64, int64]) {},
+	LogUnjust:      func(context.Context, int64, int64, Msg[int64, int64, int64]) {},
+}
+
+type testChainSplit struct {
+	ValueSource map[int64]int64 // Use different value source for certain processes (used for chain-split-halt feature).
+	DecideRound int             // Deterministic consensus at specific round.
+	PreparedVal int             // If prepared value decided, as opposed to leader's value.
+	ShouldHalt  bool            // If halt is expected (no consensus reached).
+} + +var errChainSplitHalt = errors.New("chain split halt") + +func TestChainSplit(t *testing.T) { + t.Run("same value", func(t *testing.T) { + testQBFTChainSplit(t, testChainSplit{ + DecideRound: 1, + ValueSource: map[int64]int64{ + 1: 1, + 2: 1, + 3: 1, + 4: 1, + }, + PreparedVal: 1, + }) + }) + + t.Run("non-leader peer has different value", func(t *testing.T) { + testQBFTChainSplit(t, testChainSplit{ + DecideRound: 1, + ValueSource: map[int64]int64{ + 1: 1, + 2: 3, + 3: 1, + 4: 1, + }, + PreparedVal: 1, + }) + }) + + t.Run("first leader has different value, second leader succeeds", func(t *testing.T) { + testQBFTChainSplit(t, testChainSplit{ + DecideRound: 2, + ValueSource: map[int64]int64{ + 1: 3, + 2: 1, + 3: 1, + 4: 1, + }, + PreparedVal: 1, + }) + }) + + t.Run("no consensus - halt", func(t *testing.T) { + testQBFTChainSplit(t, testChainSplit{ + ValueSource: map[int64]int64{ + 1: 1, + 2: 1, + 3: 3, + 4: 3, + }, + ShouldHalt: true, + }) + }) +} + +func testQBFTChainSplit(t *testing.T, test testChainSplit) { + t.Helper() + + const ( + n = 4 + maxRound = 10 + fifoLimit = 100 + ) + + var ( + ctx, cancel = context.WithCancel(context.Background()) + clock = new(fakeClock) + receiveChannelsPerNode = make(map[int64]chan Msg[int64, int64, int64]) + broadcast = make(chan Msg[int64, int64, int64]) + resultChan = make(chan []Msg[int64, int64, int64], n) + runChan = make(chan error, n) + instance = int64(0) + ) + defer cancel() + + isLeader := makeIsLeader(n) + defs := Definition[int64, int64, int64]{ + IsLeader: isLeader, + NewTimer: func(round int64) (<-chan time.Time, func()) { + return clock.NewTimer(time.Duration(math.Pow(2, float64(round-1))) * time.Second) + }, + Decide: func(_ context.Context, instance int64, value int64, qcommit []Msg[int64, int64, int64]) { + resultChan <- qcommit + }, + Compare: func(ctx context.Context, qcommit Msg[int64, int64, int64], inputValueSourceCh <-chan int64, inputValueSource int64, returnCh chan error, returnIVS chan int64) { + vs, 
_ := qcommit.ValueSource() + + if inputValueSource == 0 { + inputValueSource = <-inputValueSourceCh + returnIVS <- inputValueSource + } + + if vs != inputValueSource { + returnCh <- errors.New("mismatch", z.I64("leadervalue", vs), z.I64("localvalue", inputValueSource)) + return + } + + returnCh <- nil + }, + LogRoundChange: func(ctx context.Context, instance int64, process, round, newRound int64, rule UponRule, msgs []Msg[int64, int64, int64]) { + t.Logf("%s %v@%d change to %d ~= %v", clock.NowStr(), process, round, newRound, rule) + }, + LogUponRule: func(_ context.Context, instance int64, process, round int64, msg Msg[int64, int64, int64], rule UponRule) { + t.Logf("%s %d => %v@%d -> %v@%d ~= %v", clock.NowStr(), msg.Source(), msg.Type(), msg.Round(), process, round, rule) + + if round > maxRound { + cancel() + } + }, + LogUnjust: func(_ context.Context, instance int64, process int64, msg Msg[int64, int64, int64]) { + t.Logf("Unjust: %#v", msg) + }, + Nodes: n, + FIFOLimit: fifoLimit, + } + + // Start each charon node + for i := int64(1); i <= n; i++ { + receive := make(chan Msg[int64, int64, int64], 1000) + receiveChannelsPerNode[i] = receive + transport := Transport[int64, int64, int64]{ + Broadcast: func(ctx context.Context, typ MsgType, instance int64, source int64, round int64, value int64, + pr int64, pv int64, justify []Msg[int64, int64, int64], + ) error { + if round > maxRound { + if test.ShouldHalt { + return errChainSplitHalt + } + + return errors.New("max round reach") + } + + t.Logf("%s %v => %v@%d", clock.NowStr(), source, typ, round) + + msg := newMsg(typ, instance, source, round, value, value, pr, pv, justify) + receive <- msg // Always send to self first (no jitter, no drops). + + bcast(t, broadcast, msg, 0, clock) + + return nil + }, + Receive: receive, + } + + go func(i int64) { + // Only enqueue input values for instances that: + // - have a value delay + // - or expect multiple rounds + // - or otherwise only the leader of round 1. 
+ vChan := make(chan int64, 1) + vsChan := make(chan int64, 1) + + go func() { + vChan <- test.ValueSource[i] + + vsChan <- test.ValueSource[i] + }() + + runChan <- Run(ctx, defs, transport, instance, i, vChan, vsChan) + }(i) + } + + var ( + results = make(map[int64]Msg[int64, int64, int64]) + count int + decided bool + done int + ) + + for { + select { + case msg := <-broadcast: + for target, out := range receiveChannelsPerNode { + if target == msg.Source() { + continue // Do not broadcast to self, we sent to self already. + } + + out <- msg + + if rand.Float64() < 0.1 { // Send 10% messages twice + out <- msg + } + } + case qCommit := <-resultChan: + for _, commit := range qCommit { + // Ensure that all results are the same + for _, previous := range results { + require.Equal(t, previous.Value(), commit.Value(), "commit values") + } + + require.EqualValues(t, test.DecideRound, commit.Round(), "wrong decide round") + + if test.PreparedVal != 0 { // Check prepared value if set + require.EqualValues(t, test.PreparedVal, commit.Value(), "wrong prepared value") + } + + results[commit.Source()] = commit + } + + count++ + if count != n { + continue + } + + round := qCommit[0].Round() + t.Logf("Got all results in round %d after %s: %#v", round, clock.SinceT0(), results) + + // Trigger shutdown + decided = true + + cancel() + case err := <-runChan: + if !decided && !errors.Is(err, errChainSplitHalt) { + require.Fail(t, "unexpected run error", err) + } + + done++ + if done == n { + return + } + default: + time.Sleep(time.Microsecond) + clock.Advance(time.Millisecond * 1) + } + } }