diff --git a/.gitignore b/.gitignore index 7091e8e..c070dfd 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ go.work main leveldb +storage diff --git a/cmd/main.go b/cmd/main.go index 0f813b0..4170119 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -3,6 +3,7 @@ package main import ( "flag" "fmt" + "github.com/syndtr/goleveldb/leveldb" "os" "os/signal" "simple-kv-store/internal/logprovider" @@ -10,7 +11,6 @@ import ( "strconv" "strings" "syscall" - "github.com/syndtr/goleveldb/leveldb" "go.uber.org/zap" ) @@ -48,15 +48,16 @@ func main() { idCnt++ // 命令行cluster按id排序传入,记录时跳过自己的id,先保证所有节点互相记录的id一致 continue } - idClusterPairs[strconv.Itoa(idCnt)] = addr + idClusterPairs[strconv.Itoa(idCnt)] = addr idCnt++ } if *isNewDb { os.RemoveAll("leveldb/simple-kv-store" + *id) + os.RemoveAll("storage/node" + *id + ".json") } // 打开或创建每个结点自己的数据库 - db, err := leveldb.OpenFile("leveldb/simple-kv-store" + *id, nil) + db, err := leveldb.OpenFile("leveldb/simple-kv-store"+*id, nil) if err != nil { log.Fatal("Failed to open database: ", zap.Error(err)) } @@ -64,14 +65,17 @@ func main() { iter := db.NewIterator(nil, nil) defer iter.Release() + // 打开或创建节点数据持久化文件 + storage := nodes.NewRaftStorage("storage/node" + *id + ".json") + // 计数 count := 0 for iter.Next() { count++ } - fmt.Printf(*id + "结点目前有数据:%d\n", count) + fmt.Printf(*id+"结点目前有数据:%d\n", count) - node := nodes.Init(*id, idClusterPairs, *pipe, db) + node := nodes.Init(*id, idClusterPairs, *pipe, db, storage) log.Info("id: " + *id + "节点开始监听: " + *port + "端口") // 监听rpc node.Rpc(*port) @@ -79,6 +83,6 @@ func main() { nodes.Start(node) sig := <-sigs - fmt.Println("node_" + *id + "接收到信号:", sig) + fmt.Println("node_"+*id+"接收到信号:", sig) } diff --git a/internal/nodes/init.go b/internal/nodes/init.go index bda72d3..2cdfeff 100644 --- a/internal/nodes/init.go +++ b/internal/nodes/init.go @@ -22,7 +22,7 @@ func newNode(address string) *Public_node_info { } } -func Init(selfId string, nodeAddr map[string]string, pipe string, db *leveldb.DB) *Node { +func Init(selfId string, nodeAddr map[string]string, pipe string, db *leveldb.DB, rstorage *RaftStorage) *Node { ns := make(map[string]*Public_node_info) for id, addr := range nodeAddr { ns[id] = newNode(addr) @@ -34,7 +34,7 @@ func Init(selfId string, nodeAddr map[string]string, pipe string, db *leveldb.DB leaderId: "", nodes: ns, pipeAddr: pipe, - maxLogId: -1, + maxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了 currTerm: 1, log: make([]RaftLogEntry, 0), commitIndex: -1, @@ -42,6 +42,7 @@ func Init(selfId string, nodeAddr map[string]string, pipe string, db *leveldb.DB nextIndex: make(map[string]int), matchIndex: make(map[string]int), db: db, + storage: rstorage, } node.initLeaderState() return node @@ -110,6 +111,7 @@ func (n *Node) initLeaderState() { n.nextIndex[peerId] = len(n.log) // 发送日志的下一个索引 n.matchIndex[peerId] = 0 // 复制日志的最新匹配索引 } + n.storage.SetTermAndVote(n.currTerm, n.votedFor) } func Start(node *Node) { diff --git a/internal/nodes/node.go b/internal/nodes/node.go index b5bc0c8..74da648 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -62,7 +62,10 @@ type Node struct { // 已经发送给每个节点的最大索引 matchIndex map[string]int + // 存kv(模拟状态机) db *leveldb.DB + // 持久化节点数据(currterm votedfor log) + storage *RaftStorage votedFor string electionTimer *time.Timer @@ -131,8 +134,13 @@ func (node *Node) sendKV(id string, callMode CallMode) { } if appendReply.Term != node.currTerm { - // 转变成follower? - break + log.Info("Leader " + node.selfId + " 收到更高的 term=" + strconv.Itoa(appendReply.Term) + ",转换为 Follower") + node.currTerm = appendReply.Term + node.state = Follower + node.votedFor = "" + node.storage.SetTermAndVote(node.currTerm, node.votedFor) + node.resetElectionTimer() + return } nextIndex-- // 失败往前传一格 } @@ -190,14 +198,16 @@ func (node *Node) AppendEntries(arg AppendEntriesArg, reply *AppendEntriesReply) return nil } - // todo: 这里也要持久化 - if node.leaderId != arg.LeaderId { - node.leaderId = arg.LeaderId // 记录Leader - } + node.leaderId = arg.LeaderId // 记录Leader if node.currTerm < arg.Term { - node.currTerm = arg.Term + log.Info("Node " + node.selfId + " 发现更高 term=" + strconv.Itoa(arg.Term)) + node.currTerm = arg.Term + node.state = Follower + node.votedFor = "" + node.storage.SetTermAndVote(node.currTerm, node.votedFor) } + node.storage.SetTermAndVote(node.currTerm, node.votedFor) // 2. 检查 prevLogIndex 是否有效 if arg.PrevLogIndex >= len(node.log) || (arg.PrevLogIndex >= 0 && node.log[arg.PrevLogIndex].Term != arg.PrevLogTerm) { @@ -206,10 +216,13 @@ func (node *Node) AppendEntries(arg AppendEntriesArg, reply *AppendEntriesReply) } // 3. 处理日志冲突(如果存在不同 term,则截断日志) - idx := arg.PrevLogIndex + 1 - if idx < len(node.log) && node.log[idx].Term != arg.Entries[0].Term { - node.log = node.log[:idx] // 截断冲突日志 - } + idx := arg.PrevLogIndex + 1 + for i := idx; i < len(node.log) && i-idx < len(arg.Entries); i++ { + if node.log[i].Term != arg.Entries[i-idx].Term { + node.log = node.log[:idx] + break + } + } // log.Info(strconv.Itoa(idx) + strconv.Itoa(len(node.log))) // 4. 追加新的日志条目 @@ -223,20 +236,23 @@ func (node *Node) AppendEntries(arg AppendEntriesArg, reply *AppendEntriesReply) idx++ } - // 5. 更新 maxLogId + // 暴力持久化 + node.storage.WriteLog(node.log) + + // 更新 maxLogId node.maxLogId = len(node.log) - 1 - // 6. 更新 commitIndex + // 更新 commitIndex if arg.LeaderCommit < node.maxLogId { node.commitIndex = arg.LeaderCommit } else { node.commitIndex = node.maxLogId } - // 7. 提交已提交的日志 + // 提交已提交的日志 node.applyCommittedLogs() - // 8. 在成功接受日志或心跳后,重置选举超时 + // 在成功接受日志或心跳后,重置选举超时 node.resetElectionTimer() *reply = AppendEntriesReply{node.currTerm, true} return nil diff --git a/internal/nodes/node_storage.go b/internal/nodes/node_storage.go new file mode 100644 index 0000000..192c222 --- /dev/null +++ b/internal/nodes/node_storage.go @@ -0,0 +1,147 @@ +package nodes + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + + "go.uber.org/zap" +) + +// RaftStorage 结构,持久化 currentTerm、votedFor 和 logEntries +type RaftStorage struct { + mu sync.Mutex + filePath string + CurrentTerm int `json:"current_term"` + VotedFor string `json:"voted_for"` + LogEntries []RaftLogEntry `json:"log_entries"` +} + +// NewRaftStorage 创建 Raft 存储 +func NewRaftStorage(filePath string) *RaftStorage { + storage := &RaftStorage{ + filePath: filePath, + } + storage.loadData() // 载入已有数据 + return storage +} + +// loadData 读取 JSON 文件数据 +func (rs *RaftStorage) loadData() { + rs.mu.Lock() + defer rs.mu.Unlock() + + file, err := os.Open(rs.filePath) + if err != nil { + log.Info("文件不存在:" + rs.filePath) + rs.saveData() // 文件不存在时创建默认数据 + return + } + defer file.Close() + + err = json.NewDecoder(file).Decode(rs) + if err != nil { + log.Error("读取文件失败:" + rs.filePath) + } +} + +// 持久化数据到 JSON(必须持有锁,不能直接外部调用) +func (rs *RaftStorage) saveData() { + // 获取文件所在的目录 + dir := filepath.Dir(rs.filePath) + + // 确保目录存在 + if err := os.MkdirAll(dir, 0755); err != nil { + log.Error("创建存储目录失败", zap.Error(err)) + return + } + + file, err := os.Create(rs.filePath) + if err != nil { + log.Error("持久化节点出错", zap.Error(err)) + return + } + defer file.Close() + + err = json.NewEncoder(file).Encode(rs) + if err != nil { + log.Error("持久化写入失败") + } +} + +// SetCurrentTerm 设置当前 term,并清空 votedFor(符合 Raft 规范) +func (rs *RaftStorage) SetCurrentTerm(term int) { + rs.mu.Lock() + defer rs.mu.Unlock() + if term > rs.CurrentTerm { + rs.CurrentTerm = term + rs.VotedFor = "" // 新任期清空投票 + rs.saveData() + } +} + +// GetCurrentTerm 获取当前 term +func (rs *RaftStorage) GetCurrentTerm() int { + rs.mu.Lock() + defer rs.mu.Unlock() + return rs.CurrentTerm +} + +// SetVotedFor 记录投票给谁 +func (rs *RaftStorage) SetVotedFor(candidate string) { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.VotedFor = candidate + rs.saveData() +} + +// GetVotedFor 获取投票对象 +func (rs *RaftStorage) GetVotedFor() string { + rs.mu.Lock() + defer rs.mu.Unlock() + return rs.VotedFor +} + +func (rs *RaftStorage) SetTermAndVote(term int, candidate string) { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.VotedFor = candidate + rs.CurrentTerm = term + rs.saveData() +} + +// append日志 +func (rs *RaftStorage) AppendLog(rlogE RaftLogEntry) { + rs.mu.Lock() + defer rs.mu.Unlock() + + rs.LogEntries = append(rs.LogEntries, rlogE) + rs.saveData() +} + +// 更改日志 +func (rs *RaftStorage) WriteLog(rlogEs []RaftLogEntry) { + rs.mu.Lock() + defer rs.mu.Unlock() + + rs.LogEntries = rlogEs + rs.saveData() +} + +// 获取所有日志 +func (rs *RaftStorage) GetLogEntries() []RaftLogEntry { + rs.mu.Lock() + defer rs.mu.Unlock() + return rs.LogEntries +} + +// GetLastLogIndex 获取最新日志的 index +func (rs *RaftStorage) GetLastLogIndex() int { + rs.mu.Lock() + defer rs.mu.Unlock() + if len(rs.LogEntries) == 0 { + return 0 + } + return len(rs.LogEntries)-1 +} diff --git a/internal/nodes/server_node.go b/internal/nodes/server_node.go index 7d2d6e0..da41d24 100644 --- a/internal/nodes/server_node.go +++ b/internal/nodes/server_node.go @@ -29,8 +29,9 @@ func (node *Node) WriteKV(kvCall LogEntryCall, reply *ServerReply) error { node.maxLogId++ logId := node.maxLogId - node.log = append(node.log, RaftLogEntry{kvCall.LogE, logId, node.currTerm}) - // node.db.Put([]byte(kvCall.LogE.Key), []byte(kvCall.LogE.Value), nil) + rLogE := RaftLogEntry{kvCall.LogE, logId, node.currTerm} + node.log = append(node.log, rLogE) + node.storage.AppendLog(rLogE) log.Info("leader" + node.selfId + "处理请求 : " + kvCall.LogE.print() + ", 模拟方式 : " + strconv.Itoa(int(kvCall.CallState))) // 广播给其它节点 node.BroadCastKV(kvCall.CallState) diff --git a/internal/nodes/vote.go b/internal/nodes/vote.go index 4c2c3d6..cf3799b 100644 --- a/internal/nodes/vote.go +++ b/internal/nodes/vote.go @@ -28,6 +28,7 @@ func (n *Node) startElection() { n.currTerm++ n.state = Candidate n.votedFor = n.selfId // 自己投自己 + n.storage.SetTermAndVote(n.currTerm, n.votedFor) log.Sugar().Infof("[%s] 开始选举,当前任期: %d", n.selfId, n.currTerm) @@ -70,6 +71,7 @@ func (n *Node) startElection() { n.currTerm = reply.Term n.state = Follower n.votedFor = "" + n.storage.SetTermAndVote(n.currTerm, n.votedFor) n.resetElectionTimer() mu.Unlock() return @@ -136,7 +138,6 @@ func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *R func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error { n.mu.Lock() - log.Info(n.selfId) defer n.mu.Unlock() // 1. 如果候选人的任期小于当前任期,则拒绝投票 if args.Term < n.currTerm { @@ -181,6 +182,7 @@ func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error reply.VoteGranted = false } + n.storage.SetTermAndVote(n.currTerm, n.votedFor) reply.Term = n.currTerm return nil } \ No newline at end of file