From 07c5a3719f4285660f76fad50bb941a489a0e42f Mon Sep 17 00:00:00 2001 From: augurier <14434658+augurier@user.noreply.gitee.com> Date: Sat, 12 Apr 2025 13:16:51 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=82=E5=BA=94=E4=B8=A4=E7=A7=8D=E7=89=88?= =?UTF-8?q?=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/main.go | 2 +- internal/nodes/init.go | 10 +++- internal/nodes/replica.go | 11 ---- internal/nodes/thread_transport.go | 115 +++++++++++++++++++++++-------------- internal/nodes/vote.go | 6 -- 5 files changed, 82 insertions(+), 62 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index bd36b09..b3e9a2a 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -72,7 +72,7 @@ func main() { defer storage.Close() // 初始化 - node := nodes.InitRPCNode(*id, *port, idClusterPairs, db, storage, !*isRestart) + node := nodes.InitRPCNode(*id, *port, idClusterPairs, db, storage, *isRestart) // 开启 raft quitChan := make(chan struct{}, 1) diff --git a/internal/nodes/init.go b/internal/nodes/init.go index 27999df..64eec06 100644 --- a/internal/nodes/init.go +++ b/internal/nodes/init.go @@ -159,7 +159,10 @@ func (node *Node) switchReq(req RPCRequest, delayTime time.Duration) { if !ok || !ok2 { req.Done <- errors.New("type assertion failed for AppendEntries") } else { - req.Done <- node.AppendEntries(arg, resp) + var respCopy AppendEntriesReply + err := node.AppendEntries(arg, &respCopy) + *resp = respCopy + req.Done <- err } case "Node.RequestVote": @@ -168,7 +171,10 @@ func (node *Node) switchReq(req RPCRequest, delayTime time.Duration) { if !ok || !ok2 { req.Done <- errors.New("type assertion failed for RequestVote") } else { - req.Done <- node.RequestVote(arg, resp) + var respCopy RequestVoteReply + err := node.RequestVote(arg, &respCopy) + *resp = respCopy + req.Done <- err } case "Node.WriteKV": diff --git a/internal/nodes/replica.go b/internal/nodes/replica.go index 6807254..de8b98b 100644 --- a/internal/nodes/replica.go +++ b/internal/nodes/replica.go @@ -19,7 +19,6 @@ type AppendEntriesArg struct { } type AppendEntriesReply struct { - Mu sync.Mutex Term int Success bool } @@ -121,7 +120,6 @@ func (node *Node) sendKV(peerId string, failCount *int, failMutex *sync.Mutex) { return } - appendReply.Mu.Lock() if appendReply.Term != node.CurrTerm { log.Sugar().Infof("term=%s的leader[%s]因为[%s]收到更高的term=%s, 转换为follower", strconv.Itoa(node.CurrTerm), node.SelfId, peerId, strconv.Itoa(appendReply.Term)) @@ -132,17 +130,14 @@ func (node *Node) sendKV(peerId string, failCount *int, failMutex *sync.Mutex) { node.VotedFor = "" node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor) node.ResetElectionTimer() - appendReply.Mu.Unlock() node.Mu.Unlock() return } if appendReply.Success { - appendReply.Mu.Unlock() break } - appendReply.Mu.Unlock() NextIndex-- // 失败往前传一格 } @@ -207,10 +202,8 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply // 如果 term 过期,拒绝接受日志 if node.CurrTerm > arg.Term { - reply.Mu.Lock() reply.Term = node.CurrTerm reply.Success = false - reply.Mu.Unlock() return nil } @@ -228,10 +221,8 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply // 检查 prevLogIndex 是否有效 if arg.PrevLogIndex >= len(node.Log) || (arg.PrevLogIndex >= 0 && node.Log[arg.PrevLogIndex].Term != arg.PrevLogTerm) { - reply.Mu.Lock() reply.Term = node.CurrTerm reply.Success = false - reply.Mu.Unlock() return nil } @@ -274,9 +265,7 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply // 在成功接受日志或心跳后,重置选举超时 node.ResetElectionTimer() - reply.Mu.Lock() reply.Term = node.CurrTerm reply.Success = true - reply.Mu.Unlock() return nil } \ No newline at end of file diff --git a/internal/nodes/thread_transport.go b/internal/nodes/thread_transport.go index 1b2e9d8..c437eee 100644 --- a/internal/nodes/thread_transport.go +++ b/internal/nodes/thread_transport.go @@ -120,42 +120,38 @@ func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod } var isConnected bool - if threadClient.SourceId == "" { // 来自客户端的连接 + if threadClient.SourceId == "" { isConnected = true } else { t.mu.Lock() - isConnected = t.connectivityMap[threadClient.SourceId][threadClient.TargetId] // 检查连通性 - t.mu.Unlock() + isConnected = t.connectivityMap[threadClient.SourceId][threadClient.TargetId] + t.mu.Unlock() + } + if !isConnected { + return fmt.Errorf("网络分区: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId) } - - if !isConnected { - return fmt.Errorf("网络分区: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId) - } - - // 获取目标节点的 channel targetChan, exists := t.getNodeChan(threadClient.TargetId) if !exists { return fmt.Errorf("目标节点 [%s] 不存在", threadClient.TargetId) } - // 创建响应通道(用于返回 RPC 结果) done := make(chan error, 1) - behavior := t.Ctx.GetBehavior(threadClient.SourceId, threadClient.TargetId) - // 发送请求 - request := RPCRequest{ - ServiceMethod: serviceMethod, - Args: args, - Reply: reply, - Done: done, - SourceId: threadClient.SourceId, - Behavior: behavior, + + // 辅助函数:复制 replyCopy 到原始 reply + copyReply := func(dst, src interface{}) { + switch d := dst.(type) { + case *AppendEntriesReply: + *d = *(src.(*AppendEntriesReply)) + case *RequestVoteReply: + *d = *(src.(*RequestVoteReply)) + } } - sendRequest := func(req RPCRequest, targetChan chan RPCRequest) bool { + sendRequest := func(req RPCRequest, ch chan RPCRequest) bool { select { - case targetChan <- req: + case ch <- req: return true default: return false @@ -168,36 +164,71 @@ func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod if !ok { log.Fatal("没有设置对应的retry次数") } - request.Behavior = NormalRpc - // 尝试发送多次, 期待同一个done + + var lastErr error for i := 0; i < retryTimes; i++ { + var replyCopy interface{} + useCopy := true + + switch r := reply.(type) { + case *AppendEntriesReply: + tmp := *r + replyCopy = &tmp + case *RequestVoteReply: + tmp := *r + replyCopy = &tmp + default: + replyCopy = reply // 其他类型不复制 + useCopy = false + } + + request := RPCRequest{ + ServiceMethod: serviceMethod, + Args: args, + Reply: replyCopy, + Done: done, + SourceId: threadClient.SourceId, + Behavior: NormalRpc, + } + if !sendRequest(request, targetChan) { return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId) } + + select { + case err := <-done: + if err == nil && useCopy { + copyReply(reply, replyCopy) + } + if err == nil { + return nil + } + lastErr = err + case <-time.After(250 * time.Millisecond): + lastErr = fmt.Errorf("RPC 调用超时: %s", serviceMethod) + } } + return lastErr default: + request := RPCRequest{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + SourceId: threadClient.SourceId, + Behavior: behavior, + } + if !sendRequest(request, targetChan) { return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId) } - } - - // 等待响应或超时 - select { - case err := <-done: - if threadClient.SourceId == "" { // 来自客户端的连接 - isConnected = true - } else { - t.mu.Lock() - isConnected = t.connectivityMap[threadClient.TargetId][threadClient.SourceId] // 检查连通性 - t.mu.Unlock() - } - - if !isConnected { - return fmt.Errorf("network partition: %s cannot reach %s", threadClient.TargetId, threadClient.SourceId) + + select { + case err := <-done: + return err + case <-time.After(250 * time.Millisecond): + return fmt.Errorf("RPC 调用超时: %s", serviceMethod) } - return err - case <-time.After(250 * time.Millisecond): - return fmt.Errorf("RPC 调用超时: %s", serviceMethod) } -} +} \ No newline at end of file diff --git a/internal/nodes/vote.go b/internal/nodes/vote.go index f1eb7ae..6ec64b6 100644 --- a/internal/nodes/vote.go +++ b/internal/nodes/vote.go @@ -18,7 +18,6 @@ type RequestVoteArgs struct { } type RequestVoteReply struct { - Mu sync.Mutex Term int // 当前节点的最新任期 VoteGranted bool // 是否同意投票 } @@ -82,7 +81,6 @@ func (n *Node) StartElection() { return } - reply.Mu.Lock() if reply.Term > n.CurrTerm { // 发现更高任期,回退为 Follower log.Sugar().Infof("[%s] 发现更高的 Term (%d),回退为 Follower", n.SelfId, reply.Term) @@ -91,14 +89,12 @@ func (n *Node) StartElection() { n.VotedFor = "" n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor) n.ResetElectionTimer() - reply.Mu.Unlock() return } if reply.VoteGranted { grantedVotes++ } - reply.Mu.Unlock() if grantedVotes == totalNodes / 2 + 1 { n.State = Leader @@ -151,8 +147,6 @@ func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error n.Mu.Lock() defer n.Mu.Unlock() - reply.Mu.Lock() - defer reply.Mu.Unlock() // 如果候选人的任期小于当前任期,则拒绝投票 if args.Term < n.CurrTerm { reply.Term = n.CurrTerm