diff --git a/cmd/main.go b/cmd/main.go index b0b8bd8..0f813b0 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -29,10 +29,9 @@ func main() { signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) port := flag.String("port", ":9091", "rpc listen port") - cluster := flag.String("cluster", "127.0.0.1:9092,127.0.0.1:9093", "comma sep") + cluster := flag.String("cluster", "127.0.0.1:9091,127.0.0.1:9092,127.0.0.1:9093", "comma sep") id := flag.String("id", "1", "node ID") pipe := flag.String("pipe", "", "input from scripts") - isLeader := flag.Bool("isleader", false, "init node state") isNewDb := flag.Bool("isNewDb", true, "new test or restart") // 参数解析 @@ -47,6 +46,7 @@ func main() { for _, addr := range clusters { if idCnt == selfi { idCnt++ // 命令行cluster按id排序传入,记录时跳过自己的id,先保证所有节点互相记录的id一致 + continue } idClusterPairs[strconv.Itoa(idCnt)] = addr idCnt++ @@ -76,7 +76,7 @@ func main() { // 监听rpc node.Rpc(*port) // 开启 raft - nodes.Start(node, *isLeader) + nodes.Start(node) sig := <-sigs fmt.Println("node_" + *id + "接收到信号:", sig) diff --git a/internal/client/client_node.go b/internal/client/client_node.go index 7eaf99d..a144e10 100644 --- a/internal/client/client_node.go +++ b/internal/client/client_node.go @@ -26,31 +26,34 @@ const ( func (client *Client) Write(kvCall nodes.LogEntryCall) Status { log.Info("client write request key :" + kvCall.LogE.Key) - c, err := rpc.DialHTTP("tcp", client.Address) - if err != nil { - log.Error("dialing: ", zap.Error(err)) - return Fail - } - defer func(server *rpc.Client) { - err := c.Close() + var reply nodes.ServerReply + reply.Isleader = false + addr := client.Address + for !reply.Isleader { + c, err := rpc.DialHTTP("tcp", addr) if err != nil { - log.Error("client close err: ", zap.Error(err)) + log.Error("dialing: ", zap.Error(err)) + return Fail } - }(c) - var reply nodes.ServerReply - callErr := c.Call("Node.WriteKV", kvCall, &reply) // RPC - if callErr != nil { - log.Error("dialing: ", zap.Error(callErr)) - return Fail - } + callErr := c.Call("Node.WriteKV", kvCall, &reply) // RPC + if callErr != nil { + log.Error("dialing: ", zap.Error(callErr)) + return Fail + } + err = c.Close() + if err != nil { + log.Error("client close err: ", zap.Error(err)) + } - if reply.Isconnect { // 发送成功 - return Ok - } else { // 失败 - return Fail + if !reply.Isleader { // 发过去的不是leader + addr = reply.LeaderAddress + } else { // 成功 + return Ok + } } + return Fail } func (client *Client) Read(key string, value *string) Status { // 查不到value为空 @@ -79,15 +82,12 @@ func (client *Client) Read(key string, value *string) Status { // 查不到value return Fail } - if reply.Isconnect { // 发送成功 - if reply.HaveValue { - *value = reply.Value - return Ok - } else { - return NotFound - } - } else { // 失败 - return Fail + // 目前一定发送成功 + if reply.HaveValue { + *value = reply.Value + return Ok + } else { + return NotFound } } diff --git a/internal/nodes/init.go b/internal/nodes/init.go index 1a68bdd..bda72d3 100644 --- a/internal/nodes/init.go +++ b/internal/nodes/init.go @@ -1,13 +1,12 @@ package nodes import ( - "io" + "fmt" + "math/rand" "net" "net/http" "net/rpc" - "os" "simple-kv-store/internal/logprovider" - "strconv" "time" "github.com/syndtr/goleveldb/leveldb" @@ -23,7 +22,7 @@ func newNode(address string) *Public_node_info { } } -func Init(id string, nodeAddr map[string]string, pipe string, db *leveldb.DB) *Node { +func Init(selfId string, nodeAddr map[string]string, pipe string, db *leveldb.DB) *Node { ns := make(map[string]*Public_node_info) for id, addr := range nodeAddr { ns[id] = newNode(addr) @@ -31,86 +30,125 @@ func Init(id string, nodeAddr map[string]string, pipe string, db *leveldb.DB) *N // 创建节点 node := &Node{ - selfId: id, - nodes: ns, - pipeAddr: pipe, - maxLogId: -1, - currTerm: 1, - log: make([]RaftLogEntry, 0), + selfId: selfId, + leaderId: "", + nodes: ns, + pipeAddr: pipe, + maxLogId: -1, + currTerm: 1, + log: make([]RaftLogEntry, 0), commitIndex: -1, lastApplied: -1, - nextIndex: make(map[string]int), - matchIndex: make(map[string]int), - db: db, + nextIndex: make(map[string]int), + matchIndex: make(map[string]int), + db: db, } - for nodeId := range nodeAddr { - if nodeId != id { // 不初始化自身 - node.nextIndex[nodeId] = node.maxLogId + 1 - node.matchIndex[nodeId] = 0 - } - } + node.initLeaderState() return node } -func Start(node *Node, isLeader bool) { - if isLeader { - node.state = Candidate // 需要身份转变 - } else { - node.state = Follower +// func Start(node *Node, isLeader bool) { +// if isLeader { +// node.state = Candidate // 需要身份转变 +// } else { +// node.state = Follower +// } + +// go func() { +// for { +// switch node.state { +// case Follower: + +// case Candidate: +// // todo 成为leader的初始化 +// // node.currTerm = 1 + +// // candidate发布一个监听输入线程后,变成leader +// node.state = Leader +// go func() { +// if node.pipeAddr == "" { // 客户端远程调用server_node方法 +// log.Info("请运行客户端进程进行读写") +// } else { // 命令行提供了管道,支持管道(键盘)输入 +// pipe, err := os.Open(node.pipeAddr) +// if err != nil { +// log.Error("Failed to open pipe") +// } +// defer pipe.Close() + +// // 不断读取管道中的输入 +// buffer := make([]byte, 256) +// for { +// n, err := pipe.Read(buffer) +// if err != nil && err != io.EOF { +// log.Error("Error reading from pipe") +// } +// if n > 0 { +// input := string(buffer[:n]) +// // 将用户输入封装成一个 LogEntry +// kv := LogEntry{input, ""} // 目前键盘输入key,value 0 +// logId := node.maxLogId +// node.maxLogId++ +// node.log[logId] = RaftLogEntry{kv, logId, node.currTerm} + +// log.Info("send : logId = " + strconv.Itoa(logId) + ", key = " + input) +// // 广播给其它节点 +// node.BroadCastKV(Normal) +// // 持久化 +// node.db.Put([]byte(kv.Key), []byte(kv.Value), nil) +// } +// } +// } +// }() +// case Leader: +// time.Sleep(50 * time.Millisecond) +// } +// } +// }() +// } +func (n *Node) initLeaderState() { + for peerId := range n.nodes { + n.nextIndex[peerId] = len(n.log) // 发送日志的下一个索引 + n.matchIndex[peerId] = 0 // 复制日志的最新匹配索引 } +} + +func Start(node *Node) { + node.state = Follower // 所有节点以 Follower 状态启动 + node.resetElectionTimer() // 启动选举超时定时器 go func() { for { switch node.state { case Follower: + // 监听心跳超时 + fmt.Printf("Node %s is a follower, 监听中...\n", node.selfId) - case Candidate: - // todo 成为leader的初始化 - // node.currTerm = 1 - - // candidate发布一个监听输入线程后,变成leader - node.state = Leader - go func() { - if node.pipeAddr == "" { // 客户端远程调用server_node方法 - log.Info("请运行客户端进程进行读写") - } else { // 命令行提供了管道,支持管道(键盘)输入 - pipe, err := os.Open(node.pipeAddr) - if err != nil { - log.Error("Failed to open pipe") - } - defer pipe.Close() - - // 不断读取管道中的输入 - buffer := make([]byte, 256) - for { - n, err := pipe.Read(buffer) - if err != nil && err != io.EOF { - log.Error("Error reading from pipe") - } - if n > 0 { - input := string(buffer[:n]) - // 将用户输入封装成一个 LogEntry - kv := LogEntry{input, ""} // 目前键盘输入key,value 0 - logId := node.maxLogId - node.maxLogId++ - node.log[logId] = RaftLogEntry{kv, logId, node.currTerm} - - log.Info("send : logId = " + strconv.Itoa(logId) + ", key = " + input) - // 广播给其它节点 - node.BroadCastKV(Normal) - // 持久化 - node.db.Put([]byte(kv.Key), []byte(kv.Value), nil) - } - } - } - }() case Leader: - time.Sleep(50 * time.Millisecond) + // 发送心跳 + fmt.Printf("Node %s is the leader, 发送心跳...\n", node.selfId) + node.BroadCastKV(Normal) } + time.Sleep(50 * time.Millisecond) } }() } +// follower 500-1000ms内没收到appendentries心跳,就变成candidate发起选举 +func (node *Node) resetElectionTimer() { + if node.electionTimer == nil { + node.electionTimer = time.NewTimer(time.Duration(500+rand.Intn(500)) * time.Millisecond) + go func() { + for { + <-node.electionTimer.C + node.startElection() + } + }() + } else { + node.electionTimer.Stop() + node.electionTimer.Reset(time.Duration(500+rand.Intn(500)) * time.Millisecond) + } +} + func (node *Node) Rpc(port string) { err := rpc.Register(node) if err != nil { diff --git a/internal/nodes/node.go b/internal/nodes/node.go index 9379e10..b5bc0c8 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -5,6 +5,7 @@ import ( "net/rpc" "sort" "strconv" + "sync" "time" "github.com/syndtr/goleveldb/leveldb" @@ -25,8 +26,11 @@ type Public_node_info struct { } type Node struct { + mu sync.Mutex // 当前节点id selfId string + // 记录的leader(不能用votedfor:投票的leader可能没有收到多数票) + leaderId string // 除当前节点外其他节点信息 nodes map[string]*Public_node_info @@ -60,6 +64,8 @@ type Node struct { db *leveldb.DB + votedFor string + electionTimer *time.Timer } func (node *Node) BroadCastKV(callMode CallMode) { @@ -72,6 +78,7 @@ func (node *Node) BroadCastKV(callMode CallMode) { } func (node *Node) sendKV(id string, callMode CallMode) { + switch callMode { case Fail: log.Info("模拟发送失败") @@ -96,6 +103,9 @@ func (node *Node) sendKV(id string, callMode CallMode) { } }(client) + node.mu.Lock() + defer node.mu.Unlock() + var appendReply AppendEntriesReply appendReply.Success = false nextIndex := node.nextIndex[id] @@ -110,6 +120,7 @@ func (node *Node) sendKV(id string, callMode CallMode) { PrevLogIndex: nextIndex - 1, Entries: sendEntries, LeaderCommit: node.commitIndex, + LeaderId: node.selfId, } if arg.PrevLogIndex >= 0 { arg.PrevLogTerm = node.log[arg.PrevLogIndex].Term @@ -133,9 +144,6 @@ func (node *Node) sendKV(id string, callMode CallMode) { } func (node *Node) updateCommitIndex() { - // node.mu.Lock() - // defer node.mu.Unlock() - totalNodes := len(node.nodes) // 收集所有 matchIndex 并排序 @@ -149,9 +157,9 @@ func (node *Node) updateCommitIndex() { majorityIndex := matchIndexes[totalNodes/2] // 取 N/2 位置上的索引(多数派) // 确保这个索引的日志条目属于当前 term,防止提交旧 term 的日志 - if majorityIndex > node.commitIndex && node.log[majorityIndex].Term == node.currTerm { + if majorityIndex > node.commitIndex && majorityIndex < len(node.log) && node.log[majorityIndex].Term == node.currTerm { node.commitIndex = majorityIndex - log.Info("Leader 更新 commitIndex: " + strconv.Itoa(majorityIndex)) + log.Info("Leader" + node.selfId + "更新 commitIndex: " + strconv.Itoa(majorityIndex)) // 应用日志到状态机 node.applyCommittedLogs() @@ -173,14 +181,23 @@ func (node *Node) applyCommittedLogs() { // RPC call func (node *Node) AppendEntries(arg AppendEntriesArg, reply *AppendEntriesReply) error { - // node.mu.Lock() - // defer node.mu.Unlock() + node.mu.Lock() + defer node.mu.Unlock() // 1. 如果 term 过期,拒绝接受日志 if node.currTerm > arg.Term { *reply = AppendEntriesReply{node.currTerm, false} return nil } + + // todo: 这里也要持久化 + if node.leaderId != arg.LeaderId { + node.leaderId = arg.LeaderId // 记录Leader + } + + if node.currTerm < arg.Term { + node.currTerm = arg.Term + } // 2. 检查 prevLogIndex 是否有效 if arg.PrevLogIndex >= len(node.log) || (arg.PrevLogIndex >= 0 && node.log[arg.PrevLogIndex].Term != arg.PrevLogTerm) { @@ -197,6 +214,7 @@ func (node *Node) AppendEntries(arg AppendEntriesArg, reply *AppendEntriesReply) // 4. 追加新的日志条目 for _, raftLogEntry := range arg.Entries { + log.Info(node.selfId + "结点写入" + raftLogEntry.print()) if idx < len(node.log) { node.log[idx] = raftLogEntry } else { @@ -218,13 +236,15 @@ func (node *Node) AppendEntries(arg AppendEntriesArg, reply *AppendEntriesReply) // 7. 提交已提交的日志 node.applyCommittedLogs() + // 8. 在成功接受日志或心跳后,重置选举超时 + node.resetElectionTimer() *reply = AppendEntriesReply{node.currTerm, true} return nil } type AppendEntriesArg struct { Term int - // leaderId string + LeaderId string PrevLogIndex int PrevLogTerm int Entries []RaftLogEntry diff --git a/internal/nodes/server_node.go b/internal/nodes/server_node.go index 7c30121..7d2d6e0 100644 --- a/internal/nodes/server_node.go +++ b/internal/nodes/server_node.go @@ -8,26 +8,39 @@ import ( // leader node作为server为client注册的方法 type ServerReply struct{ - Isconnect bool + Isleader bool + LeaderAddress string // 自己不是leader则返回leader地址 HaveValue bool Value string } // RPC call -func (node *Node) WriteKV(kvCall LogEntryCall, reply *ServerReply) error { +func (node *Node) WriteKV(kvCall LogEntryCall, reply *ServerReply) error { + log.Info(node.selfId + "收到客户端write请求") + if node.state != Leader { + reply.Isleader = false + if (node.leaderId == "") { + log.Fatal("还没选出第一个leader") + return nil + } + reply.LeaderAddress = node.nodes[node.leaderId].address + log.Info(node.selfId + "转交给" + node.leaderId) + return nil + } + node.maxLogId++ logId := node.maxLogId node.log = append(node.log, RaftLogEntry{kvCall.LogE, logId, node.currTerm}) // node.db.Put([]byte(kvCall.LogE.Key), []byte(kvCall.LogE.Value), nil) - log.Info("server write request : " + kvCall.LogE.print() + ", 模拟方式 : " + strconv.Itoa(int(kvCall.CallState))) + log.Info("leader" + node.selfId + "处理请求 : " + kvCall.LogE.print() + ", 模拟方式 : " + strconv.Itoa(int(kvCall.CallState))) // 广播给其它节点 node.BroadCastKV(kvCall.CallState) - reply.Isconnect = true + reply.Isleader = true return nil } // RPC call func (node *Node) ReadKey(key string, reply *ServerReply) error { log.Info("server read : " + key) - // 先只读leader自己 + // 先只读自己(无论自己是不是leader),也方便测试 value, err := node.db.Get([]byte(key), nil) if err == leveldb.ErrNotFound { reply.HaveValue = false @@ -35,7 +48,7 @@ func (node *Node) ReadKey(key string, reply *ServerReply) error { reply.HaveValue = true reply.Value = string(value) } - reply.Isconnect = true + reply.Isleader = true return nil } diff --git a/internal/nodes/vote.go b/internal/nodes/vote.go new file mode 100644 index 0000000..4c2c3d6 --- /dev/null +++ b/internal/nodes/vote.go @@ -0,0 +1,186 @@ +package nodes + +import ( + "net/rpc" + "strconv" + "sync" + "time" + + "go.uber.org/zap" +) + +type RequestVoteArgs struct { + Term int // 候选人的当前任期 + CandidateId string // 候选人 ID + LastLogIndex int // 候选人最后一条日志的索引 + LastLogTerm int // 候选人最后一条日志的任期 +} + +type RequestVoteReply struct { + Term int // 当前节点的最新任期 + VoteGranted bool // 是否同意投票 +} + +func (n *Node) startElection() { + n.mu.Lock() + defer n.mu.Unlock() + // 1. 增加当前任期,转换为 Candidate + n.currTerm++ + n.state = Candidate + n.votedFor = n.selfId // 自己投自己 + + log.Sugar().Infof("[%s] 开始选举,当前任期: %d", n.selfId, n.currTerm) + + // 2. 重新设置选举超时,防止重复选举 + n.resetElectionTimer() + + // 3. 构造 RequestVote 请求 + var lastLogIndex int + var lastLogTerm int + + if len(n.log) == 0 { + lastLogIndex = 0 + lastLogTerm = 0 // 论文中定义,空日志时 Term 设为 0 + } else { + lastLogIndex = len(n.log) - 1 + lastLogTerm = n.log[lastLogIndex].Term + } + args := RequestVoteArgs{ + Term: n.currTerm, + CandidateId: n.selfId, + LastLogIndex: lastLogIndex, + LastLogTerm: lastLogTerm, + } + + // 4. 并行向其他节点发送请求投票 + var mu sync.Mutex + cond := sync.NewCond(&mu) + totalNodes := len(n.nodes) + grantedVotes := 1 // 自己的票 + + for peerId := range n.nodes { + go func(peerId string) { + reply := RequestVoteReply{} + if n.sendRequestVote(peerId, &args, &reply) { + mu.Lock() + + if reply.Term > n.currTerm { + // 发现更高任期,回退为 Follower + log.Sugar().Infof("[%s] 发现更高的 Term (%d),回退为 Follower", n.selfId, reply.Term) + n.currTerm = reply.Term + n.state = Follower + n.votedFor = "" + n.resetElectionTimer() + mu.Unlock() + return + } + + if reply.VoteGranted { + grantedVotes++ + } + + if grantedVotes > totalNodes/2 { + n.state = Leader + log.Sugar().Infof("[%s] 当选 Leader!", n.selfId) + n.initLeaderState() + cond.Broadcast() + } + + mu.Unlock() + } + }(peerId) + } + + // 5. 等待选举结果 + timeout := time.After(300 * time.Millisecond) + + for { + mu.Lock() + if n.state != Candidate { // 选举成功或回退,不再等待 + mu.Unlock() + return + } + select { + case <-timeout: + log.Sugar().Infof("[%s] 选举超时,重新发起选举", n.selfId) + mu.Unlock() + return + default: + cond.Wait() + } + mu.Unlock() + } +} + +func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *RequestVoteReply) bool { + log.Sugar().Infof("Sending RequestVote to %s at %s", peerId, node.nodes[peerId].address) + client, err := rpc.DialHTTP("tcp", node.nodes[peerId].address) + if err != nil { + log.Error("dialing: ", zap.Error(err)) + return false + } + + defer func(client *rpc.Client) { + err := client.Close() + if err != nil { + log.Error("client close err: ", zap.Error(err)) + } + }(client) + + callErr := client.Call("Node.RequestVote", args, reply) // RPC + if callErr != nil { + log.Error("dialing node_"+peerId+"fail: ", zap.Error(callErr)) + } + return callErr == nil +} + +func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error { + n.mu.Lock() + log.Info(n.selfId) + defer n.mu.Unlock() + // 1. 如果候选人的任期小于当前任期,则拒绝投票 + if args.Term < n.currTerm { + reply.Term = n.currTerm + reply.VoteGranted = false + return nil + } + + // 2. 如果请求的 Term 更高,则更新当前 Term 并回退为 Follower + if args.Term > n.currTerm { + n.currTerm = args.Term + n.state = Follower + n.votedFor = "" + n.resetElectionTimer() // 重新设置选举超时 + } + + // 3. 检查是否已经投过票,且是否投给了同一个候选人 + if n.votedFor == "" || n.votedFor == args.CandidateId { + // 4. 检查日志是否足够新 + var lastLogIndex int + var lastLogTerm int + + if len(n.log) == 0 { + lastLogIndex = -1 + lastLogTerm = 0 + } else { + lastLogIndex = len(n.log) - 1 + lastLogTerm = n.log[lastLogIndex].Term + } + + if args.LastLogTerm > lastLogTerm || + (args.LastLogTerm == lastLogTerm && args.LastLogIndex >= lastLogIndex) { + // 5. 投票给候选人 + n.votedFor = args.CandidateId + log.Info("term" + strconv.Itoa(n.currTerm) + ", " + n.selfId + "投票给" + n.votedFor) + reply.VoteGranted = true + n.resetElectionTimer() + } else { + reply.VoteGranted = false + } + } else { + reply.VoteGranted = false + } + + reply.Term = n.currTerm + return nil +} \ No newline at end of file diff --git a/test/common.go b/test/common.go index 554927d..65c5929 100644 --- a/test/common.go +++ b/test/common.go @@ -8,16 +8,9 @@ import ( "strings" ) -func ExecuteNodeI(i int, isLeader bool, isNewDb bool, clusters []string) *exec.Cmd { - tmpClusters := append(clusters[:i], clusters[i+1:]...) +func ExecuteNodeI(i int, isNewDb bool, clusters []string) *exec.Cmd { port := fmt.Sprintf(":%d", uint16(9090)+uint16(i)) - var isleader string - if isLeader { - isleader = "true" - } else { - isleader = "false" - } var isnewdb string if isNewDb { isnewdb = "true" @@ -28,8 +21,7 @@ func ExecuteNodeI(i int, isLeader bool, isNewDb bool, clusters []string) *exec.C "../main", "-id", strconv.Itoa(i + 1), "-port", port, - "-cluster", strings.Join(tmpClusters, ","), - "-isleader=" + isleader, + "-cluster", strings.Join(clusters, ","), "-isNewDb=" + isnewdb, ) cmd.Stdout = os.Stdout diff --git a/test/restart_follower_test.go b/test/restart_follower_test.go index 323507d..3624ad2 100644 --- a/test/restart_follower_test.go +++ b/test/restart_follower_test.go @@ -26,9 +26,9 @@ func TestFollowerRestart(t *testing.T) { for i := 0; i < n; i++ { var cmd *exec.Cmd if i == 0 { - cmd = ExecuteNodeI(i, true, true, clusters) + cmd = ExecuteNodeI(i, true, clusters) } else { - cmd = ExecuteNodeI(i, false, true, clusters) + cmd = ExecuteNodeI(i, true, clusters) } if cmd == nil { @@ -69,7 +69,7 @@ func TestFollowerRestart(t *testing.T) { } } // 恢复结点 - cmd := ExecuteNodeI(n - 1, false, false, clusters) + cmd := ExecuteNodeI(n - 1, false, clusters) if cmd == nil { t.Errorf("recover test1 fail") return diff --git a/test/server_client_test.go b/test/server_client_test.go index 4a6e489..72ad6d1 100644 --- a/test/server_client_test.go +++ b/test/server_client_test.go @@ -24,12 +24,7 @@ func TestServerClient(t *testing.T) { // 结点启动 var cmds []*exec.Cmd for i := 0; i < n; i++ { - var cmd *exec.Cmd - if i == 0 { - cmd = ExecuteNodeI(i, true, true, clusters) - } else { - cmd = ExecuteNodeI(i, false, true, clusters) - } + cmd := ExecuteNodeI(i, true, clusters) if cmd == nil { return @@ -40,7 +35,7 @@ func TestServerClient(t *testing.T) { time.Sleep(time.Second) // 等待启动完毕 // client启动 - c := clientPkg.Client{Address: "127.0.0.1:9090", ServerId: "1"} + c := clientPkg.Client{Address: "127.0.0.1:9092", ServerId: "3"} // 写入 var s clientPkg.Status