From f669abd7ad8611f05afbb287114c54b8e4ec38b6 Mon Sep 17 00:00:00 2001 From: augurier <14434658+augurier@user.noreply.gitee.com> Date: Sun, 30 Mar 2025 18:24:15 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=86=E5=8C=BA=E6=B5=8B=E8=AF=95=EF=BC=88?= =?UTF-8?q?=E4=BB=A5=E5=8F=8A=E9=80=82=E5=BA=94=E7=BB=86=E7=B2=92=E5=BA=A6?= =?UTF-8?q?=E7=9A=84=E4=BF=AE=E6=94=B9=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/client/client_node.go | 12 ++- internal/nodes/init.go | 98 +++++++++--------- internal/nodes/log.go | 8 -- internal/nodes/node.go | 34 +++---- internal/nodes/real_transport.go | 2 +- internal/nodes/replica.go | 191 ++++++++++++++++++----------------- internal/nodes/server_node.go | 40 ++++---- internal/nodes/thread_transport.go | 60 +++++++++-- internal/nodes/transport.go | 2 +- internal/nodes/vote.go | 114 ++++++++++----------- test/restart_node_test.go | 2 +- test/server_client_test.go | 2 +- threadTest/network_partition_test.go | 150 +++++++++++++++++++++++++++ threadTest/restart_node_test.go | 2 +- threadTest/server_client_test.go | 2 +- 15 files changed, 457 insertions(+), 262 deletions(-) create mode 100644 threadTest/network_partition_test.go diff --git a/internal/client/client_node.go b/internal/client/client_node.go index a59a9aa..1ef1de6 100644 --- a/internal/client/client_node.go +++ b/internal/client/client_node.go @@ -35,7 +35,7 @@ func (client *Client) FindActiveNode() nodes.ClientInterface { var c nodes.ClientInterface for { // 直到找到一个可连接的节点(保证至少一个节点活着) peerId := getRandomAddress(client.PeerIds) - c, err = client.Transport.DialHTTPWithTimeout("tcp", peerId) + c, err = client.Transport.DialHTTPWithTimeout("tcp", "", peerId) if err != nil { log.Error("dialing: ", zap.Error(err)) } else { @@ -72,9 +72,13 @@ func (client *Client) Write(kvCall nodes.LogEntryCall) Status { if !reply.Isleader { // 对方不是leader,根据反馈找leader leaderId := reply.LeaderId client.CloseRpcClient(c) - c, err = client.Transport.DialHTTPWithTimeout("tcp", leaderId) - for err != nil { // 重新找下一个存活节点 + if leaderId == "" { // 这个节点不知道leader是谁,再随机找 c = client.FindActiveNode() + } else { // dial leader + c, err = client.Transport.DialHTTPWithTimeout("tcp", "", leaderId) + for err != nil { // dial失败,重新找下一个存活节点 + c = client.FindActiveNode() + } } } else { // 成功 client.CloseRpcClient(c) @@ -132,7 +136,7 @@ func (client *Client) FindLeader() string { if !reply.Isleader { // 对方不是leader,根据反馈找leader client.CloseRpcClient(c) - c, err = client.Transport.DialHTTPWithTimeout("tcp", reply.LeaderId) + c, err = client.Transport.DialHTTPWithTimeout("tcp", "", reply.LeaderId) for err != nil { // 重新找下一个存活节点 c = client.FindActiveNode() } diff --git a/internal/nodes/init.go b/internal/nodes/init.go index 4193d14..2c3b51e 100644 --- a/internal/nodes/init.go +++ b/internal/nodes/init.go @@ -16,7 +16,7 @@ import ( var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel) // 运行在进程上的初始化 + rpc注册 -func InitRPCNode(selfId string, port string, nodeAddr map[string]string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool) *Node { +func InitRPCNode(SelfId string, port string, nodeAddr map[string]string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool) *Node { var nodeIds []string for id := range nodeAddr { nodeIds = append(nodeIds, id) @@ -24,29 +24,29 @@ func InitRPCNode(selfId string, port string, nodeAddr map[string]string, db *lev // 创建节点 node := &Node{ - selfId: selfId, - leaderId: "", - nodes: nodeIds, - maxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了 - currTerm: 1, - log: make([]RaftLogEntry, 0), - commitIndex: -1, - lastApplied: -1, - nextIndex: make(map[string]int), - matchIndex: make(map[string]int), - db: db, - storage: rstorage, - transport: &HTTPTransport{NodeMap: nodeAddr}, + SelfId: SelfId, + LeaderId: "", + Nodes: nodeIds, + MaxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了 + CurrTerm: 1, + Log: make([]RaftLogEntry, 0), + CommitIndex: -1, + LastApplied: -1, + NextIndex: make(map[string]int), + MatchIndex: make(map[string]int), + Db: db, + Storage: rstorage, + Transport: &HTTPTransport{NodeMap: nodeAddr}, } node.initLeaderState() if isRestart { - node.currTerm = rstorage.GetCurrentTerm() - node.votedFor = rstorage.GetVotedFor() - node.log = rstorage.GetLogEntries() - log.Sugar().Infof("[%s]从重启中恢复log数量: %d", selfId, len(node.log)) + node.CurrTerm = rstorage.GetCurrentTerm() + node.VotedFor = rstorage.GetVotedFor() + node.Log = rstorage.GetLogEntries() + log.Sugar().Infof("[%s]从重启中恢复log数量: %d", SelfId, len(node.Log)) } - log.Sugar().Infof("[%s]开始监听" + port + "端口", selfId) + log.Sugar().Infof("[%s]开始监听" + port + "端口", SelfId) node.ListenPort(port) return node @@ -73,33 +73,33 @@ func (node *Node) ListenPort(port string) { } // 线程模拟的初始化 -func InitThreadNode(selfId string, peerIds []string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool, threadTransport *ThreadTransport) (*Node, chan struct{}) { +func InitThreadNode(SelfId string, peerIds []string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool, threadTransport *ThreadTransport) (*Node, chan struct{}) { rpcChan := make(chan RPCRequest, 100) // 要监听的chan // 创建节点 node := &Node{ - selfId: selfId, - leaderId: "", - nodes: peerIds, - maxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了 - currTerm: 1, - log: make([]RaftLogEntry, 0), - commitIndex: -1, - lastApplied: -1, - nextIndex: make(map[string]int), - matchIndex: make(map[string]int), - db: db, - storage: rstorage, - transport: threadTransport, + SelfId: SelfId, + LeaderId: "", + Nodes: peerIds, + MaxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了 + CurrTerm: 1, + Log: make([]RaftLogEntry, 0), + CommitIndex: -1, + LastApplied: -1, + NextIndex: make(map[string]int), + MatchIndex: make(map[string]int), + Db: db, + Storage: rstorage, + Transport: threadTransport, } node.initLeaderState() if isRestart { - node.currTerm = rstorage.GetCurrentTerm() - node.votedFor = rstorage.GetVotedFor() - node.log = rstorage.GetLogEntries() - log.Sugar().Infof("[%s]从重启中恢复log数量: %d", selfId, len(node.log)) + node.CurrTerm = rstorage.GetCurrentTerm() + node.VotedFor = rstorage.GetVotedFor() + node.Log = rstorage.GetLogEntries() + log.Sugar().Infof("[%s]从重启中恢复log数量: %d", SelfId, len(node.Log)) } - threadTransport.RegisterNodeChan(selfId, rpcChan) + threadTransport.RegisterNodeChan(SelfId, rpcChan) quitChan := make(chan struct{}, 1) go node.listenForChan(rpcChan, quitChan) @@ -107,7 +107,7 @@ func InitThreadNode(selfId string, peerIds []string, db *leveldb.DB, rstorage *R } func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{}) { - defer node.db.Close() + defer node.Db.Close() for { select { @@ -162,7 +162,7 @@ func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{}) req.Done <- fmt.Errorf("未知方法: %s", req.ServiceMethod) } case <-quitChan: - log.Sugar().Infof("[%s] 监听线程收到退出信号", node.selfId) + log.Sugar().Infof("[%s] 监听线程收到退出信号", node.SelfId) return } } @@ -170,14 +170,14 @@ func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{}) // 共同部分和启动 func (n *Node) initLeaderState() { - for _, peerId := range n.nodes { - n.nextIndex[peerId] = len(n.log) // 发送日志的下一个索引 - n.matchIndex[peerId] = 0 // 复制日志的最新匹配索引 + for _, peerId := range n.Nodes { + n.NextIndex[peerId] = len(n.Log) // 发送日志的下一个索引 + n.MatchIndex[peerId] = 0 // 复制日志的最新匹配索引 } } func Start(node *Node, quitChan chan struct{}) { - node.state = Follower // 所有节点以 Follower 状态启动 + node.State = Follower // 所有节点以 Follower 状态启动 node.resetElectionTimer() // 启动选举超时定时器 go func() { @@ -187,20 +187,20 @@ func Start(node *Node, quitChan chan struct{}) { for { select { case <-quitChan: - fmt.Printf("[%s] Raft start 退出...\n", node.selfId) + fmt.Printf("[%s] Raft start 退出...\n", node.SelfId) return // 退出 goroutine case <-ticker.C: - switch node.state { + switch node.State { case Follower: // 监听心跳超时 - fmt.Printf("[%s] is a follower, 监听中...\n", node.selfId) + // fmt.Printf("[%s] is a follower, 监听中...\n", node.SelfId) case Leader: // 发送心跳 - fmt.Printf("[%s] is the leader, 发送心跳...\n", node.selfId) + // fmt.Printf("[%s] is the leader, 发送心跳...\n", node.SelfId) node.resetElectionTimer() // leader 不主动触发选举 - node.BroadCastKV(Normal) + node.BroadCastKV() } } } diff --git a/internal/nodes/log.go b/internal/nodes/log.go index b2f1537..08c0830 100644 --- a/internal/nodes/log.go +++ b/internal/nodes/log.go @@ -2,13 +2,6 @@ package nodes import "strconv" -type CallMode = uint8 -const ( - Normal CallMode = iota + 1 - Delay - Fail -) - type LogEntry struct { Key string Value string @@ -28,7 +21,6 @@ func (RLogE *RaftLogEntry) print() string { type LogEntryCall struct { LogE LogEntry - CallState CallMode } type KVReply struct { diff --git a/internal/nodes/node.go b/internal/nodes/node.go index 0e3f053..b1e1c8b 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -16,48 +16,48 @@ const ( ) type Node struct { - mu sync.Mutex + Mu sync.Mutex // 当前节点id - selfId string + SelfId string // 记录的leader(不能用votedfor:投票的leader可能没有收到多数票) - leaderId string + LeaderId string // 除当前节点外其他节点id - nodes []string + Nodes []string // 当前节点状态 - state State + State State // 任期 - currTerm int + CurrTerm int // 简单的kv存储 - log []RaftLogEntry + Log []RaftLogEntry // leader用来标记新log, = log.len - maxLogId int + MaxLogId int // 已提交的index - commitIndex int + CommitIndex int // 最后应用(写到db)的index - lastApplied int + LastApplied int // 需要发送给每个节点的下一个索引 - nextIndex map[string]int + NextIndex map[string]int // 已经发送给每个节点的最大索引 - matchIndex map[string]int + MatchIndex map[string]int // 存kv(模拟状态机) - db *leveldb.DB + Db *leveldb.DB // 持久化节点数据(currterm votedfor log) - storage *RaftStorage + Storage *RaftStorage - votedFor string - electionTimer *time.Timer + VotedFor string + ElectionTimer *time.Timer // 通信方式 - transport Transport + Transport Transport } diff --git a/internal/nodes/real_transport.go b/internal/nodes/real_transport.go index fea359a..e5f45eb 100644 --- a/internal/nodes/real_transport.go +++ b/internal/nodes/real_transport.go @@ -13,7 +13,7 @@ type HTTPTransport struct{ } // 封装有超时的dial -func (t *HTTPTransport) DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error) { +func (t *HTTPTransport) DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) { done := make(chan struct{}) var client *rpc.Client var err error diff --git a/internal/nodes/replica.go b/internal/nodes/replica.go index 49401d0..be380f7 100644 --- a/internal/nodes/replica.go +++ b/internal/nodes/replica.go @@ -1,10 +1,9 @@ package nodes import ( - "math/rand" "sort" "strconv" - "time" + "sync" "go.uber.org/zap" ) @@ -24,31 +23,31 @@ type AppendEntriesReply struct { } // leader收到新内容要广播,以及心跳广播(同步自己的log) -func (node *Node) BroadCastKV(callMode CallMode) { +func (node *Node) BroadCastKV() { + log.Sugar().Infof("leader[%s]广播消息", node.SelfId) + failCount := 0 + // 这里增加一个锁,防止并发修改成功计数 + var failMutex sync.Mutex // 遍历所有节点 - for _, id := range node.nodes { - go func(id string, kv CallMode) { - node.sendKV(id, callMode) - }(id, callMode) + for _, id := range node.Nodes { + go func(id string) { + node.sendKV(id, &failCount, &failMutex) + }(id) } } -func (node *Node) sendKV(peerId string, callMode CallMode) { - - switch callMode { - case Fail: - log.Info("模拟发送失败") - // 这么写向所有的node发送都失败,也可以随机数确定是否失败 - case Delay: - log.Info("模拟发送延迟") - // 随机延迟0-5ms - time.Sleep(time.Millisecond * time.Duration(rand.Intn(5))) - default: - } - - client, err := node.transport.DialHTTPWithTimeout("tcp", peerId) +func (node *Node) sendKV(peerId string, failCount *int, failMutex *sync.Mutex) { + client, err := node.Transport.DialHTTPWithTimeout("tcp", node.SelfId, peerId) if err != nil { - log.Error(node.selfId + "dialling [" + peerId + "] fail: ", zap.Error(err)) + log.Error("[" + node.SelfId + "]dialling [" + peerId + "] fail: ", zap.Error(err)) + failMutex.Lock() + *failCount++ + if *failCount == len(node.Nodes) / 2 + 1 { // 无法联系超过半数:自己有问题,降级 + node.LeaderId = "" + node.State = Follower + node.resetElectionTimer() + } + failMutex.Unlock() return } @@ -59,69 +58,78 @@ func (node *Node) sendKV(peerId string, callMode CallMode) { } }(client) - node.mu.Lock() - defer node.mu.Unlock() + node.Mu.Lock() + defer node.Mu.Unlock() var appendReply AppendEntriesReply appendReply.Success = false - nextIndex := node.nextIndex[peerId] - // log.Info("nextindex " + strconv.Itoa(nextIndex)) + NextIndex := node.NextIndex[peerId] + // log.Info("NextIndex " + strconv.Itoa(NextIndex)) for (!appendReply.Success) { - if nextIndex < 0 { + if NextIndex < 0 { log.Fatal("assert >= 0 here") } - sendEntries := node.log[nextIndex:] + sendEntries := node.Log[NextIndex:] arg := AppendEntriesArg{ - Term: node.currTerm, - PrevLogIndex: nextIndex - 1, + Term: node.CurrTerm, + PrevLogIndex: NextIndex - 1, Entries: sendEntries, - LeaderCommit: node.commitIndex, - LeaderId: node.selfId, + LeaderCommit: node.CommitIndex, + LeaderId: node.SelfId, } if arg.PrevLogIndex >= 0 { - arg.PrevLogTerm = node.log[arg.PrevLogIndex].Term + arg.PrevLogTerm = node.Log[arg.PrevLogIndex].Term } - callErr := node.transport.CallWithTimeout(client, "Node.AppendEntries", &arg, &appendReply) // RPC + callErr := node.Transport.CallWithTimeout(client, "Node.AppendEntries", &arg, &appendReply) // RPC if callErr != nil { - log.Error(node.selfId + "calling [" + peerId + "] fail: ", zap.Error(callErr)) + log.Error("[" + node.SelfId + "]calling [" + peerId + "] fail: ", zap.Error(callErr)) + failMutex.Lock() + *failCount++ + if *failCount == len(node.Nodes) / 2 + 1 { // 无法联系超过半数:自己有问题,降级 + node.LeaderId = "" + node.State = Follower + node.resetElectionTimer() + } + failMutex.Unlock() return } - if appendReply.Term != node.currTerm { - log.Info("term=" + strconv.Itoa(node.currTerm) + "的Leader[" + node.selfId + "]收到更高的 term=" + strconv.Itoa(appendReply.Term) + ",转换为 Follower") - node.currTerm = appendReply.Term - node.state = Follower - node.votedFor = "" - node.storage.SetTermAndVote(node.currTerm, node.votedFor) + if appendReply.Term != node.CurrTerm { + log.Info("term=" + strconv.Itoa(node.CurrTerm) + "的Leader[" + node.SelfId + "]收到更高的 term=" + strconv.Itoa(appendReply.Term) + ",转换为 Follower") + node.LeaderId = "" + node.CurrTerm = appendReply.Term + node.State = Follower + node.VotedFor = "" + node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor) node.resetElectionTimer() return } - nextIndex-- // 失败往前传一格 + NextIndex-- // 失败往前传一格 } // 不变成follower情况下 - node.nextIndex[peerId] = node.maxLogId + 1 - node.matchIndex[peerId] = node.maxLogId + node.NextIndex[peerId] = node.MaxLogId + 1 + node.MatchIndex[peerId] = node.MaxLogId node.updateCommitIndex() } func (node *Node) updateCommitIndex() { - totalNodes := len(node.nodes) + totalNodes := len(node.Nodes) - // 收集所有 matchIndex 并排序 - matchIndexes := make([]int, 0, totalNodes) - for _, index := range node.matchIndex { - matchIndexes = append(matchIndexes, index) + // 收集所有 MatchIndex 并排序 + MatchIndexes := make([]int, 0, totalNodes) + for _, index := range node.MatchIndex { + MatchIndexes = append(MatchIndexes, index) } - sort.Ints(matchIndexes) // 排序 + sort.Ints(MatchIndexes) // 排序 - // 计算多数派 commitIndex - majorityIndex := matchIndexes[totalNodes/2] // 取 N/2 位置上的索引(多数派) + // 计算多数派 CommitIndex + majorityIndex := MatchIndexes[totalNodes/2] // 取 N/2 位置上的索引(多数派) // 确保这个索引的日志条目属于当前 term,防止提交旧 term 的日志 - if majorityIndex > node.commitIndex && majorityIndex < len(node.log) && node.log[majorityIndex].Term == node.currTerm { - node.commitIndex = majorityIndex - log.Info("Leader[" + node.selfId + "]更新 commitIndex: " + strconv.Itoa(majorityIndex)) + if majorityIndex > node.CommitIndex && majorityIndex < len(node.Log) && node.Log[majorityIndex].Term == node.CurrTerm { + node.CommitIndex = majorityIndex + log.Info("Leader[" + node.SelfId + "]更新 CommitIndex: " + strconv.Itoa(majorityIndex)) // 应用日志到状态机 node.applyCommittedLogs() @@ -130,13 +138,13 @@ func (node *Node) updateCommitIndex() { // 应用日志到状态机 func (node *Node) applyCommittedLogs() { - for node.lastApplied < node.commitIndex { - node.lastApplied++ - logEntry := node.log[node.lastApplied] - log.Sugar().Infof("[%s]应用日志到状态机: " + logEntry.print(), node.selfId) - err := node.db.Put([]byte(logEntry.LogE.Key), []byte(logEntry.LogE.Value), nil) + for node.LastApplied < node.CommitIndex { + node.LastApplied++ + logEntry := node.Log[node.LastApplied] + log.Sugar().Infof("[%s]应用日志到状态机: " + logEntry.print(), node.SelfId) + err := node.Db.Put([]byte(logEntry.LogE.Key), []byte(logEntry.LogE.Value), nil) if err != nil { - log.Error(node.selfId + "应用状态机失败: ", zap.Error(err)) + log.Error(node.SelfId + "应用状态机失败: ", zap.Error(err)) } } } @@ -147,65 +155,66 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply // defer func() { // log.Sugar().Infof("AppendEntries 处理时间: %v", time.Since(start)) // }() - node.mu.Lock() - defer node.mu.Unlock() + log.Sugar().Infof("[%s]收到[%s]的AppendEntries", node.SelfId, arg.LeaderId) + node.Mu.Lock() + defer node.Mu.Unlock() // 如果 term 过期,拒绝接受日志 - if node.currTerm > arg.Term { - *reply = AppendEntriesReply{node.currTerm, false} + if node.CurrTerm > arg.Term { + *reply = AppendEntriesReply{node.CurrTerm, false} return nil } - node.leaderId = arg.LeaderId // 记录Leader + node.LeaderId = arg.LeaderId // 记录Leader // 如果term比自己高,或自己不是follower但收到相同term的心跳 - if node.currTerm < arg.Term || node.state != Follower { - log.Sugar().Infof("[%s]发现更高 term(%s)", node.selfId, strconv.Itoa(arg.Term)) - node.currTerm = arg.Term - node.state = Follower - node.votedFor = "" - // node.storage.SetTermAndVote(node.currTerm, node.votedFor) + if node.CurrTerm < arg.Term || node.State != Follower { + log.Sugar().Infof("[%s]发现更高 term(%s)", node.SelfId, strconv.Itoa(arg.Term)) + node.CurrTerm = arg.Term + node.State = Follower + node.VotedFor = "" + // node.storage.SetTermAndVote(node.CurrTerm, node.VotedFor) } - node.storage.SetTermAndVote(node.currTerm, node.votedFor) + node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor) // 检查 prevLogIndex 是否有效 - if arg.PrevLogIndex >= len(node.log) || (arg.PrevLogIndex >= 0 && node.log[arg.PrevLogIndex].Term != arg.PrevLogTerm) { - *reply = AppendEntriesReply{node.currTerm, false} + if arg.PrevLogIndex >= len(node.Log) || (arg.PrevLogIndex >= 0 && node.Log[arg.PrevLogIndex].Term != arg.PrevLogTerm) { + *reply = AppendEntriesReply{node.CurrTerm, false} return nil } // 处理日志冲突(如果存在不同 term,则截断日志) idx := arg.PrevLogIndex + 1 - for i := idx; i < len(node.log) && i-idx < len(arg.Entries); i++ { - if node.log[i].Term != arg.Entries[i-idx].Term { - node.log = node.log[:idx] + for i := idx; i < len(node.Log) && i-idx < len(arg.Entries); i++ { + if node.Log[i].Term != arg.Entries[i-idx].Term { + node.Log = node.Log[:idx] break } } - // log.Info(strconv.Itoa(idx) + strconv.Itoa(len(node.log))) + // log.Info(strconv.Itoa(idx) + strconv.Itoa(len(node.Log))) // 追加新的日志条目 for _, raftLogEntry := range arg.Entries { - log.Sugar().Infof("[%s]写入:" + raftLogEntry.print(), node.selfId) - if idx < len(node.log) { - node.log[idx] = raftLogEntry + log.Sugar().Infof("[%s]写入:" + raftLogEntry.print(), node.SelfId) + if idx < len(node.Log) { + node.Log[idx] = raftLogEntry } else { - node.log = append(node.log, raftLogEntry) + node.Log = append(node.Log, raftLogEntry) } idx++ } // 暴力持久化 - node.storage.WriteLog(node.log) + node.Storage.WriteLog(node.Log) - // 更新 maxLogId - node.maxLogId = len(node.log) - 1 + // 更新 MaxLogId + node.MaxLogId = len(node.Log) - 1 - // 更新 commitIndex - if arg.LeaderCommit < node.maxLogId { - node.commitIndex = arg.LeaderCommit + // 更新 CommitIndex + if arg.LeaderCommit < node.MaxLogId { + node.CommitIndex = arg.LeaderCommit } else { - node.commitIndex = node.maxLogId + node.CommitIndex = node.MaxLogId } // 提交已提交的日志 @@ -213,6 +222,6 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply // 在成功接受日志或心跳后,重置选举超时 node.resetElectionTimer() - *reply = AppendEntriesReply{node.currTerm, true} + *reply = AppendEntriesReply{node.CurrTerm, true} return nil } \ No newline at end of file diff --git a/internal/nodes/server_node.go b/internal/nodes/server_node.go index daa1089..09b3223 100644 --- a/internal/nodes/server_node.go +++ b/internal/nodes/server_node.go @@ -1,8 +1,6 @@ package nodes import ( - "strconv" - "github.com/syndtr/goleveldb/leveldb" ) @@ -15,38 +13,34 @@ type ServerReply struct{ } // RPC call func (node *Node) WriteKV(kvCall *LogEntryCall, reply *ServerReply) error { - log.Sugar().Infof("[%s]收到客户端write请求", node.selfId) + log.Sugar().Infof("[%s]收到客户端write请求", node.SelfId) // 自己不是leader,转交leader地址回复 - if node.state != Leader { + if node.State != Leader { reply.Isleader = false - if (node.leaderId == "") { - log.Fatal("还没选出第一个leader") - return nil - } - reply.LeaderId = node.leaderId - log.Sugar().Infof("[%s]转交给[%s]", node.selfId, node.leaderId) + reply.LeaderId = node.LeaderId // 可能是空,那client就随机再找一个节点 + log.Sugar().Infof("[%s]转交给[%s]", node.SelfId, node.LeaderId) return nil } // 自己是leader,修改自己的记录并广播 - node.maxLogId++ - logId := node.maxLogId - rLogE := RaftLogEntry{kvCall.LogE, logId, node.currTerm} - node.log = append(node.log, rLogE) - node.storage.AppendLog(rLogE) - log.Info("leader[" + node.selfId + "]处理请求 : " + kvCall.LogE.print() + ", 模拟方式 : " + strconv.Itoa(int(kvCall.CallState))) + node.MaxLogId++ + logId := node.MaxLogId + rLogE := RaftLogEntry{kvCall.LogE, logId, node.CurrTerm} + node.Log = append(node.Log, rLogE) + node.Storage.AppendLog(rLogE) + log.Info("leader[" + node.SelfId + "]处理请求 : " + kvCall.LogE.print()) // 广播给其它节点 - node.BroadCastKV(kvCall.CallState) + node.BroadCastKV() reply.Isleader = true return nil } // RPC call func (node *Node) ReadKey(key *string, reply *ServerReply) error { - log.Sugar().Infof("[%s]收到客户端read请求", node.selfId) + log.Sugar().Infof("[%s]收到客户端read请求", node.SelfId) // 先只读自己(无论自己是不是leader),也方便测试 - value, err := node.db.Get([]byte(*key), nil) + value, err := node.Db.Get([]byte(*key), nil) if err == leveldb.ErrNotFound { reply.HaveValue = false } else { @@ -64,17 +58,17 @@ type FindLeaderReply struct{ } func (node *Node) FindLeader(_ struct{}, reply *FindLeaderReply) error { // 自己不是leader,转交leader地址回复 - if node.state != Leader { + if node.State != Leader { reply.Isleader = false - if (node.leaderId == "") { + if (node.LeaderId == "") { log.Fatal("还没选出第一个leader") return nil } - reply.LeaderId = node.leaderId + reply.LeaderId = node.LeaderId return nil } - reply.LeaderId = node.selfId + reply.LeaderId = node.SelfId reply.Isleader = true return nil } diff --git a/internal/nodes/thread_transport.go b/internal/nodes/thread_transport.go index 879577d..9eb07b3 100644 --- a/internal/nodes/thread_transport.go +++ b/internal/nodes/thread_transport.go @@ -18,11 +18,13 @@ type RPCRequest struct { type ThreadTransport struct { mu sync.Mutex nodeChans map[string]chan RPCRequest // 每个节点的消息通道 + connectivityMap map[string]map[string]bool // 模拟网络分区 } // 线程版 dial的返回clientinterface type ThreadClient struct { - targetId string + SourceId string + TargetId string } func (c *ThreadClient) Close() error { @@ -33,6 +35,7 @@ func (c *ThreadClient) Close() error { func NewThreadTransport() *ThreadTransport { return &ThreadTransport{ nodeChans: make(map[string]chan RPCRequest), + connectivityMap: make(map[string]map[string]bool), } } @@ -41,6 +44,24 @@ func (t *ThreadTransport) RegisterNodeChan(nodeId string, ch chan RPCRequest) { t.mu.Lock() defer t.mu.Unlock() t.nodeChans[nodeId] = ch + + // 初始化连通性(默认所有节点互相可达) + if _, exists := t.connectivityMap[nodeId]; !exists { + t.connectivityMap[nodeId] = make(map[string]bool) + } + for peerId := range t.nodeChans { + t.connectivityMap[nodeId][peerId] = true + t.connectivityMap[peerId][nodeId] = true + } +} + +// 设置两个节点的连通性 +func (t *ThreadTransport) SetConnectivity(from, to string, isConnected bool) { + t.mu.Lock() + defer t.mu.Unlock() + if _, exists := t.connectivityMap[from]; exists { + t.connectivityMap[from][to] = isConnected + } } // 获取节点的 channel @@ -52,27 +73,41 @@ func (t *ThreadTransport) getNodeChan(nodeId string) (chan RPCRequest, bool) { } // 模拟 Dial 操作 -func (t *ThreadTransport) DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error) { +func (t *ThreadTransport) DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) { t.mu.Lock() defer t.mu.Unlock() if _, exists := t.nodeChans[peerId]; !exists { return nil, fmt.Errorf("节点 [%s] 不存在", peerId) } - return &ThreadClient{targetId: peerId}, nil + return &ThreadClient{SourceId: myId, TargetId: peerId}, nil } // 模拟 Call 操作 func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error { threadClient, ok := client.(*ThreadClient) if !ok { - return fmt.Errorf("无效的客户端") + return fmt.Errorf("无效的caller") } + var isConnected bool + if threadClient.SourceId == "" { // 来自客户端的连接 + isConnected = true + } else { + t.mu.Lock() + isConnected = t.connectivityMap[threadClient.SourceId][threadClient.TargetId] // 检查连通性 + t.mu.Unlock() + } + + + if !isConnected { + return fmt.Errorf("network partition: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId) + } + // 获取目标节点的 channel - targetChan, exists := t.getNodeChan(threadClient.targetId) + targetChan, exists := t.getNodeChan(threadClient.TargetId) if !exists { - return fmt.Errorf("目标节点 [%s] 不存在", threadClient.targetId) + return fmt.Errorf("目标节点 [%s] 不存在", threadClient.TargetId) } // 创建响应通道(用于返回 RPC 结果) @@ -91,11 +126,22 @@ func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod // 等待响应或超时 select { case err := <-done: + if threadClient.SourceId == "" { // 来自客户端的连接 + isConnected = true + } else { + t.mu.Lock() + isConnected = t.connectivityMap[threadClient.TargetId][threadClient.SourceId] // 检查连通性 + t.mu.Unlock() + } + + if !isConnected { + return fmt.Errorf("network partition: %s cannot reach %s", threadClient.TargetId, threadClient.SourceId) + } return err case <-time.After(100 * time.Millisecond): return fmt.Errorf("RPC 调用超时: %s", serviceMethod) } default: - return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.targetId) + return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId) } } diff --git a/internal/nodes/transport.go b/internal/nodes/transport.go index 40f4cdd..8d92985 100644 --- a/internal/nodes/transport.go +++ b/internal/nodes/transport.go @@ -5,6 +5,6 @@ type ClientInterface interface{ } type Transport interface { - DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error) + DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error } diff --git a/internal/nodes/vote.go b/internal/nodes/vote.go index a87a8d6..2f6fefb 100644 --- a/internal/nodes/vote.go +++ b/internal/nodes/vote.go @@ -22,15 +22,15 @@ type RequestVoteReply struct { } func (n *Node) startElection() { - n.mu.Lock() - defer n.mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() // 增加当前任期,转换为 Candidate - n.currTerm++ - n.state = Candidate - n.votedFor = n.selfId // 自己投自己 - n.storage.SetTermAndVote(n.currTerm, n.votedFor) + n.CurrTerm++ + n.State = Candidate + n.VotedFor = n.SelfId // 自己投自己 + n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor) - log.Sugar().Infof("[%s] 开始选举,当前任期: %d", n.selfId, n.currTerm) + log.Sugar().Infof("[%s] 开始选举,当前任期: %d", n.SelfId, n.CurrTerm) // 重新设置选举超时,防止重复选举 n.resetElectionTimer() @@ -39,40 +39,40 @@ func (n *Node) startElection() { var lastLogIndex int var lastLogTerm int - if len(n.log) == 0 { + if len(n.Log) == 0 { lastLogIndex = 0 lastLogTerm = 0 // 论文中定义,空日志时 Term 设为 0 } else { - lastLogIndex = len(n.log) - 1 - lastLogTerm = n.log[lastLogIndex].Term + lastLogIndex = len(n.Log) - 1 + lastLogTerm = n.Log[lastLogIndex].Term } args := RequestVoteArgs{ - Term: n.currTerm, - CandidateId: n.selfId, + Term: n.CurrTerm, + CandidateId: n.SelfId, LastLogIndex: lastLogIndex, LastLogTerm: lastLogTerm, } // 并行向其他节点发送请求投票 - var mu sync.Mutex - totalNodes := len(n.nodes) + var Mu sync.Mutex + totalNodes := len(n.Nodes) grantedVotes := 1 // 自己的票 - for _, peerId := range n.nodes { + for _, peerId := range n.Nodes { go func(peerId string) { reply := RequestVoteReply{} if n.sendRequestVote(peerId, &args, &reply) { - mu.Lock() + Mu.Lock() - if reply.Term > n.currTerm { + if reply.Term > n.CurrTerm { // 发现更高任期,回退为 Follower - log.Sugar().Infof("[%s] 发现更高的 Term (%d),回退为 Follower", n.selfId, reply.Term) - n.currTerm = reply.Term - n.state = Follower - n.votedFor = "" - n.storage.SetTermAndVote(n.currTerm, n.votedFor) + log.Sugar().Infof("[%s] 发现更高的 Term (%d),回退为 Follower", n.SelfId, reply.Term) + n.CurrTerm = reply.Term + n.State = Follower + n.VotedFor = "" + n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor) n.resetElectionTimer() - mu.Unlock() + Mu.Unlock() return } @@ -81,32 +81,32 @@ func (n *Node) startElection() { } if grantedVotes == totalNodes / 2 + 1 { - n.state = Leader - log.Sugar().Infof("[%s] 当选 Leader!", n.selfId) + n.State = Leader + log.Sugar().Infof("[%s] 当选 Leader!", n.SelfId) n.initLeaderState() } - mu.Unlock() + Mu.Unlock() } }(peerId) } // 等待选举结果 time.Sleep(300 * time.Millisecond) - mu.Lock() - if n.state == Candidate { - log.Sugar().Infof("[%s] 选举超时,重新发起选举", n.selfId) - // n.state = Follower 这里不修改,如果appendentries收到term合理的心跳,再变回follower + Mu.Lock() + if n.State == Candidate { + log.Sugar().Infof("[%s] 选举超时,重新发起选举", n.SelfId) + // n.State = Follower 这里不修改,如果appendentries收到term合理的心跳,再变回follower n.resetElectionTimer() } - mu.Unlock() + Mu.Unlock() } func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *RequestVoteReply) bool { - log.Sugar().Infof("[%s] 请求 [%s] 投票", node.selfId, peerId) - client, err := node.transport.DialHTTPWithTimeout("tcp", peerId) + log.Sugar().Infof("[%s] 请求 [%s] 投票", node.SelfId, peerId) + client, err := node.Transport.DialHTTPWithTimeout("tcp", node.SelfId, peerId) if err != nil { - log.Error(node.selfId + "dialing [" + peerId + "] fail: ", zap.Error(err)) + log.Error("[" + node.SelfId + "]dialing [" + peerId + "] fail: ", zap.Error(err)) return false } @@ -117,50 +117,50 @@ func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *R } }(client) - callErr := node.transport.CallWithTimeout(client, "Node.RequestVote", args, reply) // RPC + callErr := node.Transport.CallWithTimeout(client, "Node.RequestVote", args, reply) // RPC if callErr != nil { - log.Error(node.selfId + "calling [" + peerId + "] fail: ", zap.Error(callErr)) + log.Error("[" + node.SelfId + "]calling [" + peerId + "] fail: ", zap.Error(callErr)) } return callErr == nil } func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error { - n.mu.Lock() - defer n.mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() // 如果候选人的任期小于当前任期,则拒绝投票 - if args.Term < n.currTerm { - reply.Term = n.currTerm + if args.Term < n.CurrTerm { + reply.Term = n.CurrTerm reply.VoteGranted = false return nil } // 如果请求的 Term 更高,则更新当前 Term 并回退为 Follower - if args.Term > n.currTerm { - n.currTerm = args.Term - n.state = Follower - n.votedFor = "" + if args.Term > n.CurrTerm { + n.CurrTerm = args.Term + n.State = Follower + n.VotedFor = "" n.resetElectionTimer() // 重新设置选举超时 } // 检查是否已经投过票,且是否投给了同一个候选人 - if n.votedFor == "" || n.votedFor == args.CandidateId { + if n.VotedFor == "" || n.VotedFor == args.CandidateId { // 检查日志是否足够新 var lastLogIndex int var lastLogTerm int - if len(n.log) == 0 { + if len(n.Log) == 0 { lastLogIndex = -1 lastLogTerm = 0 } else { - lastLogIndex = len(n.log) - 1 - lastLogTerm = n.log[lastLogIndex].Term + lastLogIndex = len(n.Log) - 1 + lastLogTerm = n.Log[lastLogIndex].Term } if args.LastLogTerm > lastLogTerm || (args.LastLogTerm == lastLogTerm && args.LastLogIndex >= lastLogIndex) { // 够新就投票给候选人 - n.votedFor = args.CandidateId - log.Sugar().Infof("在term(%s), [%s]投票给[%s]", strconv.Itoa(n.currTerm), n.selfId, n.votedFor) + n.VotedFor = args.CandidateId + log.Sugar().Infof("在term(%s), [%s]投票给[%s]", strconv.Itoa(n.CurrTerm), n.SelfId, n.VotedFor) reply.VoteGranted = true n.resetElectionTimer() } else { @@ -170,23 +170,23 @@ func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error reply.VoteGranted = false } - n.storage.SetTermAndVote(n.currTerm, n.votedFor) - reply.Term = n.currTerm + n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor) + reply.Term = n.CurrTerm return nil } // follower 500-1000ms内没收到appendentries心跳,就变成candidate发起选举 func (node *Node) resetElectionTimer() { - if node.electionTimer == nil { - node.electionTimer = time.NewTimer(time.Duration(500+rand.Intn(500)) * time.Millisecond) + if node.ElectionTimer == nil { + node.ElectionTimer = time.NewTimer(time.Duration(500+rand.Intn(500)) * time.Millisecond) go func() { for { - <-node.electionTimer.C + <-node.ElectionTimer.C node.startElection() } }() } else { - node.electionTimer.Stop() - node.electionTimer.Reset(time.Duration(500+rand.Intn(500)) * time.Millisecond) + node.ElectionTimer.Stop() + node.ElectionTimer.Reset(time.Duration(500+rand.Intn(500)) * time.Millisecond) } } \ No newline at end of file diff --git a/test/restart_node_test.go b/test/restart_node_test.go index a9136f6..5efcfbe 100644 --- a/test/restart_node_test.go +++ b/test/restart_node_test.go @@ -52,7 +52,7 @@ func TestNodeRestart(t *testing.T) { for i := 0; i < 5; i++ { key := strconv.Itoa(i) newlog := nodes.LogEntry{Key: key, Value: "hello"} - s := cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal}) + s := cWrite.Write(nodes.LogEntryCall{LogE: newlog}) if s != clientPkg.Ok { t.Errorf("write test fail") } diff --git a/test/server_client_test.go b/test/server_client_test.go index 822c548..8751c67 100644 --- a/test/server_client_test.go +++ b/test/server_client_test.go @@ -52,7 +52,7 @@ func TestServerClient(t *testing.T) { for i := 0; i < 10; i++ { key := strconv.Itoa(i) newlog := nodes.LogEntry{Key: key, Value: "hello"} - s := c.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal}) + s := c.Write(nodes.LogEntryCall{LogE: newlog}) if s != clientPkg.Ok { t.Errorf("write test fail") } diff --git a/threadTest/network_partition_test.go b/threadTest/network_partition_test.go new file mode 100644 index 0000000..3e23382 --- /dev/null +++ b/threadTest/network_partition_test.go @@ -0,0 +1,150 @@ +package threadTest + +import ( + "fmt" + clientPkg "simple-kv-store/internal/client" + "simple-kv-store/internal/nodes" + "strconv" + "strings" + "testing" + "time" +) + +func TestBasicConnectivity(t *testing.T) { + transport := nodes.NewThreadTransport() + + transport.RegisterNodeChan("1", make(chan nodes.RPCRequest, 10)) + transport.RegisterNodeChan("2", make(chan nodes.RPCRequest, 10)) + + // 断开 A 和 B + transport.SetConnectivity("1", "2", false) + + err := transport.CallWithTimeout(&nodes.ThreadClient{SourceId: "1", TargetId: "2"}, "Node.AppendEntries", &nodes.AppendEntriesArg{}, &nodes.AppendEntriesReply{}) + if err == nil { + t.Errorf("Expected network partition error, but got nil") + } + + // 恢复连接 + transport.SetConnectivity("1", "2", true) + + err = transport.CallWithTimeout(&nodes.ThreadClient{SourceId: "1", TargetId: "2"}, "Node.AppendEntries", &nodes.AppendEntriesArg{}, &nodes.AppendEntriesReply{}) + if !strings.Contains(err.Error(), "RPC 调用超时") { + t.Errorf("Expected success, but got error: %v", err) + } +} + +func TestElectionWithPartition(t *testing.T) { + // 登记结点信息 + n := 3 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + threadTransport := nodes.NewThreadTransport() + for i := 0; i < n; i++ { + n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + time.Sleep(2 * time.Second) // 等待启动完毕 + fmt.Println("开始分区模拟1") + var leaderNo int + for i := 0; i < n; i++ { + if nodeCollections[i].State == nodes.Leader { + leaderNo = i + for j := 0; j < n; j++ { + if i != j { // 切断其它节点到leader的消息 + threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[i].SelfId, false) + } + } + } + } + time.Sleep(2 * time.Second) + + if nodeCollections[leaderNo].State == nodes.Leader { + t.Errorf("分区退选失败") + } + + // 恢复网络 + for j := 0; j < n; j++ { + if leaderNo != j { // 恢复其它节点到leader的消息 + threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[leaderNo].SelfId, true) + } + } + + time.Sleep(1 * time.Second) + + var leaderCnt int + for i := 0; i < n; i++ { + if nodeCollections[i].State == nodes.Leader { + leaderCnt++ + leaderNo = i + } + } + if leaderCnt != 1 { + t.Errorf("多leader产生") + } + + fmt.Println("开始分区模拟2") + for j := 0; j < n; j++ { + if leaderNo != j { // 切断leader到其它节点的消息 + threadTransport.SetConnectivity(nodeCollections[leaderNo].SelfId, nodeCollections[j].SelfId, false) + } + } + time.Sleep(1 * time.Second) + + if nodeCollections[leaderNo].State == nodes.Leader { + t.Errorf("分区退选失败") + } + + leaderCnt = 0 + for j := 0; j < n; j++ { + if nodeCollections[j].State == nodes.Leader { + leaderCnt++ + } + } + if leaderCnt != 1 { + t.Errorf("多leader产生") + } + + // client启动 + c := clientPkg.Client{PeerIds: peerIds, Transport: threadTransport} + var s clientPkg.Status + for i := 0; i < 5; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + s = c.Write(nodes.LogEntryCall{LogE: newlog}) + if s != clientPkg.Ok { + t.Errorf("write test fail") + } + } + + time.Sleep(time.Second) // 等待写入完毕 + + // 恢复网络 + for j := 0; j < n; j++ { + if leaderNo != j { + threadTransport.SetConnectivity(nodeCollections[leaderNo].SelfId, nodeCollections[j].SelfId, true) + } + } + + time.Sleep(time.Second) + // 日志一致性检查 + for i := 0; i < n; i++ { + if len(nodeCollections[i].Log) != 5 { + t.Errorf("日志数量不一致:" + strconv.Itoa(len(nodeCollections[i].Log))) + } + } +} \ No newline at end of file diff --git a/threadTest/restart_node_test.go b/threadTest/restart_node_test.go index 2f57593..24804f6 100644 --- a/threadTest/restart_node_test.go +++ b/threadTest/restart_node_test.go @@ -39,7 +39,7 @@ func TestNodeRestart(t *testing.T) { for i := 0; i < 5; i++ { key := strconv.Itoa(i) newlog := nodes.LogEntry{Key: key, Value: "hello"} - s := cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal}) + s = cWrite.Write(nodes.LogEntryCall{LogE: newlog}) if s != clientPkg.Ok { t.Errorf("write test fail") } diff --git a/threadTest/server_client_test.go b/threadTest/server_client_test.go index 7768e66..7cb8690 100644 --- a/threadTest/server_client_test.go +++ b/threadTest/server_client_test.go @@ -40,7 +40,7 @@ func TestServerClient(t *testing.T) { for i := 0; i < 10; i++ { key := strconv.Itoa(i) newlog := nodes.LogEntry{Key: key, Value: "hello"} - s := c.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal}) + s = c.Write(nodes.LogEntryCall{LogE: newlog}) if s != clientPkg.Ok { t.Errorf("write test fail") }