
Merge pull request 'ld' (#2) from ld into master

Reviewed-on: https://gitea.shuishan.net.cn/10225501448/simple-kv-store/pulls/2
李度 committed 5 months ago
Parent commit: ea6fd20628
33 changed files with 3978 additions and 470 deletions
  1. .gitignore (+3, −0)
  2. README.md (+232, −16)
  3. cmd/main.go (+25, −24)
  4. internal/client/client_node.go (+122, −46)
  5. internal/logprovider/traceback.go (+15, −0)
  6. internal/nodes/init.go (+217, −77)
  7. internal/nodes/log.go (+19, −6)
  8. internal/nodes/node.go (+43, −72)
  9. internal/nodes/node_storage.go (+194, −0)
  10. internal/nodes/random_timetable.go (+46, −0)
  11. internal/nodes/real_transport.go (+111, −0)
  12. internal/nodes/replica.go (+271, −0)
  13. internal/nodes/server_node.go (+68, −16)
  14. internal/nodes/simulate_ctx.go (+65, −0)
  15. internal/nodes/thread_transport.go (+234, −0)
  16. internal/nodes/transport.go (+10, −0)
  17. internal/nodes/vote.go (+215, −0)
  18. pics/plan.png (BIN)
  19. pics/plus.png (BIN)
  20. pics/robust.png (BIN)
  21. raft第二次汇报.pptx (BIN)
  22. scripts/run.sh (+0, −61)
  23. test/common.go (+7, −15)
  24. test/restart_follower_test.go (+0, −112)
  25. test/restart_node_test.go (+104, −0)
  26. test/server_client_test.go (+22, −25)
  27. threadTest/common.go (+305, −0)
  28. threadTest/election_test.go (+313, −0)
  29. threadTest/fuzz/fuzz_test.go (+445, −0)
  30. threadTest/log_replication_test.go (+326, −0)
  31. threadTest/network_partition_test.go (+245, −0)
  32. threadTest/restart_node_test.go (+129, −0)
  33. threadTest/server_client_test.go (+192, −0)

.gitignore (+3, −0)

@ -26,3 +26,6 @@ go.work
main
leveldb
storage
*.log
testdata

README.md (+232, −16)

@ -1,24 +1,240 @@
# go-raft-kv
---
# Overview
This project is a distributed key-value store built in Go on top of the Raft consensus algorithm.
Project highlights:
- Supports both thread- and process-based communication, making it easy to switch between testing and deployment while reducing coupling in the system
- Provides a fuzz-testing mechanism based on state invariants, improving robustness
- Highly modular, making it easy to extend with new features
This report serves mainly as a work summary and as an aid for reading the code (key excerpts are quoted). Some of the concrete problems, thoughts, and lessons from the work are hard to capture concisely in prose; they are covered in the presentation slides instead.
```
---cmd
    ---main.go                  entry point for the process version
---internal
    ---client                   client using the read/write service exposed by nodes
    ---logprovider              thin logging wrapper, for easier debugging
    ---nodes                    core distributed-system code
        init.go                 node initialization (both variants) and main-loop startup
        log.go                  data structures for the entries a node stores
        node_storage.go         node-data persistence abstraction; serialized into leveldb
        node.go                 the node's data structures
        random_timetable.go     controls the random timing used across the system
        real_transport.go       RPC communication for the process version
        replica.go              log-replication logic
        server_node.go          the read/write service a node provides to clients as a server
        simulate_ctx.go         controls message behavior in tests
        thread_transport.go     communication for the thread version
        transport.go            base transport interface shared by both variants
        vote.go                 leader-election logic
---test                         tests for the process version
---threadTest                   tests for the thread version
    ---fuzz                     randomized testing
    election_test.go            election
    log_replication_test.go     log replication
    network_partition_test.go   network partitions
    restart_node_test.go        crash recovery
    server_client_test.go       client interaction
```
# The Raft system
## Main flow
In init.go every node initializes itself, spawns its listener goroutine, and enters the main loop in the Start function. On each heartbeat tick of the main loop, a node that considers itself the leader broadcasts and resets its election timer. (init.go)
```go
func Start(node *Node, quitChan chan struct{}) {
	node.Mu.Lock()
	node.State = Follower // every node starts as a Follower
	node.Mu.Unlock()
	node.ResetElectionTimer() // start the election-timeout timer
	go func() {
		ticker := time.NewTicker(50 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-quitChan:
				fmt.Printf("[%s] Raft start 退出...\n", node.SelfId)
				return // exit the goroutine
			case <-ticker.C:
				node.Mu.Lock()
				state := node.State
				node.Mu.Unlock()
				switch state {
				case Follower:
					// wait for heartbeat timeout
				case Leader:
					// send heartbeats
					node.ResetElectionTimer() // the leader never triggers its own election
					node.BroadCastKV()
				}
			}
		}
	}()
}
```
The two RPC handlers:
appendEntries: the broadcast iterates over every peer node and issues the call in sendKV, implementing the log-replication logic.
requestVote: every node owns an election timer managed by ResetElectionTimer; if the timer is not reset for a while, the node calls StartElection, which iterates over every peer node and issues the call in sendRequestVote, implementing leader election. The leader resets the timer on each heartbeat (so it never triggers its own re-election); a follower resets it whenever it receives AppendEntries.
```go
func (node *Node) ResetElectionTimer() {
	node.MuElection.Lock()
	defer node.MuElection.Unlock()
	if node.ElectionTimer == nil {
		node.ElectionTimer = time.NewTimer(node.RTTable.GetElectionTimeout())
		go func() {
			for {
				<-node.ElectionTimer.C
				node.StartElection()
			}
		}()
	} else {
		node.ElectionTimer.Stop()
		node.ElectionTimer.Reset(time.Duration(500+rand.Intn(500)) * time.Millisecond)
	}
}
```
## How the client works (client & server_node.go)
### Client writes
For each write the client connects to a random node in the cluster. Four cases can occur:
- a. The node believes it is the leader: it handles the request directly (records the entry, then broadcasts).
- b. The node believes it is a follower and knows the leader: it returns the leader's id. The client connects to that id, and the new node goes through the same four cases again.
- c. The node believes it is a follower but does not know who the leader is: it returns an empty id, and the client connects to another random node.
- d. The connection times out: the client retries with another random node.
```go
func (client *Client) Write(kv nodes.LogEntry) Status {
	kvCall := nodes.LogEntryCall{LogE: kv,
		Id: nodes.LogEntryCallId{ClientId: client.ClientId, LogId: client.NextLogId}}
	client.NextLogId++
	c := client.FindActiveNode()
	var err error
	timeout := time.Second
	deadline := time.Now().Add(timeout)
	for { // follow feedback from live nodes until the leader is found
		if time.Now().After(deadline) {
			return Fail
		}
		var reply nodes.ServerReply
		reply.Isleader = false
		callErr := client.Transport.CallWithTimeout(c, "Node.WriteKV", &kvCall, &reply) // RPC
		if callErr != nil { // the node may crash between dial and call; find another live node
			log.Error("dialing: ", zap.Error(callErr))
			client.CloseRpcClient(c)
			c = client.FindActiveNode()
			continue
		}
		if !reply.Isleader { // callee is not the leader; use its feedback to locate the leader
			leaderId := reply.LeaderId
			client.CloseRpcClient(c)
			if leaderId == "" { // this node does not know the leader; pick another at random
				c = client.FindActiveNode()
			} else { // dial the leader
				c, err = client.Transport.DialHTTPWithTimeout("tcp", "", leaderId)
				if err != nil { // dial failed; look for the next live node
					c = client.FindActiveNode()
				}
			}
		} else { // success
			client.CloseRpcClient(c)
			return Ok
		}
	}
}
```
### Client reads
The client connects to a random node in the cluster and reads that node's committed kv state.
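Putting reads and writes together, end-to-end client usage looks roughly like the sketch below (adapted from test/restart_node_test.go; the ids and addresses are placeholders and a running cluster is assumed):
```go
// Map node ids to the addresses they listen on (assumed already running).
addressMap := map[string]string{
	"1": "127.0.0.1:9091",
	"2": "127.0.0.1:9092",
	"3": "127.0.0.1:9093",
}
peerIds := []string{"1", "2", "3"}
transport := &nodes.HTTPTransport{NodeMap: addressMap}
c := clientPkg.NewClient("client-0", peerIds, transport)

// Write is routed to the leader as described above.
if s := c.Write(nodes.LogEntry{Key: "foo", Value: "bar"}); s != clientPkg.Ok {
	// the request could not reach a leader before the deadline
}

// Read queries a random node's committed state.
var value string
switch c.Read("foo", &value) {
case clientPkg.Ok: // value now holds the committed value
case clientPkg.NotFound: // key not committed on the queried node yet
}
```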
## Persisting critical state
This is encapsulated in node_storage.go: mainly SetTermAndVote, which serializes term and vote and writes them to leveldb, plus the log writes. These are called whenever the data changes, and the corresponding getters read the state back on recovery.
```go
// SetTermAndVote 原子更新 term 和 vote
func (rs *RaftStorage) SetTermAndVote(term int, candidate string) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.isfinish {
return
}
batch := new(leveldb.Batch)
batch.Put([]byte("current_term"), []byte(strconv.Itoa(term)))
batch.Put([]byte("voted_for"), []byte(candidate))
err := rs.db.Write(batch, nil) // 原子提交
if err != nil {
log.Error("SetTermAndVote 持久化失败:", zap.Error(err))
}
}
```
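On restart, the counterpart of this persistence is the recovery path in init.go, which reads the state back before the node rejoins the cluster:
```go
// From InitRPCNode / InitThreadNode (init.go): restore persisted state on restart.
if isRestart {
	node.CurrTerm = rstorage.GetCurrentTerm() // defaults to 0 if nothing was persisted
	node.VotedFor = rstorage.GetVotedFor()
	node.Log = rstorage.GetLogEntries()
}
```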
## Switching between the two systems
Everything network-facing in the Raft system is factored out into dial and call, the two methods of each node's Transport interface; the process and thread versions each implement their own derived Transport, keeping the two fully isolated from one another.
```go
type Transport interface {
	DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error)
	CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error
}
```
Process version: dial and call are Go's native net/rpc methods wrapped with a timeout (real_transport.go).
Thread version: a single threadTransport is shared by all nodes. At initialization each node registers its own chan into the transport's map and spawns a goroutine to listen on it; when a request arrives, the node invokes the corresponding local method (thread_transport.go).
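Concretely, a node only differs in which Transport it is constructed with; a rough sketch of the two setups (wiring simplified, addresses are placeholders):
```go
// Process version: real RPC over TCP, with an id → address map.
realTransport := &nodes.HTTPTransport{NodeMap: map[string]string{
	"1": "127.0.0.1:9091",
	"2": "127.0.0.1:9092",
}}

// Thread version: one in-memory transport shared by all nodes, plus a
// simulation context for injecting delays, failures, and retries in tests.
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
// InitThreadNode then registers each node's channel into this shared transport.
```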
# Testing
## Unit tests
The system is tested module by module along five axes: leader election, log replication, crash recovery, network partitions, and client interaction. The tests interleave fine-grained simulation of message states, to validate early and mid-project that the code matches the design and to head off large problems.
## Fuzz testing
There are two dimensions: multi-system randomization (basic), which exercises faults across different node counts and random timing configurations, and single-system randomization (robust), which injects multiple random faults into a single system; a final round (plus) combines both dimensions.
The tests assert invariants taken from the Raft TLA+ specification, ensuring the system stays consistent while it runs.
The fuzz tests not only cover everything the unit tests do; the randomized runs also uncovered more boundary-condition faults, and the state-invariant checks confirm the system remains correct and available over long runs under different configurations.
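As an illustration, election safety (at most one leader per term) can be asserted over all nodes roughly as follows; this is a sketch of the kind of invariant check used, assuming test-side access to every node, not a verbatim excerpt:
```go
// Sketch: Raft election safety — no term may have two leaders.
func assertAtMostOneLeaderPerTerm(t *testing.T, cluster []*nodes.Node) {
	leaders := make(map[int]string) // term -> leader id
	for _, n := range cluster {
		n.Mu.Lock()
		if n.State == nodes.Leader {
			if other, ok := leaders[n.CurrTerm]; ok && other != n.SelfId {
				t.Fatalf("two leaders in term %d: %s and %s", n.CurrTerm, other, n.SelfId)
			}
			leaders[n.CurrTerm] = n.SelfId
		}
		n.Mu.Unlock()
	}
}
```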
![plus fuzz test results](pics/plus.png)
![robust fuzz test results](pics/robust.png)
## A short record of bugs
- Ambiguity around LogId 0: log-numbering mismatches between interfaces
- Candidates deadlocking when their random election timeouts happened to coincide
- Non-atomic persistence of critical state, and conflating it with the state machine
- Missing unique request ids on the client, causing duplicate execution
- Incorrect lock usage while refactoring the system
- The asynchronous-semantics trap of pseudo-synchronous interfaces
- Bugs from mixing tests and system (timeouts caused by injected delay, incomplete shutdown, filesystem errors, bad locking)
# Environment and running
The environment used is WSL + Ubuntu.
`go mod download` installs the dependencies.
`./scripts/build.sh` builds `main` in the repository root.
`./scripts/run.sh` runs three nodes. You can currently type input in the terminal; the leader (n1) prints a send log and the other nodes print receive logs. After terminal input the script exits on timeout (the run time can be adjusted in the script).
# Notes
The scripts need execute permission on first run: `chmod +x <script>`.
A tcp listen error usually means an earlier process did not exit cleanly and is still holding the port:
`lsof -i :9091` to find the pid
`kill -9 <pid>` to kill the process
## About the tests
Nodes are created by spawning new processes; creating them as threads would cause duplicate RPC registration.
# Todo list
Handling of messaging failures
Local persistence of the kv data
Crash and recovery (plus the corresponding tests)
`./scripts/build.sh` builds `main` in the repository root (needed by the process-level tests)
# References
In Search of an Understandable Consensus Algorithm
Consensus: Bridging Theory and Practice
Raft TLA+ Specification
Apart from the logging wrappers under logprovider, whose use of some Go logging libraries follows a blog post, the entire project is independent, original work.
# Division of work
| Name | Work | Contribution |
|--------|--------|--------|
| 李度 | Raft system design + implementation, test design + implementation | 75% |
| 马也驰 | Raft system design, test design + implementation | 25% |
![project plan](pics/plan.png)

cmd/main.go (+25, −24)

@ -3,6 +3,7 @@ package main
import (
"flag"
"fmt"
"github.com/syndtr/goleveldb/leveldb"
"os"
"os/signal"
"simple-kv-store/internal/logprovider"
@ -10,7 +11,6 @@ import (
"strconv"
"strings"
"syscall"
"github.com/syndtr/goleveldb/leveldb"
"go.uber.org/zap"
)
@ -29,11 +29,9 @@ func main() {
signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)
port := flag.String("port", ":9091", "rpc listen port")
cluster := flag.String("cluster", "127.0.0.1:9092,127.0.0.1:9093", "comma sep")
cluster := flag.String("cluster", "127.0.0.1:9091,127.0.0.1:9092,127.0.0.1:9093", "comma sep")
id := flag.String("id", "1", "node ID")
pipe := flag.String("pipe", "", "input from scripts")
isLeader := flag.Bool("isleader", false, "init node state")
isNewDb := flag.Bool("isNewDb", true, "new test or restart")
isRestart := flag.Bool("isRestart", false, "new test or restart")
// parse flags
flag.Parse()
@ -42,43 +40,46 @@ func main() {
idCnt := 1
selfi, err := strconv.Atoi(*id)
if err != nil {
log.Error("figure id only")
log.Fatal("figure id only")
}
for _, addr := range clusters {
if idCnt == selfi {
idCnt++ // the cluster flag is passed sorted by id; skip our own id so that all nodes record consistent ids for each other
continue
}
idClusterPairs[strconv.Itoa(idCnt)] = addr
idClusterPairs[strconv.Itoa(idCnt)] = addr
idCnt++
}
if *isNewDb {
os.RemoveAll("leveldb/simple-kv-store" + *id)
// the storage/ directory holds each node's critical persisted state; once a node is created it must never be deleted
if !*isRestart {
os.RemoveAll("storage/node" + *id)
}
// open or create each node's own database
// Create each node's own database. There was a misconception here at first: state-machine recovery should rely on
// the node's persisted log, but simulating the state machine with leveldb made the state machine itself persistent; deleting the old db avoids that contradiction.
// So the leveldb/ directory holds the simulated state-machine database, which must be deleted on every node start.
os.RemoveAll("leveldb/simple-kv-store" + *id)
db, err := leveldb.OpenFile("leveldb/simple-kv-store" + *id, nil)
if err != nil {
log.Fatal("Failed to open database: ", zap.Error(err))
}
defer db.Close() // make sure the database is closed when we are done
iter := db.NewIterator(nil, nil)
defer iter.Release()
// count the entries
count := 0
for iter.Next() {
count++
}
fmt.Printf(*id + "结点目前有数据:%d\n", count)
// open or create the node's persistent-state files
storage := nodes.NewRaftStorage("storage/node" + *id)
defer storage.Close()
// initialize
node := nodes.InitRPCNode(*id, *port, idClusterPairs, db, storage, *isRestart)
node := nodes.Init(*id, idClusterPairs, *pipe, db)
log.Info("id: " + *id + "节点开始监听: " + *port + "端口")
// listen for rpc
node.Rpc(*port)
// start raft
nodes.Start(node, *isLeader)
quitChan := make(chan struct{}, 1)
nodes.Start(node, quitChan)
sig := <-sigs
fmt.Println("node_" + *id + "接收到信号:", sig)
fmt.Println("node_"+ *id +"接收到信号:", sig)
close(quitChan)
}

internal/client/client_node.go (+122, −46)

@ -1,9 +1,10 @@
package clientPkg
import (
"net/rpc"
"math/rand"
"simple-kv-store/internal/logprovider"
"simple-kv-store/internal/nodes"
"time"
"go.uber.org/zap"
)
@ -11,9 +12,11 @@ import (
var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel)
type Client struct {
// the server-side node to connect to (node1)
ServerId string
Address string
ClientId string // unique id for each client
NextLogId int
// the server-side node cluster to connect to
PeerIds []string
Transport nodes.Transport
}
type Status = uint8
@ -24,32 +27,83 @@ const (
Fail
)
func (client *Client) Write(kvCall nodes.LogEntryCall) Status {
log.Info("client write request key :" + kvCall.LogE.Key)
c, err := rpc.DialHTTP("tcp", client.Address)
if err != nil {
log.Error("dialing: ", zap.Error(err))
return Fail
}
func NewClient(clientId string, peerIds []string, transport nodes.Transport) *Client {
return &Client{ClientId: clientId, NextLogId: 0, PeerIds: peerIds, Transport: transport}
}
defer func(server *rpc.Client) {
err := c.Close()
func getRandomAddress(peerIds []string) string {
// pick a random id
randomKey := peerIds[rand.Intn(len(peerIds))]
return randomKey
}
func (client *Client) FindActiveNode() nodes.ClientInterface {
var err error
var c nodes.ClientInterface
for { // until a reachable node is found (assumes at least one node is alive)
peerId := getRandomAddress(client.PeerIds)
c, err = client.Transport.DialHTTPWithTimeout("tcp", "", peerId)
if err != nil {
log.Error("client close err: ", zap.Error(err))
log.Error("dialing: ", zap.Error(err))
} else {
log.Sugar().Infof("client发现活跃节点[%s]", peerId)
return c
}
}(c)
}
}
var reply nodes.ServerReply
callErr := c.Call("Node.WriteKV", kvCall, &reply) // RPC
if callErr != nil {
log.Error("dialing: ", zap.Error(callErr))
return Fail
func (client *Client) CloseRpcClient(c nodes.ClientInterface) {
err := c.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
}
}
func (client *Client) Write(kv nodes.LogEntry) Status {
defer logprovider.DebugTraceback("client")
log.Info("client write request key :" + kv.Key)
kvCall := nodes.LogEntryCall{LogE: kv,
Id: nodes.LogEntryCallId{ClientId: client.ClientId, LogId: client.NextLogId}}
client.NextLogId++
c := client.FindActiveNode()
var err error
timeout := time.Second
deadline := time.Now().Add(timeout)
for { // follow feedback from live nodes until the leader is found
if time.Now().After(deadline) {
log.Error("系统繁忙,疑似出错")
return Fail
}
var reply nodes.ServerReply
reply.Isleader = false
callErr := client.Transport.CallWithTimeout(c, "Node.WriteKV", &kvCall, &reply) // RPC
if callErr != nil { // the node may crash between dial and call; find another live node
log.Error("dialing: ", zap.Error(callErr))
client.CloseRpcClient(c)
c = client.FindActiveNode()
continue
}
if reply.Isconnect { // send succeeded
return Ok
} else { // failed
return Fail
if !reply.Isleader { // callee is not the leader; use its feedback to locate the leader
leaderId := reply.LeaderId
client.CloseRpcClient(c)
if leaderId == "" { // this node does not know the leader; pick another at random
c = client.FindActiveNode()
} else { // dial the leader
c, err = client.Transport.DialHTTPWithTimeout("tcp", "", leaderId)
if err != nil { // dial failed; look for the next live node
c = client.FindActiveNode()
}
}
} else { // success
client.CloseRpcClient(c)
return Ok
}
}
}
@ -58,37 +112,59 @@ func (client *Client) Read(key string, value *string) Status { // value not found
if value == nil {
return Fail
}
c, err := rpc.DialHTTP("tcp", client.Address)
if err != nil {
log.Error("dialing: ", zap.Error(err))
return Fail
}
defer func(server *rpc.Client) {
err := c.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
var c nodes.ClientInterface
for {
c = client.FindActiveNode()
var reply nodes.ServerReply
callErr := client.Transport.CallWithTimeout(c, "Node.ReadKey", &key, &reply) // RPC
if callErr != nil {
log.Error("dialing: ", zap.Error(callErr))
client.CloseRpcClient(c)
continue
}
}(c)
var reply nodes.ServerReply
callErr := c.Call("Node.ReadKey", key, &reply) // RPC
if callErr != nil {
log.Error("dialing: ", zap.Error(callErr))
return Fail
}
if reply.Isconnect { // send succeeded
// currently this always succeeds
if reply.HaveValue {
*value = reply.Value
client.CloseRpcClient(c)
return Ok
} else {
client.CloseRpcClient(c)
return NotFound
}
}
}
func (client *Client) FindLeader() string {
var arg struct{}
var reply nodes.FindLeaderReply
reply.Isleader = false
c := client.FindActiveNode()
var err error
for !reply.Isleader { // follow feedback from live nodes until the leader is found
callErr := client.Transport.CallWithTimeout(c, "Node.FindLeader", &arg, &reply) // RPC
if callErr != nil { // the node may crash between dial and call; find another live node
log.Error("dialing: ", zap.Error(callErr))
client.CloseRpcClient(c)
c = client.FindActiveNode()
continue
}
} else { // failed
return Fail
if !reply.Isleader { // callee is not the leader; use its feedback to locate the leader
client.CloseRpcClient(c)
c, err = client.Transport.DialHTTPWithTimeout("tcp", "", reply.LeaderId)
for err != nil { // look for the next live node
c = client.FindActiveNode()
}
} else { // success
client.CloseRpcClient(c)
return reply.LeaderId
}
}
log.Fatal("客户端会一直找存活节点,不会运行到这里")
return "fault"
}

internal/logprovider/traceback.go (+15, −0)

@ -0,0 +1,15 @@
package logprovider
import (
"fmt"
"os"
"runtime/debug"
)
func DebugTraceback(errFuncName string) {
if r := recover(); r != nil {
msg := fmt.Sprintf("panic in goroutine: %v\n%s", r, debug.Stack())
f, _ := os.Create(errFuncName + ".log")
fmt.Fprint(f, msg)
f.Close()
}
}

internal/nodes/init.go (+217, −77)

@ -1,13 +1,12 @@
package nodes
import (
"io"
"errors"
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"simple-kv-store/internal/logprovider"
"strconv"
"time"
"github.com/syndtr/goleveldb/leveldb"
@ -16,88 +15,48 @@ import (
var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel)
func newNode(address string) *Public_node_info {
return &Public_node_info{
connect: false,
address: address,
}
}
func Init(id string, nodeAddr map[string]string, pipe string, db *leveldb.DB) *Node {
ns := make(map[string]*Public_node_info)
for id, addr := range nodeAddr {
ns[id] = newNode(addr)
// initialization for the process version + rpc registration
func InitRPCNode(SelfId string, port string, nodeAddr map[string]string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool) *Node {
var nodeIds []string
for id := range nodeAddr {
nodeIds = append(nodeIds, id)
}
// create the node
return &Node{
selfId: id,
nodes: ns,
pipeAddr: pipe,
maxLogId: 0,
log: make(map[int]LogEntry),
db: db,
node := &Node{
SelfId: SelfId,
LeaderId: "",
Nodes: nodeIds,
MaxLogId: -1, // discovered later that the paper starts at 1 (initialized to 0); left as-is
CurrTerm: 1,
Log: make([]RaftLogEntry, 0),
CommitIndex: -1,
LastApplied: -1,
NextIndex: make(map[string]int),
MatchIndex: make(map[string]int),
Db: db,
Storage: rstorage,
Transport: &HTTPTransport{NodeMap: nodeAddr},
RTTable: NewRTTable(),
SeenRequests: make(map[LogEntryCallId]bool),
IsFinish: false,
}
}
func Start(node *Node, isLeader bool) {
if isLeader {
node.state = Candidate // 需要身份转变
} else {
node.state = Follower
node.initLeaderState()
if isRestart {
node.CurrTerm = rstorage.GetCurrentTerm()
node.VotedFor = rstorage.GetVotedFor()
node.Log = rstorage.GetLogEntries()
log.Sugar().Infof("[%s]从重启中恢复log数量: %d", SelfId, len(node.Log))
}
go func() {
for {
switch node.state {
case Follower:
case Candidate:
// the candidate spawns an input-listener goroutine, then becomes leader
node.state = Leader
go func() {
if node.pipeAddr == "" { // clients invoke the server_node methods remotely
log.Info("请运行客户端进程进行读写")
} else { // a pipe was supplied on the command line; support pipe (keyboard) input
pipe, err := os.Open(node.pipeAddr)
if err != nil {
log.Error("Failed to open pipe")
}
defer pipe.Close()
// keep reading input from the pipe
buffer := make([]byte, 256)
for {
n, err := pipe.Read(buffer)
if err != nil && err != io.EOF {
log.Error("Error reading from pipe")
}
if n > 0 {
input := string(buffer[:n])
// wrap the user input in a LogEntry
kv := LogEntry{input, ""} // for now the keyboard supplies the key; the value is empty
logId := node.maxLogId
node.maxLogId++
node.log[logId] = kv
log.Info("send : logId = " + strconv.Itoa(logId) + ", key = " + input)
// broadcast to the other nodes
kvCall := LogEntryCall{kv, Normal}
node.BroadCastKV(logId, kvCall)
// persist
node.db.Put([]byte(kv.Key), []byte(kv.Value), nil)
}
}
}
}()
case Leader:
time.Sleep(50 * time.Millisecond)
}
}
}()
log.Sugar().Infof("[%s]开始监听" + port + "端口", SelfId)
node.ListenPort(port)
return node
}
func (node *Node) Rpc(port string) {
func (node *Node) ListenPort(port string) {
err := rpc.Register(node)
if err != nil {
log.Fatal("rpc register failed", zap.Error(err))
@ -115,3 +74,184 @@ func (node *Node) Rpc(port string) {
}
}()
}
// initialization for the thread-simulated version
func InitThreadNode(SelfId string, peerIds []string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool, threadTransport *ThreadTransport) (*Node, chan struct{}) {
rpcChan := make(chan RPCRequest, 100) // the channel this node listens on
// create the node
node := &Node{
SelfId: SelfId,
LeaderId: "",
Nodes: peerIds,
MaxLogId: -1, // discovered later that the paper starts at 1 (initialized to 0); left as-is
CurrTerm: 1,
Log: make([]RaftLogEntry, 0),
CommitIndex: -1,
LastApplied: -1,
NextIndex: make(map[string]int),
MatchIndex: make(map[string]int),
Db: db,
Storage: rstorage,
Transport: threadTransport,
RTTable: NewRTTable(),
SeenRequests: make(map[LogEntryCallId]bool),
IsFinish: false,
}
node.initLeaderState()
if isRestart {
node.CurrTerm = rstorage.GetCurrentTerm()
node.VotedFor = rstorage.GetVotedFor()
node.Log = rstorage.GetLogEntries()
log.Sugar().Infof("[%s]从重启中恢复log数量: %d", SelfId, len(node.Log))
}
threadTransport.RegisterNodeChan(SelfId, rpcChan)
quitChan := make(chan struct{}, 1)
go node.listenForChan(rpcChan, quitChan)
return node, quitChan
}
func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{}) {
defer logprovider.DebugTraceback("listen")
for {
select {
case req := <-rpcChan:
switch req.Behavior {
case DelayRpc:
threadTran, ok := node.Transport.(*ThreadTransport)
if !ok {
log.Fatal("无效的delayRpc模式")
}
duration, ok2 := threadTran.Ctx.GetDelay(req.SourceId, node.SelfId)
if !ok2 {
log.Fatal("没有设置对应的delay时间")
}
go node.switchReq(req, duration)
case FailRpc:
continue
default:
go node.switchReq(req, 0)
}
case <-quitChan:
node.Mu.Lock()
defer node.Mu.Unlock()
log.Sugar().Infof("[%s] 监听线程收到退出信号", node.SelfId)
node.IsFinish = true
node.Db.Close()
node.Storage.Close()
return
}
}
}
func (node *Node) switchReq(req RPCRequest, delayTime time.Duration) {
defer logprovider.DebugTraceback("switch")
time.Sleep(delayTime)
switch req.ServiceMethod {
case "Node.AppendEntries":
arg, ok := req.Args.(*AppendEntriesArg)
resp, ok2 := req.Reply.(*AppendEntriesReply)
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for AppendEntries")
} else {
var respCopy AppendEntriesReply
err := node.AppendEntries(arg, &respCopy)
*resp = respCopy
req.Done <- err
}
case "Node.RequestVote":
arg, ok := req.Args.(*RequestVoteArgs)
resp, ok2 := req.Reply.(*RequestVoteReply)
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for RequestVote")
} else {
var respCopy RequestVoteReply
err := node.RequestVote(arg, &respCopy)
*resp = respCopy
req.Done <- err
}
case "Node.WriteKV":
arg, ok := req.Args.(*LogEntryCall)
resp, ok2 := req.Reply.(*ServerReply)
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for WriteKV")
} else {
req.Done <- node.WriteKV(arg, resp)
}
case "Node.ReadKey":
arg, ok := req.Args.(*string)
resp, ok2 := req.Reply.(*ServerReply)
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for ReadKey")
} else {
req.Done <- node.ReadKey(arg, resp)
}
case "Node.FindLeader":
arg, ok := req.Args.(struct{})
resp, ok2 := req.Reply.(*FindLeaderReply)
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for FindLeader")
} else {
req.Done <- node.FindLeader(arg, resp)
}
default:
req.Done <- fmt.Errorf("未知方法: %s", req.ServiceMethod)
}
}
// shared parts and startup
func (n *Node) initLeaderState() {
for _, peerId := range n.Nodes {
n.NextIndex[peerId] = len(n.Log) // next log index to send to each peer
n.MatchIndex[peerId] = 0 // latest index known to be replicated on each peer
}
}
func Start(node *Node, quitChan chan struct{}) {
node.Mu.Lock()
node.State = Follower // every node starts as a Follower
node.Mu.Unlock()
node.ResetElectionTimer() // start the election-timeout timer
go func() {
defer logprovider.DebugTraceback("start")
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-quitChan:
fmt.Printf("[%s] Raft start 退出...\n", node.SelfId)
return // exit the goroutine
case <-ticker.C:
node.Mu.Lock()
state := node.State
node.Mu.Unlock()
switch state {
case Follower:
// wait for heartbeat timeout
case Leader:
// send heartbeats
node.ResetElectionTimer() // the leader never triggers its own election
node.BroadCastKV()
}
}
}
}()
}

internal/nodes/log.go (+19, −6)

@ -1,19 +1,32 @@
package nodes
const (
Normal State = iota + 1
Delay
Fail
)
import "strconv"
type LogEntry struct {
Key string
Value string
}
func (LogE *LogEntry) print() string {
return "key: " + LogE.Key + ", value: " + LogE.Value
}
type RaftLogEntry struct {
LogE LogEntry
LogId int
Term int
}
func (RLogE *RaftLogEntry) print() string {
return "logid: " + strconv.Itoa(RLogE.LogId) + ", term: " + strconv.Itoa(RLogE.Term) + ", " + RLogE.LogE.print()
}
type LogEntryCall struct {
Id LogEntryCallId
LogE LogEntry
CallState State
}
type LogEntryCallId struct {
ClientId string
LogId int
}
type KVReply struct {

internal/nodes/node.go (+43, −72)

@ -1,13 +1,10 @@
package nodes
import (
"math/rand"
"net/rpc"
"strconv"
"sync"
"time"
"github.com/syndtr/goleveldb/leveldb"
"go.uber.org/zap"
)
type State = uint8
@ -18,84 +15,58 @@ const (
Leader
)
type Public_node_info struct {
connect bool
address string
}
type Node struct {
Mu sync.Mutex
MuElection sync.Mutex
// this node's id
selfId string
SelfId string
// the recorded leader (votedFor will not do: the candidate we voted for may not have won a majority)
LeaderId string
// info about the other nodes
nodes map[string]*Public_node_info
// pipe name
pipeAddr string
// ids of the other nodes
Nodes []string
// current node state
state State
State State
// term
CurrTerm int
// simple kv store
log map[int]LogEntry
Log []RaftLogEntry
// used by the leader to number new log entries
maxLogId int
// used by the leader to number new log entries, = len(log)
MaxLogId int
db *leveldb.DB
}
// highest committed index
CommitIndex int
func (node *Node) BroadCastKV(logId int, kvCall LogEntryCall) {
// iterate over all nodes
for id, _ := range node.nodes {
go func(id string, kv LogEntryCall) {
var reply KVReply
node.sendKV(id, logId, kvCall, &reply)
}(id, kvCall)
}
}
// last index applied (written to db)
LastApplied int
func (node *Node) sendKV(id string, logId int, kvCall LogEntryCall, reply *KVReply) {
switch kvCall.CallState {
case Fail:
log.Info("模拟发送失败")
// written this way, the send fails for every node; could instead decide per send with a random number
case Delay:
log.Info("模拟发送延迟")
// random 0–5 ms delay
time.Sleep(time.Millisecond * time.Duration(rand.Intn(5)))
default:
}
client, err := rpc.DialHTTP("tcp", node.nodes[id].address)
if err != nil {
log.Error("dialing: ", zap.Error(err))
return
}
defer func(client *rpc.Client) {
err := client.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
}
}(client)
arg := LogIdAndEntry{logId, kvCall.LogE}
callErr := client.Call("Node.ReceiveKV", arg, reply) // RPC
if callErr != nil {
log.Error("dialing node_"+id+"fail: ", zap.Error(callErr))
}
}
// next index to send to each node
NextIndex map[string]int
// RPC call
func (node *Node) ReceiveKV(arg LogIdAndEntry, reply *KVReply) error {
log.Info("node_" + node.selfId + " receive: logId = " + strconv.Itoa(arg.LogId) + ", key = " + arg.Entry.Key)
entry, ok := node.log[arg.LogId]
if !ok {
node.log[arg.LogId] = entry
}
// persist
node.db.Put([]byte(arg.Entry.Key), []byte(arg.Entry.Value), nil)
reply.Reply = true // an rpc call needs a reply, but success is actually signalled by the error return value
return nil
// highest index already sent to each node
MatchIndex map[string]int
// the kv store (simulated state machine)
Db *leveldb.DB
// persisted node data (currTerm, votedFor, log)
Storage *RaftStorage
VotedFor string
ElectionTimer *time.Timer
// the communication mechanism
Transport Transport
// the system's random timing
RTTable *RandomTimeTable
// client requests already handled
SeenRequests map[LogEntryCallId]bool
IsFinish bool
}

internal/nodes/node_storage.go (+194, −0)

@ -0,0 +1,194 @@
package nodes
import (
"encoding/json"
"strconv"
"strings"
"sync"
"github.com/syndtr/goleveldb/leveldb"
"go.uber.org/zap"
)
// RaftStorage persists currentTerm, votedFor, and logEntries
type RaftStorage struct {
mu sync.Mutex
db *leveldb.DB
filePath string
isfinish bool
}
// NewRaftStorage creates the Raft storage
func NewRaftStorage(filePath string) *RaftStorage {
db, err := leveldb.OpenFile(filePath, nil)
if err != nil {
log.Fatal("无法打开 LevelDB:", zap.Error(err))
}
return &RaftStorage{
db: db,
filePath: filePath,
isfinish: false,
}
}
// SetCurrentTerm sets the current term
func (rs *RaftStorage) SetCurrentTerm(term int) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.isfinish {
return
}
err := rs.db.Put([]byte("current_term"), []byte(strconv.Itoa(term)), nil)
if err != nil {
log.Error("SetCurrentTerm 持久化失败:", zap.Error(err))
}
}
// GetCurrentTerm returns the current term
func (rs *RaftStorage) GetCurrentTerm() int {
rs.mu.Lock()
defer rs.mu.Unlock()
data, err := rs.db.Get([]byte("current_term"), nil)
if err != nil {
return 0 // default term = 0
}
term, _ := strconv.Atoi(string(data))
return term
}
// SetVotedFor records who was voted for
func (rs *RaftStorage) SetVotedFor(candidate string) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.isfinish {
return
}
err := rs.db.Put([]byte("voted_for"), []byte(candidate), nil)
if err != nil {
log.Error("SetVotedFor 持久化失败:", zap.Error(err))
}
}
// GetVotedFor returns the vote recipient
func (rs *RaftStorage) GetVotedFor() string {
rs.mu.Lock()
defer rs.mu.Unlock()
data, err := rs.db.Get([]byte("voted_for"), nil)
if err != nil {
return ""
}
return string(data)
}
// SetTermAndVote atomically updates term and vote
func (rs *RaftStorage) SetTermAndVote(term int, candidate string) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.isfinish {
return
}
batch := new(leveldb.Batch)
batch.Put([]byte("current_term"), []byte(strconv.Itoa(term)))
batch.Put([]byte("voted_for"), []byte(candidate))
err := rs.db.Write(batch, nil) // atomic commit
if err != nil {
log.Error("SetTermAndVote 持久化失败:", zap.Error(err))
}
}
// AppendLog appends a log entry
func (rs *RaftStorage) AppendLog(entry RaftLogEntry) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.db == nil {
return
}
// serialize the entry
batch := new(leveldb.Batch)
data, _ := json.Marshal(entry)
key := "log_" + strconv.Itoa(entry.LogId)
batch.Put([]byte(key), data)
lastIndex := strconv.Itoa(entry.LogId)
batch.Put([]byte("last_log_index"), []byte(lastIndex))
err := rs.db.Write(batch, nil)
if err != nil {
log.Error("AppendLog 持久化失败:", zap.Error(err))
}
}
// GetLastLogIndex returns the index of the newest log entry
func (rs *RaftStorage) GetLastLogIndex() int {
rs.mu.Lock()
defer rs.mu.Unlock()
data, err := rs.db.Get([]byte("last_log_index"), nil)
if err != nil {
return -1
}
index, _ := strconv.Atoi(string(data))
return index
}
// WriteLog writes log entries in batch (atomically)
func (rs *RaftStorage) WriteLog(entries []RaftLogEntry) {
if len(entries) == 0 {
return
}
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.isfinish {
return
}
batch := new(leveldb.Batch)
for _, entry := range entries {
data, _ := json.Marshal(entry)
key := "log_" + strconv.Itoa(entry.LogId)
batch.Put([]byte(key), data)
}
// update the latest log index
lastIndex := strconv.Itoa(entries[len(entries)-1].LogId)
batch.Put([]byte("last_log_index"), []byte(lastIndex))
err := rs.db.Write(batch, nil)
if err != nil {
log.Error("WriteLog 持久化失败:", zap.Error(err))
}
}
// GetLogEntries returns all log entries
func (rs *RaftStorage) GetLogEntries() []RaftLogEntry {
rs.mu.Lock()
defer rs.mu.Unlock()
var logs []RaftLogEntry
iter := rs.db.NewIterator(nil, nil) // iterate over all keys
defer iter.Release()
for iter.Next() {
key := string(iter.Key())
if strings.HasPrefix(key, "log_") { // filter log keys
var entry RaftLogEntry
if err := json.Unmarshal(iter.Value(), &entry); err == nil {
logs = append(logs, entry)
} else {
log.Error("解析日志失败:", zap.Error(err))
}
}
}
return logs
}
// Close closes the database
func (rs *RaftStorage) Close() {
rs.mu.Lock()
defer rs.mu.Unlock()
rs.db.Close()
rs.isfinish = true
}

internal/nodes/random_timetable.go (+46, −0)

@ -0,0 +1,46 @@
package nodes
import (
"math/rand"
"sync"
"time"
)
type RandomTimeTable struct {
Mu sync.Mutex
electionTimeOut time.Duration
israndom bool
// heartbeat 50ms
// rpcTimeout 50ms
// follower -> candidate 500ms
// wait-for-election-result 300ms
}
func NewRTTable() *RandomTimeTable {
return &RandomTimeTable{
israndom: true,
}
}
func (rttable *RandomTimeTable) GetElectionTimeout() time.Duration {
rttable.Mu.Lock()
defer rttable.Mu.Unlock()
if rttable.israndom {
return time.Duration(500+rand.Intn(500)) * time.Millisecond
} else {
return rttable.electionTimeOut
}
}
func (rttable *RandomTimeTable) SetElectionTimeout(t time.Duration) {
rttable.Mu.Lock()
defer rttable.Mu.Unlock()
rttable.israndom = false
rttable.electionTimeOut = t
}
func (rttable *RandomTimeTable) ResetElectionTimeout() {
rttable.Mu.Lock()
defer rttable.Mu.Unlock()
rttable.israndom = true
}

internal/nodes/real_transport.go (+111, −0)

@ -0,0 +1,111 @@
package nodes
import (
"errors"
"fmt"
"net/rpc"
"time"
)
// Transport implementation for real rpc communication
type HTTPTransport struct{
NodeMap map[string]string // id → address mapping
}
// dial wrapped with a timeout
func (t *HTTPTransport) DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) {
done := make(chan struct{})
var client *rpc.Client
var err error
go func() {
client, err = rpc.DialHTTP(network, t.NodeMap[peerId])
close(done)
}()
select {
case <-done:
return &HTTPClient{rpcClient: client}, err
case <-time.After(50 * time.Millisecond):
return nil, fmt.Errorf("dial timeout: %s", t.NodeMap[peerId])
}
}
func (t *HTTPTransport) CallWithTimeout(clientInterface ClientInterface, serviceMethod string, args interface{}, reply interface{}) error {
c, ok := clientInterface.(*HTTPClient)
client := c.rpcClient
if !ok {
return fmt.Errorf("invalid client type")
}
done := make(chan error, 1)
go func() {
switch serviceMethod {
case "Node.AppendEntries":
arg, ok := args.(*AppendEntriesArg)
resp, ok2 := reply.(*AppendEntriesReply)
if !ok || !ok2 {
done <- errors.New("type assertion failed for AppendEntries")
return
}
done <- client.Call(serviceMethod, arg, resp)
case "Node.RequestVote":
arg, ok := args.(*RequestVoteArgs)
resp, ok2 := reply.(*RequestVoteReply)
if !ok || !ok2 {
done <- errors.New("type assertion failed for RequestVote")
return
}
done <- client.Call(serviceMethod, arg, resp)
case "Node.WriteKV":
arg, ok := args.(*LogEntryCall)
resp, ok2 := reply.(*ServerReply)
if !ok || !ok2 {
done <- errors.New("type assertion failed for WriteKV")
return
}
done <- client.Call(serviceMethod, arg, resp)
case "Node.ReadKey":
arg, ok := args.(*string)
resp, ok2 := reply.(*ServerReply)
if !ok || !ok2 {
done <- errors.New("type assertion failed for ReadKey")
return
}
done <- client.Call(serviceMethod, arg, resp)
case "Node.FindLeader":
arg, ok := args.(struct{})
resp, ok2 := reply.(*FindLeaderReply)
if !ok || !ok2 {
done <- errors.New("type assertion failed for FindLeader")
return
}
done <- client.Call(serviceMethod, arg, resp)
default:
done <- fmt.Errorf("unknown service method: %s", serviceMethod)
}
}()
select {
case err := <-done:
return err
case <-time.After(50 * time.Millisecond): // the timeout
return fmt.Errorf("call timeout: %s", serviceMethod)
}
}
type HTTPClient struct {
rpcClient *rpc.Client
}
func (h *HTTPClient) Close() error {
return h.rpcClient.Close()
}

internal/nodes/replica.go (+271, −0)

@ -0,0 +1,271 @@
package nodes
import (
"simple-kv-store/internal/logprovider"
"sort"
"strconv"
"sync"
"go.uber.org/zap"
)
type AppendEntriesArg struct {
Term int
LeaderId string
PrevLogIndex int
PrevLogTerm int
Entries []RaftLogEntry
LeaderCommit int
}
type AppendEntriesReply struct {
Term int
Success bool
}
// the leader broadcasts when it receives new content, and on heartbeats (syncing its log)
func (node *Node) BroadCastKV() {
log.Sugar().Infof("leader[%s]广播消息", node.SelfId)
defer logprovider.DebugTraceback("broadcast")
failCount := 0
// a lock to keep the failure count safe under concurrent updates
var failMutex sync.Mutex
// iterate over all nodes
for _, id := range node.Nodes {
go func(id string) {
defer logprovider.DebugTraceback("send")
node.sendKV(id, &failCount, &failMutex)
}(id)
}
}
func (node *Node) sendKV(peerId string, failCount *int, failMutex *sync.Mutex) {
node.Mu.Lock()
selfId := node.SelfId
node.Mu.Unlock()
client, err := node.Transport.DialHTTPWithTimeout("tcp", selfId, peerId)
if err != nil {
node.Mu.Lock()
log.Error("[" + node.SelfId + "]dialling [" + peerId + "] fail: ", zap.Error(err))
failMutex.Lock()
*failCount++
if *failCount == len(node.Nodes) / 2 + 1 { // cannot reach a majority: assume we are the problem and step down
node.LeaderId = ""
node.State = Follower
node.ResetElectionTimer()
}
failMutex.Unlock()
node.Mu.Unlock()
return
}
defer func(client ClientInterface) {
err := client.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
}
}(client)
node.Mu.Lock()
NextIndex := node.NextIndex[peerId]
// log.Info("NextIndex " + strconv.Itoa(NextIndex))
for {
if NextIndex < 0 {
log.Fatal("assert >= 0 here")
}
sendEntries := node.Log[NextIndex:]
arg := AppendEntriesArg{
Term: node.CurrTerm,
PrevLogIndex: NextIndex - 1,
Entries: sendEntries,
LeaderCommit: node.CommitIndex,
LeaderId: node.SelfId,
}
if arg.PrevLogIndex >= 0 {
arg.PrevLogTerm = node.Log[arg.PrevLogIndex].Term
}
// record the key data, then unlock
currTerm := node.CurrTerm
currState := node.State
MaxLogId := node.MaxLogId
var appendReply AppendEntriesReply
appendReply.Success = false
node.Mu.Unlock()
callErr := node.Transport.CallWithTimeout(client, "Node.AppendEntries", &arg, &appendReply) // RPC
node.Mu.Lock()
if node.CurrTerm != currTerm || node.MaxLogId != MaxLogId || node.State != currState {
node.Mu.Unlock()
return
}
if callErr != nil {
log.Error("[" + node.SelfId + "]calling [" + peerId + "] fail: ", zap.Error(callErr))
failMutex.Lock()
*failCount++
if *failCount == len(node.Nodes) / 2 + 1 { // cannot reach a majority: assume we are the problem and step down
log.Info("term=" + strconv.Itoa(node.CurrTerm) + "的Leader[" + node.SelfId + "]无法联系到半数节点, 降级为 Follower")
node.LeaderId = ""
node.State = Follower
node.ResetElectionTimer()
}
failMutex.Unlock()
node.Mu.Unlock()
return
}
if appendReply.Term != node.CurrTerm {
log.Sugar().Infof("term=%s的leader[%s]因为[%s]收到更高的term=%s, 转换为follower",
strconv.Itoa(node.CurrTerm), node.SelfId, peerId, strconv.Itoa(appendReply.Term))
node.LeaderId = ""
node.CurrTerm = appendReply.Term
node.State = Follower
node.VotedFor = ""
node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor)
node.ResetElectionTimer()
node.Mu.Unlock()
return
}
if appendReply.Success {
break
}
NextIndex-- // on failure, step back one entry
}
// assuming we did not become a follower
node.NextIndex[peerId] = node.MaxLogId + 1
node.MatchIndex[peerId] = node.MaxLogId
node.updateCommitIndex()
node.Mu.Unlock()
}
func (node *Node) updateCommitIndex() {
if node.Mu.TryLock() {
log.Fatal("这里要保证有锁")
}
if node.IsFinish {
return
}
totalNodes := len(node.Nodes)
// collect all MatchIndex values and sort them
MatchIndexes := make([]int, 0, totalNodes)
for _, index := range node.MatchIndex {
MatchIndexes = append(MatchIndexes, index)
}
sort.Ints(MatchIndexes) // sort
// compute the majority CommitIndex
majorityIndex := MatchIndexes[totalNodes/2] // take the index at position N/2 (the majority)
// make sure the entry at this index is from the current term, to avoid committing an old term's entries
if majorityIndex > node.CommitIndex && majorityIndex < len(node.Log) && node.Log[majorityIndex].Term == node.CurrTerm {
node.CommitIndex = majorityIndex
log.Info("Leader[" + node.SelfId + "]更新 CommitIndex: " + strconv.Itoa(majorityIndex))
// apply the logs to the state machine
node.applyCommittedLogs()
}
}
// apply the logs to the state machine
func (node *Node) applyCommittedLogs() {
for node.LastApplied < node.CommitIndex {
node.LastApplied++
logEntry := node.Log[node.LastApplied]
log.Sugar().Infof("[%s]应用日志到状态机: " + logEntry.print(), node.SelfId)
err := node.Db.Put([]byte(logEntry.LogE.Key), []byte(logEntry.LogE.Value), nil)
if err != nil {
log.Error(node.SelfId + "应用状态机失败: ", zap.Error(err))
}
}
}
// RPC call
func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply) error {
defer logprovider.DebugTraceback("append")
node.Mu.Lock()
defer node.Mu.Unlock()
log.Sugar().Infof("[%s]在term=%d收到[%s]的AppendEntries", node.SelfId, node.CurrTerm, arg.LeaderId)
// if the term is stale, reject the entries
if node.CurrTerm > arg.Term {
reply.Term = node.CurrTerm
reply.Success = false
return nil
}
node.LeaderId = arg.LeaderId // record the leader
// if the term is higher than ours, or we are not a follower but received a same-term heartbeat
if node.CurrTerm < arg.Term || node.State != Follower {
log.Sugar().Infof("[%s]发现更高 term(%s)", node.SelfId, strconv.Itoa(arg.Term))
node.CurrTerm = arg.Term
node.State = Follower
node.VotedFor = ""
// node.storage.SetTermAndVote(node.CurrTerm, node.VotedFor)
}
node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor)
// check whether prevLogIndex is valid
if arg.PrevLogIndex >= len(node.Log) || (arg.PrevLogIndex >= 0 && node.Log[arg.PrevLogIndex].Term != arg.PrevLogTerm) {
reply.Term = node.CurrTerm
reply.Success = false
return nil
}
// handle log conflicts (truncate the log on a term mismatch)
idx := arg.PrevLogIndex + 1
for i := idx; i < len(node.Log) && i-idx < len(arg.Entries); i++ {
if node.Log[i].Term != arg.Entries[i-idx].Term {
node.Log = node.Log[:idx]
break
}
}
// log.Info(strconv.Itoa(idx) + strconv.Itoa(len(node.Log)))
// append the new log entries
for _, raftLogEntry := range arg.Entries {
log.Sugar().Infof("[%s]写入:" + raftLogEntry.print(), node.SelfId)
if idx < len(node.Log) {
node.Log[idx] = raftLogEntry
} else {
node.Log = append(node.Log, raftLogEntry)
}
idx++
}
// brute-force persistence
node.Storage.WriteLog(node.Log)
// update MaxLogId
node.MaxLogId = len(node.Log) - 1
// update CommitIndex
if arg.LeaderCommit < node.MaxLogId {
node.CommitIndex = arg.LeaderCommit
} else {
node.CommitIndex = node.MaxLogId
}
// apply the committed entries
node.applyCommittedLogs()
// after successfully accepting entries or a heartbeat, reset the election timeout
node.ResetElectionTimer()
reply.Term = node.CurrTerm
reply.Success = true
return nil
}

internal/nodes/server_node.go (+68, −16)

@ -1,42 +1,94 @@
package nodes
import (
"strconv"
"simple-kv-store/internal/logprovider"
"github.com/syndtr/goleveldb/leveldb"
)
// methods the leader node registers as a server for clients
type ServerReply struct{
Isconnect bool
Isleader bool
LeaderId string // if this node is not the leader, return the leader's id
HaveValue bool
Value string
}
// RPC call
func (node *Node) WriteKV(kvCall LogEntryCall, reply *ServerReply) error {
logId := node.maxLogId
node.maxLogId++
node.log[logId] = kvCall.LogE
node.db.Put([]byte(kvCall.LogE.Key), []byte(kvCall.LogE.Value), nil)
log.Info("server write : logId = " + strconv.Itoa(logId) + ", key = " + kvCall.LogE.Key)
func (node *Node) WriteKV(kvCall *LogEntryCall, reply *ServerReply) error {
defer logprovider.DebugTraceback("write")
node.Mu.Lock()
defer node.Mu.Unlock()
log.Sugar().Infof("[%s]收到客户端write请求", node.SelfId)
// not the leader: hand back the leader's address in the reply
if node.State != Leader {
reply.Isleader = false
reply.LeaderId = node.LeaderId // may be empty, in which case the client picks another random node
log.Sugar().Infof("[%s]转交给[%s]", node.SelfId, node.LeaderId)
return nil
}
if node.SeenRequests[kvCall.Id] {
log.Sugar().Infof("Leader [%s] 已处理过client[%s]的请求 %d, 跳过", node.SelfId, kvCall.Id.ClientId, kvCall.Id.LogId)
reply.Isleader = true
return nil
}
node.SeenRequests[kvCall.Id] = true
// we are the leader: update our own record and broadcast
node.MaxLogId++
logId := node.MaxLogId
rLogE := RaftLogEntry{kvCall.LogE, logId, node.CurrTerm}
node.Log = append(node.Log, rLogE)
node.Storage.AppendLog(rLogE)
log.Info("leader[" + node.SelfId + "]处理请求 : " + kvCall.LogE.print())
// broadcast to the other nodes
node.BroadCastKV(logId, kvCall)
reply.Isconnect = true
node.BroadCastKV()
reply.Isleader = true
return nil
}
// RPC call
func (node *Node) ReadKey(key string, reply *ServerReply) error {
log.Info("server read : " + key)
// for now, read only from the leader itself
value, err := node.db.Get([]byte(key), nil)
func (node *Node) ReadKey(key *string, reply *ServerReply) error {
defer logprovider.DebugTraceback("read")
node.Mu.Lock()
defer node.Mu.Unlock()
log.Sugar().Infof("[%s]收到客户端read请求", node.SelfId)
// for now, read only from this node (whether or not it is the leader); convenient for tests too
value, err := node.Db.Get([]byte(*key), nil)
if err == leveldb.ErrNotFound {
reply.HaveValue = false
} else {
reply.HaveValue = true
reply.Value = string(value)
}
reply.Isconnect = true
reply.Isleader = true
return nil
}
// RPC call: used by tests to find the current leader
type FindLeaderReply struct{
Isleader bool
LeaderId string
}
func (node *Node) FindLeader(_ struct{}, reply *FindLeaderReply) error {
defer logprovider.DebugTraceback("find")
node.Mu.Lock()
defer node.Mu.Unlock()
// not the leader: hand back the leader's address in the reply
if node.State != Leader {
reply.Isleader = false
if (node.LeaderId == "") {
log.Fatal("还没选出第一个leader")
return nil
}
reply.LeaderId = node.LeaderId
return nil
}
reply.LeaderId = node.SelfId
reply.Isleader = true
return nil
}

internal/nodes/simulate_ctx.go (+65, −0)

@ -0,0 +1,65 @@
package nodes
import (
"fmt"
"sync"
"time"
)
// Ctx manages the communication behavior between nodes
type Ctx struct {
mu sync.Mutex
Behavior map[string]CallBehavior // (src,target) -> CallBehavior
Delay map[string]time.Duration // (src,target) -> delay
Retries map[string]int // retry count recorded for (src,target)
}
// NewCtx creates the context
func NewCtx() *Ctx {
return &Ctx{
Behavior: make(map[string]CallBehavior),
Delay: make(map[string]time.Duration),
Retries: make(map[string]int),
}
}
// SetBehavior sets the RPC behavior for A->B
func (c *Ctx) SetBehavior(src, dst string, behavior CallBehavior, delay time.Duration, retries int) {
c.mu.Lock()
defer c.mu.Unlock()
key := fmt.Sprintf("%s->%s", src, dst)
c.Behavior[key] = behavior
c.Delay[key] = delay
c.Retries[key] = retries
}
// GetBehavior returns the behavior for A->B
func (c *Ctx) GetBehavior(src, dst string) (CallBehavior) {
c.mu.Lock()
defer c.mu.Unlock()
key := fmt.Sprintf("%s->%s", src, dst)
if state, exists := c.Behavior[key]; exists {
return state
}
return NormalRpc
}
func (c *Ctx) GetDelay(src, dst string) (t time.Duration, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
key := fmt.Sprintf("%s->%s", src, dst)
if t, ok = c.Delay[key]; ok {
return t, ok
}
return 0, ok
}
func (c *Ctx) GetRetries(src, dst string) (times int, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
key := fmt.Sprintf("%s->%s", src, dst)
if times, ok = c.Retries[key]; ok {
return times, ok
}
return 0, ok
}

internal/nodes/thread_transport.go (+234, −0)

@ -0,0 +1,234 @@
package nodes
import (
"fmt"
"sync"
"time"
)
type CallBehavior = uint8
const (
NormalRpc CallBehavior = iota + 1
DelayRpc
RetryRpc
FailRpc
)
// RPC request structure
type RPCRequest struct {
ServiceMethod string
Args interface{}
Reply interface{}
Done chan error // used to return the response
SourceId string
// simulated rpc request behavior
Behavior CallBehavior
}
// thread-version Transport
type ThreadTransport struct {
mu sync.Mutex
nodeChans map[string]chan RPCRequest // each node's message channel
connectivityMap map[string]map[string]bool // simulates network partitions
Ctx *Ctx
}
// the ClientInterface returned by the thread-version dial
type ThreadClient struct {
SourceId string
TargetId string
}
func (c *ThreadClient) Close() error {
return nil
}
// initialize the thread communication system
func NewThreadTransport(ctx *Ctx) *ThreadTransport {
return &ThreadTransport{
nodeChans: make(map[string]chan RPCRequest),
connectivityMap: make(map[string]map[string]bool),
Ctx: ctx,
}
}
// register a new node channel
func (t *ThreadTransport) RegisterNodeChan(nodeId string, ch chan RPCRequest) {
t.mu.Lock()
defer t.mu.Unlock()
t.nodeChans[nodeId] = ch
// initialize connectivity (by default all nodes can reach each other)
if _, exists := t.connectivityMap[nodeId]; !exists {
t.connectivityMap[nodeId] = make(map[string]bool)
}
for peerId := range t.nodeChans {
t.connectivityMap[nodeId][peerId] = true
t.connectivityMap[peerId][nodeId] = true
}
}
// set the connectivity between two nodes
func (t *ThreadTransport) SetConnectivity(from, to string, isConnected bool) {
t.mu.Lock()
defer t.mu.Unlock()
if _, exists := t.connectivityMap[from]; exists {
t.connectivityMap[from][to] = isConnected
}
}
func (t *ThreadTransport) ResetConnectivity() {
t.mu.Lock()
defer t.mu.Unlock()
for firstId:= range t.nodeChans {
for peerId:= range t.nodeChans {
if firstId != peerId {
t.connectivityMap[firstId][peerId] = true
t.connectivityMap[peerId][firstId] = true
}
}
}
}
// get a node's channel
func (t *ThreadTransport) getNodeChan(nodeId string) (chan RPCRequest, bool) {
t.mu.Lock()
defer t.mu.Unlock()
ch, ok := t.nodeChans[nodeId]
return ch, ok
}
// simulated Dial
func (t *ThreadTransport) DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) {
t.mu.Lock()
defer t.mu.Unlock()
if _, exists := t.nodeChans[peerId]; !exists {
return nil, fmt.Errorf("节点 [%s] 不存在", peerId)
}
return &ThreadClient{SourceId: myId, TargetId: peerId}, nil
}
// simulated Call
func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error {
threadClient, ok := client.(*ThreadClient)
if !ok {
return fmt.Errorf("无效的caller")
}
var isConnected bool
if threadClient.SourceId == "" {
isConnected = true
} else {
t.mu.Lock()
isConnected = t.connectivityMap[threadClient.SourceId][threadClient.TargetId]
t.mu.Unlock()
}
if !isConnected {
return fmt.Errorf("网络分区: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId)
}
targetChan, exists := t.getNodeChan(threadClient.TargetId)
if !exists {
return fmt.Errorf("目标节点 [%s] 不存在", threadClient.TargetId)
}
done := make(chan error, 1)
behavior := t.Ctx.GetBehavior(threadClient.SourceId, threadClient.TargetId)
// helper: copy replyCopy back into the original reply
copyReply := func(dst, src interface{}) {
switch d := dst.(type) {
case *AppendEntriesReply:
*d = *(src.(*AppendEntriesReply))
case *RequestVoteReply:
*d = *(src.(*RequestVoteReply))
}
}
sendRequest := func(req RPCRequest, ch chan RPCRequest) bool {
select {
case ch <- req:
return true
default:
return false
}
}
switch behavior {
case RetryRpc:
retryTimes, ok := t.Ctx.GetRetries(threadClient.SourceId, threadClient.TargetId)
if !ok {
log.Fatal("没有设置对应的retry次数")
}
var lastErr error
for i := 0; i < retryTimes; i++ {
var replyCopy interface{}
useCopy := true
switch r := reply.(type) {
case *AppendEntriesReply:
tmp := *r
replyCopy = &tmp
case *RequestVoteReply:
tmp := *r
replyCopy = &tmp
default:
replyCopy = reply // other types are not copied
useCopy = false
}
request := RPCRequest{
ServiceMethod: serviceMethod,
Args: args,
Reply: replyCopy,
Done: done,
SourceId: threadClient.SourceId,
Behavior: NormalRpc,
}
if !sendRequest(request, targetChan) {
return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId)
}
select {
case err := <-done:
if err == nil && useCopy {
copyReply(reply, replyCopy)
}
if err == nil {
return nil
}
lastErr = err
case <-time.After(250 * time.Millisecond):
lastErr = fmt.Errorf("RPC 调用超时: %s", serviceMethod)
}
}
return lastErr
default:
request := RPCRequest{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
SourceId: threadClient.SourceId,
Behavior: behavior,
}
if !sendRequest(request, targetChan) {
return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId)
}
select {
case err := <-done:
return err
case <-time.After(250 * time.Millisecond):
return fmt.Errorf("RPC 调用超时: %s", serviceMethod)
}
}
}

internal/nodes/transport.go (+10, −0)

@ -0,0 +1,10 @@
package nodes
type ClientInterface interface{
Close() error
}
type Transport interface {
DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error)
CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error
}

internal/nodes/vote.go (+215, −0)

@ -0,0 +1,215 @@
package nodes
import (
"math/rand"
"simple-kv-store/internal/logprovider"
"strconv"
"sync"
"time"
"go.uber.org/zap"
)
type RequestVoteArgs struct {
Term int // the candidate's current term
CandidateId string // candidate ID
LastLogIndex int // index of the candidate's last log entry
LastLogTerm int // term of the candidate's last log entry
}
type RequestVoteReply struct {
Term int // this node's latest term
VoteGranted bool // whether the vote is granted
}
func (n *Node) StartElection() {
defer logprovider.DebugTraceback("startElection")
n.Mu.Lock()
if n.IsFinish {
n.Mu.Unlock()
return
}
// increment the current term and become a Candidate
n.CurrTerm++
n.State = Candidate
n.VotedFor = n.SelfId // vote for ourselves
n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor)
log.Sugar().Infof("[%s] 开始选举,当前任期: %d", n.SelfId, n.CurrTerm)
// reset the election timeout, to avoid repeated elections
n.ResetElectionTimer()
// build the RequestVote request
var lastLogIndex int
var lastLogTerm int
if len(n.Log) == 0 {
lastLogIndex = 0
lastLogTerm = 0 // as defined in the paper, Term is 0 when the log is empty
} else {
lastLogIndex = len(n.Log) - 1
lastLogTerm = n.Log[lastLogIndex].Term
}
args := RequestVoteArgs{
Term: n.CurrTerm,
CandidateId: n.SelfId,
LastLogIndex: lastLogIndex,
LastLogTerm: lastLogTerm,
}
// send vote requests to the other nodes in parallel
var Mu sync.Mutex
totalNodes := len(n.Nodes)
grantedVotes := 1 // our own vote
currTerm := n.CurrTerm
currState := n.State
n.Mu.Unlock()
for _, peerId := range n.Nodes {
go func(peerId string) {
defer logprovider.DebugTraceback("vote")
var reply RequestVoteReply
if n.sendRequestVote(peerId, &args, &reply) {
Mu.Lock()
defer Mu.Unlock()
n.Mu.Lock()
defer n.Mu.Unlock()
if currTerm != n.CurrTerm || currState != n.State {
return
}
if reply.Term > n.CurrTerm {
// higher term discovered; fall back to Follower
log.Sugar().Infof("[%s] 发现更高的 Term (%d),回退为 Follower", n.SelfId, reply.Term)
n.CurrTerm = reply.Term
n.State = Follower
n.VotedFor = ""
n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor)
n.ResetElectionTimer()
return
}
if reply.VoteGranted {
grantedVotes++
}
if grantedVotes == totalNodes / 2 + 1 {
n.State = Leader
log.Sugar().Infof("[%s] 当选 Leader!", n.SelfId)
n.initLeaderState()
}
}
}(peerId)
}
// wait for the election result
time.Sleep(300 * time.Millisecond)
Mu.Lock()
defer Mu.Unlock()
n.Mu.Lock()
defer n.Mu.Unlock()
if n.State == Candidate {
log.Sugar().Infof("[%s] 选举超时,等待后将重新发起选举", n.SelfId)
// n.State = Follower is deliberately not set here; we revert to follower only when AppendEntries delivers a heartbeat with a valid term
n.ResetElectionTimer()
}
}
func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *RequestVoteReply) bool {
log.Sugar().Infof("[%s] 请求 [%s] 投票", node.SelfId, peerId)
client, err := node.Transport.DialHTTPWithTimeout("tcp", node.SelfId, peerId)
if err != nil {
log.Error("[" + node.SelfId + "]dialing [" + peerId + "] fail: ", zap.Error(err))
return false
}
defer func(client ClientInterface) {
err := client.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
}
}(client)
callErr := node.Transport.CallWithTimeout(client, "Node.RequestVote", args, reply) // RPC
if callErr != nil {
log.Error("[" + node.SelfId + "]calling [" + peerId + "] fail: ", zap.Error(callErr))
}
return callErr == nil
}
func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error {
defer logprovider.DebugTraceback("requestVote")
n.Mu.Lock()
defer n.Mu.Unlock()
// if the candidate's term is lower than our current term, refuse to vote
if args.Term < n.CurrTerm {
reply.Term = n.CurrTerm
reply.VoteGranted = false
return nil
}
// if the request's Term is higher, update our current Term and fall back to Follower
if args.Term > n.CurrTerm {
n.CurrTerm = args.Term
n.State = Follower
n.VotedFor = ""
n.ResetElectionTimer() // reset the election timeout
}
// check whether we already voted, and whether it was for this same candidate
if n.VotedFor == "" || n.VotedFor == args.CandidateId {
// check whether the candidate's log is up to date enough
var lastLogIndex int
var lastLogTerm int
if len(n.Log) == 0 {
lastLogIndex = -1
lastLogTerm = 0
} else {
lastLogIndex = len(n.Log) - 1
lastLogTerm = n.Log[lastLogIndex].Term
}
if args.LastLogTerm > lastLogTerm ||
(args.LastLogTerm == lastLogTerm && args.LastLogIndex >= lastLogIndex) {
// up to date enough: grant the vote
n.VotedFor = args.CandidateId
log.Sugar().Infof("在term(%s), [%s]投票给[%s]", strconv.Itoa(n.CurrTerm), n.SelfId, n.VotedFor)
reply.VoteGranted = true
n.ResetElectionTimer()
} else {
reply.VoteGranted = false
}
} else {
reply.VoteGranted = false
}
n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor)
reply.Term = n.CurrTerm
return nil
}
// a follower that has not received an AppendEntries heartbeat for a while becomes a candidate and starts an election
func (node *Node) ResetElectionTimer() {
node.MuElection.Lock()
defer node.MuElection.Unlock()
if node.ElectionTimer == nil {
node.ElectionTimer = time.NewTimer(node.RTTable.GetElectionTimeout())
go func() {
defer logprovider.DebugTraceback("reset")
for {
<-node.ElectionTimer.C
node.StartElection()
}
}()
} else {
node.ElectionTimer.Stop()
node.ElectionTimer.Reset(time.Duration(500+rand.Intn(500)) * time.Millisecond)
}
}

pics/plan.png (BIN, 2046×966, 160 KiB)

pics/plus.png (BIN, 1327×855, 185 KiB)

pics/robust.png (BIN, 1151×539, 81 KiB)

raft第二次汇报.pptx (BIN)

scripts/run.sh (+0, −61)

@ -1,61 +0,0 @@
#!/bin/bash
# run-time limit, in seconds
RUN_TIME=10
# pipe used to pass data
PIPE_NAME="/tmp/input_pipe"
# start node 1
echo "Starting Node 1..."
timeout $RUN_TIME ./main -id 1 -port ":9091" -cluster "127.0.0.1:9092,127.0.0.1:9093" -pipe "$PIPE_NAME" -isleader=true &
# start node 2
echo "Starting Node 2..."
timeout $RUN_TIME ./main -id 2 -port ":9092" -cluster "127.0.0.1:9091,127.0.0.1:9093" -pipe "$PIPE_NAME" &
# start node 3
echo "Starting Node 3..."
timeout $RUN_TIME ./main -id 3 -port ":9093" -cluster "127.0.0.1:9091,127.0.0.1:9092" -pipe "$PIPE_NAME" &
echo "All nodes started successfully!"
# create a pipe for inter-process communication
if [[ ! -p "$PIPE_NAME" ]]; then
mkfifo "$PIPE_NAME"
fi
# capture terminal input and forward it to the three nodes through the pipe
echo "Enter input to send to nodes:"
start_time=$(date +%s)
while true; do
# read user input from the terminal
read -r user_input
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
# exit once the run-time limit is exceeded
if [ $elapsed_time -ge $RUN_TIME ]; then
echo 'Timeout reached, normal exit now'
break
fi
# skip empty input
if [[ -z "$user_input" ]]; then
continue
fi
# send the user input into the pipe
echo "$user_input" > "$PIPE_NAME"
# if the input is "exit", end the script
if [[ "$user_input" == "exit" ]]; then
break
fi
done
# remove the pipe
rm "$PIPE_NAME"
# wait for all node processes to finish
wait
wait

test/common.go (+7, −15)

@ -8,29 +8,21 @@ import (
"strings"
)
func ExecuteNodeI(i int, isLeader bool, isNewDb bool, clusters []string) *exec.Cmd {
tmpClusters := append(clusters[:i], clusters[i+1:]...)
func ExecuteNodeI(i int, isRestart bool, clusters []string) *exec.Cmd {
port := fmt.Sprintf(":%d", uint16(9090)+uint16(i))
var isleader string
if isLeader {
isleader = "true"
var isRestartStr string
if isRestart {
isRestartStr = "true"
} else {
isleader = "false"
}
var isnewdb string
if isNewDb {
isnewdb = "true"
} else {
isnewdb = "false"
isRestartStr = "false"
}
cmd := exec.Command(
"../main",
"-id", strconv.Itoa(i + 1),
"-port", port,
"-cluster", strings.Join(tmpClusters, ","),
"-isleader=" + isleader,
"-isNewDb=" + isnewdb,
"-cluster", strings.Join(clusters, ","),
"-isRestart=" + isRestartStr,
)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

test/restart_follower_test.go (+0, −112)

@ -1,112 +0,0 @@
package test
import (
"fmt"
"os/exec"
"simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"syscall"
"testing"
"time"
)
func TestFollowerRestart(t *testing.T) {
// register the node info
n := 3
var clusters []string
for i := 0; i < n; i++ {
port := fmt.Sprintf("%d", uint16(9090)+uint16(i))
addr := "127.0.0.1:" + port
clusters = append(clusters, addr)
}
// start the nodes
var cmds []*exec.Cmd
for i := 0; i < n; i++ {
var cmd *exec.Cmd
if i == 0 {
cmd = ExecuteNodeI(i, true, true, clusters)
} else {
cmd = ExecuteNodeI(i, false, true, clusters)
}
if cmd == nil {
return
} else {
cmds = append(cmds, cmd)
}
}
time.Sleep(time.Second) // wait for startup to finish
// start the client, connected to the leader
cWrite := clientPkg.Client{Address: clusters[0], ServerId: "1"}
// write
var s clientPkg.Status
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // wait for the writes to finish
// simulate a crash of the last node
err := cmds[n - 1].Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
// continue writing
for i := 5; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
// recover the node
cmd := ExecuteNodeI(n - 1, false, false, clusters)
if cmd == nil {
t.Errorf("recover test1 fail")
return
} else {
cmds[n - 1] = cmd
}
time.Sleep(time.Second) // wait for startup to complete
// Start a client connected to node n-1 (to read its data)
cRead := clientPkg.Client{Address: clusters[n - 1], ServerId: "n"}
// Read the data written before the crash
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
var value string
s = cRead.Read(key, &value)
if s != clientPkg.Ok {
t.Errorf("Read test1 fail")
}
}
// Read keys that were never written
for i := 5; i < 15; i++ {
key := strconv.Itoa(i)
var value string
s = cRead.Read(key, &value)
if s != clientPkg.NotFound {
t.Errorf("Read test2 fail")
}
}
// Signal the processes to exit
for _, cmd := range cmds {
err := cmd.Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
}
}

+ 104
- 0
test/restart_node_test.go View File

@ -0,0 +1,104 @@
package test
import (
"fmt"
"os/exec"
"simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"syscall"
"testing"
"time"
)
func TestNodeRestart(t *testing.T) {
// Register node addresses
n := 5
var clusters []string
var peerIds []string
addressMap := make(map[string]string)
for i := 0; i < n; i++ {
port := fmt.Sprintf("%d", uint16(9090)+uint16(i))
addr := "127.0.0.1:" + port
clusters = append(clusters, addr)
addressMap[strconv.Itoa(i + 1)] = addr
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var cmds []*exec.Cmd
for i := 0; i < n; i++ {
cmd := ExecuteNodeI(i, false, clusters)
cmds = append(cmds, cmd)
}
// Signal all processes to exit
defer func(){
for _, cmd := range cmds {
err := cmd.Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
}
}()
time.Sleep(time.Second) // wait for startup to complete
// Start a client connected to an arbitrary node
transport := &nodes.HTTPTransport{NodeMap: addressMap}
cWrite := clientPkg.NewClient("0", peerIds, transport)
// Write entries
var s clientPkg.Status
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := cWrite.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // wait for the writes to finish
// Crash and restart the nodes one by one
for i := 0; i < n; i++ {
err := cmds[i].Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
time.Sleep(time.Second)
cmd := ExecuteNodeI(i, true, clusters)
if cmd == nil {
t.Errorf("recover test1 fail")
return
} else {
cmds[i] = cmd
}
time.Sleep(time.Second) // wait for startup to complete
}
// Start a client
cRead := clientPkg.NewClient("0", peerIds, transport)
// Read back the written data
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
var value string
s = cRead.Read(key, &value)
if s != clientPkg.Ok {
t.Errorf("Read test1 fail")
}
}
// Read keys that were never written
for i := 5; i < 15; i++ {
key := strconv.Itoa(i)
var value string
s = cRead.Read(key, &value)
if s != clientPkg.NotFound {
t.Errorf("Read test2 fail")
}
}
}

+ 22
- 25
test/server_client_test.go View File

@ -15,50 +15,57 @@ func TestServerClient(t *testing.T) {
// Register node addresses
n := 5
var clusters []string
var peerIds []string
addressMap := make(map[string]string)
for i := 0; i < n; i++ {
port := fmt.Sprintf("%d", uint16(9090)+uint16(i))
addr := "127.0.0.1:" + port
clusters = append(clusters, addr)
addressMap[strconv.Itoa(i + 1)] = addr
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var cmds []*exec.Cmd
for i := 0; i < n; i++ {
var cmd *exec.Cmd
if i == 0 {
cmd = ExecuteNodeI(i, true, true, clusters)
} else {
cmd = ExecuteNodeI(i, false, true, clusters)
}
if cmd == nil {
return
} else {
cmds = append(cmds, cmd)
}
cmd := ExecuteNodeI(i, false, clusters)
cmds = append(cmds, cmd)
}
// Signal all processes to exit
defer func(){
for _, cmd := range cmds {
err := cmd.Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
}
}()
time.Sleep(time.Second) // wait for startup to complete
// Start a client
c := clientPkg.Client{Address: "127.0.0.1:9090", ServerId: "1"}
transport := &nodes.HTTPTransport{NodeMap: addressMap}
c := clientPkg.NewClient("0", peerIds, transport)
// Write entries
var s clientPkg.Status
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := c.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
s := c.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // wait for the writes to finish
// Read back the written data
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
var value string
s = c.Read(key, &value)
if s != clientPkg.Ok && value != "hello" + key {
if s != clientPkg.Ok || value != "hello" {
t.Errorf("Read test1 fail")
}
}
@ -72,14 +79,4 @@ func TestServerClient(t *testing.T) {
t.Errorf("Read test2 fail")
}
}
// Signal the processes to exit
for _, cmd := range cmds {
err := cmd.Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
}
}

+ 305
- 0
threadTest/common.go View File

@ -0,0 +1,305 @@
package threadTest
import (
"fmt"
"os"
clientPkg "simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
"github.com/syndtr/goleveldb/leveldb"
)
func ExecuteNodeI(id string, isRestart bool, peerIds []string, threadTransport *nodes.ThreadTransport) (*nodes.Node, chan struct{}) {
if !isRestart {
os.RemoveAll("storage/node" + id)
}
// Create a temporary directory for leveldb
dbPath, err := os.MkdirTemp("", "simple-kv-store-"+id+"-")
if err != nil {
panic(fmt.Sprintf("failed to create temporary database directory: %s", err))
}
// Create a temporary directory for storage
storagePath, err := os.MkdirTemp("", "raft-storage-"+id+"-")
if err != nil {
panic(fmt.Sprintf("failed to create temporary storage directory: %s", err))
}
db, err := leveldb.OpenFile(dbPath, nil)
if err != nil {
panic(fmt.Sprintf("Failed to open database: %s", err))
}
// Initialize the Raft persistent storage
storage := nodes.NewRaftStorage(storagePath)
var otherIds []string
for _, ids := range peerIds {
if ids != id {
otherIds = append(otherIds, ids) // keep every peer id except our own
}
}
// Initialize the node
node, quitChan := nodes.InitThreadNode(id, otherIds, db, storage, isRestart, threadTransport)
// Start the Raft main loop
go nodes.Start(node, quitChan)
return node, quitChan
}
func ExecuteStaticNodeI(id string, isRestart bool, peerIds []string, threadTransport *nodes.ThreadTransport) (*nodes.Node, chan struct{}) {
if !isRestart {
os.RemoveAll("storage/node" + id)
}
os.RemoveAll("leveldb/simple-kv-store" + id)
db, err := leveldb.OpenFile("leveldb/simple-kv-store"+id, nil)
if err != nil {
fmt.Println("Failed to open database: ", err)
}
// Open or create the node's persistence file
storage := nodes.NewRaftStorage("storage/node" + id)
var otherIds []string
for _, ids := range peerIds {
if ids != id {
otherIds = append(otherIds, ids) // keep every peer id except our own
}
}
// Initialize the node
node, quitChan := nodes.InitThreadNode(id, otherIds, db, storage, isRestart, threadTransport)
// The Raft main loop is deliberately not started; the static tests drive the nodes by hand
// go nodes.Start(node, quitChan)
return node, quitChan
}
func StopElectionReset(nodeCollections []*nodes.Node) {
for i := 0; i < len(nodeCollections); i++ {
node := nodeCollections[i]
go func(node *nodes.Node) {
ticker := time.NewTicker(400 * time.Millisecond)
defer ticker.Stop()
for {
<-ticker.C
node.ResetElectionTimer() // keep resetting so no election fires spontaneously
}
}(node)
}
}
func SendKvCall(kvCall *nodes.LogEntryCall, node *nodes.Node) {
node.Mu.Lock()
defer node.Mu.Unlock()
node.MaxLogId++
logId := node.MaxLogId
rLogE := nodes.RaftLogEntry{LogE: kvCall.LogE, LogId: logId, Term: node.CurrTerm}
node.Log = append(node.Log, rLogE)
node.Storage.AppendLog(rLogE)
// Broadcast to the other nodes
node.BroadCastKV()
}
func ClientWriteLog(t *testing.T, startLogid int, endLogid int, cWrite *clientPkg.Client) {
var s clientPkg.Status
for i := startLogid; i < endLogid; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s = cWrite.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
}
func FindLeader(t *testing.T, nodeCollections []*nodes.Node) (i int) {
for i, node := range nodeCollections {
if node.State == nodes.Leader {
return i
}
}
t.Errorf("the cluster currently has no leader")
t.FailNow()
return 0
}
func CheckOneLeader(t *testing.T, nodeCollections []*nodes.Node) {
cnt := 0
for _, node := range nodeCollections {
node.Mu.Lock()
if node.State == nodes.Leader {
cnt++
}
node.Mu.Unlock()
}
if cnt != 1 {
t.Errorf("found %d leaders (expected 1)", cnt)
t.FailNow()
}
}
func CheckNoLeader(t *testing.T, nodeCollections []*nodes.Node) {
cnt := 0
for _, node := range nodeCollections {
node.Mu.Lock()
if node.State == nodes.Leader {
cnt++
}
node.Mu.Unlock()
}
if cnt != 0 {
t.Errorf("found %d leaders (expected 0)", cnt)
t.FailNow()
}
}
func CheckZeroOrOneLeader(t *testing.T, nodeCollections []*nodes.Node) {
cnt := 0
for _, node := range nodeCollections {
node.Mu.Lock()
if node.State == nodes.Leader {
cnt++
}
node.Mu.Unlock()
}
if cnt > 1 {
errmsg := fmt.Sprintf("among %d nodes, found %d leaders (>1)", len(nodeCollections), cnt)
WriteFailLog(nodeCollections[0].SelfId, errmsg)
t.Error(errmsg)
t.FailNow()
}
}
func CheckIsLeader(t *testing.T, node *nodes.Node) {
node.Mu.Lock()
defer node.Mu.Unlock()
if node.State != nodes.Leader {
t.Errorf("[%s] is not the leader", node.SelfId)
t.FailNow()
}
}
func CheckTerm(t *testing.T, node *nodes.Node, targetTerm int) {
node.Mu.Lock()
defer node.Mu.Unlock()
if node.CurrTerm != targetTerm {
t.Errorf("[%s] actual term=%d (expected %d)", node.SelfId, node.CurrTerm, targetTerm)
t.FailNow()
}
}
func CheckLogNum(t *testing.T, node *nodes.Node, targetnum int) {
node.Mu.Lock()
defer node.Mu.Unlock()
if len(node.Log) != targetnum {
t.Errorf("[%s] actual log count=%d (expected %d)", node.SelfId, len(node.Log), targetnum)
t.FailNow()
}
}
func CheckSameLog(t *testing.T, nodeCollections []*nodes.Node) {
nodeCollections[0].Mu.Lock()
defer nodeCollections[0].Mu.Unlock()
standard_node := nodeCollections[0]
for i, node := range nodeCollections {
if i != 0 {
node.Mu.Lock()
if len(node.Log) != len(standard_node.Log) {
errmsg := fmt.Sprintf("[%s] and [%s] have different log lengths", nodeCollections[0].SelfId, node.SelfId)
WriteFailLog(node.SelfId, errmsg)
t.Error(errmsg)
t.FailNow()
}
for idx, log := range node.Log {
standard_log := standard_node.Log[idx]
if log.Term != standard_log.Term ||
log.LogE.Key != standard_log.LogE.Key ||
log.LogE.Value != standard_log.LogE.Value {
errmsg := fmt.Sprintf("[%s] and [%s] disagree at log index %d", standard_node.SelfId, node.SelfId, idx)
WriteFailLog(node.SelfId, errmsg)
t.Error(errmsg)
t.FailNow()
}
}
node.Mu.Unlock()
}
}
}
func CheckLeaderInvariant(t *testing.T, nodeCollections []*nodes.Node) {
leaderCnt := make(map[int]bool)
for _, node := range nodeCollections {
node.Mu.Lock()
if node.State == nodes.Leader {
if _, exist := leaderCnt[node.CurrTerm]; exist {
errmsg := fmt.Sprintf("multiple leaders in term %d (%s)", node.CurrTerm, node.SelfId)
WriteFailLog(node.SelfId, errmsg)
t.Error(errmsg)
} else {
leaderCnt[node.CurrTerm] = true
}
}
node.Mu.Unlock()
}
}
func CheckLogInvariant(t *testing.T, nodeCollections []*nodes.Node) {
nodeCollections[0].Mu.Lock()
defer nodeCollections[0].Mu.Unlock()
standard_node := nodeCollections[0]
standard_len := len(standard_node.Log)
for i, node := range nodeCollections {
if i != 0 {
node.Mu.Lock()
len2 := len(node.Log)
var shorti int
if len2 < standard_len {
shorti = len2
} else {
shorti = standard_len
}
if shorti == 0 {
node.Mu.Unlock()
continue
}
alreadySame := false
for i := shorti - 1; i >= 0; i-- {
standard_log := standard_node.Log[i]
log := node.Log[i]
if alreadySame {
if log.Term != standard_log.Term ||
log.LogE.Key != standard_log.LogE.Key ||
log.LogE.Value != standard_log.LogE.Value {
errmsg := fmt.Sprintf("[%s] and [%s] disagree at log index %d", standard_node.SelfId, node.SelfId, i)
WriteFailLog(node.SelfId, errmsg)
t.Error(errmsg)
t.FailNow()
}
} else {
if log.Term == standard_log.Term &&
log.LogE.Key == standard_log.LogE.Key &&
log.LogE.Value == standard_log.LogE.Value {
alreadySame = true
}
}
}
node.Mu.Unlock()
}
}
}
func WriteFailLog(name string, errmsg string) {
f, _ := os.Create(name + ".log")
fmt.Fprint(f, errmsg)
f.Close()
}
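Taken together these helpers form a small in-process harness: ExecuteStaticNodeI builds a node without starting its main loop, StopElectionReset pins every election timer so elections fire only on demand, SendKvCall appends an entry on a leader, and the Check* functions assert cluster state. A minimal sketch of how they compose into a test (the test name and key/value literals are illustrative, not part of this commit):

```go
package threadTest

import (
	"strconv"
	"testing"
	"time"

	"simple-kv-store/internal/nodes"
)

// Sketch: drive one election by hand, replicate one entry, check the invariants.
func TestHarnessSketch(t *testing.T) {
	n := 3
	var peerIds []string
	for i := 0; i < n; i++ {
		peerIds = append(peerIds, strconv.Itoa(i+1))
	}
	threadTransport := nodes.NewThreadTransport(nodes.NewCtx())

	var nodeCollections []*nodes.Node
	var quitCollections []chan struct{}
	for i := 0; i < n; i++ {
		node, quitChan := ExecuteStaticNodeI(strconv.Itoa(i+1), false, peerIds, threadTransport)
		nodeCollections = append(nodeCollections, node)
		quitCollections = append(quitCollections, quitChan)
	}
	StopElectionReset(nodeCollections) // timers pinned: no spontaneous elections
	defer func() {
		for _, quitChan := range quitCollections {
			close(quitChan)
		}
	}()

	for i := 0; i < n; i++ {
		nodeCollections[i].State = nodes.Follower
	}
	nodeCollections[0].StartElection() // node 1 campaigns
	time.Sleep(time.Second)

	CheckOneLeader(t, nodeCollections)
	CheckIsLeader(t, nodeCollections[0])

	// Append one entry on the leader; BroadCastKV inside SendKvCall replicates it.
	SendKvCall(&nodes.LogEntryCall{LogE: nodes.LogEntry{Key: "k", Value: "v"}}, nodeCollections[0])
	time.Sleep(time.Second)
	CheckSameLog(t, nodeCollections)
}
```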

+ 313
- 0
threadTest/election_test.go View File

@ -0,0 +1,313 @@
package threadTest
import (
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
)
func TestInitElection(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
}
func TestRepeatElection(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
go nodeCollections[0].StartElection()
go nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 3)
}
func TestBelowHalfCandidateElection(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
go nodeCollections[0].StartElection()
go nodeCollections[1].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
for i := 0; i < n; i++ {
CheckTerm(t, nodeCollections[i], 2)
}
}
func TestOverHalfCandidateElection(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
go nodeCollections[0].StartElection()
go nodeCollections[1].StartElection()
go nodeCollections[2].StartElection()
time.Sleep(time.Second)
CheckZeroOrOneLeader(t, nodeCollections)
for i := 0; i < n; i++ {
CheckTerm(t, nodeCollections[i], 2)
}
}
func TestRepeatVoteRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
ctx.SetBehavior("1", "2", nodes.RetryRpc, 0, 2)
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < n; i++ {
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2)
ctx.SetBehavior("2", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2)
}
go nodeCollections[0].StartElection()
go nodeCollections[1].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckTerm(t, nodeCollections[0], 3)
}
func TestFailVoteRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
ctx.SetBehavior("1", "2", nodes.FailRpc, 0, 0)
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
ctx.SetBehavior("1", "3", nodes.FailRpc, 0, 0)
ctx.SetBehavior("1", "4", nodes.FailRpc, 0, 0)
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckNoLeader(t, nodeCollections)
CheckTerm(t, nodeCollections[0], 3)
}
func TestDelayVoteRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.DelayRpc, time.Second, 0)
}
nodeCollections[0].StartElection()
time.Sleep(2 * time.Second)
CheckNoLeader(t, nodeCollections)
for i := 0; i < n; i++ {
CheckTerm(t, nodeCollections[i], 2)
}
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.DelayRpc, 50 * time.Millisecond, 0)
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
for i := 0; i < n; i++ {
CheckTerm(t, nodeCollections[i], 3)
}
}
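Across these tests ctx.SetBehavior(a, b, behavior, delay, count) shapes every RPC sent from node a to node b. The parameter names below are inferred from the call sites, not taken from simulate_ctx.go, so read this as a summary of observed usage rather than the transport's contract:

```go
// Inferred semantics of Ctx.SetBehavior (argument names are guesses):
//
//	ctx.SetBehavior(src, dst, nodes.NormalRpc, 0, 0)          // restore normal delivery
//	ctx.SetBehavior(src, dst, nodes.FailRpc, 0, 0)            // drop every src->dst RPC
//	ctx.SetBehavior(src, dst, nodes.DelayRpc, time.Second, 0) // hold each src->dst RPC for 1s
//	ctx.SetBehavior(src, dst, nodes.RetryRpc, 0, 2)           // deliver each src->dst RPC twice more
//
// Behaviors are directional: shaping src->dst says nothing about dst->src,
// which is why tests that want a full cut set both directions explicitly.
```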

+ 445
- 0
threadTest/fuzz/fuzz_test.go View File

@ -0,0 +1,445 @@
package fuzz
import (
"fmt"
"math/rand"
"os"
"runtime/debug"
"sync"
"testing"
"time"
clientPkg "simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"simple-kv-store/threadTest"
"strconv"
)
// 1. Random cluster configurations with random message behavior
func FuzzRaftBasic(f *testing.F) {
var seenSeeds sync.Map
// Add the initial seed
f.Add(int64(1))
fmt.Println("Running")
f.Fuzz(func(t *testing.T, seed int64) {
if _, loaded := seenSeeds.LoadOrStore(seed, true); loaded {
t.Skipf("Seed %d already tested, skipping...", seed)
return
}
defer func() {
if r := recover(); r != nil {
msg := fmt.Sprintf("goroutine panic: %v\n%s", r, debug.Stack())
f, _ := os.Create("panic_goroutine.log")
fmt.Fprint(f, msg)
f.Close()
}
}()
r := rand.New(rand.NewSource(seed)) // use a seed-local rand
n := 3 + 2*(r.Intn(4))
fmt.Printf("randomized %d nodes\n", n)
logs := (r.Intn(10))
fmt.Printf("randomized %d log entries\n", logs)
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1))
}
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
for i := 0; i < n; i++ {
node, quitChan := threadTest.ExecuteNodeI(strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1), false, peerIds, threadTransport)
nodeCollections = append(nodeCollections, node)
node.RTTable.SetElectionTimeout(750 * time.Millisecond)
quitCollections = append(quitCollections, quitChan)
}
// Inject per-pair (a-b) communication behavior
faultyNodes := injectRandomBehavior(ctx, r, peerIds)
time.Sleep(time.Second)
clientObj := clientPkg.NewClient("0", peerIds, threadTransport)
for i := 0; i < logs; i++ {
key := fmt.Sprintf("k%d", i)
log := nodes.LogEntry{Key: key, Value: "v"}
clientObj.Write(log)
}
time.Sleep(time.Second)
var rightNodeCollections []*nodes.Node
for _, node := range nodeCollections {
if !faultyNodes[node.SelfId] {
rightNodeCollections = append(rightNodeCollections, node)
}
}
threadTest.CheckSameLog(t, rightNodeCollections)
threadTest.CheckLeaderInvariant(t, nodeCollections)
for _, quitChan := range quitCollections {
close(quitChan)
}
time.Sleep(time.Second)
for i := 0; i < n; i++ {
// make sure shutdown has completed
nodeCollections[i].Mu.Lock()
if !nodeCollections[i].IsFinish {
nodeCollections[i].IsFinish = true
}
nodeCollections[i].Mu.Unlock()
os.RemoveAll("leveldb/simple-kv-store" + strconv.Itoa(int(seed)) + "." + strconv.Itoa(i+1))
os.RemoveAll("storage/node" + strconv.Itoa(int(seed)) + "." + strconv.Itoa(i+1))
}
})
}
// Inject behavior between nodes
func injectRandomBehavior(ctx *nodes.Ctx, r *rand.Rand, peers []string) map[string]bool /*id:Isfault*/ {
behaviors := []nodes.CallBehavior{
nodes.FailRpc,
nodes.DelayRpc,
nodes.RetryRpc,
}
n := len(peers)
maxFaulty := r.Intn(n/2 + 1) // pick 0 ~ n/2 faulty nodes at random
// Shuffle to choose which nodes are faulty
shuffled := append([]string(nil), peers...)
r.Shuffle(n, func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
faultyNodes := make(map[string]bool)
for i := 0; i < maxFaulty; i++ {
faultyNodes[shuffled[i]] = true
}
for _, one := range peers {
if faultyNodes[one] {
b := behaviors[r.Intn(len(behaviors))]
delay := time.Duration(r.Intn(100)) * time.Millisecond
switch b {
case nodes.FailRpc:
fmt.Printf("[%s] fault behavior: fail\n", one)
case nodes.DelayRpc:
fmt.Printf("[%s] fault behavior: delay\n", one)
case nodes.RetryRpc:
fmt.Printf("[%s] fault behavior: retry\n", one)
}
for _, two := range peers {
if one == two {
continue
}
if faultyNodes[one] && faultyNodes[two] {
ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0)
ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0)
} else {
ctx.SetBehavior(one, two, b, delay, 2)
ctx.SetBehavior(two, one, b, delay, 2)
}
}
}
}
return faultyNodes
}
// 2. Inject random behavior into a long-running cluster
func FuzzRaftRobust(f *testing.F) {
var seenSeeds sync.Map
var fuzzMu sync.Mutex
// Add the initial seed
f.Add(int64(0))
fmt.Println("Running")
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i+1))
}
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
quitCollections := make(map[string]chan struct{})
nodeCollections := make(map[string]*nodes.Node)
for i := 0; i < n; i++ {
id := strconv.Itoa(i+1)
node, quitChan := threadTest.ExecuteNodeI(id, false, peerIds, threadTransport)
nodeCollections[id] = node
quitCollections[id] = quitChan
}
f.Fuzz(func(t *testing.T, seed int64) {
fuzzMu.Lock()
defer fuzzMu.Unlock()
if _, loaded := seenSeeds.LoadOrStore(seed, true); loaded {
t.Skipf("Seed %d already tested, skipping...", seed)
return
}
defer func() {
if r := recover(); r != nil {
msg := fmt.Sprintf("goroutine panic: %v\n%s", r, debug.Stack())
f, _ := os.Create("panic_goroutine.log")
fmt.Fprint(f, msg)
f.Close()
}
}()
r := rand.New(rand.NewSource(seed)) // use a seed-local rand
clientObj := clientPkg.NewClient("0", peerIds, threadTransport)
faultyNodes := injectRandomBehavior2(ctx, r, peerIds, threadTransport, quitCollections)
key := fmt.Sprintf("k%d", seed % 10)
log := nodes.LogEntry{Key: key, Value: "v"}
clientObj.Write(log)
time.Sleep(time.Second)
var rightNodeCollections []*nodes.Node
for _, node := range nodeCollections {
_, exist := faultyNodes[node.SelfId]
if !exist {
rightNodeCollections = append(rightNodeCollections, node)
}
}
threadTest.CheckLogInvariant(t, rightNodeCollections)
threadTest.CheckLeaderInvariant(t, rightNodeCollections)
// ResetFaultyNodes
threadTransport.ResetConnectivity()
for id, isrestart := range faultyNodes {
if !isrestart {
for _, peerId := range peerIds {
if id == peerId {
continue
}
ctx.SetBehavior(id, peerId, nodes.NormalRpc, 0, 0)
ctx.SetBehavior(peerId, id, nodes.NormalRpc, 0, 0)
}
} else {
newNode, quitChan := threadTest.ExecuteNodeI(id, true, peerIds, threadTransport)
quitCollections[id] = quitChan
nodeCollections[id] = newNode
}
fmt.Printf("[%s] fault cleared\n", id)
}
})
for _, quitChan := range quitCollections {
close(quitChan)
}
time.Sleep(time.Second)
for id, node := range nodeCollections {
// make sure shutdown has completed
node.Mu.Lock()
if !node.IsFinish {
node.IsFinish = true
}
node.Mu.Unlock()
os.RemoveAll("leveldb/simple-kv-store" + id)
os.RemoveAll("storage/node" + id)
}
}
// 3. Combined scenario
func FuzzRaftPlus(f *testing.F) {
var seenSeeds sync.Map
// Add the initial seed
f.Add(int64(0))
fmt.Println("Running")
f.Fuzz(func(t *testing.T, seed int64) {
if _, loaded := seenSeeds.LoadOrStore(seed, true); loaded {
t.Skipf("Seed %d already tested, skipping...", seed)
return
}
defer func() {
if r := recover(); r != nil {
msg := fmt.Sprintf("goroutine panic: %v\n%s", r, debug.Stack())
f, _ := os.Create("panic_goroutine.log")
fmt.Fprint(f, msg)
f.Close()
}
}()
r := rand.New(rand.NewSource(seed)) // use a seed-local rand
n := 3 + 2*(r.Intn(4))
fmt.Printf("randomized %d nodes\n", n)
ElectionTimeOut := 500 + r.Intn(500)
fmt.Printf("randomized election timeout: %d ms\n", ElectionTimeOut)
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1))
}
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
quitCollections := make(map[string]chan struct{})
nodeCollections := make(map[string]*nodes.Node)
for i := 0; i < n; i++ {
id := strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1)
node, quitChan := threadTest.ExecuteNodeI(id, false, peerIds, threadTransport)
nodeCollections[id] = node
node.RTTable.SetElectionTimeout(time.Duration(ElectionTimeOut) * time.Millisecond)
quitCollections[id] = quitChan
}
clientObj := clientPkg.NewClient("0", peerIds, threadTransport)
for i := 0; i < 5; i++ { // simulate 5 rounds of fault injection
fmt.Printf("fault-injection round %d starting\n", i + 1)
faultyNodes := injectRandomBehavior2(ctx, r, peerIds, threadTransport, quitCollections)
key := fmt.Sprintf("k%d", i)
log := nodes.LogEntry{Key: key, Value: "v"}
clientObj.Write(log)
time.Sleep(time.Second)
var rightNodeCollections []*nodes.Node
for _, node := range nodeCollections {
_, exist := faultyNodes[node.SelfId]
if !exist {
rightNodeCollections = append(rightNodeCollections, node)
}
}
threadTest.CheckLogInvariant(t, rightNodeCollections)
threadTest.CheckLeaderInvariant(t, rightNodeCollections)
// ResetFaultyNodes
threadTransport.ResetConnectivity()
for id, isrestart := range faultyNodes {
if !isrestart {
for _, peerId := range peerIds {
if id == peerId {
continue
}
ctx.SetBehavior(id, peerId, nodes.NormalRpc, 0, 0)
ctx.SetBehavior(peerId, id, nodes.NormalRpc, 0, 0)
}
} else {
newNode, quitChan := threadTest.ExecuteNodeI(id, true, peerIds, threadTransport)
quitCollections[id] = quitChan
nodeCollections[id] = newNode
}
fmt.Printf("[%s] fault cleared\n", id)
}
}
for _, quitChan := range quitCollections {
close(quitChan)
}
time.Sleep(time.Second)
for id, node := range nodeCollections {
// make sure shutdown has completed
node.Mu.Lock()
if !node.IsFinish {
node.IsFinish = true
}
node.Mu.Unlock()
os.RemoveAll("leveldb/simple-kv-store" + id)
os.RemoveAll("storage/node" + id)
}
})
}
func injectRandomBehavior2(ctx *nodes.Ctx, r *rand.Rand, peers []string, tran *nodes.ThreadTransport, quitCollections map[string]chan struct{}) map[string]bool /*id:needRestart*/ {
n := len(peers)
maxFaulty := r.Intn(n/2 + 1) // pick 0 ~ n/2 faulty nodes at random
// Shuffle to choose which nodes are faulty
shuffled := append([]string(nil), peers...)
r.Shuffle(n, func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
faultyNodes := make(map[string]bool)
for i := 0; i < maxFaulty; i++ {
faultyNodes[shuffled[i]] = false
}
PartitionNodes := make(map[string]bool)
for _, one := range peers {
_, exist := faultyNodes[one]
if exist {
b := r.Intn(5)
switch b {
case 0:
fmt.Printf("[%s] fault behavior: fail\n", one)
for _, two := range peers {
if one == two {
continue
}
ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0)
ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0)
}
case 1:
fmt.Printf("[%s] fault behavior: delay\n", one)
t := r.Intn(100)
fmt.Printf("[%s] delay time = %d ms\n", one, t)
delay := time.Duration(t) * time.Millisecond
for _, two := range peers {
if one == two {
continue
}
_, exist2 := faultyNodes[two]
if exist2 {
ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0)
ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0)
} else {
ctx.SetBehavior(one, two, nodes.DelayRpc, delay, 0)
ctx.SetBehavior(two, one, nodes.DelayRpc, delay, 0)
}
}
case 2:
fmt.Printf("[%s] fault behavior: retry\n", one)
for _, two := range peers {
if one == two {
continue
}
_, exist2 := faultyNodes[two]
if exist2 {
ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0)
ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0)
} else {
ctx.SetBehavior(one, two, nodes.RetryRpc, 0, 2)
ctx.SetBehavior(two, one, nodes.RetryRpc, 0, 2)
}
}
case 3:
fmt.Printf("[%s] fault behavior: stop\n", one)
faultyNodes[one] = true
close(quitCollections[one])
case 4:
fmt.Printf("[%s] fault behavior: partition\n", one)
PartitionNodes[one] = true
}
}
}
for id := range PartitionNodes {
for _, two := range peers {
if !PartitionNodes[two] {
tran.SetConnectivity(id, two, false)
tran.SetConnectivity(two, id, false)
}
}
}
return faultyNodes
}
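All three fuzz targets share one oracle: after every injected fault, the nodes that stayed healthy must still satisfy the two invariants from threadTest/common.go. A condensed sketch of that check (healthyNodes is a hypothetical helper standing in for the filtering loops above), plus the Go 1.18+ invocation:

```go
// Shared oracle, condensed (sketch only):
//
//	healthy := healthyNodes(nodeCollections, faultyNodes) // hypothetical filter
//	threadTest.CheckLogInvariant(t, healthy)    // once two logs agree at an index,
//	                                            // they agree on every earlier entry
//	threadTest.CheckLeaderInvariant(t, healthy) // at most one leader per term
//
// With native fuzzing (Go 1.18+), each target runs standalone, e.g.:
//
//	go test -fuzz='^FuzzRaftBasic$' -fuzztime=60s ./threadTest/fuzz
//
// Failures are captured in panic_goroutine.log and in the per-node logs
// written by WriteFailLog.
```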

+ 326
- 0
threadTest/log_replication_test.go View File

@ -0,0 +1,326 @@
package threadTest
import (
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
)
func TestNormalReplication(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
time.Sleep(time.Second)
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
func TestParallelReplication(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
go nodeCollections[0].BroadCastKV()
}
time.Sleep(time.Second)
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
func TestFollowerLagging(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
close(quitCollections[1])
time.Sleep(time.Second)
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
node, q := ExecuteStaticNodeI("2", true, peerIds, threadTransport)
quitCollections[1] = q
nodeCollections[1] = node
nodeCollections[1].State = nodes.Follower
StopElectionReset(nodeCollections[1:2])
nodeCollections[0].BroadCastKV()
time.Sleep(time.Second)
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
func TestFailLogAppendRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < n; i++ {
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.FailRpc, 0, 0)
}
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
time.Sleep(time.Second)
for i := 1; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 0)
}
}
func TestRepeatLogAppendRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < n; i++ {
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2)
}
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
time.Sleep(time.Second)
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
func TestDelayLogAppendRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < n; i++ {
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.DelayRpc, time.Second, 0)
}
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
time.Sleep(time.Millisecond * 100)
for i := 0; i < n; i++ {
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.NormalRpc, 0, 0)
}
for i := 5; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
time.Sleep(time.Second * 2)
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
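These replication tests append entries directly on the leader through SendKvCall instead of going through a client, which keeps the timing under the test's control. Condensed from SendKvCall in threadTest/common.go, the leader-side append path is:

```go
// Leader-side append path (condensed from SendKvCall):
//
//	node.MaxLogId++ // allocate the next log id, under node.Mu
//	rLogE := nodes.RaftLogEntry{LogE: kvCall.LogE, LogId: node.MaxLogId, Term: node.CurrTerm}
//	node.Log = append(node.Log, rLogE) // in-memory log
//	node.Storage.AppendLog(rLogE)      // persisted before broadcasting
//	node.BroadCastKV()                 // AppendEntries fan-out to the peers
```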

+ 245
- 0
threadTest/network_partition_test.go View File

@ -0,0 +1,245 @@
package threadTest
import (
"fmt"
clientPkg "simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"strings"
"testing"
"time"
)
func TestBasicConnectivity(t *testing.T) {
transport := nodes.NewThreadTransport(nodes.NewCtx())
transport.RegisterNodeChan("1", make(chan nodes.RPCRequest, 10))
transport.RegisterNodeChan("2", make(chan nodes.RPCRequest, 10))
// Disconnect node 1 from node 2
transport.SetConnectivity("1", "2", false)
err := transport.CallWithTimeout(&nodes.ThreadClient{SourceId: "1", TargetId: "2"}, "Node.AppendEntries", &nodes.AppendEntriesArg{}, &nodes.AppendEntriesReply{})
if err == nil {
t.Errorf("Expected network partition error, but got nil")
}
// Restore the connection
transport.SetConnectivity("1", "2", true)
err = transport.CallWithTimeout(&nodes.ThreadClient{SourceId: "1", TargetId: "2"}, "Node.AppendEntries", &nodes.AppendEntriesArg{}, &nodes.AppendEntriesReply{})
if !strings.Contains(err.Error(), "RPC 调用超时") { // the transport's timeout message, roughly "RPC call timed out"
t.Errorf("Expected success, but got error: %v", err)
}
}
func TestSinglePartition(t *testing.T) {
// Register node addresses
n := 3
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // wait for startup to complete
fmt.Println("starting partition scenario 1")
var leaderNo int
for i := 0; i < n; i++ {
if nodeCollections[i].State == nodes.Leader {
leaderNo = i
for j := 0; j < n; j++ {
if i != j { // cut messages from the other nodes to the leader
threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[i].SelfId, false)
}
}
}
}
time.Sleep(2 * time.Second)
if nodeCollections[leaderNo].State == nodes.Leader {
t.Errorf("the partitioned leader failed to step down")
}
// Restore the network
for j := 0; j < n; j++ {
if leaderNo != j { // restore messages from the other nodes to the leader
threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[leaderNo].SelfId, true)
}
}
time.Sleep(1 * time.Second)
var leaderCnt int
for i := 0; i < n; i++ {
if nodeCollections[i].State == nodes.Leader {
leaderCnt++
leaderNo = i
}
}
if leaderCnt != 1 {
t.Errorf("multiple leaders elected")
}
fmt.Println("starting partition scenario 2")
for j := 0; j < n; j++ {
if leaderNo != j { // cut messages from the leader to the other nodes
threadTransport.SetConnectivity(nodeCollections[leaderNo].SelfId, nodeCollections[j].SelfId, false)
}
}
time.Sleep(1 * time.Second)
if nodeCollections[leaderNo].State == nodes.Leader {
t.Errorf("the partitioned leader failed to step down")
}
leaderCnt = 0
for j := 0; j < n; j++ {
if nodeCollections[j].State == nodes.Leader {
leaderCnt++
}
}
if leaderCnt != 1 {
t.Errorf("multiple leaders elected")
}
// Start a client
c := clientPkg.NewClient("0", peerIds, threadTransport)
var s clientPkg.Status
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s = c.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // wait for the writes to finish
// Restore the network
for j := 0; j < n; j++ {
if leaderNo != j {
threadTransport.SetConnectivity(nodeCollections[leaderNo].SelfId, nodeCollections[j].SelfId, true)
}
}
time.Sleep(time.Second)
// Check log consistency
for i := 0; i < n; i++ {
if len(nodeCollections[i].Log) != 5 {
t.Errorf("unexpected log length: " + strconv.Itoa(len(nodeCollections[i].Log)))
}
}
}
func TestQuorumPartition(t *testing.T) {
// Register node addresses
n := 5 // odd, so the partition isolates fewer than half the nodes
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // wait for startup to complete
fmt.Println("starting partition scenario 1")
for i := 0; i < n / 2; i++ {
for j := n / 2; j < n; j++ {
threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[i].SelfId, false)
threadTransport.SetConnectivity(nodeCollections[i].SelfId, nodeCollections[j].SelfId, false)
}
}
time.Sleep(2 * time.Second)
leaderCnt := 0
for i := 0; i < n / 2; i++ {
if nodeCollections[i].State == nodes.Leader {
leaderCnt++
}
}
if leaderCnt != 0 {
t.Errorf("the minority partition must not elect a leader")
}
for i := n / 2; i < n; i++ {
if nodeCollections[i].State == nodes.Leader {
leaderCnt++
}
}
if leaderCnt != 1 {
t.Errorf("the majority partition should elect exactly one leader")
}
// Start a client
c := clientPkg.NewClient("0", peerIds, threadTransport)
var s clientPkg.Status
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s = c.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // wait for the writes to finish
// Restore the network
for i := 0; i < n / 2; i++ {
for j := n / 2; j < n; j++ {
threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[i].SelfId, true)
threadTransport.SetConnectivity(nodeCollections[i].SelfId, nodeCollections[j].SelfId, true)
}
}
time.Sleep(1 * time.Second)
leaderCnt = 0
for j := 0; j < n; j++ {
if nodeCollections[j].State == nodes.Leader {
leaderCnt++
}
}
if leaderCnt != 1 {
t.Errorf("multiple leaders elected")
}
// Check log consistency
for i := 0; i < n; i++ {
if len(nodeCollections[i].Log) != 5 {
t.Errorf("unexpected log length: " + strconv.Itoa(len(nodeCollections[i].Log)))
}
}
}
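Both fault primitives here are directional, and TestSinglePartition leans on that: scenario 1 cuts traffic from the followers to the leader, scenario 2 cuts traffic from the leader to the followers; in both cases the test asserts that the isolated leader steps down and a single new leader emerges. A one-way cut and its repair, restated as a sketch (not code from this commit):

```go
// One-way cut and repair (sketch):
//
//	transport.SetConnectivity("2", "1", false) // node 2 can no longer reach node 1,
//	                                           // but node 1 may still reach node 2
//	transport.SetConnectivity("2", "1", true)  // link repaired
//
// A full partition therefore needs both directions cut, which is exactly
// what TestQuorumPartition does for every cross-partition pair.
```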

+ 129
- 0
threadTest/restart_node_test.go View File

@ -0,0 +1,129 @@
package threadTest
import (
"simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
)
func TestNodeRestart(t *testing.T) {
// Register node addresses
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
_, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
}
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // wait for startup to complete
// Start a client connected to an arbitrary node
cWrite := clientPkg.NewClient("0", peerIds, threadTransport)
// Write entries
ClientWriteLog(t, 0, 5, cWrite)
time.Sleep(time.Second) // wait for the writes to finish
// Crash and restart the nodes one by one
for i := 0; i < n; i++ {
close(quitCollections[i])
time.Sleep(time.Second)
_, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), true, peerIds, threadTransport)
quitCollections[i] = quitChan
time.Sleep(time.Second) // wait for startup to complete
}
// Start a client
cRead := clientPkg.NewClient("0", peerIds, threadTransport)
// Read back the written data
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
var value string
s := cRead.Read(key, &value)
if s != clientPkg.Ok {
t.Errorf("Read test1 fail")
}
}
// Read keys that were never written
for i := 5; i < 15; i++ {
key := strconv.Itoa(i)
var value string
s := cRead.Read(key, &value)
if s != clientPkg.NotFound {
t.Errorf("Read test2 fail")
}
}
}
func TestRestartWhileWriting(t *testing.T) {
// Register node addresses
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // wait for startup to complete
leaderIdx := FindLeader(t, nodeCollections)
// Start a client connected to an arbitrary node
cWrite := clientPkg.NewClient("0", peerIds, threadTransport)
// Write entries
go ClientWriteLog(t, 0, 5, cWrite)
go func() {
close(quitCollections[leaderIdx])
n, quitChan := ExecuteNodeI(strconv.Itoa(leaderIdx + 1), true, peerIds, threadTransport)
quitCollections[leaderIdx] = quitChan
nodeCollections[leaderIdx] = n
}()
time.Sleep(time.Second) // wait for the restarted node to come back
// Start a client
cRead := clientPkg.NewClient("0", peerIds, threadTransport)
// Read back the written data
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
var value string
s := cRead.Read(key, &value)
if s != clientPkg.Ok {
t.Errorf("Read test1 fail")
}
}
CheckLogNum(t, nodeCollections[leaderIdx], 5)
}

+ 192
- 0
threadTest/server_client_test.go View File

@ -0,0 +1,192 @@
package threadTest
import (
"simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
)
func TestServerClient(t *testing.T) {
// Register node addresses
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
_, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
}
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // wait for startup to complete
// Start a client
c := clientPkg.NewClient("0", peerIds, threadTransport)
// Write entries
var s clientPkg.Status
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s = c.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // wait for the writes to finish
// Read back the written data
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
var value string
s = c.Read(key, &value)
if s != clientPkg.Ok || value != "hello" {
t.Errorf("Read test1 fail")
}
}
// Read keys that were never written
for i := 10; i < 15; i++ {
key := strconv.Itoa(i)
var value string
s = c.Read(key, &value)
if s != clientPkg.NotFound {
t.Errorf("Read test2 fail")
}
}
}
func TestRepeatClientReq(t *testing.T) {
// Register node addresses
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // wait for startup to complete
// Start a client
c := clientPkg.NewClient("0", peerIds, threadTransport)
for i := 0; i < n; i++ {
ctx.SetBehavior("", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2)
}
// Write entries
var s clientPkg.Status
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s = c.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // wait for the writes to finish
// Read back the written data
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
var value string
s = c.Read(key, &value)
if s != clientPkg.Ok || value != "hello" {
t.Errorf("Read test1 fail")
}
}
// Read keys that were never written
for i := 10; i < 15; i++ {
key := strconv.Itoa(i)
var value string
s = c.Read(key, &value)
if s != clientPkg.NotFound {
t.Errorf("Read test2 fail")
}
}
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
func TestParallelClientReq(t *testing.T) {
// Register node addresses
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// Start the nodes
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
// Signal all nodes to stop
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // wait for startup to complete
// Start the clients
c1 := clientPkg.NewClient("0", peerIds, threadTransport)
c2 := clientPkg.NewClient("1", peerIds, threadTransport)
// Write entries
go ClientWriteLog(t, 0, 10, c1)
go ClientWriteLog(t, 0, 10, c2)
time.Sleep(time.Second) // wait for the writes to finish
// Read back the written data
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
var value string
s := c1.Read(key, &value)
if s != clientPkg.Ok || value != "hello" {
t.Errorf("Read test1 fail")
}
}
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 20)
}
}
