diff --git a/internal/nodes/init.go b/internal/nodes/init.go index a393a5e..77c70f9 100644 --- a/internal/nodes/init.go +++ b/internal/nodes/init.go @@ -112,6 +112,22 @@ func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{}) for { select { case req := <-rpcChan: + switch req.Behavior { + case DelayRpc: + threadTran, ok := node.Transport.(*ThreadTransport) + if !ok { + log.Fatal("无效的delayRpc模式") + } + duration, ok2 := threadTran.Ctx.GetDelay(req.SourceId, node.SelfId) + if !ok2 { + log.Fatal("没有设置对应的delay时间") + } + time.Sleep(duration) + + case FailRpc: + continue + } + switch req.ServiceMethod { case "Node.AppendEntries": arg, ok := req.Args.(*AppendEntriesArg) diff --git a/internal/nodes/simulate_ctx.go b/internal/nodes/simulate_ctx.go new file mode 100644 index 0000000..0fa2336 --- /dev/null +++ b/internal/nodes/simulate_ctx.go @@ -0,0 +1,65 @@ +package nodes + +import ( + "fmt" + "sync" + "time" +) + +// Ctx 结构体:管理不同节点之间的通信行为 +type Ctx struct { + mu sync.Mutex + Behavior map[string]CallBehavior // (src,target) -> CallBehavior + Delay map[string]time.Duration // (src,target) -> 延迟时间 + Retries map[string]int // 记录 (src,target) 的重发调用次数 +} + +// NewCtx 创建上下文 +func NewCtx() *Ctx { + return &Ctx{ + Behavior: make(map[string]CallBehavior), + Delay: make(map[string]time.Duration), + Retries: make(map[string]int), + } +} + +// SetBehavior 设置 A->B 的 RPC 行为 +func (c *Ctx) SetBehavior(src, dst string, behavior CallBehavior, delay time.Duration, retries int) { + c.mu.Lock() + defer c.mu.Unlock() + key := fmt.Sprintf("%s->%s", src, dst) + c.Behavior[key] = behavior + c.Delay[key] = delay + c.Retries[key] = retries +} + +// GetBehavior 获取 A->B 的行为 +func (c *Ctx) GetBehavior(src, dst string) (CallBehavior) { + c.mu.Lock() + defer c.mu.Unlock() + key := fmt.Sprintf("%s->%s", src, dst) + if state, exists := c.Behavior[key]; exists { + return state + } + return NormalRpc +} + +func (c *Ctx) GetDelay(src, dst string) (t time.Duration, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + key := fmt.Sprintf("%s->%s", src, dst) + if t, ok = c.Delay[key]; ok { + return t, ok + } + return 0, ok +} + +func (c *Ctx) GetRetries(src, dst string) (times int, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + key := fmt.Sprintf("%s->%s", src, dst) + if times, ok = c.Retries[key]; ok { + return times, ok + } + return 0, ok +} \ No newline at end of file diff --git a/internal/nodes/thread_transport.go b/internal/nodes/thread_transport.go index 9eb07b3..eaee05a 100644 --- a/internal/nodes/thread_transport.go +++ b/internal/nodes/thread_transport.go @@ -6,12 +6,26 @@ import ( "time" ) +type CallBehavior = uint8 + +const ( + NormalRpc CallBehavior = iota + 1 + DelayRpc + RetryRpc + FailRpc +) + // RPC 请求结构 type RPCRequest struct { ServiceMethod string Args interface{} Reply interface{} + Done chan error // 用于返回响应 + + SourceId string + // 模拟rpc请求状态 + Behavior CallBehavior } // 线程版 Transport @@ -19,6 +33,7 @@ type ThreadTransport struct { mu sync.Mutex nodeChans map[string]chan RPCRequest // 每个节点的消息通道 connectivityMap map[string]map[string]bool // 模拟网络分区 + Ctx *Ctx } // 线程版 dial的返回clientinterface @@ -32,10 +47,11 @@ func (c *ThreadClient) Close() error { } // 初始化线程通信系统 -func NewThreadTransport() *ThreadTransport { +func NewThreadTransport(ctx *Ctx) *ThreadTransport { return &ThreadTransport{ nodeChans: make(map[string]chan RPCRequest), connectivityMap: make(map[string]map[string]bool), + Ctx: ctx, } } @@ -101,7 +117,7 @@ func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod if !isConnected { - return fmt.Errorf("network partition: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId) + return fmt.Errorf("网络分区: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId) } // 获取目标节点的 channel @@ -113,35 +129,62 @@ func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod // 创建响应通道(用于返回 RPC 结果) done := make(chan error, 1) + behavior := t.Ctx.GetBehavior(threadClient.SourceId, threadClient.TargetId) // 发送请求 request := RPCRequest{ ServiceMethod: serviceMethod, Args: args, Reply: reply, Done: done, + SourceId: threadClient.SourceId, + Behavior: behavior, } - select { - case targetChan <- request: - // 等待响应或超时 + sendRequest := func(req RPCRequest, targetChan chan RPCRequest) bool { select { - case err := <-done: - if threadClient.SourceId == "" { // 来自客户端的连接 - isConnected = true - } else { - t.mu.Lock() - isConnected = t.connectivityMap[threadClient.TargetId][threadClient.SourceId] // 检查连通性 - t.mu.Unlock() - } + case targetChan <- req: + return true + default: + return false + } + } - if !isConnected { - return fmt.Errorf("network partition: %s cannot reach %s", threadClient.TargetId, threadClient.SourceId) + switch behavior { + case RetryRpc: + retryTimes, ok := t.Ctx.GetRetries(threadClient.SourceId, threadClient.TargetId) + if !ok { + log.Fatal("没有设置对应的retry次数") + } + request.Behavior = NormalRpc + // 尝试发送多次, 期待同一个done + for i := 0; i < retryTimes; i++ { + if !sendRequest(request, targetChan) { + return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId) } - return err - case <-time.After(100 * time.Millisecond): - return fmt.Errorf("RPC 调用超时: %s", serviceMethod) } + default: - return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId) + if !sendRequest(request, targetChan) { + return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId) + } + } + + // 等待响应或超时 + select { + case err := <-done: + if threadClient.SourceId == "" { // 来自客户端的连接 + isConnected = true + } else { + t.mu.Lock() + isConnected = t.connectivityMap[threadClient.TargetId][threadClient.SourceId] // 检查连通性 + t.mu.Unlock() + } + + if !isConnected { + return fmt.Errorf("network partition: %s cannot reach %s", threadClient.TargetId, threadClient.SourceId) + } + return err + case <-time.After(100 * time.Millisecond): + return fmt.Errorf("RPC 调用超时: %s", serviceMethod) } } diff --git a/threadTest/election_test.go b/threadTest/election_test.go index a69d7ab..335c729 100644 --- a/threadTest/election_test.go +++ b/threadTest/election_test.go @@ -17,7 +17,7 @@ func TestInitElection(t *testing.T) { // 结点启动 var quitCollections []chan struct{} var nodeCollections []*nodes.Node - threadTransport := nodes.NewThreadTransport() + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) for i := 0; i < n; i++ { n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) quitCollections = append(quitCollections, quitChan) @@ -54,7 +54,7 @@ func TestRepeatElection(t *testing.T) { // 结点启动 var quitCollections []chan struct{} var nodeCollections []*nodes.Node - threadTransport := nodes.NewThreadTransport() + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) for i := 0; i < n; i++ { n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) quitCollections = append(quitCollections, quitChan) @@ -92,7 +92,7 @@ func TestBelowHalfCandidateElection(t *testing.T) { // 结点启动 var quitCollections []chan struct{} var nodeCollections []*nodes.Node - threadTransport := nodes.NewThreadTransport() + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) for i := 0; i < n; i++ { n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) quitCollections = append(quitCollections, quitChan) @@ -131,7 +131,7 @@ func TestOverHalfCandidateElection(t *testing.T) { // 结点启动 var quitCollections []chan struct{} var nodeCollections []*nodes.Node - threadTransport := nodes.NewThreadTransport() + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) for i := 0; i < n; i++ { n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) quitCollections = append(quitCollections, quitChan) @@ -159,4 +159,82 @@ func TestOverHalfCandidateElection(t *testing.T) { for i := 0; i < n; i++ { CheckTerm(t, nodeCollections[i], 2) } +} + +func TestRepeatVoteRpc(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + ctx := nodes.NewCtx() + threadTransport := nodes.NewThreadTransport(ctx) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections, quitCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + ctx.SetBehavior("1", "2", nodes.RetryRpc, 0, 2) + nodeCollections[0].StartElection() + time.Sleep(time.Second) + + CheckOneLeader(t, nodeCollections) + CheckIsLeader(t, nodeCollections[0]) + CheckTerm(t, nodeCollections[0], 2) +} + +func TestFailVoteRpc(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + ctx := nodes.NewCtx() + threadTransport := nodes.NewThreadTransport(ctx) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections, quitCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + ctx.SetBehavior("1", "2", nodes.FailRpc, 0, 0) + nodeCollections[0].StartElection() + time.Sleep(time.Second) + + CheckOneLeader(t, nodeCollections) + CheckIsLeader(t, nodeCollections[0]) + CheckTerm(t, nodeCollections[0], 2) } \ No newline at end of file diff --git a/threadTest/log_replication_test.go b/threadTest/log_replication_test.go index 15a897f..7cc0955 100644 --- a/threadTest/log_replication_test.go +++ b/threadTest/log_replication_test.go @@ -17,7 +17,7 @@ func TestNormalReplication(t *testing.T) { // 结点启动 var quitCollections []chan struct{} var nodeCollections []*nodes.Node - threadTransport := nodes.NewThreadTransport() + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) for i := 0; i < n; i++ { n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) quitCollections = append(quitCollections, quitChan) @@ -64,7 +64,7 @@ func TestParallelReplication(t *testing.T) { // 结点启动 var quitCollections []chan struct{} var nodeCollections []*nodes.Node - threadTransport := nodes.NewThreadTransport() + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) for i := 0; i < n; i++ { n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) quitCollections = append(quitCollections, quitChan) @@ -112,7 +112,7 @@ func TestFollowerLagging(t *testing.T) { // 结点启动 var quitCollections []chan struct{} var nodeCollections []*nodes.Node - threadTransport := nodes.NewThreadTransport() + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) for i := 0; i < n; i++ { n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) quitCollections = append(quitCollections, quitChan) diff --git a/threadTest/network_partition_test.go b/threadTest/network_partition_test.go index c615737..94b765b 100644 --- a/threadTest/network_partition_test.go +++ b/threadTest/network_partition_test.go @@ -11,7 +11,7 @@ import ( ) func TestBasicConnectivity(t *testing.T) { - transport := nodes.NewThreadTransport() + transport := nodes.NewThreadTransport(nodes.NewCtx()) transport.RegisterNodeChan("1", make(chan nodes.RPCRequest, 10)) transport.RegisterNodeChan("2", make(chan nodes.RPCRequest, 10)) @@ -44,7 +44,7 @@ func TestSingelPartition(t *testing.T) { // 结点启动 var quitCollections []chan struct{} var nodeCollections []*nodes.Node - threadTransport := nodes.NewThreadTransport() + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) for i := 0; i < n; i++ { n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) quitCollections = append(quitCollections, quitChan) @@ -160,7 +160,7 @@ func TestQuorumPartition(t *testing.T) { // 结点启动 var quitCollections []chan struct{} var nodeCollections []*nodes.Node - threadTransport := nodes.NewThreadTransport() + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) for i := 0; i < n; i++ { n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) quitCollections = append(quitCollections, quitChan) diff --git a/threadTest/restart_node_test.go b/threadTest/restart_node_test.go index 24804f6..9dc3b95 100644 --- a/threadTest/restart_node_test.go +++ b/threadTest/restart_node_test.go @@ -18,7 +18,7 @@ func TestNodeRestart(t *testing.T) { // 结点启动 var quitCollections []chan struct{} - threadTransport := nodes.NewThreadTransport() + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) for i := 0; i < n; i++ { _, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) quitCollections = append(quitCollections, quitChan) diff --git a/threadTest/server_client_test.go b/threadTest/server_client_test.go index 7cb8690..295f5a6 100644 --- a/threadTest/server_client_test.go +++ b/threadTest/server_client_test.go @@ -18,7 +18,7 @@ func TestServerClient(t *testing.T) { // 结点启动 var quitCollections []chan struct{} - threadTransport := nodes.NewThreadTransport() + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) for i := 0; i < n; i++ { _, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) quitCollections = append(quitCollections, quitChan)