diff --git a/.gitignore b/.gitignore
index 7091e8e..afb9a4e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,3 +26,6 @@ go.work
 main
 leveldb
+storage
+*.log
+testdata
diff --git a/README.md b/README.md
index 9a0791c..2c56ccc 100644
--- a/README.md
+++ b/README.md
@@ -1,24 +1,240 @@
 # go-raft-kv
 ---
+# Introduction
+This project is a distributed key-value store written in Go on top of the Raft consensus algorithm.
-A distributed kv store implemented in Go
+Project highlights:
+
+- Two communication modes, thread-level and process-level, which make it easy to switch between testing and deployment while keeping the system loosely coupled.
+- A fuzz-testing harness built on state invariants, improving robustness.
+- A highly modular design that is easy to extend with new features.
+
+This report serves as a work summary and as a guide for reading the code (key fragments are quoted inline). Concrete problems met during the work, the reasoning behind fixes, and lessons learned do not compress well into prose; they are covered in the presentation slides.
+
+# Project layout
+```
+---cmd
+    ---main.go            process-mode entry point
+---internal
+    ---client             client-side reads/writes against the nodes
+    ---logprovider        thin logging wrapper for easier debugging
+    ---nodes              core distributed-system code
+        init.go           node initialization (both modes) and the main loop
+        log.go            data structures for the entries a node stores
+        node_storage.go   node persistence, serialized into leveldb
+        node.go           node data structures
+        random_timetable.go  control of the system's randomized timing
+        real_transport.go process-mode RPC transport
+        replica.go        log-replication logic
+        server_node.go    read/write service a node exposes to clients
+        simulate_ctx.go   controls message behavior during tests
+        thread_transport.go  thread-mode transport
+        transport.go      common transport interface shared by both modes
+        vote.go           leader-election logic
+---test                   process-mode tests
+---threadTest             thread-mode tests
+    ---fuzz               randomized tests
+    election_test.go          elections
+    log_replication_test.go   log replication
+    network_partition_test.go network partitions
+    restart_node_test.go      crash recovery
+    server_client_test.go     client interaction
+```
+
+# The Raft system
+## Main flow
+In init.go each node is initialized and spawns its listener, then enters the main loop in Start. On every tick of the heartbeat interval, a node that currently believes it is the leader broadcasts and resets its election timer (init.go):
+```go
+func Start(node *Node, quitChan chan struct{}) {
+	node.Mu.Lock()
+	node.State = Follower // every node boots as a Follower
+	node.Mu.Unlock()
+	node.ResetElectionTimer() // arm the election-timeout timer
+
+	go func() {
+		ticker := time.NewTicker(50 * time.Millisecond)
+		defer ticker.Stop()
+
+		for {
+			select {
+			case <-quitChan:
+				fmt.Printf("[%s] Raft start exiting...\n", node.SelfId)
+				return // leave the goroutine
+
+			case <-ticker.C:
+				node.Mu.Lock()
+				state := node.State
+				node.Mu.Unlock()
+
+				switch state {
+				case Follower:
+					// waiting for a heartbeat timeout
+
+				case Leader:
+					// send heartbeats
+					node.ResetElectionTimer() // the leader never triggers its own election
+					node.BroadCastKV()
+				}
+			}
+		}
+	}()
+}
+```
+
+Two RPC handlers drive the protocol:
+
+- AppendEntries: BroadCastKV iterates over every peer node and issues the call in sendKV, implementing log replication.
+- RequestVote: every node owns a ResetElectionTimer timer; if it goes unreset for long enough, StartElection fires, iterating over every peer node and issuing the call in sendRequestVote, which implements leader election. The leader resets the timer on every heartbeat (so it never triggers its own re-election); followers reset it whenever they receive AppendEntries.
+
+```go
+func (node *Node) ResetElectionTimer() {
+	node.MuElection.Lock()
+	defer node.MuElection.Unlock()
+	if node.ElectionTimer == nil {
+		node.ElectionTimer = time.NewTimer(node.RTTable.GetElectionTimeout())
+		go func() {
+			for {
+				<-node.ElectionTimer.C
+				node.StartElection()
+			}
+		}()
+	} else {
+		node.ElectionTimer.Stop()
+		node.ElectionTimer.Reset(time.Duration(500+rand.Intn(500)) * time.Millisecond)
+	}
+}
+```
+
+## How the client works (client & server_node.go)
+### Client writes
+The client connects to a random node in the cluster; four cases follow:
+
+a. The node believes it is the leader: it handles the request directly (records the entry, then broadcasts).
+b. The node is a follower and knows the leader: it returns the leader's id; the client dials that id and the new node goes through the same four cases.
+c. The node is a follower but does not know the leader: it returns an empty id, and the client retries another random node.
+d. The connection times out: the client retries another random node.
+
+```go
+func (client *Client) Write(kv nodes.LogEntry) Status {
+	kvCall := nodes.LogEntryCall{LogE: kv,
+		Id: nodes.LogEntryCallId{ClientId: client.ClientId, LogId: client.NextLogId}}
+	client.NextLogId++
+
+	c := client.FindActiveNode()
+	var err error
+
+	timeout := time.Second
+	deadline := time.Now().Add(timeout)
+
+	for { // follow the live nodes' hints until the leader is found
+		if time.Now().After(deadline) {
+			return Fail
+		}
+
+		var reply nodes.ServerReply
+		reply.Isleader = false
+
+		callErr := client.Transport.CallWithTimeout(c, "Node.WriteKV", &kvCall, &reply) // RPC
+		if callErr != nil { // the node may crash between dial and call; find another live node
+			log.Error("dialing: ", zap.Error(callErr))
+			client.CloseRpcClient(c)
+			c = client.FindActiveNode()
+			continue
+		}
+
+		if !reply.Isleader { // callee is not the leader; follow its hint
+			leaderId := reply.LeaderId
+			client.CloseRpcClient(c)
+			if leaderId == "" { // it does not know the leader either; pick randomly again
+				c = client.FindActiveNode()
+			} else { // dial the leader it named
+				c, err = client.Transport.DialHTTPWithTimeout("tcp", "", leaderId)
+				if err != nil { // dial failed; fall back to another live node
+					c = client.FindActiveNode()
+				}
+			}
+		} else { // success
+			client.CloseRpcClient(c)
+			return Ok
+		}
+	}
+}
+```
+
+### Client reads
+The client connects to a random node in the cluster and reads that node's committed kv state.
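+
+For orientation, the sketch below shows roughly how the tests under test/ drive this client against a locally running cluster. It is illustrative, not shipped code: the addresses assume the default ports from cmd/main.go, and the cluster is assumed to be already started and to have elected a leader (the tests wait about a second for this).
+
+```go
+package main
+
+import (
+	"fmt"
+
+	clientPkg "simple-kv-store/internal/client"
+	"simple-kv-store/internal/nodes"
+)
+
+func main() {
+	// Assumed: three nodes already listening on these ports (see cmd/main.go).
+	addressMap := map[string]string{
+		"1": "127.0.0.1:9091",
+		"2": "127.0.0.1:9092",
+		"3": "127.0.0.1:9093",
+	}
+	peerIds := []string{"1", "2", "3"}
+
+	// Process-mode transport; the thread-mode tests build a ThreadTransport instead.
+	transport := &nodes.HTTPTransport{NodeMap: addressMap}
+	c := clientPkg.NewClient("client-0", peerIds, transport)
+
+	if s := c.Write(nodes.LogEntry{Key: "greeting", Value: "hello"}); s != clientPkg.Ok {
+		fmt.Println("write failed")
+	}
+
+	var value string
+	if s := c.Read("greeting", &value); s == clientPkg.Ok {
+		fmt.Println("read back:", value)
+	}
+}
+```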
log.Error("dialing: ", zap.Error(callErr)) + client.CloseRpcClient(c) + c = client.FindActiveNode() + continue + } + + if !reply.Isleader { // 对方不是leader,根据反馈找leader + leaderId := reply.LeaderId + client.CloseRpcClient(c) + if leaderId == "" { // 这个节点不知道leader是谁,再随机找 + c = client.FindActiveNode() + } else { // dial leader + c, err = client.Transport.DialHTTPWithTimeout("tcp", "", leaderId) + if err != nil { // dial失败,重新找下一个存活节点 + c = client.FindActiveNode() + } + } + } else { // 成功 + client.CloseRpcClient(c) + return Ok + } + } +} +``` + +### 客户端读 +随机连上集群中一个节点,读它commit的kv。 + +## 重要数据持久化 +封装在了node_storage.go中,主要是setTermAndVote序列化后写入leveldb,log的写入。在这些数据变化时调用它们进行持久化,以及相应的恢复时读取。 +```go +// SetTermAndVote 原子更新 term 和 vote +func (rs *RaftStorage) SetTermAndVote(term int, candidate string) { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.isfinish { + return + } + + batch := new(leveldb.Batch) + batch.Put([]byte("current_term"), []byte(strconv.Itoa(term))) + batch.Put([]byte("voted_for"), []byte(candidate)) + + err := rs.db.Write(batch, nil) // 原子提交 + if err != nil { + log.Error("SetTermAndVote 持久化失败:", zap.Error(err)) + } +} +``` + +## 两种系统的切换 +将raft系统中,所有涉及网络的部分提取出来,抽象为dial和call方法,作为每个node的接口类transport的两个基类方法,进程版与线程版的transport派生类分别实现,使得相互之间实现隔离。 +```go +type Transport interface { + DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) + CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error +} +``` + +进程版:dial和call均为go原生rpc库的方法,加一层timeout封装(real_transport.go) +线程版:threadTransport为每个节点共用,节点初始化时把一个自己的chan注册进里面的map,然后go一个线程去监听它,收到req后去调用自己对应的函数(thread_transport.go) + +# 测试部分 +## 单元测试 +从Leader选举、日志复制、崩溃恢复、网络分区、客户端交互五个维度,对系统进行分模块的测试。测试中夹杂消息状态的细粒度模拟,尽可能在项目前中期验证代码与思路的一致性,避免大的问题。 + +## fuzz测试 +分为不同节点、系统随机时间配置测试异常的多系统随机(basic),与对单个系统注入多个随机异常的单系统随机(robust),这两个维度,以及最后综合两个维度的进一步测试(plus)。 +测试中加入了raft的TLA标准,作为测试断言,确保系统在运行中的稳定性。 +fuzz test不仅覆盖了单元测试的内容,也在随机的测试中发现了更多边界条件的异常,以及通过系统状态的不变量检测,确保系统在不同配置下支持长时间的运行中保持正确可用。 + +![alt text](pics/plus.png) + +![alt text](pics/robust.png) + +## bug的简单记录 +LogId0的歧义,不同接口对接日志编号出现问题 +随机选举超时相同导致的candidate卡死问题 +重要数据持久化不原子、与状态机概念混淆 +客户端缺乏消息唯一标识,导致系统重复执行 +重构系统过程中lock使用不当 +伪同步接口的异步语义陷阱 +测试和系统混合产生的bug(延迟导致的超时、退出不完全导致的异常、文件系统异常、lock不当) # 环境与运行 使用环境是wsl+ubuntu go mod download安装依赖 -./scripts/build.sh 会在根目录下编译出main -./scripts/run.sh 运行三个节点,目前能在终端进行读入,leader(n1)节点输出send log,其余节点输出receive log。终端输入后如果超时就退出(脚本运行时间可以在其中调整)。 - -# 注意 -脚本第一次运行需要权限获取 chmod +x <脚本> -如果出现tcp listen error可能是因为之前的进程没用正常退出,占用了端口 -lsof -i :9091查看pid -kill -9 杀死进程 -## 关于测试 -通过新开进程的方式创建节点,如果通过线程创建,会出现重复注册rpc问题 - -# todo list -消息通讯异常的处理 -kv本地持久化 -崩溃与恢复(以及对应的测试) \ No newline at end of file +./scripts/build.sh 会在根目录下编译出main(进程级的测试需要) + +# 参考资料 +In Search of an Understandable Consensus Algorithm +Consensus: Bridging Theory and Practice +Raft TLA+ Specification +全项目除了logprovider文件夹下的一些go的日志库使用参考了一篇博客的封装,其余皆为独立原创。 + +# 分工 +| 姓名 | 工作 | 贡献度 | +|--------|--------|--------| +| 李度 | raft系统设计+实现,测试设计+实现 | 75% | +| 马也驰 | raft系统设计,测试设计+实现 | 25% | + +![alt text](pics/plan.png) + diff --git a/cmd/main.go b/cmd/main.go index b0b8bd8..b3e9a2a 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -3,6 +3,7 @@ package main import ( "flag" "fmt" + "github.com/syndtr/goleveldb/leveldb" "os" "os/signal" "simple-kv-store/internal/logprovider" @@ -10,7 +11,6 @@ import ( "strconv" "strings" "syscall" - "github.com/syndtr/goleveldb/leveldb" "go.uber.org/zap" ) @@ -29,11 +29,9 @@ func main() { signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) port := flag.String("port", 
":9091", "rpc listen port") - cluster := flag.String("cluster", "127.0.0.1:9092,127.0.0.1:9093", "comma sep") + cluster := flag.String("cluster", "127.0.0.1:9091,127.0.0.1:9092,127.0.0.1:9093", "comma sep") id := flag.String("id", "1", "node ID") - pipe := flag.String("pipe", "", "input from scripts") - isLeader := flag.Bool("isleader", false, "init node state") - isNewDb := flag.Bool("isNewDb", true, "new test or restart") + isRestart := flag.Bool("isRestart", false, "new test or restart") // 参数解析 flag.Parse() @@ -42,43 +40,46 @@ func main() { idCnt := 1 selfi, err := strconv.Atoi(*id) if err != nil { - log.Error("figure id only") + log.Fatal("figure id only") } for _, addr := range clusters { if idCnt == selfi { idCnt++ // 命令行cluster按id排序传入,记录时跳过自己的id,先保证所有节点互相记录的id一致 + continue } - idClusterPairs[strconv.Itoa(idCnt)] = addr + idClusterPairs[strconv.Itoa(idCnt)] = addr idCnt++ } - if *isNewDb { - os.RemoveAll("leveldb/simple-kv-store" + *id) + // storage/文件夹下为node重要数据持久化数据库,节点一旦创建成功就不能被删除 + if !*isRestart { + os.RemoveAll("storage/node" + *id) } - // 打开或创建每个结点自己的数据库 + + // 创建每个结点自己的数据库。这里一开始理解上有些误区,状态机的状态恢复应该靠节点的持久化log, + // 而用leveldb模拟状态机,造成了状态机本身的持久化,因此通过删去旧db避免这一矛盾 + // 因此leveldb/文件夹下为状态机模拟数据库,每次节点启动都需要删除该数据库 + os.RemoveAll("leveldb/simple-kv-store" + *id) + db, err := leveldb.OpenFile("leveldb/simple-kv-store" + *id, nil) if err != nil { log.Fatal("Failed to open database: ", zap.Error(err)) } defer db.Close() // 确保数据库在使用完毕后关闭 - iter := db.NewIterator(nil, nil) - defer iter.Release() - // 计数 - count := 0 - for iter.Next() { - count++ - } - fmt.Printf(*id + "结点目前有数据:%d\n", count) + // 打开或创建节点数据持久化文件 + storage := nodes.NewRaftStorage("storage/node" + *id) + defer storage.Close() + + // 初始化 + node := nodes.InitRPCNode(*id, *port, idClusterPairs, db, storage, *isRestart) - node := nodes.Init(*id, idClusterPairs, *pipe, db) - log.Info("id: " + *id + "节点开始监听: " + *port + "端口") - // 监听rpc - node.Rpc(*port) // 开启 raft - nodes.Start(node, *isLeader) + quitChan := make(chan struct{}, 1) + nodes.Start(node, quitChan) sig := <-sigs - fmt.Println("node_" + *id + "接收到信号:", sig) + fmt.Println("node_"+ *id +"接收到信号:", sig) + close(quitChan) } diff --git a/internal/client/client_node.go b/internal/client/client_node.go index 7eaf99d..6e77a63 100644 --- a/internal/client/client_node.go +++ b/internal/client/client_node.go @@ -1,9 +1,10 @@ package clientPkg import ( - "net/rpc" + "math/rand" "simple-kv-store/internal/logprovider" "simple-kv-store/internal/nodes" + "time" "go.uber.org/zap" ) @@ -11,9 +12,11 @@ import ( var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel) type Client struct { - // 连接的server端节点(node1) - ServerId string - Address string + ClientId string // 每个client唯一标识 + NextLogId int + // 连接的server端节点群 + PeerIds []string + Transport nodes.Transport } type Status = uint8 @@ -24,32 +27,83 @@ const ( Fail ) -func (client *Client) Write(kvCall nodes.LogEntryCall) Status { - log.Info("client write request key :" + kvCall.LogE.Key) - c, err := rpc.DialHTTP("tcp", client.Address) - if err != nil { - log.Error("dialing: ", zap.Error(err)) - return Fail - } +func NewClient(clientId string, peerIds []string, transport nodes.Transport) *Client { + return &Client{ClientId: clientId, NextLogId: 0, PeerIds: peerIds, Transport: transport} +} - defer func(server *rpc.Client) { - err := c.Close() +func getRandomAddress(peerIds []string) string { + // 随机选一个 id + randomKey := peerIds[rand.Intn(len(peerIds))] + return randomKey +} + +func (client *Client) FindActiveNode() nodes.ClientInterface { + 
var err error + var c nodes.ClientInterface + for { // 直到找到一个可连接的节点(保证至少一个节点活着) + peerId := getRandomAddress(client.PeerIds) + c, err = client.Transport.DialHTTPWithTimeout("tcp", "", peerId) if err != nil { - log.Error("client close err: ", zap.Error(err)) + log.Error("dialing: ", zap.Error(err)) + } else { + log.Sugar().Infof("client发现活跃节点[%s]", peerId) + return c } - }(c) + } +} - var reply nodes.ServerReply - callErr := c.Call("Node.WriteKV", kvCall, &reply) // RPC - if callErr != nil { - log.Error("dialing: ", zap.Error(callErr)) - return Fail +func (client *Client) CloseRpcClient(c nodes.ClientInterface) { + err := c.Close() + if err != nil { + log.Error("client close err: ", zap.Error(err)) } +} + +func (client *Client) Write(kv nodes.LogEntry) Status { + defer logprovider.DebugTraceback("client") + log.Info("client write request key :" + kv.Key) + kvCall := nodes.LogEntryCall{LogE: kv, + Id: nodes.LogEntryCallId{ClientId: client.ClientId, LogId: client.NextLogId}} + client.NextLogId++ + + c := client.FindActiveNode() + var err error + + timeout := time.Second + deadline := time.Now().Add(timeout) + + for { // 根据存活节点的反馈,直到找到leader + if time.Now().After(deadline) { + log.Error("系统繁忙,疑似出错") + return Fail + } + + var reply nodes.ServerReply + reply.Isleader = false + + callErr := client.Transport.CallWithTimeout(c, "Node.WriteKV", &kvCall, &reply) // RPC + if callErr != nil { // dial和call之间可能崩溃,重新找存活节点 + log.Error("dialing: ", zap.Error(callErr)) + client.CloseRpcClient(c) + c = client.FindActiveNode() + continue + } - if reply.Isconnect { // 发送成功 - return Ok - } else { // 失败 - return Fail + if !reply.Isleader { // 对方不是leader,根据反馈找leader + leaderId := reply.LeaderId + client.CloseRpcClient(c) + if leaderId == "" { // 这个节点不知道leader是谁,再随机找 + c = client.FindActiveNode() + } else { // dial leader + c, err = client.Transport.DialHTTPWithTimeout("tcp", "", leaderId) + if err != nil { // dial失败,重新找下一个存活节点 + c = client.FindActiveNode() + } + } + } else { // 成功 + client.CloseRpcClient(c) + return Ok + } } } @@ -58,37 +112,59 @@ func (client *Client) Read(key string, value *string) Status { // 查不到value if value == nil { return Fail } - - c, err := rpc.DialHTTP("tcp", client.Address) - if err != nil { - log.Error("dialing: ", zap.Error(err)) - return Fail - } - - defer func(server *rpc.Client) { - err := c.Close() - if err != nil { - log.Error("client close err: ", zap.Error(err)) + var c nodes.ClientInterface + for { + c = client.FindActiveNode() + + var reply nodes.ServerReply + callErr := client.Transport.CallWithTimeout(c, "Node.ReadKey", &key, &reply) // RPC + if callErr != nil { + log.Error("dialing: ", zap.Error(callErr)) + client.CloseRpcClient(c) + continue } - }(c) - - var reply nodes.ServerReply - callErr := c.Call("Node.ReadKey", key, &reply) // RPC - if callErr != nil { - log.Error("dialing: ", zap.Error(callErr)) - return Fail - } - if reply.Isconnect { // 发送成功 + // 目前一定发送成功 if reply.HaveValue { *value = reply.Value + client.CloseRpcClient(c) return Ok } else { + client.CloseRpcClient(c) return NotFound + } + } +} + +func (client *Client) FindLeader() string { + var arg struct{} + var reply nodes.FindLeaderReply + reply.Isleader = false + c := client.FindActiveNode() + var err error + + for !reply.Isleader { // 根据存活节点的反馈,直到找到leader + callErr := client.Transport.CallWithTimeout(c, "Node.FindLeader", &arg, &reply) // RPC + if callErr != nil { // dial和call之间可能崩溃,重新找存活节点 + log.Error("dialing: ", zap.Error(callErr)) + client.CloseRpcClient(c) + c = client.FindActiveNode() + continue } - 
} else { // 失败 - return Fail + + if !reply.Isleader { // 对方不是leader,根据反馈找leader + client.CloseRpcClient(c) + c, err = client.Transport.DialHTTPWithTimeout("tcp", "", reply.LeaderId) + for err != nil { // 重新找下一个存活节点 + c = client.FindActiveNode() + } + } else { // 成功 + client.CloseRpcClient(c) + return reply.LeaderId + } } + log.Fatal("客户端会一直找存活节点,不会运行到这里") + return "fault" } diff --git a/internal/logprovider/traceback.go b/internal/logprovider/traceback.go new file mode 100644 index 0000000..98cfcc7 --- /dev/null +++ b/internal/logprovider/traceback.go @@ -0,0 +1,15 @@ +package logprovider + +import ( + "fmt" + "os" + "runtime/debug" +) +func DebugTraceback(errFuncName string) { + if r := recover(); r != nil { + msg := fmt.Sprintf("panic in goroutine: %v\n%s", r, debug.Stack()) + f, _ := os.Create(errFuncName + ".log") + fmt.Fprint(f, msg) + f.Close() + } +} \ No newline at end of file diff --git a/internal/nodes/init.go b/internal/nodes/init.go index 0a38451..64eec06 100644 --- a/internal/nodes/init.go +++ b/internal/nodes/init.go @@ -1,13 +1,12 @@ package nodes import ( - "io" + "errors" + "fmt" "net" "net/http" "net/rpc" - "os" "simple-kv-store/internal/logprovider" - "strconv" "time" "github.com/syndtr/goleveldb/leveldb" @@ -16,88 +15,48 @@ import ( var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel) -func newNode(address string) *Public_node_info { - return &Public_node_info{ - connect: false, - address: address, - } -} - -func Init(id string, nodeAddr map[string]string, pipe string, db *leveldb.DB) *Node { - ns := make(map[string]*Public_node_info) - for id, addr := range nodeAddr { - ns[id] = newNode(addr) +// 运行在进程上的初始化 + rpc注册 +func InitRPCNode(SelfId string, port string, nodeAddr map[string]string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool) *Node { + var nodeIds []string + for id := range nodeAddr { + nodeIds = append(nodeIds, id) } // 创建节点 - return &Node{ - selfId: id, - nodes: ns, - pipeAddr: pipe, - maxLogId: 0, - log: make(map[int]LogEntry), - db: db, + node := &Node{ + SelfId: SelfId, + LeaderId: "", + Nodes: nodeIds, + MaxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了 + CurrTerm: 1, + Log: make([]RaftLogEntry, 0), + CommitIndex: -1, + LastApplied: -1, + NextIndex: make(map[string]int), + MatchIndex: make(map[string]int), + Db: db, + Storage: rstorage, + Transport: &HTTPTransport{NodeMap: nodeAddr}, + RTTable: NewRTTable(), + SeenRequests: make(map[LogEntryCallId]bool), + IsFinish: false, } -} - -func Start(node *Node, isLeader bool) { - if isLeader { - node.state = Candidate // 需要身份转变 - } else { - node.state = Follower + node.initLeaderState() + if isRestart { + node.CurrTerm = rstorage.GetCurrentTerm() + node.VotedFor = rstorage.GetVotedFor() + node.Log = rstorage.GetLogEntries() + log.Sugar().Infof("[%s]从重启中恢复log数量: %d", SelfId, len(node.Log)) } - go func() { - for { - switch node.state { - case Follower: - - case Candidate: - // candidate发布一个监听输入线程后,变成leader - node.state = Leader - go func() { - if node.pipeAddr == "" { // 客户端远程调用server_node方法 - log.Info("请运行客户端进程进行读写") - } else { // 命令行提供了管道,支持管道(键盘)输入 - pipe, err := os.Open(node.pipeAddr) - if err != nil { - log.Error("Failed to open pipe") - } - defer pipe.Close() - - // 不断读取管道中的输入 - buffer := make([]byte, 256) - for { - n, err := pipe.Read(buffer) - if err != nil && err != io.EOF { - log.Error("Error reading from pipe") - } - if n > 0 { - input := string(buffer[:n]) - // 将用户输入封装成一个 LogEntry - kv := LogEntry{input, ""} // 目前键盘输入key,value 0 - logId := node.maxLogId - node.maxLogId++ - node.log[logId] = kv 
- - log.Info("send : logId = " + strconv.Itoa(logId) + ", key = " + input) - // 广播给其它节点 - kvCall := LogEntryCall{kv, Normal} - node.BroadCastKV(logId, kvCall) - // 持久化 - node.db.Put([]byte(kv.Key), []byte(kv.Value), nil) - } - } - } - }() - case Leader: - time.Sleep(50 * time.Millisecond) - } - } - }() + log.Sugar().Infof("[%s]开始监听" + port + "端口", SelfId) + node.ListenPort(port) + + return node } -func (node *Node) Rpc(port string) { +func (node *Node) ListenPort(port string) { + err := rpc.Register(node) if err != nil { log.Fatal("rpc register failed", zap.Error(err)) @@ -115,3 +74,184 @@ func (node *Node) Rpc(port string) { } }() } + +// 线程模拟的初始化 +func InitThreadNode(SelfId string, peerIds []string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool, threadTransport *ThreadTransport) (*Node, chan struct{}) { + rpcChan := make(chan RPCRequest, 100) // 要监听的chan + // 创建节点 + node := &Node{ + SelfId: SelfId, + LeaderId: "", + Nodes: peerIds, + MaxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了 + CurrTerm: 1, + Log: make([]RaftLogEntry, 0), + CommitIndex: -1, + LastApplied: -1, + NextIndex: make(map[string]int), + MatchIndex: make(map[string]int), + Db: db, + Storage: rstorage, + Transport: threadTransport, + RTTable: NewRTTable(), + SeenRequests: make(map[LogEntryCallId]bool), + IsFinish: false, + } + node.initLeaderState() + if isRestart { + node.CurrTerm = rstorage.GetCurrentTerm() + node.VotedFor = rstorage.GetVotedFor() + node.Log = rstorage.GetLogEntries() + log.Sugar().Infof("[%s]从重启中恢复log数量: %d", SelfId, len(node.Log)) + } + + threadTransport.RegisterNodeChan(SelfId, rpcChan) + quitChan := make(chan struct{}, 1) + go node.listenForChan(rpcChan, quitChan) + + return node, quitChan +} + +func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{}) { + defer logprovider.DebugTraceback("listen") + + for { + select { + case req := <-rpcChan: + switch req.Behavior { + case DelayRpc: + threadTran, ok := node.Transport.(*ThreadTransport) + if !ok { + log.Fatal("无效的delayRpc模式") + } + duration, ok2 := threadTran.Ctx.GetDelay(req.SourceId, node.SelfId) + if !ok2 { + log.Fatal("没有设置对应的delay时间") + } + go node.switchReq(req, duration) + + case FailRpc: + continue + default: + go node.switchReq(req, 0) + } + + case <-quitChan: + node.Mu.Lock() + defer node.Mu.Unlock() + log.Sugar().Infof("[%s] 监听线程收到退出信号", node.SelfId) + node.IsFinish = true + node.Db.Close() + node.Storage.Close() + return + } + } +} + +func (node *Node) switchReq(req RPCRequest, delayTime time.Duration) { + defer logprovider.DebugTraceback("switch") + time.Sleep(delayTime) + + switch req.ServiceMethod { + case "Node.AppendEntries": + arg, ok := req.Args.(*AppendEntriesArg) + resp, ok2 := req.Reply.(*AppendEntriesReply) + if !ok || !ok2 { + req.Done <- errors.New("type assertion failed for AppendEntries") + } else { + var respCopy AppendEntriesReply + err := node.AppendEntries(arg, &respCopy) + *resp = respCopy + req.Done <- err + } + + case "Node.RequestVote": + arg, ok := req.Args.(*RequestVoteArgs) + resp, ok2 := req.Reply.(*RequestVoteReply) + if !ok || !ok2 { + req.Done <- errors.New("type assertion failed for RequestVote") + } else { + var respCopy RequestVoteReply + err := node.RequestVote(arg, &respCopy) + *resp = respCopy + req.Done <- err + } + + case "Node.WriteKV": + arg, ok := req.Args.(*LogEntryCall) + resp, ok2 := req.Reply.(*ServerReply) + if !ok || !ok2 { + req.Done <- errors.New("type assertion failed for WriteKV") + } else { + req.Done <- node.WriteKV(arg, resp) + } + + case "Node.ReadKey": + arg, ok 
:= req.Args.(*string) + resp, ok2 := req.Reply.(*ServerReply) + if !ok || !ok2 { + req.Done <- errors.New("type assertion failed for ReadKey") + } else { + req.Done <- node.ReadKey(arg, resp) + } + + case "Node.FindLeader": + arg, ok := req.Args.(struct{}) + resp, ok2 := req.Reply.(*FindLeaderReply) + if !ok || !ok2 { + req.Done <- errors.New("type assertion failed for FindLeader") + } else { + req.Done <- node.FindLeader(arg, resp) + } + + default: + req.Done <- fmt.Errorf("未知方法: %s", req.ServiceMethod) + } +} + +// 共同部分和启动 +func (n *Node) initLeaderState() { + for _, peerId := range n.Nodes { + n.NextIndex[peerId] = len(n.Log) // 发送日志的下一个索引 + n.MatchIndex[peerId] = 0 // 复制日志的最新匹配索引 + } +} + +func Start(node *Node, quitChan chan struct{}) { + node.Mu.Lock() + node.State = Follower // 所有节点以 Follower 状态启动 + node.Mu.Unlock() + node.ResetElectionTimer() // 启动选举超时定时器 + + go func() { + defer logprovider.DebugTraceback("start") + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-quitChan: + fmt.Printf("[%s] Raft start 退出...\n", node.SelfId) + return // 退出 goroutine + + case <-ticker.C: + node.Mu.Lock() + state := node.State + node.Mu.Unlock() + + switch state { + case Follower: + // 监听心跳超时 + + case Leader: + // 发送心跳 + node.ResetElectionTimer() // leader 不主动触发选举 + node.BroadCastKV() + } + } + } + }() +} + + + diff --git a/internal/nodes/log.go b/internal/nodes/log.go index 3d13831..f8b0447 100644 --- a/internal/nodes/log.go +++ b/internal/nodes/log.go @@ -1,19 +1,32 @@ package nodes -const ( - Normal State = iota + 1 - Delay - Fail -) +import "strconv" type LogEntry struct { Key string Value string } +func (LogE *LogEntry) print() string { + return "key: " + LogE.Key + ", value: " + LogE.Value +} + +type RaftLogEntry struct { + LogE LogEntry + LogId int + Term int +} +func (RLogE *RaftLogEntry) print() string { + return "logid: " + strconv.Itoa(RLogE.LogId) + ", term: " + strconv.Itoa(RLogE.Term) + ", " + RLogE.LogE.print() +} type LogEntryCall struct { + Id LogEntryCallId LogE LogEntry - CallState State +} + +type LogEntryCallId struct { + ClientId string + LogId int } type KVReply struct { diff --git a/internal/nodes/node.go b/internal/nodes/node.go index f1a17c8..7bacadb 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -1,13 +1,10 @@ package nodes import ( - "math/rand" - "net/rpc" - "strconv" + "sync" "time" "github.com/syndtr/goleveldb/leveldb" - "go.uber.org/zap" ) type State = uint8 @@ -18,84 +15,58 @@ const ( Leader ) -type Public_node_info struct { - connect bool - address string -} - type Node struct { + Mu sync.Mutex + MuElection sync.Mutex // 当前节点id - selfId string + SelfId string + // 记录的leader(不能用votedfor:投票的leader可能没有收到多数票) + LeaderId string - // 除当前节点外其他节点信息 - nodes map[string]*Public_node_info - - //管道名 - pipeAddr string + // 除当前节点外其他节点id + Nodes []string // 当前节点状态 - state State + State State + + // 任期 + CurrTerm int // 简单的kv存储 - log map[int]LogEntry + Log []RaftLogEntry - // leader用来标记新log - maxLogId int + // leader用来标记新log, = log.len + MaxLogId int - db *leveldb.DB -} + // 已提交的index + CommitIndex int -func (node *Node) BroadCastKV(logId int, kvCall LogEntryCall) { - // 遍历所有节点 - for id, _ := range node.nodes { - go func(id string, kv LogEntryCall) { - var reply KVReply - node.sendKV(id, logId, kvCall, &reply) - }(id, kvCall) - } -} + // 最后应用(写到db)的index + LastApplied int -func (node *Node) sendKV(id string, logId int, kvCall LogEntryCall, reply *KVReply) { - switch kvCall.CallState { - case Fail: - 
log.Info("模拟发送失败") - // 这么写向所有的node发送都失败,也可以随机数确定是否失败 - case Delay: - log.Info("模拟发送延迟") - // 随机延迟0-5ms - time.Sleep(time.Millisecond * time.Duration(rand.Intn(5))) - default: - } - - client, err := rpc.DialHTTP("tcp", node.nodes[id].address) - if err != nil { - log.Error("dialing: ", zap.Error(err)) - return - } - - defer func(client *rpc.Client) { - err := client.Close() - if err != nil { - log.Error("client close err: ", zap.Error(err)) - } - }(client) - - arg := LogIdAndEntry{logId, kvCall.LogE} - callErr := client.Call("Node.ReceiveKV", arg, reply) // RPC - if callErr != nil { - log.Error("dialing node_"+id+"fail: ", zap.Error(callErr)) - } -} + // 需要发送给每个节点的下一个索引 + NextIndex map[string]int -// RPC call -func (node *Node) ReceiveKV(arg LogIdAndEntry, reply *KVReply) error { - log.Info("node_" + node.selfId + " receive: logId = " + strconv.Itoa(arg.LogId) + ", key = " + arg.Entry.Key) - entry, ok := node.log[arg.LogId] - if !ok { - node.log[arg.LogId] = entry - } - // 持久化 - node.db.Put([]byte(arg.Entry.Key), []byte(arg.Entry.Value), nil) - reply.Reply = true // rpc call需要有reply,但实际上调用是否成功是error返回值决定 - return nil + // 已经发送给每个节点的最大索引 + MatchIndex map[string]int + + // 存kv(模拟状态机) + Db *leveldb.DB + // 持久化节点数据(currterm votedfor log) + Storage *RaftStorage + + VotedFor string + ElectionTimer *time.Timer + + // 通信方式 + Transport Transport + + // 系统的随机时间 + RTTable *RandomTimeTable + + // 已经处理过的客户端请求 + SeenRequests map[LogEntryCallId]bool + + IsFinish bool } + diff --git a/internal/nodes/node_storage.go b/internal/nodes/node_storage.go new file mode 100644 index 0000000..b6434f4 --- /dev/null +++ b/internal/nodes/node_storage.go @@ -0,0 +1,194 @@ +package nodes + +import ( + "encoding/json" + "strconv" + "strings" + "sync" + + "github.com/syndtr/goleveldb/leveldb" + "go.uber.org/zap" +) + +// RaftStorage 结构,持久化 currentTerm、votedFor 和 logEntries +type RaftStorage struct { + mu sync.Mutex + db *leveldb.DB + filePath string + isfinish bool +} + +// NewRaftStorage 创建 Raft 存储 +func NewRaftStorage(filePath string) *RaftStorage { + db, err := leveldb.OpenFile(filePath, nil) + if err != nil { + log.Fatal("无法打开 LevelDB:", zap.Error(err)) + } + + return &RaftStorage{ + db: db, + filePath: filePath, + isfinish: false, + } +} + +// SetCurrentTerm 设置当前 term +func (rs *RaftStorage) SetCurrentTerm(term int) { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.isfinish { + return + } + err := rs.db.Put([]byte("current_term"), []byte(strconv.Itoa(term)), nil) + if err != nil { + log.Error("SetCurrentTerm 持久化失败:", zap.Error(err)) + } +} + +// GetCurrentTerm 获取当前 term +func (rs *RaftStorage) GetCurrentTerm() int { + rs.mu.Lock() + defer rs.mu.Unlock() + data, err := rs.db.Get([]byte("current_term"), nil) + if err != nil { + return 0 // 默认 term = 0 + } + term, _ := strconv.Atoi(string(data)) + return term +} + +// SetVotedFor 记录投票给谁 +func (rs *RaftStorage) SetVotedFor(candidate string) { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.isfinish { + return + } + err := rs.db.Put([]byte("voted_for"), []byte(candidate), nil) + if err != nil { + log.Error("SetVotedFor 持久化失败:", zap.Error(err)) + } +} + +// GetVotedFor 获取投票对象 +func (rs *RaftStorage) GetVotedFor() string { + rs.mu.Lock() + defer rs.mu.Unlock() + data, err := rs.db.Get([]byte("voted_for"), nil) + if err != nil { + return "" + } + return string(data) +} + +// SetTermAndVote 原子更新 term 和 vote +func (rs *RaftStorage) SetTermAndVote(term int, candidate string) { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.isfinish { + return + } + + batch := new(leveldb.Batch) + 
batch.Put([]byte("current_term"), []byte(strconv.Itoa(term))) + batch.Put([]byte("voted_for"), []byte(candidate)) + + err := rs.db.Write(batch, nil) // 原子提交 + if err != nil { + log.Error("SetTermAndVote 持久化失败:", zap.Error(err)) + } +} + +// AppendLog 追加日志 +func (rs *RaftStorage) AppendLog(entry RaftLogEntry) { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.db == nil { + return + } + + // 序列化日志 + batch := new(leveldb.Batch) + data, _ := json.Marshal(entry) + key := "log_" + strconv.Itoa(entry.LogId) + batch.Put([]byte(key), data) + + lastIndex := strconv.Itoa(entry.LogId) + batch.Put([]byte("last_log_index"), []byte(lastIndex)) + err := rs.db.Write(batch, nil) + if err != nil { + log.Error("AppendLog 持久化失败:", zap.Error(err)) + } +} + +// GetLastLogIndex 获取最新日志的 index +func (rs *RaftStorage) GetLastLogIndex() int { + rs.mu.Lock() + defer rs.mu.Unlock() + data, err := rs.db.Get([]byte("last_log_index"), nil) + if err != nil { + return -1 + } + index, _ := strconv.Atoi(string(data)) + return index +} + +// WriteLog 批量写入日志(保证原子性) +func (rs *RaftStorage) WriteLog(entries []RaftLogEntry) { + if len(entries) == 0 { + return + } + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.isfinish { + return + } + + batch := new(leveldb.Batch) + for _, entry := range entries { + data, _ := json.Marshal(entry) + key := "log_" + strconv.Itoa(entry.LogId) + batch.Put([]byte(key), data) + } + + // 更新最新日志索引 + lastIndex := strconv.Itoa(entries[len(entries)-1].LogId) + batch.Put([]byte("last_log_index"), []byte(lastIndex)) + + err := rs.db.Write(batch, nil) + if err != nil { + log.Error("WriteLog 持久化失败:", zap.Error(err)) + } +} + +// GetLogEntries 获取所有日志 +func (rs *RaftStorage) GetLogEntries() []RaftLogEntry { + rs.mu.Lock() + defer rs.mu.Unlock() + + var logs []RaftLogEntry + iter := rs.db.NewIterator(nil, nil) // 遍历所有键值 + defer iter.Release() + + for iter.Next() { + key := string(iter.Key()) + if strings.HasPrefix(key, "log_") { // 过滤日志 key + var entry RaftLogEntry + if err := json.Unmarshal(iter.Value(), &entry); err == nil { + logs = append(logs, entry) + } else { + log.Error("解析日志失败:", zap.Error(err)) + } + } + } + + return logs +} + +// Close 关闭数据库 +func (rs *RaftStorage) Close() { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.db.Close() + rs.isfinish = true +} diff --git a/internal/nodes/random_timetable.go b/internal/nodes/random_timetable.go new file mode 100644 index 0000000..5496cf1 --- /dev/null +++ b/internal/nodes/random_timetable.go @@ -0,0 +1,46 @@ +package nodes + +import ( + "math/rand" + "sync" + "time" +) + +type RandomTimeTable struct { + Mu sync.Mutex + electionTimeOut time.Duration + israndom bool + // heartbeat 50ms + // rpcTimeout 50ms + // follower变candidate 500ms + // 等待选举成功时间 300ms +} + +func NewRTTable() *RandomTimeTable { + return &RandomTimeTable{ + israndom: true, + } +} + +func (rttable *RandomTimeTable) GetElectionTimeout() time.Duration { + rttable.Mu.Lock() + defer rttable.Mu.Unlock() + if rttable.israndom { + return time.Duration(500+rand.Intn(500)) * time.Millisecond + } else { + return rttable.electionTimeOut + } +} + +func (rttable *RandomTimeTable) SetElectionTimeout(t time.Duration) { + rttable.Mu.Lock() + defer rttable.Mu.Unlock() + rttable.israndom = false + rttable.electionTimeOut = t +} + +func (rttable *RandomTimeTable) ResetElectionTimeout() { + rttable.Mu.Lock() + defer rttable.Mu.Unlock() + rttable.israndom = true +} \ No newline at end of file diff --git a/internal/nodes/real_transport.go b/internal/nodes/real_transport.go new file mode 100644 index 0000000..e5f45eb --- 
/dev/null +++ b/internal/nodes/real_transport.go @@ -0,0 +1,111 @@ +package nodes + +import ( + "errors" + "fmt" + "net/rpc" + "time" +) + +// 真实rpc通讯的transport类型实现 +type HTTPTransport struct{ + NodeMap map[string]string // id到addr的映射 +} + +// 封装有超时的dial +func (t *HTTPTransport) DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) { + done := make(chan struct{}) + var client *rpc.Client + var err error + + go func() { + client, err = rpc.DialHTTP(network, t.NodeMap[peerId]) + close(done) + }() + + select { + case <-done: + return &HTTPClient{rpcClient: client}, err + case <-time.After(50 * time.Millisecond): + return nil, fmt.Errorf("dial timeout: %s", t.NodeMap[peerId]) + } +} + +func (t *HTTPTransport) CallWithTimeout(clientInterface ClientInterface, serviceMethod string, args interface{}, reply interface{}) error { + c, ok := clientInterface.(*HTTPClient) + client := c.rpcClient + if !ok { + return fmt.Errorf("invalid client type") + } + + done := make(chan error, 1) + + go func() { + switch serviceMethod { + case "Node.AppendEntries": + arg, ok := args.(*AppendEntriesArg) + resp, ok2 := reply.(*AppendEntriesReply) + if !ok || !ok2 { + done <- errors.New("type assertion failed for AppendEntries") + return + } + done <- client.Call(serviceMethod, arg, resp) + + case "Node.RequestVote": + arg, ok := args.(*RequestVoteArgs) + resp, ok2 := reply.(*RequestVoteReply) + if !ok || !ok2 { + done <- errors.New("type assertion failed for RequestVote") + return + } + done <- client.Call(serviceMethod, arg, resp) + + case "Node.WriteKV": + arg, ok := args.(*LogEntryCall) + resp, ok2 := reply.(*ServerReply) + if !ok || !ok2 { + done <- errors.New("type assertion failed for WriteKV") + return + } + done <- client.Call(serviceMethod, arg, resp) + + case "Node.ReadKey": + arg, ok := args.(*string) + resp, ok2 := reply.(*ServerReply) + if !ok || !ok2 { + done <- errors.New("type assertion failed for ReadKey") + return + } + done <- client.Call(serviceMethod, arg, resp) + + case "Node.FindLeader": + arg, ok := args.(struct{}) + resp, ok2 := reply.(*FindLeaderReply) + if !ok || !ok2 { + done <- errors.New("type assertion failed for FindLeader") + return + } + done <- client.Call(serviceMethod, arg, resp) + + default: + done <- fmt.Errorf("unknown service method: %s", serviceMethod) + } + }() + + select { + case err := <-done: + return err + case <-time.After(50 * time.Millisecond): // 设置超时时间 + return fmt.Errorf("call timeout: %s", serviceMethod) + } +} + +type HTTPClient struct { + rpcClient *rpc.Client +} + +func (h *HTTPClient) Close() error { + return h.rpcClient.Close() +} + + diff --git a/internal/nodes/replica.go b/internal/nodes/replica.go new file mode 100644 index 0000000..de8b98b --- /dev/null +++ b/internal/nodes/replica.go @@ -0,0 +1,271 @@ +package nodes + +import ( + "simple-kv-store/internal/logprovider" + "sort" + "strconv" + "sync" + + "go.uber.org/zap" +) + +type AppendEntriesArg struct { + Term int + LeaderId string + PrevLogIndex int + PrevLogTerm int + Entries []RaftLogEntry + LeaderCommit int +} + +type AppendEntriesReply struct { + Term int + Success bool +} + +// leader收到新内容要广播,以及心跳广播(同步自己的log) +func (node *Node) BroadCastKV() { + log.Sugar().Infof("leader[%s]广播消息", node.SelfId) + defer logprovider.DebugTraceback("broadcast") + failCount := 0 + // 这里增加一个锁,防止并发修改成功计数 + var failMutex sync.Mutex + // 遍历所有节点 + for _, id := range node.Nodes { + go func(id string) { + defer logprovider.DebugTraceback("send") + node.sendKV(id, &failCount, &failMutex) + 
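+			// One goroutine per peer, so a slow or dead follower never blocks the
+			// fan-out; failCount/failMutex tally the peers the leader cannot reach.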
}(id) + } +} + +func (node *Node) sendKV(peerId string, failCount *int, failMutex *sync.Mutex) { + node.Mu.Lock() + selfId := node.SelfId + node.Mu.Unlock() + + client, err := node.Transport.DialHTTPWithTimeout("tcp", selfId, peerId) + if err != nil { + node.Mu.Lock() + log.Error("[" + node.SelfId + "]dialling [" + peerId + "] fail: ", zap.Error(err)) + failMutex.Lock() + *failCount++ + if *failCount == len(node.Nodes) / 2 + 1 { // 无法联系超过半数:自己有问题,降级 + node.LeaderId = "" + node.State = Follower + node.ResetElectionTimer() + } + failMutex.Unlock() + node.Mu.Unlock() + return + } + + defer func(client ClientInterface) { + err := client.Close() + if err != nil { + log.Error("client close err: ", zap.Error(err)) + } + }(client) + + node.Mu.Lock() + + NextIndex := node.NextIndex[peerId] + // log.Info("NextIndex " + strconv.Itoa(NextIndex)) + for { + if NextIndex < 0 { + log.Fatal("assert >= 0 here") + } + + + sendEntries := node.Log[NextIndex:] + arg := AppendEntriesArg{ + Term: node.CurrTerm, + PrevLogIndex: NextIndex - 1, + Entries: sendEntries, + LeaderCommit: node.CommitIndex, + LeaderId: node.SelfId, + } + if arg.PrevLogIndex >= 0 { + arg.PrevLogTerm = node.Log[arg.PrevLogIndex].Term + } + // 记录关键数据后解锁 + currTerm := node.CurrTerm + currState := node.State + MaxLogId := node.MaxLogId + + var appendReply AppendEntriesReply + appendReply.Success = false + node.Mu.Unlock() + + callErr := node.Transport.CallWithTimeout(client, "Node.AppendEntries", &arg, &appendReply) // RPC + + node.Mu.Lock() + if node.CurrTerm != currTerm || node.MaxLogId != MaxLogId || node.State != currState { + node.Mu.Unlock() + return + } + + if callErr != nil { + log.Error("[" + node.SelfId + "]calling [" + peerId + "] fail: ", zap.Error(callErr)) + failMutex.Lock() + *failCount++ + if *failCount == len(node.Nodes) / 2 + 1 { // 无法联系超过半数:自己有问题,降级 + log.Info("term=" + strconv.Itoa(node.CurrTerm) + "的Leader[" + node.SelfId + "]无法联系到半数节点, 降级为 Follower") + node.LeaderId = "" + node.State = Follower + node.ResetElectionTimer() + } + failMutex.Unlock() + node.Mu.Unlock() + return + } + + if appendReply.Term != node.CurrTerm { + log.Sugar().Infof("term=%s的leader[%s]因为[%s]收到更高的term=%s, 转换为follower", + strconv.Itoa(node.CurrTerm), node.SelfId, peerId, strconv.Itoa(appendReply.Term)) + + node.LeaderId = "" + node.CurrTerm = appendReply.Term + node.State = Follower + node.VotedFor = "" + node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor) + node.ResetElectionTimer() + node.Mu.Unlock() + return + } + + if appendReply.Success { + break + } + + NextIndex-- // 失败往前传一格 + } + + // 不变成follower情况下 + node.NextIndex[peerId] = node.MaxLogId + 1 + node.MatchIndex[peerId] = node.MaxLogId + node.updateCommitIndex() + node.Mu.Unlock() +} + +func (node *Node) updateCommitIndex() { + if node.Mu.TryLock() { + log.Fatal("这里要保证有锁") + } + if node.IsFinish { + return + } + + totalNodes := len(node.Nodes) + + // 收集所有 MatchIndex 并排序 + MatchIndexes := make([]int, 0, totalNodes) + for _, index := range node.MatchIndex { + MatchIndexes = append(MatchIndexes, index) + } + sort.Ints(MatchIndexes) // 排序 + + // 计算多数派 CommitIndex + majorityIndex := MatchIndexes[totalNodes/2] // 取 N/2 位置上的索引(多数派) + + // 确保这个索引的日志条目属于当前 term,防止提交旧 term 的日志 + if majorityIndex > node.CommitIndex && majorityIndex < len(node.Log) && node.Log[majorityIndex].Term == node.CurrTerm { + node.CommitIndex = majorityIndex + log.Info("Leader[" + node.SelfId + "]更新 CommitIndex: " + strconv.Itoa(majorityIndex)) + + // 应用日志到状态机 + node.applyCommittedLogs() + } +} + +// 应用日志到状态机 +func (node 
*Node) applyCommittedLogs() { + for node.LastApplied < node.CommitIndex { + node.LastApplied++ + logEntry := node.Log[node.LastApplied] + log.Sugar().Infof("[%s]应用日志到状态机: " + logEntry.print(), node.SelfId) + err := node.Db.Put([]byte(logEntry.LogE.Key), []byte(logEntry.LogE.Value), nil) + if err != nil { + log.Error(node.SelfId + "应用状态机失败: ", zap.Error(err)) + } + } +} + +// RPC call +func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply) error { + defer logprovider.DebugTraceback("append") + + node.Mu.Lock() + defer node.Mu.Unlock() + log.Sugar().Infof("[%s]在term=%d收到[%s]的AppendEntries", node.SelfId, node.CurrTerm, arg.LeaderId) + + + // 如果 term 过期,拒绝接受日志 + if node.CurrTerm > arg.Term { + reply.Term = node.CurrTerm + reply.Success = false + return nil + } + + node.LeaderId = arg.LeaderId // 记录Leader + + // 如果term比自己高,或自己不是follower但收到相同term的心跳 + if node.CurrTerm < arg.Term || node.State != Follower { + log.Sugar().Infof("[%s]发现更高 term(%s)", node.SelfId, strconv.Itoa(arg.Term)) + node.CurrTerm = arg.Term + node.State = Follower + node.VotedFor = "" + // node.storage.SetTermAndVote(node.CurrTerm, node.VotedFor) + } + node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor) + + // 检查 prevLogIndex 是否有效 + if arg.PrevLogIndex >= len(node.Log) || (arg.PrevLogIndex >= 0 && node.Log[arg.PrevLogIndex].Term != arg.PrevLogTerm) { + reply.Term = node.CurrTerm + reply.Success = false + return nil + } + + // 处理日志冲突(如果存在不同 term,则截断日志) + idx := arg.PrevLogIndex + 1 + for i := idx; i < len(node.Log) && i-idx < len(arg.Entries); i++ { + if node.Log[i].Term != arg.Entries[i-idx].Term { + node.Log = node.Log[:idx] + break + } + } + // log.Info(strconv.Itoa(idx) + strconv.Itoa(len(node.Log))) + + // 追加新的日志条目 + for _, raftLogEntry := range arg.Entries { + log.Sugar().Infof("[%s]写入:" + raftLogEntry.print(), node.SelfId) + if idx < len(node.Log) { + node.Log[idx] = raftLogEntry + } else { + node.Log = append(node.Log, raftLogEntry) + } + idx++ + } + + // 暴力持久化 + node.Storage.WriteLog(node.Log) + + // 更新 MaxLogId + node.MaxLogId = len(node.Log) - 1 + + // 更新 CommitIndex + if arg.LeaderCommit < node.MaxLogId { + node.CommitIndex = arg.LeaderCommit + } else { + node.CommitIndex = node.MaxLogId + } + + // 提交已提交的日志 + node.applyCommittedLogs() + + // 在成功接受日志或心跳后,重置选举超时 + node.ResetElectionTimer() + reply.Term = node.CurrTerm + reply.Success = true + return nil +} \ No newline at end of file diff --git a/internal/nodes/server_node.go b/internal/nodes/server_node.go index 58eed40..e07a33e 100644 --- a/internal/nodes/server_node.go +++ b/internal/nodes/server_node.go @@ -1,42 +1,94 @@ package nodes import ( - "strconv" + "simple-kv-store/internal/logprovider" "github.com/syndtr/goleveldb/leveldb" ) // leader node作为server为client注册的方法 type ServerReply struct{ - Isconnect bool + Isleader bool + LeaderId string // 自己不是leader则返回leader HaveValue bool Value string } // RPC call -func (node *Node) WriteKV(kvCall LogEntryCall, reply *ServerReply) error { - - logId := node.maxLogId - node.maxLogId++ - node.log[logId] = kvCall.LogE - node.db.Put([]byte(kvCall.LogE.Key), []byte(kvCall.LogE.Value), nil) - log.Info("server write : logId = " + strconv.Itoa(logId) + ", key = " + kvCall.LogE.Key) +func (node *Node) WriteKV(kvCall *LogEntryCall, reply *ServerReply) error { + defer logprovider.DebugTraceback("write") + node.Mu.Lock() + defer node.Mu.Unlock() + + log.Sugar().Infof("[%s]收到客户端write请求", node.SelfId) + // 自己不是leader,转交leader地址回复 + if node.State != Leader { + reply.Isleader = false + reply.LeaderId = 
node.LeaderId // 可能是空,那client就随机再找一个节点 + log.Sugar().Infof("[%s]转交给[%s]", node.SelfId, node.LeaderId) + return nil + } + + if node.SeenRequests[kvCall.Id] { + log.Sugar().Infof("Leader [%s] 已处理过client[%s]的请求 %d, 跳过", node.SelfId, kvCall.Id.ClientId, kvCall.Id.LogId) + reply.Isleader = true + return nil + } + node.SeenRequests[kvCall.Id] = true + + // 自己是leader,修改自己的记录并广播 + node.MaxLogId++ + logId := node.MaxLogId + rLogE := RaftLogEntry{kvCall.LogE, logId, node.CurrTerm} + node.Log = append(node.Log, rLogE) + node.Storage.AppendLog(rLogE) + log.Info("leader[" + node.SelfId + "]处理请求 : " + kvCall.LogE.print()) // 广播给其它节点 - node.BroadCastKV(logId, kvCall) - reply.Isconnect = true + node.BroadCastKV() + reply.Isleader = true return nil } + // RPC call -func (node *Node) ReadKey(key string, reply *ServerReply) error { - log.Info("server read : " + key) - // 先只读leader自己 - value, err := node.db.Get([]byte(key), nil) +func (node *Node) ReadKey(key *string, reply *ServerReply) error { + defer logprovider.DebugTraceback("read") + node.Mu.Lock() + defer node.Mu.Unlock() + log.Sugar().Infof("[%s]收到客户端read请求", node.SelfId) + + // 先只读自己(无论自己是不是leader),也方便测试 + value, err := node.Db.Get([]byte(*key), nil) if err == leveldb.ErrNotFound { reply.HaveValue = false } else { reply.HaveValue = true reply.Value = string(value) } - reply.Isconnect = true + reply.Isleader = true return nil } +// RPC call 测试中寻找当前leader +type FindLeaderReply struct{ + Isleader bool + LeaderId string +} +func (node *Node) FindLeader(_ struct{}, reply *FindLeaderReply) error { + defer logprovider.DebugTraceback("find") + node.Mu.Lock() + defer node.Mu.Unlock() + + // 自己不是leader,转交leader地址回复 + if node.State != Leader { + reply.Isleader = false + if (node.LeaderId == "") { + log.Fatal("还没选出第一个leader") + return nil + } + reply.LeaderId = node.LeaderId + return nil + } + + reply.LeaderId = node.SelfId + reply.Isleader = true + return nil +} diff --git a/internal/nodes/simulate_ctx.go b/internal/nodes/simulate_ctx.go new file mode 100644 index 0000000..0fa2336 --- /dev/null +++ b/internal/nodes/simulate_ctx.go @@ -0,0 +1,65 @@ +package nodes + +import ( + "fmt" + "sync" + "time" +) + +// Ctx 结构体:管理不同节点之间的通信行为 +type Ctx struct { + mu sync.Mutex + Behavior map[string]CallBehavior // (src,target) -> CallBehavior + Delay map[string]time.Duration // (src,target) -> 延迟时间 + Retries map[string]int // 记录 (src,target) 的重发调用次数 +} + +// NewCtx 创建上下文 +func NewCtx() *Ctx { + return &Ctx{ + Behavior: make(map[string]CallBehavior), + Delay: make(map[string]time.Duration), + Retries: make(map[string]int), + } +} + +// SetBehavior 设置 A->B 的 RPC 行为 +func (c *Ctx) SetBehavior(src, dst string, behavior CallBehavior, delay time.Duration, retries int) { + c.mu.Lock() + defer c.mu.Unlock() + key := fmt.Sprintf("%s->%s", src, dst) + c.Behavior[key] = behavior + c.Delay[key] = delay + c.Retries[key] = retries +} + +// GetBehavior 获取 A->B 的行为 +func (c *Ctx) GetBehavior(src, dst string) (CallBehavior) { + c.mu.Lock() + defer c.mu.Unlock() + key := fmt.Sprintf("%s->%s", src, dst) + if state, exists := c.Behavior[key]; exists { + return state + } + return NormalRpc +} + +func (c *Ctx) GetDelay(src, dst string) (t time.Duration, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + key := fmt.Sprintf("%s->%s", src, dst) + if t, ok = c.Delay[key]; ok { + return t, ok + } + return 0, ok +} + +func (c *Ctx) GetRetries(src, dst string) (times int, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + key := fmt.Sprintf("%s->%s", src, dst) + if times, ok = c.Retries[key]; ok { + return 
times, ok + } + return 0, ok +} \ No newline at end of file diff --git a/internal/nodes/thread_transport.go b/internal/nodes/thread_transport.go new file mode 100644 index 0000000..c437eee --- /dev/null +++ b/internal/nodes/thread_transport.go @@ -0,0 +1,234 @@ +package nodes + +import ( + "fmt" + "sync" + "time" +) + +type CallBehavior = uint8 + +const ( + NormalRpc CallBehavior = iota + 1 + DelayRpc + RetryRpc + FailRpc +) + +// RPC 请求结构 +type RPCRequest struct { + ServiceMethod string + Args interface{} + Reply interface{} + + Done chan error // 用于返回响应 + + SourceId string + // 模拟rpc请求状态 + Behavior CallBehavior +} + +// 线程版 Transport +type ThreadTransport struct { + mu sync.Mutex + nodeChans map[string]chan RPCRequest // 每个节点的消息通道 + connectivityMap map[string]map[string]bool // 模拟网络分区 + Ctx *Ctx +} + +// 线程版 dial的返回clientinterface +type ThreadClient struct { + SourceId string + TargetId string +} + +func (c *ThreadClient) Close() error { + return nil +} + +// 初始化线程通信系统 +func NewThreadTransport(ctx *Ctx) *ThreadTransport { + return &ThreadTransport{ + nodeChans: make(map[string]chan RPCRequest), + connectivityMap: make(map[string]map[string]bool), + Ctx: ctx, + } +} + +// 注册一个新节点chan +func (t *ThreadTransport) RegisterNodeChan(nodeId string, ch chan RPCRequest) { + t.mu.Lock() + defer t.mu.Unlock() + t.nodeChans[nodeId] = ch + + // 初始化连通性(默认所有节点互相可达) + if _, exists := t.connectivityMap[nodeId]; !exists { + t.connectivityMap[nodeId] = make(map[string]bool) + } + for peerId := range t.nodeChans { + t.connectivityMap[nodeId][peerId] = true + t.connectivityMap[peerId][nodeId] = true + } +} + +// 设置两个节点的连通性 +func (t *ThreadTransport) SetConnectivity(from, to string, isConnected bool) { + t.mu.Lock() + defer t.mu.Unlock() + if _, exists := t.connectivityMap[from]; exists { + t.connectivityMap[from][to] = isConnected + } +} + +func (t *ThreadTransport) ResetConnectivity() { + t.mu.Lock() + defer t.mu.Unlock() + for firstId:= range t.nodeChans { + for peerId:= range t.nodeChans { + if firstId != peerId { + t.connectivityMap[firstId][peerId] = true + t.connectivityMap[peerId][firstId] = true + } + } + } +} + +// 获取节点的 channel +func (t *ThreadTransport) getNodeChan(nodeId string) (chan RPCRequest, bool) { + t.mu.Lock() + defer t.mu.Unlock() + ch, ok := t.nodeChans[nodeId] + return ch, ok +} + +// 模拟 Dial 操作 +func (t *ThreadTransport) DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) { + t.mu.Lock() + defer t.mu.Unlock() + if _, exists := t.nodeChans[peerId]; !exists { + return nil, fmt.Errorf("节点 [%s] 不存在", peerId) + } + + return &ThreadClient{SourceId: myId, TargetId: peerId}, nil +} + +// 模拟 Call 操作 +func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error { + threadClient, ok := client.(*ThreadClient) + if !ok { + return fmt.Errorf("无效的caller") + } + + var isConnected bool + if threadClient.SourceId == "" { + isConnected = true + } else { + t.mu.Lock() + isConnected = t.connectivityMap[threadClient.SourceId][threadClient.TargetId] + t.mu.Unlock() + } + if !isConnected { + return fmt.Errorf("网络分区: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId) + } + + targetChan, exists := t.getNodeChan(threadClient.TargetId) + if !exists { + return fmt.Errorf("目标节点 [%s] 不存在", threadClient.TargetId) + } + + done := make(chan error, 1) + behavior := t.Ctx.GetBehavior(threadClient.SourceId, threadClient.TargetId) + + // 辅助函数:复制 replyCopy 到原始 reply + copyReply := func(dst, src 
interface{}) { + switch d := dst.(type) { + case *AppendEntriesReply: + *d = *(src.(*AppendEntriesReply)) + case *RequestVoteReply: + *d = *(src.(*RequestVoteReply)) + } + } + + sendRequest := func(req RPCRequest, ch chan RPCRequest) bool { + select { + case ch <- req: + return true + default: + return false + } + } + + switch behavior { + case RetryRpc: + retryTimes, ok := t.Ctx.GetRetries(threadClient.SourceId, threadClient.TargetId) + if !ok { + log.Fatal("没有设置对应的retry次数") + } + + var lastErr error + for i := 0; i < retryTimes; i++ { + var replyCopy interface{} + useCopy := true + + switch r := reply.(type) { + case *AppendEntriesReply: + tmp := *r + replyCopy = &tmp + case *RequestVoteReply: + tmp := *r + replyCopy = &tmp + default: + replyCopy = reply // 其他类型不复制 + useCopy = false + } + + request := RPCRequest{ + ServiceMethod: serviceMethod, + Args: args, + Reply: replyCopy, + Done: done, + SourceId: threadClient.SourceId, + Behavior: NormalRpc, + } + + if !sendRequest(request, targetChan) { + return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId) + } + + select { + case err := <-done: + if err == nil && useCopy { + copyReply(reply, replyCopy) + } + if err == nil { + return nil + } + lastErr = err + case <-time.After(250 * time.Millisecond): + lastErr = fmt.Errorf("RPC 调用超时: %s", serviceMethod) + } + } + return lastErr + + default: + request := RPCRequest{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + SourceId: threadClient.SourceId, + Behavior: behavior, + } + + if !sendRequest(request, targetChan) { + return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId) + } + + select { + case err := <-done: + return err + case <-time.After(250 * time.Millisecond): + return fmt.Errorf("RPC 调用超时: %s", serviceMethod) + } + } +} \ No newline at end of file diff --git a/internal/nodes/transport.go b/internal/nodes/transport.go new file mode 100644 index 0000000..8d92985 --- /dev/null +++ b/internal/nodes/transport.go @@ -0,0 +1,10 @@ +package nodes + +type ClientInterface interface{ + Close() error +} + +type Transport interface { + DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) + CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error +} diff --git a/internal/nodes/vote.go b/internal/nodes/vote.go new file mode 100644 index 0000000..6ec64b6 --- /dev/null +++ b/internal/nodes/vote.go @@ -0,0 +1,215 @@ +package nodes + +import ( + "math/rand" + "simple-kv-store/internal/logprovider" + "strconv" + "sync" + "time" + + "go.uber.org/zap" +) + +type RequestVoteArgs struct { + Term int // 候选人的当前任期 + CandidateId string // 候选人 ID + LastLogIndex int // 候选人最后一条日志的索引 + LastLogTerm int // 候选人最后一条日志的任期 +} + +type RequestVoteReply struct { + Term int // 当前节点的最新任期 + VoteGranted bool // 是否同意投票 +} + +func (n *Node) StartElection() { + defer logprovider.DebugTraceback("startElection") + n.Mu.Lock() + if n.IsFinish { + n.Mu.Unlock() + return + } + // 增加当前任期,转换为 Candidate + n.CurrTerm++ + n.State = Candidate + n.VotedFor = n.SelfId // 自己投自己 + n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor) + + log.Sugar().Infof("[%s] 开始选举,当前任期: %d", n.SelfId, n.CurrTerm) + + // 重新设置选举超时,防止重复选举 + n.ResetElectionTimer() + + // 构造 RequestVote 请求 + var lastLogIndex int + var lastLogTerm int + + if len(n.Log) == 0 { + lastLogIndex = 0 + lastLogTerm = 0 // 论文中定义,空日志时 Term 设为 0 + } else { + lastLogIndex = len(n.Log) - 1 + lastLogTerm = n.Log[lastLogIndex].Term + } + args := RequestVoteArgs{ + Term: n.CurrTerm, + 
CandidateId: n.SelfId, + LastLogIndex: lastLogIndex, + LastLogTerm: lastLogTerm, + } + + // 并行向其他节点发送请求投票 + var Mu sync.Mutex + totalNodes := len(n.Nodes) + grantedVotes := 1 // 自己的票 + + currTerm := n.CurrTerm + currState := n.State + n.Mu.Unlock() + + for _, peerId := range n.Nodes { + go func(peerId string) { + defer logprovider.DebugTraceback("vote") + var reply RequestVoteReply + if n.sendRequestVote(peerId, &args, &reply) { + Mu.Lock() + defer Mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() + + if currTerm != n.CurrTerm || currState != n.State { + return + } + + if reply.Term > n.CurrTerm { + // 发现更高任期,回退为 Follower + log.Sugar().Infof("[%s] 发现更高的 Term (%d),回退为 Follower", n.SelfId, reply.Term) + n.CurrTerm = reply.Term + n.State = Follower + n.VotedFor = "" + n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor) + n.ResetElectionTimer() + return + } + + if reply.VoteGranted { + grantedVotes++ + } + + if grantedVotes == totalNodes / 2 + 1 { + n.State = Leader + log.Sugar().Infof("[%s] 当选 Leader!", n.SelfId) + n.initLeaderState() + } + } + }(peerId) + } + + // 等待选举结果 + time.Sleep(300 * time.Millisecond) + + Mu.Lock() + defer Mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() + + if n.State == Candidate { + log.Sugar().Infof("[%s] 选举超时,等待后将重新发起选举", n.SelfId) + // n.State = Follower 这里不修改,如果appendentries收到term合理的心跳,再变回follower + n.ResetElectionTimer() + } +} + +func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *RequestVoteReply) bool { + log.Sugar().Infof("[%s] 请求 [%s] 投票", node.SelfId, peerId) + client, err := node.Transport.DialHTTPWithTimeout("tcp", node.SelfId, peerId) + if err != nil { + log.Error("[" + node.SelfId + "]dialing [" + peerId + "] fail: ", zap.Error(err)) + return false + } + + defer func(client ClientInterface) { + err := client.Close() + if err != nil { + log.Error("client close err: ", zap.Error(err)) + } + }(client) + + callErr := node.Transport.CallWithTimeout(client, "Node.RequestVote", args, reply) // RPC + if callErr != nil { + log.Error("[" + node.SelfId + "]calling [" + peerId + "] fail: ", zap.Error(callErr)) + } + return callErr == nil +} + +func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error { + defer logprovider.DebugTraceback("requestVote") + n.Mu.Lock() + defer n.Mu.Unlock() + + // 如果候选人的任期小于当前任期,则拒绝投票 + if args.Term < n.CurrTerm { + reply.Term = n.CurrTerm + reply.VoteGranted = false + return nil + } + + // 如果请求的 Term 更高,则更新当前 Term 并回退为 Follower + if args.Term > n.CurrTerm { + n.CurrTerm = args.Term + n.State = Follower + n.VotedFor = "" + n.ResetElectionTimer() // 重新设置选举超时 + } + + // 检查是否已经投过票,且是否投给了同一个候选人 + if n.VotedFor == "" || n.VotedFor == args.CandidateId { + // 检查日志是否足够新 + var lastLogIndex int + var lastLogTerm int + + if len(n.Log) == 0 { + lastLogIndex = -1 + lastLogTerm = 0 + } else { + lastLogIndex = len(n.Log) - 1 + lastLogTerm = n.Log[lastLogIndex].Term + } + + if args.LastLogTerm > lastLogTerm || + (args.LastLogTerm == lastLogTerm && args.LastLogIndex >= lastLogIndex) { + // 够新就投票给候选人 + n.VotedFor = args.CandidateId + log.Sugar().Infof("在term(%s), [%s]投票给[%s]", strconv.Itoa(n.CurrTerm), n.SelfId, n.VotedFor) + reply.VoteGranted = true + n.ResetElectionTimer() + } else { + reply.VoteGranted = false + } + } else { + reply.VoteGranted = false + } + + n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor) + reply.Term = n.CurrTerm + return nil +} + +// follower 一段时间内没收到appendentries心跳,就变成candidate发起选举 +func (node *Node) ResetElectionTimer() { + node.MuElection.Lock() + defer node.MuElection.Unlock() + 
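+	// First call: create the timer plus a goroutine that runs StartElection on
+	// every expiry. Later calls stop and re-arm it with a fresh random timeout,
+	// which is what followers do on every valid AppendEntries heartbeat.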
if node.ElectionTimer == nil { + node.ElectionTimer = time.NewTimer(node.RTTable.GetElectionTimeout()) + go func() { + defer logprovider.DebugTraceback("reset") + for { + <-node.ElectionTimer.C + node.StartElection() + } + }() + } else { + node.ElectionTimer.Stop() + node.ElectionTimer.Reset(time.Duration(500+rand.Intn(500)) * time.Millisecond) + } +} \ No newline at end of file diff --git a/pics/plan.png b/pics/plan.png new file mode 100644 index 0000000..8d1f54d Binary files /dev/null and b/pics/plan.png differ diff --git a/pics/plus.png b/pics/plus.png new file mode 100644 index 0000000..f61670e Binary files /dev/null and b/pics/plus.png differ diff --git a/pics/robust.png b/pics/robust.png new file mode 100644 index 0000000..8db9682 Binary files /dev/null and b/pics/robust.png differ diff --git a/raft第二次汇报.pptx b/raft第二次汇报.pptx new file mode 100644 index 0000000..3c4ff00 Binary files /dev/null and b/raft第二次汇报.pptx differ diff --git a/scripts/run.sh b/scripts/run.sh deleted file mode 100755 index 794c247..0000000 --- a/scripts/run.sh +++ /dev/null @@ -1,61 +0,0 @@ -#!/bin/bash - -# 设置运行时间限制:s -RUN_TIME=10 - -# 需要传递数据的管道 -PIPE_NAME="/tmp/input_pipe" - -# 启动节点1 -echo "Starting Node 1..." -timeout $RUN_TIME ./main -id 1 -port ":9091" -cluster "127.0.0.1:9092,127.0.0.1:9093" -pipe "$PIPE_NAME" -isleader=true & - -# 启动节点2 -echo "Starting Node 2..." -timeout $RUN_TIME ./main -id 2 -port ":9092" -cluster "127.0.0.1:9091,127.0.0.1:9093" -pipe "$PIPE_NAME" & - -# 启动节点3 -echo "Starting Node 3..." -timeout $RUN_TIME ./main -id 3 -port ":9093" -cluster "127.0.0.1:9091,127.0.0.1:9092" -pipe "$PIPE_NAME"& - -echo "All nodes started successfully!" -# 创建一个管道用于进程间通信 -if [[ ! -p "$PIPE_NAME" ]]; then - mkfifo "$PIPE_NAME" -fi - -# 捕获终端输入并通过管道传递给三个节点 -echo "Enter input to send to nodes:" -start_time=$(date +%s) -while true; do - # 从终端读取用户输入 - read -r user_input - - current_time=$(date +%s) - elapsed_time=$((current_time - start_time)) - - # 如果运行时间大于限制时间,就退出 - if [ $elapsed_time -ge $RUN_TIME ]; then - echo 'Timeout reached, normal exit now' - break - fi - - # 如果输入为空,跳过 - if [[ -z "$user_input" ]]; then - continue - fi - - # 将用户输入发送到管道 - echo "$user_input" > "$PIPE_NAME" - - # 如果输入 "exit",结束脚本 - if [[ "$user_input" == "exit" ]]; then - break - fi -done - -# 删除管道 -rm "$PIPE_NAME" - -# 等待所有节点完成启动 -wait diff --git a/test/common.go b/test/common.go index 554927d..3b0dc35 100644 --- a/test/common.go +++ b/test/common.go @@ -8,29 +8,21 @@ import ( "strings" ) -func ExecuteNodeI(i int, isLeader bool, isNewDb bool, clusters []string) *exec.Cmd { - tmpClusters := append(clusters[:i], clusters[i+1:]...) 
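+// ExecuteNodeI launches node i+1 as a separate OS process running the main
+// binary; isRestart=true keeps its persisted storage so the node can recover.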
+func ExecuteNodeI(i int, isRestart bool, clusters []string) *exec.Cmd { port := fmt.Sprintf(":%d", uint16(9090)+uint16(i)) - var isleader string - if isLeader { - isleader = "true" + var isRestartStr string + if isRestart { + isRestartStr = "true" } else { - isleader = "false" - } - var isnewdb string - if isNewDb { - isnewdb = "true" - } else { - isnewdb = "false" + isRestartStr = "false" } cmd := exec.Command( "../main", "-id", strconv.Itoa(i + 1), "-port", port, - "-cluster", strings.Join(tmpClusters, ","), - "-isleader=" + isleader, - "-isNewDb=" + isnewdb, + "-cluster", strings.Join(clusters, ","), + "-isRestart=" + isRestartStr, ) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr diff --git a/test/restart_follower_test.go b/test/restart_follower_test.go deleted file mode 100644 index 323507d..0000000 --- a/test/restart_follower_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package test - -import ( - "fmt" - "os/exec" - "simple-kv-store/internal/client" - "simple-kv-store/internal/nodes" - "strconv" - "syscall" - "testing" - "time" -) - -func TestFollowerRestart(t *testing.T) { - // 登记结点信息 - n := 3 - var clusters []string - for i := 0; i < n; i++ { - port := fmt.Sprintf("%d", uint16(9090)+uint16(i)) - addr := "127.0.0.1:" + port - clusters = append(clusters, addr) - } - - // 结点启动 - var cmds []*exec.Cmd - for i := 0; i < n; i++ { - var cmd *exec.Cmd - if i == 0 { - cmd = ExecuteNodeI(i, true, true, clusters) - } else { - cmd = ExecuteNodeI(i, false, true, clusters) - } - - if cmd == nil { - return - } else { - cmds = append(cmds, cmd) - } - } - - time.Sleep(time.Second) // 等待启动完毕 - // client启动, 连接leader - cWrite := clientPkg.Client{Address: clusters[0], ServerId: "1"} - - // 写入 - var s clientPkg.Status - for i := 0; i < 5; i++ { - key := strconv.Itoa(i) - newlog := nodes.LogEntry{Key: key, Value: "hello"} - s := cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal}) - if s != clientPkg.Ok { - t.Errorf("write test fail") - } - } - time.Sleep(time.Second) // 等待写入完毕 - // 模拟最后一个结点崩溃 - err := cmds[n - 1].Process.Signal(syscall.SIGTERM) - if err != nil { - fmt.Println("Error sending signal:", err) - return - } - // 继续写入 - for i := 5; i < 10; i++ { - key := strconv.Itoa(i) - newlog := nodes.LogEntry{Key: key, Value: "hello"} - s := cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal}) - if s != clientPkg.Ok { - t.Errorf("write test fail") - } - } - // 恢复结点 - cmd := ExecuteNodeI(n - 1, false, false, clusters) - if cmd == nil { - t.Errorf("recover test1 fail") - return - } else { - cmds[n - 1] = cmd - } - time.Sleep(time.Second) // 等待启动完毕 - - // client启动, 连接节点n-1(去读它的数据) - cRead := clientPkg.Client{Address: clusters[n - 1], ServerId: "n"} - // 读崩溃前写入数据 - for i := 0; i < 5; i++ { - key := strconv.Itoa(i) - var value string - s = cRead.Read(key, &value) - if s != clientPkg.Ok { - t.Errorf("Read test1 fail") - } - } - - // 读未写入数据 - for i := 5; i < 15; i++ { - key := strconv.Itoa(i) - var value string - s = cRead.Read(key, &value) - if s != clientPkg.NotFound { - t.Errorf("Read test2 fail") - } - } - - // 通知进程结束 - for _, cmd := range cmds { - err := cmd.Process.Signal(syscall.SIGTERM) - if err != nil { - fmt.Println("Error sending signal:", err) - return - } - } - -} diff --git a/test/restart_node_test.go b/test/restart_node_test.go new file mode 100644 index 0000000..b8d4093 --- /dev/null +++ b/test/restart_node_test.go @@ -0,0 +1,104 @@ +package test + +import ( + "fmt" + "os/exec" + "simple-kv-store/internal/client" + "simple-kv-store/internal/nodes" + "strconv" + 
"syscall" + "testing" + "time" +) + +func TestNodeRestart(t *testing.T) { + // 登记结点信息 + n := 5 + var clusters []string + var peerIds []string + addressMap := make(map[string]string) + for i := 0; i < n; i++ { + port := fmt.Sprintf("%d", uint16(9090)+uint16(i)) + addr := "127.0.0.1:" + port + clusters = append(clusters, addr) + addressMap[strconv.Itoa(i + 1)] = addr + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var cmds []*exec.Cmd + for i := 0; i < n; i++ { + cmd := ExecuteNodeI(i, false, clusters) + cmds = append(cmds, cmd) + } + + // 通知所有进程结束 + defer func(){ + for _, cmd := range cmds { + err := cmd.Process.Signal(syscall.SIGTERM) + if err != nil { + fmt.Println("Error sending signal:", err) + return + } + } + }() + + time.Sleep(time.Second) // 等待启动完毕 + // client启动, 连接任意节点 + transport := &nodes.HTTPTransport{NodeMap: addressMap} + cWrite := clientPkg.NewClient("0", peerIds, transport) + + // 写入 + var s clientPkg.Status + for i := 0; i < 5; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + s := cWrite.Write(newlog) + if s != clientPkg.Ok { + t.Errorf("write test fail") + } + } + time.Sleep(time.Second) // 等待写入完毕 + + // 模拟结点轮流崩溃 + for i := 0; i < n; i++ { + err := cmds[i].Process.Signal(syscall.SIGTERM) + if err != nil { + fmt.Println("Error sending signal:", err) + return + } + + time.Sleep(time.Second) + cmd := ExecuteNodeI(i, true, clusters) + if cmd == nil { + t.Errorf("recover test1 fail") + return + } else { + cmds[i] = cmd + } + time.Sleep(time.Second) // 等待启动完毕 + } + + + // client启动 + cRead := clientPkg.NewClient("0", peerIds, transport) + // 读写入数据 + for i := 0; i < 5; i++ { + key := strconv.Itoa(i) + var value string + s = cRead.Read(key, &value) + if s != clientPkg.Ok { + t.Errorf("Read test1 fail") + } + } + + // 读未写入数据 + for i := 5; i < 15; i++ { + key := strconv.Itoa(i) + var value string + s = cRead.Read(key, &value) + if s != clientPkg.NotFound { + t.Errorf("Read test2 fail") + } + } +} diff --git a/test/server_client_test.go b/test/server_client_test.go index ab106c3..fc29ddc 100644 --- a/test/server_client_test.go +++ b/test/server_client_test.go @@ -15,50 +15,57 @@ func TestServerClient(t *testing.T) { // 登记结点信息 n := 5 var clusters []string + var peerIds []string + addressMap := make(map[string]string) for i := 0; i < n; i++ { port := fmt.Sprintf("%d", uint16(9090)+uint16(i)) addr := "127.0.0.1:" + port clusters = append(clusters, addr) + addressMap[strconv.Itoa(i + 1)] = addr + peerIds = append(peerIds, strconv.Itoa(i + 1)) } // 结点启动 var cmds []*exec.Cmd for i := 0; i < n; i++ { - var cmd *exec.Cmd - if i == 0 { - cmd = ExecuteNodeI(i, true, true, clusters) - } else { - cmd = ExecuteNodeI(i, false, true, clusters) - } - - if cmd == nil { - return - } else { - cmds = append(cmds, cmd) - } + cmd := ExecuteNodeI(i, false, clusters) + cmds = append(cmds, cmd) } + // 通知所有进程结束 + defer func(){ + for _, cmd := range cmds { + err := cmd.Process.Signal(syscall.SIGTERM) + if err != nil { + fmt.Println("Error sending signal:", err) + return + } + } + }() + time.Sleep(time.Second) // 等待启动完毕 // client启动 - c := clientPkg.Client{Address: "127.0.0.1:9090", ServerId: "1"} + transport := &nodes.HTTPTransport{NodeMap: addressMap} + c := clientPkg.NewClient("0", peerIds, transport) // 写入 var s clientPkg.Status for i := 0; i < 10; i++ { key := strconv.Itoa(i) newlog := nodes.LogEntry{Key: key, Value: "hello"} - s := c.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal}) + s := c.Write(newlog) if s != clientPkg.Ok { 
t.Errorf("write test fail") } } + time.Sleep(time.Second) // 等待写入完毕 // 读写入数据 for i := 0; i < 10; i++ { key := strconv.Itoa(i) var value string s = c.Read(key, &value) - if s != clientPkg.Ok && value != "hello" + key { + if s != clientPkg.Ok || value != "hello" { t.Errorf("Read test1 fail") } } @@ -72,14 +79,4 @@ func TestServerClient(t *testing.T) { t.Errorf("Read test2 fail") } } - - // 通知进程结束 - for _, cmd := range cmds { - err := cmd.Process.Signal(syscall.SIGTERM) - if err != nil { - fmt.Println("Error sending signal:", err) - return - } - } - } diff --git a/threadTest/common.go b/threadTest/common.go new file mode 100644 index 0000000..44af24c --- /dev/null +++ b/threadTest/common.go @@ -0,0 +1,305 @@ +package threadTest + +import ( + "fmt" + "os" + clientPkg "simple-kv-store/internal/client" + "simple-kv-store/internal/nodes" + "strconv" + "testing" + "time" + + "github.com/syndtr/goleveldb/leveldb" +) + +func ExecuteNodeI(id string, isRestart bool, peerIds []string, threadTransport *nodes.ThreadTransport) (*nodes.Node, chan struct{}) { + if !isRestart { + os.RemoveAll("storage/node" + id) + } + + // 创建临时目录用于 leveldb + dbPath, err := os.MkdirTemp("", "simple-kv-store-"+id+"-") + if err != nil { + panic(fmt.Sprintf("无法创建临时数据库目录: %s", err)) + } + + // 创建临时目录用于 storage + storagePath, err := os.MkdirTemp("", "raft-storage-"+id+"-") + if err != nil { + panic(fmt.Sprintf("无法创建临时存储目录: %s", err)) + } + + db, err := leveldb.OpenFile(dbPath, nil) + if err != nil { + panic(fmt.Sprintf("Failed to open database: %s", err)) + } + + // 初始化 Raft 存储 + storage := nodes.NewRaftStorage(storagePath) + + var otherIds []string + for _, ids := range peerIds { + if ids != id { + otherIds = append(otherIds, ids) // 删除目标元素 + } + } + // 初始化 + node, quitChan := nodes.InitThreadNode(id, otherIds, db, storage, isRestart, threadTransport) + + // 开启 raft + go nodes.Start(node, quitChan) + return node, quitChan +} + +func ExecuteStaticNodeI(id string, isRestart bool, peerIds []string, threadTransport *nodes.ThreadTransport) (*nodes.Node, chan struct{}) { + if !isRestart { + os.RemoveAll("storage/node" + id) + } + + os.RemoveAll("leveldb/simple-kv-store" + id) + + db, err := leveldb.OpenFile("leveldb/simple-kv-store"+id, nil) + if err != nil { + fmt.Println("Failed to open database: ", err) + } + + // 打开或创建节点数据持久化文件 + storage := nodes.NewRaftStorage("storage/node" + id) + + var otherIds []string + for _, ids := range peerIds { + if ids != id { + otherIds = append(otherIds, ids) // 删除目标元素 + } + } + // 初始化 + node, quitChan := nodes.InitThreadNode(id, otherIds, db, storage, isRestart, threadTransport) + + // 开启 raft + // go nodes.Start(node, quitChan) + return node, quitChan +} + +func StopElectionReset(nodeCollections []*nodes.Node) { + for i := 0; i < len(nodeCollections); i++ { + node := nodeCollections[i] + go func(node *nodes.Node) { + ticker := time.NewTicker(400 * time.Millisecond) + defer ticker.Stop() + + for { + <-ticker.C + node.ResetElectionTimer() // 不主动触发选举 + } + }(node) + } +} + +func SendKvCall(kvCall *nodes.LogEntryCall, node *nodes.Node) { + node.Mu.Lock() + defer node.Mu.Unlock() + + node.MaxLogId++ + logId := node.MaxLogId + rLogE := nodes.RaftLogEntry{LogE: kvCall.LogE, LogId: logId, Term: node.CurrTerm} + node.Log = append(node.Log, rLogE) + node.Storage.AppendLog(rLogE) + // 广播给其它节点 + node.BroadCastKV() +} + +func ClientWriteLog(t *testing.T, startLogid int, endLogid int, cWrite *clientPkg.Client) { + var s clientPkg.Status + for i := startLogid; i < endLogid; i++ { + key := strconv.Itoa(i) + newlog 
:= nodes.LogEntry{Key: key, Value: "hello"} + s = cWrite.Write(newlog) + if s != clientPkg.Ok { + t.Errorf("write test fail") + } + } +} + +func FindLeader(t *testing.T, nodeCollections []*nodes.Node) (i int) { + for i, node := range nodeCollections { + if node.State == nodes.Leader { + return i + } + } + t.Errorf("系统目前没有leader") + t.FailNow() + return 0 +} + +func CheckOneLeader(t *testing.T, nodeCollections []*nodes.Node) { + cnt := 0 + for _, node := range nodeCollections { + node.Mu.Lock() + if node.State == nodes.Leader { + cnt++ + } + node.Mu.Unlock() + } + if cnt != 1 { + t.Errorf("实际有%d个leader(!=1)", cnt) + t.FailNow() + } +} + +func CheckNoLeader(t *testing.T, nodeCollections []*nodes.Node) { + cnt := 0 + for _, node := range nodeCollections { + node.Mu.Lock() + if node.State == nodes.Leader { + cnt++ + } + node.Mu.Unlock() + } + if cnt != 0 { + t.Errorf("实际有%d个leader(!=0)", cnt) + t.FailNow() + } +} + +func CheckZeroOrOneLeader(t *testing.T, nodeCollections []*nodes.Node) { + cnt := 0 + for _, node := range nodeCollections { + node.Mu.Lock() + if node.State == nodes.Leader { + cnt++ + } + node.Mu.Unlock() + } + if cnt > 1 { + errmsg := fmt.Sprintf("%d个节点中,实际有%d个leader(>1)", len(nodeCollections), cnt) + WriteFailLog(nodeCollections[0].SelfId, errmsg) + t.Error(errmsg) + t.FailNow() + } +} + +func CheckIsLeader(t *testing.T, node *nodes.Node) { + node.Mu.Lock() + defer node.Mu.Unlock() + if node.State != nodes.Leader { + t.Errorf("[%s]不是leader", node.SelfId) + t.FailNow() + } +} + +func CheckTerm(t *testing.T, node *nodes.Node, targetTerm int) { + node.Mu.Lock() + defer node.Mu.Unlock() + if node.CurrTerm != targetTerm { + t.Errorf("[%s]实际term=%d (!=%d)", node.SelfId, node.CurrTerm, targetTerm) + t.FailNow() + } +} + +func CheckLogNum(t *testing.T, node *nodes.Node, targetnum int) { + node.Mu.Lock() + defer node.Mu.Unlock() + if len(node.Log) != targetnum { + t.Errorf("[%s]实际logNum=%d (!=%d)", node.SelfId, len(node.Log), targetnum) + t.FailNow() + } +} + +func CheckSameLog(t *testing.T, nodeCollections []*nodes.Node) { + nodeCollections[0].Mu.Lock() + defer nodeCollections[0].Mu.Unlock() + standard_node := nodeCollections[0] + for i, node := range nodeCollections { + if i != 0 { + node.Mu.Lock() + if len(node.Log) != len(standard_node.Log) { + errmsg := fmt.Sprintf("[%s]和[%s]日志数量不一致", nodeCollections[0].SelfId, node.SelfId) + WriteFailLog(node.SelfId, errmsg) + t.Error(errmsg) + t.FailNow() + } + + for idx, log := range node.Log { + standard_log := standard_node.Log[idx] + if log.Term != standard_log.Term || + log.LogE.Key != standard_log.LogE.Key || + log.LogE.Value != standard_log.LogE.Value { + errmsg := fmt.Sprintf("[1]和[%s]日志id%d不一致", node.SelfId, idx) + WriteFailLog(node.SelfId, errmsg) + t.Error(errmsg) + t.FailNow() + } + } + node.Mu.Unlock() + } + } +} + +func CheckLeaderInvariant(t *testing.T, nodeCollections []*nodes.Node) { + leaderCnt := make(map[int]bool) + for _, node := range nodeCollections { + node.Mu.Lock() + if node.State == nodes.Leader { + if _, exist := leaderCnt[node.CurrTerm]; exist { + errmsg := fmt.Sprintf("在%d有多个leader(%s)", node.CurrTerm, node.SelfId) + WriteFailLog(node.SelfId, errmsg) + t.Error(errmsg) + } else { + leaderCnt[node.CurrTerm] = true + } + } + node.Mu.Unlock() + } +} + +func CheckLogInvariant(t *testing.T, nodeCollections []*nodes.Node) { + nodeCollections[0].Mu.Lock() + defer nodeCollections[0].Mu.Unlock() + standard_node := nodeCollections[0] + standard_len := len(standard_node.Log) + for i, node := range nodeCollections { + if i != 0 { 
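+			// node 0 is the reference copy: every other node's log must match
+			// it entry by entry (term, key and value)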
+ node.Mu.Lock() + len2 := len(node.Log) + var shorti int + if len2 < standard_len { + shorti = len2 + } else { + shorti = standard_len + } + if shorti == 0 { + node.Mu.Unlock() + continue + } + + alreadySame := false + for i := shorti - 1; i >= 0; i-- { + standard_log := standard_node.Log[i] + log := node.Log[i] + if alreadySame { + if log.Term != standard_log.Term || + log.LogE.Key != standard_log.LogE.Key || + log.LogE.Value != standard_log.LogE.Value { + errmsg := fmt.Sprintf("[%s]和[%s]日志id%d不一致", standard_node.SelfId, node.SelfId, i) + WriteFailLog(node.SelfId, errmsg) + t.Error(errmsg) + t.FailNow() + } + } else { + if log.Term == standard_log.Term && + log.LogE.Key == standard_log.LogE.Key && + log.LogE.Value == standard_log.LogE.Value { + alreadySame = true + } + } + } + node.Mu.Unlock() + } + } +} + +func WriteFailLog(name string, errmsg string) { + f, _ := os.Create(name + ".log") + fmt.Fprint(f, errmsg) + f.Close() +} diff --git a/threadTest/election_test.go b/threadTest/election_test.go new file mode 100644 index 0000000..f9078ee --- /dev/null +++ b/threadTest/election_test.go @@ -0,0 +1,313 @@ +package threadTest + +import ( + "simple-kv-store/internal/nodes" + "strconv" + "testing" + "time" +) + +func TestInitElection(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + nodeCollections[0].StartElection() + time.Sleep(time.Second) + + CheckOneLeader(t, nodeCollections) + CheckIsLeader(t, nodeCollections[0]) + CheckTerm(t, nodeCollections[0], 2) +} + +func TestRepeatElection(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + go nodeCollections[0].StartElection() + go nodeCollections[0].StartElection() + time.Sleep(time.Second) + + CheckOneLeader(t, nodeCollections) + CheckIsLeader(t, nodeCollections[0]) + CheckTerm(t, nodeCollections[0], 3) +} + +func TestBelowHalfCandidateElection(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, 
threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + go nodeCollections[0].StartElection() + go nodeCollections[1].StartElection() + time.Sleep(time.Second) + + CheckOneLeader(t, nodeCollections) + for i := 0; i < n; i++ { + CheckTerm(t, nodeCollections[i], 2) + } +} + +func TestOverHalfCandidateElection(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + go nodeCollections[0].StartElection() + go nodeCollections[1].StartElection() + go nodeCollections[2].StartElection() + time.Sleep(time.Second) + + CheckZeroOrOneLeader(t, nodeCollections) + for i := 0; i < n; i++ { + CheckTerm(t, nodeCollections[i], 2) + } +} + +func TestRepeatVoteRpc(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + ctx := nodes.NewCtx() + threadTransport := nodes.NewThreadTransport(ctx) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + ctx.SetBehavior("1", "2", nodes.RetryRpc, 0, 2) + nodeCollections[0].StartElection() + time.Sleep(time.Second) + + CheckOneLeader(t, nodeCollections) + CheckIsLeader(t, nodeCollections[0]) + CheckTerm(t, nodeCollections[0], 2) + + for i := 0; i < n; i++ { + ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2) + ctx.SetBehavior("2", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2) + } + + go nodeCollections[0].StartElection() + go nodeCollections[1].StartElection() + time.Sleep(time.Second) + + CheckOneLeader(t, nodeCollections) + CheckTerm(t, nodeCollections[0], 3) +} + +func TestFailVoteRpc(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + ctx := nodes.NewCtx() + threadTransport := nodes.NewThreadTransport(ctx) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, 
quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + ctx.SetBehavior("1", "2", nodes.FailRpc, 0, 0) + nodeCollections[0].StartElection() + time.Sleep(time.Second) + + CheckOneLeader(t, nodeCollections) + CheckIsLeader(t, nodeCollections[0]) + CheckTerm(t, nodeCollections[0], 2) + + ctx.SetBehavior("1", "3", nodes.FailRpc, 0, 0) + ctx.SetBehavior("1", "4", nodes.FailRpc, 0, 0) + nodeCollections[0].StartElection() + time.Sleep(time.Second) + + CheckNoLeader(t, nodeCollections) + CheckTerm(t, nodeCollections[0], 3) +} + +func TestDelayVoteRpc(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + ctx := nodes.NewCtx() + threadTransport := nodes.NewThreadTransport(ctx) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.DelayRpc, time.Second, 0) + } + + nodeCollections[0].StartElection() + time.Sleep(2 * time.Second) + + CheckNoLeader(t, nodeCollections) + for i := 0; i < n; i++ { + CheckTerm(t, nodeCollections[i], 2) + } + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.DelayRpc, 50 * time.Millisecond, 0) + } + + nodeCollections[0].StartElection() + time.Sleep(time.Second) + + CheckOneLeader(t, nodeCollections) + for i := 0; i < n; i++ { + CheckTerm(t, nodeCollections[i], 3) + } +} diff --git a/threadTest/fuzz/fuzz_test.go b/threadTest/fuzz/fuzz_test.go new file mode 100644 index 0000000..f65c789 --- /dev/null +++ b/threadTest/fuzz/fuzz_test.go @@ -0,0 +1,445 @@ +package fuzz + +import ( + "fmt" + "math/rand" + "os" + "runtime/debug" + "sync" + "testing" + "time" + + clientPkg "simple-kv-store/internal/client" + "simple-kv-store/internal/nodes" + "simple-kv-store/threadTest" + "strconv" +) + +// 1.针对随机配置随机消息状态 +func FuzzRaftBasic(f *testing.F) { + var seenSeeds sync.Map + // 添加初始种子 + f.Add(int64(1)) + fmt.Println("Running") + + f.Fuzz(func(t *testing.T, seed int64) { + if _, loaded := seenSeeds.LoadOrStore(seed, true); loaded { + t.Skipf("Seed %d already tested, skipping...", seed) + return + } + defer func() { + if r := recover(); r != nil { + msg := fmt.Sprintf("goroutine panic: %v\n%s", r, debug.Stack()) + f, _ := os.Create("panic_goroutine.log") + fmt.Fprint(f, msg) + f.Close() + } + }() + + r := rand.New(rand.NewSource(seed)) // 使用局部 rand + + n := 3 + 2*(r.Intn(4)) + fmt.Printf("随机了%d个节点\n", n) + logs := (r.Intn(10)) + fmt.Printf("随机了%d份日志\n", logs) + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1)) + } + + ctx := nodes.NewCtx() + threadTransport := nodes.NewThreadTransport(ctx) + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + + for i := 0; i < n; i++ { + node, quitChan := threadTest.ExecuteNodeI(strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1), false, peerIds, threadTransport) + nodeCollections = append(nodeCollections, node) + 
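+			// give every node the same fixed election timeout for this fuzz run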
+			node.RTTable.SetElectionTimeout(750 * time.Millisecond)
+			quitCollections = append(quitCollections, quitChan)
+		}
+
+		// inject random a->b communication faults
+		faultyNodes := injectRandomBehavior(ctx, r, peerIds)
+
+		time.Sleep(time.Second)
+
+		clientObj := clientPkg.NewClient("0", peerIds, threadTransport)
+
+		for i := 0; i < logs; i++ {
+			key := fmt.Sprintf("k%d", i)
+			log := nodes.LogEntry{Key: key, Value: "v"}
+			clientObj.Write(log)
+		}
+
+		time.Sleep(time.Second)
+
+		var rightNodeCollections []*nodes.Node
+		for _, node := range nodeCollections {
+			if !faultyNodes[node.SelfId] {
+				rightNodeCollections = append(rightNodeCollections, node)
+			}
+		}
+		threadTest.CheckSameLog(t, rightNodeCollections)
+		threadTest.CheckLeaderInvariant(t, nodeCollections)
+
+		for _, quitChan := range quitCollections {
+			close(quitChan)
+		}
+		time.Sleep(time.Second)
+		for i := 0; i < n; i++ {
+			// make sure shutdown really finished
+			nodeCollections[i].Mu.Lock()
+			if !nodeCollections[i].IsFinish {
+				nodeCollections[i].IsFinish = true
+			}
+			nodeCollections[i].Mu.Unlock()
+
+			os.RemoveAll("leveldb/simple-kv-store" + strconv.Itoa(int(seed)) + "." + strconv.Itoa(i+1))
+			os.RemoveAll("storage/node" + strconv.Itoa(int(seed)) + "." + strconv.Itoa(i+1))
+		}
+	})
+}
+
+// inject a random fault behavior between node pairs
+func injectRandomBehavior(ctx *nodes.Ctx, r *rand.Rand, peers []string) map[string]bool /* id -> isFaulty */ {
+	behaviors := []nodes.CallBehavior{
+		nodes.FailRpc,
+		nodes.DelayRpc,
+		nodes.RetryRpc,
+	}
+	n := len(peers)
+	maxFaulty := r.Intn(n/2 + 1) // pick 0 ~ n/2 faulty nodes
+
+	// randomly choose which nodes are faulty
+	shuffled := append([]string(nil), peers...)
+	r.Shuffle(n, func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
+	faultyNodes := make(map[string]bool)
+	for i := 0; i < maxFaulty; i++ {
+		faultyNodes[shuffled[i]] = true
+	}
+
+	for _, one := range peers {
+		if faultyNodes[one] {
+			b := behaviors[r.Intn(len(behaviors))]
+			delay := time.Duration(r.Intn(100)) * time.Millisecond
+
+			switch b {
+			case nodes.FailRpc:
+				fmt.Printf("[%s]的异常行为是fail\n", one)
+			case nodes.DelayRpc:
+				fmt.Printf("[%s]的异常行为是delay\n", one)
+			case nodes.RetryRpc:
+				fmt.Printf("[%s]的异常行为是retry\n", one)
+			}
+
+			for _, two := range peers {
+				if one == two {
+					continue
+				}
+
+				if faultyNodes[one] && faultyNodes[two] {
+					// cut both directions between a pair of faulty nodes
+					ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0)
+					ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0)
+				} else {
+					ctx.SetBehavior(one, two, b, delay, 2)
+					ctx.SetBehavior(two, one, b, delay, 2)
+				}
+			}
+		}
+	}
+	return faultyNodes
+}
+
+// 2. inject random faults into a single long-running system
+func FuzzRaftRobust(f *testing.F) {
+	var seenSeeds sync.Map
+	var fuzzMu sync.Mutex
+	// add the initial seed
+	f.Add(int64(0))
+	fmt.Println("Running")
+	n := 5
+
+	var peerIds []string
+	for i := 0; i < n; i++ {
+		peerIds = append(peerIds, strconv.Itoa(i+1))
+	}
+
+	ctx := nodes.NewCtx()
+	threadTransport := nodes.NewThreadTransport(ctx)
+	quitCollections := make(map[string]chan struct{})
+	nodeCollections := make(map[string]*nodes.Node)
+
+	for i := 0; i < n; i++ {
+		id := strconv.Itoa(i + 1)
+		node, quitChan := threadTest.ExecuteNodeI(id, false, peerIds, threadTransport)
+		nodeCollections[id] = node
+		quitCollections[id] = quitChan
+	}
+
+	f.Fuzz(func(t *testing.T, seed int64) {
+		fuzzMu.Lock()
+		defer fuzzMu.Unlock()
+		if _, loaded := seenSeeds.LoadOrStore(seed, true); loaded {
+			t.Skipf("Seed %d already tested, skipping...", seed)
+			return
+		}
+		defer func() {
+			if r := recover(); r != nil {
+				msg := fmt.Sprintf("goroutine panic: %v\n%s", r, debug.Stack())
+				f, _ := os.Create("panic_goroutine.log")
+				fmt.Fprint(f, msg)
+				f.Close()
+			}
+		}()
+
+		r := rand.New(rand.NewSource(seed)) // use a seed-local rand
+
+		clientObj := clientPkg.NewClient("0", peerIds, threadTransport)
+
+		faultyNodes := injectRandomBehavior2(ctx, r, peerIds, threadTransport, quitCollections)
+
+		key := fmt.Sprintf("k%d", seed % 10)
+		log := nodes.LogEntry{Key: key, Value: "v"}
+		clientObj.Write(log)
+
+		time.Sleep(time.Second)
+
+		var rightNodeCollections []*nodes.Node
+		for _, node := range nodeCollections {
+			_, exist := faultyNodes[node.SelfId]
+			if !exist {
+				rightNodeCollections = append(rightNodeCollections, node)
+			}
+		}
+		threadTest.CheckLogInvariant(t, rightNodeCollections)
+		threadTest.CheckLeaderInvariant(t, rightNodeCollections)
+
+		// reset the faulty nodes
+		threadTransport.ResetConnectivity()
+		for id, isrestart := range faultyNodes {
+			if !isrestart {
+				for _, peerId := range peerIds {
+					if id == peerId {
+						continue
+					}
+					ctx.SetBehavior(id, peerId, nodes.NormalRpc, 0, 0)
+					ctx.SetBehavior(peerId, id, nodes.NormalRpc, 0, 0)
+				}
+			} else {
+				newNode, quitChan := threadTest.ExecuteNodeI(id, true, peerIds, threadTransport)
+				quitCollections[id] = quitChan
+				nodeCollections[id] = newNode
+			}
+			fmt.Printf("[%s]恢复异常\n", id)
+		}
+	})
+	for _, quitChan := range quitCollections {
+		close(quitChan)
+	}
+	time.Sleep(time.Second)
+	for id, node := range nodeCollections {
+		// make sure shutdown really finished
+		node.Mu.Lock()
+		if !node.IsFinish {
+			node.IsFinish = true
+		}
+		node.Mu.Unlock()
+
+		os.RemoveAll("leveldb/simple-kv-store" + id)
+		os.RemoveAll("storage/node" + id)
+	}
+}
+
+// 3. combined: random configurations plus repeated fault injection
+func FuzzRaftPlus(f *testing.F) {
+	var seenSeeds sync.Map
+	// add the initial seed
+	f.Add(int64(0))
+	fmt.Println("Running")
+
+	f.Fuzz(func(t *testing.T, seed int64) {
+		if _, loaded := seenSeeds.LoadOrStore(seed, true); loaded {
+			t.Skipf("Seed %d already tested, skipping...", seed)
+			return
+		}
+		defer func() {
+			if r := recover(); r != nil {
+				msg := fmt.Sprintf("goroutine panic: %v\n%s", r, debug.Stack())
+				f, _ := os.Create("panic_goroutine.log")
+				fmt.Fprint(f, msg)
+				f.Close()
+			}
+		}()
+
+		r := rand.New(rand.NewSource(seed)) // use a seed-local rand
+
+		n := 3 + 2*(r.Intn(4))
+		fmt.Printf("随机了%d个节点\n", n)
+		ElectionTimeOut := 500 + r.Intn(500)
+		fmt.Printf("随机的投票超时时间:%d\n", ElectionTimeOut)
+
+		var peerIds []string
+		for i := 0; i < n; i++ {
+			peerIds = append(peerIds, strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1))
+		}
+
+		ctx := nodes.NewCtx()
+		threadTransport := nodes.NewThreadTransport(ctx)
+		quitCollections := make(map[string]chan struct{})
+		nodeCollections := make(map[string]*nodes.Node)
+
+		for i := 0; i < n; i++ {
+			id := strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1)
+			node, quitChan := threadTest.ExecuteNodeI(id, false, peerIds, threadTransport)
+			nodeCollections[id] = node
+			node.RTTable.SetElectionTimeout(time.Duration(ElectionTimeOut) * time.Millisecond)
+			quitCollections[id] = quitChan
+		}
+
+		clientObj := clientPkg.NewClient("0", peerIds, threadTransport)
+
+		for i := 0; i < 5; i++ { // five rounds of fault injection
+			fmt.Printf("第%d轮异常注入开始\n", i + 1)
+			faultyNodes := injectRandomBehavior2(ctx, r, peerIds, threadTransport, quitCollections)
+
+			key := fmt.Sprintf("k%d", i)
+			log := nodes.LogEntry{Key: key, Value: "v"}
+			clientObj.Write(log)
+
+			time.Sleep(time.Second)
+
+			var rightNodeCollections []*nodes.Node
+			for _, node := range nodeCollections {
+				_, exist := faultyNodes[node.SelfId]
+				if !exist {
+					rightNodeCollections = append(rightNodeCollections, node)
+				}
+			}
+			threadTest.CheckLogInvariant(t, rightNodeCollections)
+			threadTest.CheckLeaderInvariant(t, rightNodeCollections)
+
+			// reset the faulty nodes
threadTransport.ResetConnectivity() + for id, isrestart := range faultyNodes { + if !isrestart { + for _, peerId := range peerIds { + if id == peerId { + continue + } + ctx.SetBehavior(id, peerId, nodes.NormalRpc, 0, 0) + ctx.SetBehavior(peerId, id, nodes.NormalRpc, 0, 0) + } + } else { + newNode, quitChan := threadTest.ExecuteNodeI(id, true, peerIds, threadTransport) + quitCollections[id] = quitChan + nodeCollections[id] = newNode + } + fmt.Printf("[%s]恢复异常\n", id) + } + + } + + for _, quitChan := range quitCollections { + close(quitChan) + } + time.Sleep(time.Second) + for id, node := range nodeCollections { + // 确保完成退出 + node.Mu.Lock() + if !node.IsFinish { + node.IsFinish = true + } + node.Mu.Unlock() + + os.RemoveAll("leveldb/simple-kv-store" + id) + os.RemoveAll("storage/node" + id) + } + }) +} + +func injectRandomBehavior2(ctx *nodes.Ctx, r *rand.Rand, peers []string, tran *nodes.ThreadTransport, quitCollections map[string]chan struct{}) map[string]bool /*id:needRestart*/ { + + n := len(peers) + maxFaulty := r.Intn(n/2 + 1) // 随机选择 0 ~ n/2 个出问题的节点 + + // 随机选择出问题的节点 + shuffled := append([]string(nil), peers...) + r.Shuffle(n, func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] }) + faultyNodes := make(map[string]bool) + for i := 0; i < maxFaulty; i++ { + faultyNodes[shuffled[i]] = false + } + PartitionNodes := make(map[string]bool) + + for _, one := range peers { + _, exist := faultyNodes[one] + if exist { + b := r.Intn(5) + + switch b { + case 0: + fmt.Printf("[%s]的异常行为是fail\n", one) + for _, two := range peers { + if one == two { + continue + } + ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0) + ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0) + } + case 1: + fmt.Printf("[%s]的异常行为是delay\n", one) + t := r.Intn(100) + fmt.Printf("[%s]的delay time = %d\n", one, t) + delay := time.Duration(t) * time.Millisecond + for _, two := range peers { + if one == two { + continue + } + _, exist2 := faultyNodes[two] + if exist2 { + ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0) + ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0) + } else { + ctx.SetBehavior(one, two, nodes.DelayRpc, delay, 0) + ctx.SetBehavior(two, one, nodes.DelayRpc, delay, 0) + } + } + case 2: + fmt.Printf("[%s]的异常行为是retry\n", one) + for _, two := range peers { + if one == two { + continue + } + _, exist2 := faultyNodes[two] + if exist2 { + ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0) + ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0) + } else { + ctx.SetBehavior(one, two, nodes.RetryRpc, 0, 2) + ctx.SetBehavior(two, one, nodes.RetryRpc, 0, 2) + } + } + case 3: + fmt.Printf("[%s]的异常行为是stop\n", one) + faultyNodes[one] = true + close(quitCollections[one]) + + case 4: + fmt.Printf("[%s]的异常行为是partition\n", one) + PartitionNodes[one] = true + } + } + } + for id, _ := range PartitionNodes { + for _, two := range peers { + if !PartitionNodes[two] { + tran.SetConnectivity(id, two, false) + tran.SetConnectivity(two, id, false) + } + } + } + + return faultyNodes +} diff --git a/threadTest/log_replication_test.go b/threadTest/log_replication_test.go new file mode 100644 index 0000000..c9b7c2c --- /dev/null +++ b/threadTest/log_replication_test.go @@ -0,0 +1,326 @@ +package threadTest + +import ( + "simple-kv-store/internal/nodes" + "strconv" + "testing" + "time" +) + +func TestNormalReplication(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + threadTransport := 
nodes.NewThreadTransport(nodes.NewCtx()) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + nodeCollections[0].StartElection() + time.Sleep(time.Second) + CheckOneLeader(t, nodeCollections) + CheckIsLeader(t, nodeCollections[0]) + CheckTerm(t, nodeCollections[0], 2) + + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0]) + } + + time.Sleep(time.Second) + for i := 0; i < n; i++ { + CheckLogNum(t, nodeCollections[i], 10) + } +} + +func TestParallelReplication(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + nodeCollections[0].StartElection() + time.Sleep(time.Second) + CheckOneLeader(t, nodeCollections) + CheckIsLeader(t, nodeCollections[0]) + CheckTerm(t, nodeCollections[0], 2) + + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0]) + go nodeCollections[0].BroadCastKV() + } + + time.Sleep(time.Second) + for i := 0; i < n; i++ { + CheckLogNum(t, nodeCollections[i], 10) + } +} + +func TestFollowerLagging(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + nodeCollections[0].StartElection() + time.Sleep(time.Second) + CheckOneLeader(t, nodeCollections) + CheckIsLeader(t, nodeCollections[0]) + CheckTerm(t, nodeCollections[0], 2) + close(quitCollections[1]) + time.Sleep(time.Second) + + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0]) + } + + node, q := ExecuteStaticNodeI("2", true, peerIds, threadTransport) + quitCollections[1] = q + nodeCollections[1] = node + nodeCollections[1].State = nodes.Follower + 
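+	// re-register the restarted follower in the election-suppression loop so it
+	// stays passive while the leader re-replicates the entries it missed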
StopElectionReset(nodeCollections[1:2]) + nodeCollections[0].BroadCastKV() + + time.Sleep(time.Second) + for i := 0; i < n; i++ { + CheckLogNum(t, nodeCollections[i], 10) + } +} + +func TestFailLogAppendRpc(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + ctx := nodes.NewCtx() + threadTransport := nodes.NewThreadTransport(ctx) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + nodeCollections[0].StartElection() + time.Sleep(time.Second) + CheckOneLeader(t, nodeCollections) + CheckIsLeader(t, nodeCollections[0]) + CheckTerm(t, nodeCollections[0], 2) + + for i := 0; i < n; i++ { + ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.FailRpc, 0, 0) + } + + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0]) + } + + time.Sleep(time.Second) + for i := 1; i < n; i++ { + CheckLogNum(t, nodeCollections[i], 0) + } +} + +func TestRepeatLogAppendRpc(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + ctx := nodes.NewCtx() + threadTransport := nodes.NewThreadTransport(ctx) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { + nodeCollections[i].State = nodes.Follower + } + + nodeCollections[0].StartElection() + time.Sleep(time.Second) + CheckOneLeader(t, nodeCollections) + CheckIsLeader(t, nodeCollections[0]) + CheckTerm(t, nodeCollections[0], 2) + + for i := 0; i < n; i++ { + ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2) + } + + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0]) + } + + time.Sleep(time.Second) + for i := 0; i < n; i++ { + CheckLogNum(t, nodeCollections[i], 10) + } +} + +func TestDelayLogAppendRpc(t *testing.T) { + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + ctx := nodes.NewCtx() + threadTransport := nodes.NewThreadTransport(ctx) + for i := 0; i < n; i++ { + n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + StopElectionReset(nodeCollections) + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + for i := 0; i < n; i++ { 
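+		// force a clean Follower start so the test decides which node campaigns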
+		nodeCollections[i].State = nodes.Follower
+	}
+
+	nodeCollections[0].StartElection()
+	time.Sleep(time.Second)
+	CheckOneLeader(t, nodeCollections)
+	CheckIsLeader(t, nodeCollections[0])
+	CheckTerm(t, nodeCollections[0], 2)
+
+	for i := 0; i < n; i++ {
+		ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.DelayRpc, time.Second, 0)
+	}
+
+	for i := 0; i < 5; i++ {
+		key := strconv.Itoa(i)
+		newlog := nodes.LogEntry{Key: key, Value: "hello"}
+		go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
+	}
+
+	time.Sleep(time.Millisecond * 100)
+
+	for i := 0; i < n; i++ {
+		ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.NormalRpc, 0, 0)
+	}
+	for i := 5; i < 10; i++ {
+		key := strconv.Itoa(i)
+		newlog := nodes.LogEntry{Key: key, Value: "hello"}
+		go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
+	}
+
+	time.Sleep(time.Second * 2)
+	for i := 0; i < n; i++ {
+		CheckLogNum(t, nodeCollections[i], 10)
+	}
+}
diff --git a/threadTest/network_partition_test.go b/threadTest/network_partition_test.go
new file mode 100644
index 0000000..610e22d
--- /dev/null
+++ b/threadTest/network_partition_test.go
@@ -0,0 +1,245 @@
+package threadTest
+
+import (
+	"fmt"
+	clientPkg "simple-kv-store/internal/client"
+	"simple-kv-store/internal/nodes"
+	"strconv"
+	"strings"
+	"testing"
+	"time"
+)
+
+func TestBasicConnectivity(t *testing.T) {
+	transport := nodes.NewThreadTransport(nodes.NewCtx())
+
+	transport.RegisterNodeChan("1", make(chan nodes.RPCRequest, 10))
+	transport.RegisterNodeChan("2", make(chan nodes.RPCRequest, 10))
+
+	// cut the link from 1 to 2
+	transport.SetConnectivity("1", "2", false)
+
+	err := transport.CallWithTimeout(&nodes.ThreadClient{SourceId: "1", TargetId: "2"}, "Node.AppendEntries", &nodes.AppendEntriesArg{}, &nodes.AppendEntriesReply{})
+	if err == nil {
+		t.Errorf("Expected network partition error, but got nil")
+	}
+
+	// restore the link; nobody serves the channel, so the call should now fail
+	// with the transport's timeout error rather than a partition error
+	transport.SetConnectivity("1", "2", true)
+
+	err = transport.CallWithTimeout(&nodes.ThreadClient{SourceId: "1", TargetId: "2"}, "Node.AppendEntries", &nodes.AppendEntriesArg{}, &nodes.AppendEntriesReply{})
+	if err == nil || !strings.Contains(err.Error(), "RPC 调用超时") {
+		t.Errorf("Expected timeout after reconnect, but got: %v", err)
+	}
+}
+
+func TestSinglePartition(t *testing.T) {
+	// register node info
+	n := 3
+	var peerIds []string
+	for i := 0; i < n; i++ {
+		peerIds = append(peerIds, strconv.Itoa(i + 1))
+	}
+
+	// start the nodes
+	var quitCollections []chan struct{}
+	var nodeCollections []*nodes.Node
+	threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
+	for i := 0; i < n; i++ {
+		n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
+		quitCollections = append(quitCollections, quitChan)
+		nodeCollections = append(nodeCollections, n)
+	}
+
+	// tell all nodes to shut down at the end
+	defer func(){
+		for _, quitChan := range quitCollections {
+			close(quitChan)
+		}
+	}()
+
+	time.Sleep(time.Second) // wait for startup to finish
+	fmt.Println("开始分区模拟1")
+	var leaderNo int
+	for i := 0; i < n; i++ {
+		if nodeCollections[i].State == nodes.Leader {
+			leaderNo = i
+			for j := 0; j < n; j++ {
+				if i != j { // cut the other nodes' messages to the leader
+					threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[i].SelfId, false)
+				}
+			}
+		}
+	}
+	time.Sleep(2 * time.Second)
+
+	if nodeCollections[leaderNo].State == nodes.Leader {
+		t.Errorf("分区退选失败")
+	}
+
+	// restore the network
+	for j := 0; j < n; j++ {
+		if leaderNo != j { // restore the other nodes' messages to the leader
+			threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[leaderNo].SelfId, true)
+		}
+	}
+
+	time.Sleep(1 * time.Second)
+
+	var leaderCnt int
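+	// after the network heals, re-election must converge on exactly one leader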
for i := 0; i < n; i++ { + if nodeCollections[i].State == nodes.Leader { + leaderCnt++ + leaderNo = i + } + } + if leaderCnt != 1 { + t.Errorf("多leader产生") + } + + fmt.Println("开始分区模拟2") + for j := 0; j < n; j++ { + if leaderNo != j { // 切断leader到其它节点的消息 + threadTransport.SetConnectivity(nodeCollections[leaderNo].SelfId, nodeCollections[j].SelfId, false) + } + } + time.Sleep(1 * time.Second) + + if nodeCollections[leaderNo].State == nodes.Leader { + t.Errorf("分区退选失败") + } + + leaderCnt = 0 + for j := 0; j < n; j++ { + if nodeCollections[j].State == nodes.Leader { + leaderCnt++ + } + } + if leaderCnt != 1 { + t.Errorf("多leader产生") + } + + // client启动 + c := clientPkg.NewClient("0", peerIds, threadTransport) + var s clientPkg.Status + for i := 0; i < 5; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + s = c.Write(newlog) + if s != clientPkg.Ok { + t.Errorf("write test fail") + } + } + + time.Sleep(time.Second) // 等待写入完毕 + + // 恢复网络 + for j := 0; j < n; j++ { + if leaderNo != j { + threadTransport.SetConnectivity(nodeCollections[leaderNo].SelfId, nodeCollections[j].SelfId, true) + } + } + + time.Sleep(time.Second) + // 日志一致性检查 + for i := 0; i < n; i++ { + if len(nodeCollections[i].Log) != 5 { + t.Errorf("日志数量不一致:" + strconv.Itoa(len(nodeCollections[i].Log))) + } + } +} + +func TestQuorumPartition(t *testing.T) { + // 登记结点信息 + n := 5 // 奇数,模拟不超过半数节点分区 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) + for i := 0; i < n; i++ { + n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + time.Sleep(time.Second) // 等待启动完毕 + fmt.Println("开始分区模拟1") + for i := 0; i < n / 2; i++ { + for j := n / 2; j < n; j++ { + threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[i].SelfId, false) + threadTransport.SetConnectivity(nodeCollections[i].SelfId, nodeCollections[j].SelfId, false) + } + } + time.Sleep(2 * time.Second) + + leaderCnt := 0 + for i := 0; i < n / 2; i++ { + if nodeCollections[i].State == nodes.Leader { + leaderCnt++ + } + } + if leaderCnt != 0 { + t.Errorf("少数分区不应该产生leader") + } + + for i := n / 2; i < n; i++ { + if nodeCollections[i].State == nodes.Leader { + leaderCnt++ + } + } + if leaderCnt != 1 { + t.Errorf("多数分区应该产生一个leader") + } + + // client启动 + c := clientPkg.NewClient("0", peerIds, threadTransport) + var s clientPkg.Status + for i := 0; i < 5; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + s = c.Write(newlog) + if s != clientPkg.Ok { + t.Errorf("write test fail") + } + } + + time.Sleep(time.Second) // 等待写入完毕 + + // 恢复网络 + for i := 0; i < n / 2; i++ { + for j := n / 2; j < n; j++ { + threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[i].SelfId, true) + threadTransport.SetConnectivity(nodeCollections[i].SelfId, nodeCollections[j].SelfId, true) + } + } + time.Sleep(1 * time.Second) + + leaderCnt = 0 + for j := 0; j < n; j++ { + if nodeCollections[j].State == nodes.Leader { + leaderCnt++ + } + } + if leaderCnt != 1 { + t.Errorf("多leader产生") + } + + // 日志一致性检查 + for i := 0; i < n; i++ { + if len(nodeCollections[i].Log) != 5 { + 
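+			// this node is missing entries that the majority committed during the partition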
t.Errorf("日志数量不一致:" + strconv.Itoa(len(nodeCollections[i].Log))) + } + } +} \ No newline at end of file diff --git a/threadTest/restart_node_test.go b/threadTest/restart_node_test.go new file mode 100644 index 0000000..71bde83 --- /dev/null +++ b/threadTest/restart_node_test.go @@ -0,0 +1,129 @@ +package threadTest + +import ( + "simple-kv-store/internal/client" + "simple-kv-store/internal/nodes" + "strconv" + "testing" + "time" +) + +func TestNodeRestart(t *testing.T) { + // 登记结点信息 + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) + for i := 0; i < n; i++ { + _, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + } + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + time.Sleep(time.Second) // 等待启动完毕 + // client启动, 连接任意节点 + cWrite := clientPkg.NewClient("0", peerIds, threadTransport) + // 写入 + ClientWriteLog(t, 0, 5, cWrite) + time.Sleep(time.Second) // 等待写入完毕 + + // 模拟结点轮流崩溃 + for i := 0; i < n; i++ { + close(quitCollections[i]) + + time.Sleep(time.Second) + _, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), true, peerIds, threadTransport) + quitCollections[i] = quitChan + time.Sleep(time.Second) // 等待启动完毕 + } + + + // client启动 + cRead := clientPkg.NewClient("0", peerIds, threadTransport) + // 读写入数据 + for i := 0; i < 5; i++ { + key := strconv.Itoa(i) + var value string + s := cRead.Read(key, &value) + if s != clientPkg.Ok { + t.Errorf("Read test1 fail") + } + } + + // 读未写入数据 + for i := 5; i < 15; i++ { + key := strconv.Itoa(i) + var value string + s := cRead.Read(key, &value) + if s != clientPkg.NotFound { + t.Errorf("Read test2 fail") + } + } +} + + +func TestRestartWhileWriting(t *testing.T) { + // 登记结点信息 + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) + for i := 0; i < n; i++ { + n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + time.Sleep(time.Second) // 等待启动完毕 + leaderIdx := FindLeader(t, nodeCollections) + // client启动, 连接任意节点 + cWrite := clientPkg.NewClient("0", peerIds, threadTransport) + // 写入 + go ClientWriteLog(t, 0, 5, cWrite) + + go func() { + close(quitCollections[leaderIdx]) + + n, quitChan := ExecuteNodeI(strconv.Itoa(leaderIdx + 1), true, peerIds, threadTransport) + quitCollections[leaderIdx] = quitChan + nodeCollections[leaderIdx] = n + }() + + time.Sleep(time.Second) // 等待启动完毕 + // client启动 + cRead := clientPkg.NewClient("0", peerIds, threadTransport) + // 读写入数据 + for i := 0; i < 5; i++ { + key := strconv.Itoa(i) + var value string + s := cRead.Read(key, &value) + if s != clientPkg.Ok { + t.Errorf("Read test1 fail") + } + } + CheckLogNum(t, nodeCollections[leaderIdx], 5) +} diff --git a/threadTest/server_client_test.go b/threadTest/server_client_test.go new file mode 100644 index 0000000..8b80ac1 --- /dev/null +++ b/threadTest/server_client_test.go @@ -0,0 +1,192 @@ +package threadTest + +import ( + 
"simple-kv-store/internal/client" + "simple-kv-store/internal/nodes" + "strconv" + "testing" + "time" +) + +func TestServerClient(t *testing.T) { + // 登记结点信息 + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + threadTransport := nodes.NewThreadTransport(nodes.NewCtx()) + for i := 0; i < n; i++ { + _, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + } + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + time.Sleep(time.Second) // 等待启动完毕 + // client启动 + c := clientPkg.NewClient("0", peerIds, threadTransport) + + // 写入 + var s clientPkg.Status + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + s = c.Write(newlog) + if s != clientPkg.Ok { + t.Errorf("write test fail") + } + } + + time.Sleep(time.Second) // 等待写入完毕 + // 读写入数据 + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + var value string + s = c.Read(key, &value) + if s != clientPkg.Ok || value != "hello" { + t.Errorf("Read test1 fail") + } + } + + // 读未写入数据 + for i := 10; i < 15; i++ { + key := strconv.Itoa(i) + var value string + s = c.Read(key, &value) + if s != clientPkg.NotFound { + t.Errorf("Read test2 fail") + } + } +} + +func TestRepeatClientReq(t *testing.T) { + // 登记结点信息 + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + ctx := nodes.NewCtx() + threadTransport := nodes.NewThreadTransport(ctx) + for i := 0; i < n; i++ { + n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + time.Sleep(time.Second) // 等待启动完毕 + // client启动 + c := clientPkg.NewClient("0", peerIds, threadTransport) + for i := 0; i < n; i++ { + ctx.SetBehavior("", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2) + } + + // 写入 + var s clientPkg.Status + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + s = c.Write(newlog) + if s != clientPkg.Ok { + t.Errorf("write test fail") + } + } + + time.Sleep(time.Second) // 等待写入完毕 + // 读写入数据 + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + var value string + s = c.Read(key, &value) + if s != clientPkg.Ok || value != "hello" { + t.Errorf("Read test1 fail") + } + } + + // 读未写入数据 + for i := 10; i < 15; i++ { + key := strconv.Itoa(i) + var value string + s = c.Read(key, &value) + if s != clientPkg.NotFound { + t.Errorf("Read test2 fail") + } + } + + for i := 0; i < n; i++ { + CheckLogNum(t, nodeCollections[i], 10) + } +} + +func TestParallelClientReq(t *testing.T) { + // 登记结点信息 + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + var nodeCollections []*nodes.Node + ctx := nodes.NewCtx() + threadTransport := nodes.NewThreadTransport(ctx) + for i := 0; i < n; i++ { + n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + nodeCollections = append(nodeCollections, n) + } + + // 通知所有node结束 + 
defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + time.Sleep(time.Second) // 等待启动完毕 + // client启动 + c1 := clientPkg.NewClient("0", peerIds, threadTransport) + c2 := clientPkg.NewClient("1", peerIds, threadTransport) + + + + // 写入 + go ClientWriteLog(t, 0, 10, c1) + go ClientWriteLog(t, 0, 10, c2) + + time.Sleep(time.Second) // 等待写入完毕 + // 读写入数据 + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + var value string + s := c1.Read(key, &value) + if s != clientPkg.Ok || value != "hello" { + t.Errorf("Read test1 fail") + } + } + + for i := 0; i < n; i++ { + CheckLogNum(t, nodeCollections[i], 20) + } +}