Browse Source

适应两种版本

ld
augurier 5 months ago
parent
commit
07c5a3719f
5 changed files with 82 additions and 62 deletions
  1. +1
    -1
      cmd/main.go
  2. +8
    -2
      internal/nodes/init.go
  3. +0
    -11
      internal/nodes/replica.go
  4. +73
    -42
      internal/nodes/thread_transport.go
  5. +0
    -6
      internal/nodes/vote.go

+ 1
- 1
cmd/main.go View File

@ -72,7 +72,7 @@ func main() {
defer storage.Close()
// 初始化
node := nodes.InitRPCNode(*id, *port, idClusterPairs, db, storage, !*isRestart)
node := nodes.InitRPCNode(*id, *port, idClusterPairs, db, storage, *isRestart)
// 开启 raft
quitChan := make(chan struct{}, 1)

+ 8
- 2
internal/nodes/init.go View File

@ -159,7 +159,10 @@ func (node *Node) switchReq(req RPCRequest, delayTime time.Duration) {
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for AppendEntries")
} else {
req.Done <- node.AppendEntries(arg, resp)
var respCopy AppendEntriesReply
err := node.AppendEntries(arg, &respCopy)
*resp = respCopy
req.Done <- err
}
case "Node.RequestVote":
@ -168,7 +171,10 @@ func (node *Node) switchReq(req RPCRequest, delayTime time.Duration) {
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for RequestVote")
} else {
req.Done <- node.RequestVote(arg, resp)
var respCopy RequestVoteReply
err := node.RequestVote(arg, &respCopy)
*resp = respCopy
req.Done <- err
}
case "Node.WriteKV":

+ 0
- 11
internal/nodes/replica.go View File

@ -19,7 +19,6 @@ type AppendEntriesArg struct {
}
type AppendEntriesReply struct {
Mu sync.Mutex
Term int
Success bool
}
@ -121,7 +120,6 @@ func (node *Node) sendKV(peerId string, failCount *int, failMutex *sync.Mutex) {
return
}
appendReply.Mu.Lock()
if appendReply.Term != node.CurrTerm {
log.Sugar().Infof("term=%s的leader[%s]因为[%s]收到更高的term=%s, 转换为follower",
strconv.Itoa(node.CurrTerm), node.SelfId, peerId, strconv.Itoa(appendReply.Term))
@ -132,17 +130,14 @@ func (node *Node) sendKV(peerId string, failCount *int, failMutex *sync.Mutex) {
node.VotedFor = ""
node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor)
node.ResetElectionTimer()
appendReply.Mu.Unlock()
node.Mu.Unlock()
return
}
if appendReply.Success {
appendReply.Mu.Unlock()
break
}
appendReply.Mu.Unlock()
NextIndex-- // 失败往前传一格
}
@ -207,10 +202,8 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply
// 如果 term 过期,拒绝接受日志
if node.CurrTerm > arg.Term {
reply.Mu.Lock()
reply.Term = node.CurrTerm
reply.Success = false
reply.Mu.Unlock()
return nil
}
@ -228,10 +221,8 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply
// 检查 prevLogIndex 是否有效
if arg.PrevLogIndex >= len(node.Log) || (arg.PrevLogIndex >= 0 && node.Log[arg.PrevLogIndex].Term != arg.PrevLogTerm) {
reply.Mu.Lock()
reply.Term = node.CurrTerm
reply.Success = false
reply.Mu.Unlock()
return nil
}
@ -274,9 +265,7 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply
// 在成功接受日志或心跳后,重置选举超时
node.ResetElectionTimer()
reply.Mu.Lock()
reply.Term = node.CurrTerm
reply.Success = true
reply.Mu.Unlock()
return nil
}

+ 73
- 42
internal/nodes/thread_transport.go View File

@ -120,42 +120,38 @@ func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod
}
var isConnected bool
if threadClient.SourceId == "" { // 来自客户端的连接
if threadClient.SourceId == "" {
isConnected = true
} else {
t.mu.Lock()
isConnected = t.connectivityMap[threadClient.SourceId][threadClient.TargetId] // 检查连通性
t.mu.Unlock()
isConnected = t.connectivityMap[threadClient.SourceId][threadClient.TargetId]
t.mu.Unlock()
}
if !isConnected {
return fmt.Errorf("网络分区: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId)
}
if !isConnected {
return fmt.Errorf("网络分区: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId)
}
// 获取目标节点的 channel
targetChan, exists := t.getNodeChan(threadClient.TargetId)
if !exists {
return fmt.Errorf("目标节点 [%s] 不存在", threadClient.TargetId)
}
// 创建响应通道(用于返回 RPC 结果)
done := make(chan error, 1)
behavior := t.Ctx.GetBehavior(threadClient.SourceId, threadClient.TargetId)
// 发送请求
request := RPCRequest{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
SourceId: threadClient.SourceId,
Behavior: behavior,
// 辅助函数:复制 replyCopy 到原始 reply
copyReply := func(dst, src interface{}) {
switch d := dst.(type) {
case *AppendEntriesReply:
*d = *(src.(*AppendEntriesReply))
case *RequestVoteReply:
*d = *(src.(*RequestVoteReply))
}
}
sendRequest := func(req RPCRequest, targetChan chan RPCRequest) bool {
sendRequest := func(req RPCRequest, ch chan RPCRequest) bool {
select {
case targetChan <- req:
case ch <- req:
return true
default:
return false
@ -168,36 +164,71 @@ func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod
if !ok {
log.Fatal("没有设置对应的retry次数")
}
request.Behavior = NormalRpc
// 尝试发送多次, 期待同一个done
var lastErr error
for i := 0; i < retryTimes; i++ {
var replyCopy interface{}
useCopy := true
switch r := reply.(type) {
case *AppendEntriesReply:
tmp := *r
replyCopy = &tmp
case *RequestVoteReply:
tmp := *r
replyCopy = &tmp
default:
replyCopy = reply // 其他类型不复制
useCopy = false
}
request := RPCRequest{
ServiceMethod: serviceMethod,
Args: args,
Reply: replyCopy,
Done: done,
SourceId: threadClient.SourceId,
Behavior: NormalRpc,
}
if !sendRequest(request, targetChan) {
return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId)
}
select {
case err := <-done:
if err == nil && useCopy {
copyReply(reply, replyCopy)
}
if err == nil {
return nil
}
lastErr = err
case <-time.After(250 * time.Millisecond):
lastErr = fmt.Errorf("RPC 调用超时: %s", serviceMethod)
}
}
return lastErr
default:
request := RPCRequest{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
SourceId: threadClient.SourceId,
Behavior: behavior,
}
if !sendRequest(request, targetChan) {
return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId)
}
}
// 等待响应或超时
select {
case err := <-done:
if threadClient.SourceId == "" { // 来自客户端的连接
isConnected = true
} else {
t.mu.Lock()
isConnected = t.connectivityMap[threadClient.TargetId][threadClient.SourceId] // 检查连通性
t.mu.Unlock()
}
if !isConnected {
return fmt.Errorf("network partition: %s cannot reach %s", threadClient.TargetId, threadClient.SourceId)
select {
case err := <-done:
return err
case <-time.After(250 * time.Millisecond):
return fmt.Errorf("RPC 调用超时: %s", serviceMethod)
}
return err
case <-time.After(250 * time.Millisecond):
return fmt.Errorf("RPC 调用超时: %s", serviceMethod)
}
}
}

+ 0
- 6
internal/nodes/vote.go View File

@ -18,7 +18,6 @@ type RequestVoteArgs struct {
}
type RequestVoteReply struct {
Mu sync.Mutex
Term int // 当前节点的最新任期
VoteGranted bool // 是否同意投票
}
@ -82,7 +81,6 @@ func (n *Node) StartElection() {
return
}
reply.Mu.Lock()
if reply.Term > n.CurrTerm {
// 发现更高任期,回退为 Follower
log.Sugar().Infof("[%s] 发现更高的 Term (%d),回退为 Follower", n.SelfId, reply.Term)
@ -91,14 +89,12 @@ func (n *Node) StartElection() {
n.VotedFor = ""
n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor)
n.ResetElectionTimer()
reply.Mu.Unlock()
return
}
if reply.VoteGranted {
grantedVotes++
}
reply.Mu.Unlock()
if grantedVotes == totalNodes / 2 + 1 {
n.State = Leader
@ -151,8 +147,6 @@ func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error
n.Mu.Lock()
defer n.Mu.Unlock()
reply.Mu.Lock()
defer reply.Mu.Unlock()
// 如果候选人的任期小于当前任期,则拒绝投票
if args.Term < n.CurrTerm {
reply.Term = n.CurrTerm

Loading…
Cancel
Save