23 Commits

Author SHA1 Message Date
  李度 ea6fd20628 Merge pull request 'ld' (#2) from ld into master 10 months ago
  augurier df82b6a8f0 Report additions 10 months ago
  augurier 96b4df5d21 Project report 10 months ago
  augurier 07c5a3719f Adapt to both versions 10 months ago
  augurier 1bfa17a735 Complete fuzz 10 months ago
  augurier 33a41f95c0 Large-scale debugging; pass fuzz_test with random messages 10 months ago
  augurier a460aba985 Additional tests 10 months ago
  augurier 813c40913b Client request idempotency 10 months ago
  augurier 3ef099f3e8 Fine-grained message control 10 months ago
  augurier 6ce161187a More replication tests 10 months ago
  augurier da415bfe37 More election tests 10 months ago
  augurier 6b29fe96b1 Migrate persistence from JSON to LevelDB to guarantee atomic writes 10 months ago
  augurier 239cd694ab Additional partition tests 10 months ago
  augurier f669abd7ad Partition tests (and adapt to the fine-grained changes) 10 months ago
  augurier 0b7590ee2a Add transport interface to support thread simulation 10 months ago
  augurier 053dbe107f Wrap RPC calls with timeouts; fix a bug with concurrent candidates 11 months ago
  augurier d83518d633 gitignore: ignore .log files 11 months ago
  augurier e86202ca9b Add client leader discovery 11 months ago
  augurier 28dc22fb16 Fix some misunderstandings around node restart 11 months ago
  augurier 15c7f2ad4f Project structure adjustments 11 months ago
  augurier fcf27b5770 Persist the node's own data 11 months ago
  augurier b398f10900 Simple leader election logic 11 months ago
  augurier 1181f6d796 Raft log replication 11 months ago
33 changed files with 3978 additions and 470 deletions
  1. +3 −0 .gitignore
  2. +232 −16 README.md
  3. +25 −24 cmd/main.go
  4. +122 −46 internal/client/client_node.go
  5. +15 −0 internal/logprovider/traceback.go
  6. +217 −77 internal/nodes/init.go
  7. +19 −6 internal/nodes/log.go
  8. +43 −72 internal/nodes/node.go
  9. +194 −0 internal/nodes/node_storage.go
  10. +46 −0 internal/nodes/random_timetable.go
  11. +111 −0 internal/nodes/real_transport.go
  12. +271 −0 internal/nodes/replica.go
  13. +68 −16 internal/nodes/server_node.go
  14. +65 −0 internal/nodes/simulate_ctx.go
  15. +234 −0 internal/nodes/thread_transport.go
  16. +10 −0 internal/nodes/transport.go
  17. +215 −0 internal/nodes/vote.go
  18. BIN pics/plan.png
  19. BIN pics/plus.png
  20. BIN pics/robust.png
  21. BIN raft第二次汇报.pptx
  22. +0 −61 scripts/run.sh
  23. +7 −15 test/common.go
  24. +0 −112 test/restart_follower_test.go
  25. +104 −0 test/restart_node_test.go
  26. +22 −25 test/server_client_test.go
  27. +305 −0 threadTest/common.go
  28. +313 −0 threadTest/election_test.go
  29. +445 −0 threadTest/fuzz/fuzz_test.go
  30. +326 −0 threadTest/log_replication_test.go
  31. +245 −0 threadTest/network_partition_test.go
  32. +129 −0 threadTest/restart_node_test.go
  33. +192 −0 threadTest/server_client_test.go

+3 −0 .gitignore

@ -26,3 +26,6 @@ go.work
main
leveldb
storage
*.log
testdata

+232 −16 README.md

@ -1,24 +1,240 @@
# go-raft-kv
---
# Introduction
This project is a distributed KV database based on the Raft algorithm, implemented in Go.
A distributed KV database implemented in Go
Project highlights:
Supports both thread-based and process-based communication, which makes switching between testing and deployment easy while reducing coupling in the system
Provides a fuzz-testing mechanism based on state invariants to improve robustness
Highly modular design that makes it easy to extend with more features
This report mainly serves as a work summary and as a reading aid for the code (key parts are quoted inline). Some of the concrete problems, thoughts, and lessons from the work are hard to describe concisely in text and are covered in the presentation slides.
# Project structure
```
---cmd
   ---main.go                    entry point for the process-based version
---internal
   ---client                     client side, using the read/write APIs exposed by nodes
   ---logprovider                thin wrapper around log printing, for easier debugging
   ---nodes                      core distributed-system code
      init.go                    node initialization (both versions) and main-loop startup
      log.go                     data structures for the entries a node stores
      node_storage.go            node-data persistence abstraction; serialized and stored in LevelDB
      node.go                    node-related data structures
      random_timetable.go        controls the random timings in the system
      real_transport.go          process-based RPC communication
      replica.go                 log replication logic
      server_node.go             the read/write service a node exposes to clients
      simulate_ctx.go            controls message behavior during tests
      thread_transport.go        thread-based communication
      transport.go               base communication interface shared by both versions
      vote.go                    leader election logic
---test                          process-based tests
---threadTest                    thread-based tests
   ---fuzz                       randomized (fuzz) testing
   election_test.go              election
   log_replication_test.go       log replication
   network_partition_test.go     network partition
   restart_node_test.go          restart/recovery
   server_client_test.go         client interaction
```
# Raft system
## Main flow
In init.go each node is initialized and spawns its listening thread, then the Start function enters the main loop. Once per heartbeat interval in that loop, if the node finds itself to be the leader, it broadcasts and calls ResetElectionTimer (init.go).
```go
func Start(node *Node, quitChan chan struct{}) {
	node.Mu.Lock()
	node.State = Follower // every node starts as a Follower
	node.Mu.Unlock()
	node.ResetElectionTimer() // start the election timeout timer
	go func() {
		ticker := time.NewTicker(50 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-quitChan:
				fmt.Printf("[%s] Raft start 退出...\n", node.SelfId)
				return // exit the goroutine
			case <-ticker.C:
				node.Mu.Lock()
				state := node.State
				node.Mu.Unlock()
				switch state {
				case Follower:
					// wait for heartbeat timeout
				case Leader:
					// send heartbeats
					node.ResetElectionTimer() // the leader never triggers an election itself
					node.BroadCastKV()
				}
			}
		}
	}()
}
```
Two listener methods:
AppendEntries: BroadCastKV iterates over every peer node and makes the call in sendKV, implementing the log replication logic.
RequestVote: every node has a timer managed by ResetElectionTimer; if it is not reset for a while, the node calls StartElection, which iterates over every peer node and makes the call in sendRequestVote, implementing leader election. The leader resets the timer on every heartbeat (so it never re-elects itself), while a follower resets it whenever it receives AppendEntries.
```go
func (node *Node) ResetElectionTimer() {
	node.MuElection.Lock()
	defer node.MuElection.Unlock()
	if node.ElectionTimer == nil {
		node.ElectionTimer = time.NewTimer(node.RTTable.GetElectionTimeout())
		go func() {
			for {
				<-node.ElectionTimer.C
				node.StartElection()
			}
		}()
	} else {
		node.ElectionTimer.Stop()
		node.ElectionTimer.Reset(time.Duration(500+rand.Intn(500)) * time.Millisecond)
	}
}
```
## How the client works (client & server_node.go)
### Client write
Each time, the client connects to a random node in the cluster; four cases can occur:
a. The node believes it is the leader and handles the request directly (records it, then broadcasts).
b. The node believes it is a follower and knows the leader; it returns the leader's id. The client connects to that id, and the new node goes through the same four cases.
c. The node believes it is a follower but does not know who the leader is; it returns an empty id. The client picks another random node.
d. The connection times out; the client picks another random node.
```go
func (client *Client) Write(kv nodes.LogEntry) Status {
	kvCall := nodes.LogEntryCall{LogE: kv,
		Id: nodes.LogEntryCallId{ClientId: client.ClientId, LogId: client.NextLogId}}
	client.NextLogId++
	c := client.FindActiveNode()
	var err error
	timeout := time.Second
	deadline := time.Now().Add(timeout)
	for { // follow the feedback of live nodes until the leader is found
		if time.Now().After(deadline) {
			return Fail
		}
		var reply nodes.ServerReply
		reply.Isleader = false
		callErr := client.Transport.CallWithTimeout(c, "Node.WriteKV", &kvCall, &reply) // RPC
		if callErr != nil { // the node may crash between dial and call; find another live node
			log.Error("dialing: ", zap.Error(callErr))
			client.CloseRpcClient(c)
			c = client.FindActiveNode()
			continue
		}
		if !reply.Isleader { // the peer is not the leader; follow its hint to find the leader
			leaderId := reply.LeaderId
			client.CloseRpcClient(c)
			if leaderId == "" { // this node does not know the leader; pick another random node
				c = client.FindActiveNode()
			} else { // dial the leader
				c, err = client.Transport.DialHTTPWithTimeout("tcp", "", leaderId)
				if err != nil { // dial failed; find the next live node
					c = client.FindActiveNode()
				}
			}
		} else { // success
			client.CloseRpcClient(c)
			return Ok
		}
	}
}
```
### Client read
The client connects to a random node in the cluster and reads that node's committed KV state. A minimal usage sketch of the client API follows.
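The sketch below shows how a client is constructed and used; the node ids, addresses, and keys are illustrative and assume a running three-node cluster, while the constructors and method signatures follow internal/client/client_node.go.
```go
package main

import (
	"fmt"

	clientPkg "simple-kv-store/internal/client"
	"simple-kv-store/internal/nodes"
)

func main() {
	// Illustrative cluster layout: ids "1"-"3" listening on local ports.
	addrs := map[string]string{"1": "127.0.0.1:9091", "2": "127.0.0.1:9092", "3": "127.0.0.1:9093"}
	peerIds := []string{"1", "2", "3"}
	transport := &nodes.HTTPTransport{NodeMap: addrs}

	c := clientPkg.NewClient("client-0", peerIds, transport)

	// Write goes through leader discovery; Read hits a random live node's committed state.
	if c.Write(nodes.LogEntry{Key: "foo", Value: "bar"}) != clientPkg.Ok {
		fmt.Println("write failed")
	}
	var value string
	switch c.Read("foo", &value) {
	case clientPkg.Ok:
		fmt.Println("foo =", value)
	case clientPkg.NotFound:
		fmt.Println("foo not committed on this node yet")
	}
}
```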
## Persisting critical data
This is wrapped in node_storage.go; the main pieces are SetTermAndVote, which serializes the term and vote and writes them to LevelDB, and the log writes. They are called whenever this data changes, and the corresponding reads are used during recovery.
```go
// SetTermAndVote atomically updates term and vote
func (rs *RaftStorage) SetTermAndVote(term int, candidate string) {
	rs.mu.Lock()
	defer rs.mu.Unlock()
	if rs.isfinish {
		return
	}
	batch := new(leveldb.Batch)
	batch.Put([]byte("current_term"), []byte(strconv.Itoa(term)))
	batch.Put([]byte("voted_for"), []byte(candidate))
	err := rs.db.Write(batch, nil) // atomic commit
	if err != nil {
		log.Error("SetTermAndVote 持久化失败:", zap.Error(err))
	}
}
```
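A minimal sketch of how this storage is used across a restart; the directory name is illustrative, and on a real restart the same reads happen in the isRestart branch of init.go (InitRPCNode / InitThreadNode).
```go
package main

import (
	"fmt"

	"simple-kv-store/internal/nodes"
)

func main() {
	// Illustrative path; each node uses its own directory, e.g. storage/node1.
	rs := nodes.NewRaftStorage("storage/node-demo")
	rs.SetTermAndVote(3, "2")
	rs.AppendLog(nodes.RaftLogEntry{LogE: nodes.LogEntry{Key: "foo", Value: "bar"}, LogId: 0, Term: 3})
	rs.Close()

	// After a "restart": reopen the same directory and read the persisted state back,
	// exactly as the isRestart branch in init.go does.
	rs = nodes.NewRaftStorage("storage/node-demo")
	defer rs.Close()
	fmt.Println(rs.GetCurrentTerm(), rs.GetVotedFor(), len(rs.GetLogEntries()))
}
```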
## Switching between the two systems
Every network-facing piece of the Raft system is extracted and abstracted into dial and call methods, the two base methods of the Transport interface each node holds. The process-based and thread-based Transport implementations provide them separately, so the two versions stay isolated from each other.
```go
type Transport interface {
	DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error)
	CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error
}
```
Process version: dial and call are the methods of Go's native net/rpc library, wrapped with a timeout (real_transport.go).
Thread version: a single ThreadTransport is shared by all nodes; at initialization each node registers its own channel into the transport's map and spawns a goroutine to listen on it, dispatching each incoming request to the matching local method (thread_transport.go). A minimal wiring sketch is shown below.
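The sketch below shows roughly how a thread-based cluster is wired together; the ids and paths are illustrative, and the real helpers live in threadTest/common.go (not quoted here). The constructors follow init.go and thread_transport.go.
```go
package main

import (
	"time"

	"github.com/syndtr/goleveldb/leveldb"
	"simple-kv-store/internal/nodes"
)

func main() {
	// One shared transport (and simulation context) for the whole in-process cluster.
	transport := nodes.NewThreadTransport(nodes.NewCtx())

	ids := []string{"1", "2", "3"}
	for _, id := range ids {
		// Each node still gets its own state-machine db and raft storage (paths illustrative).
		db, err := leveldb.OpenFile("leveldb/demo-"+id, nil)
		if err != nil {
			panic(err)
		}
		storage := nodes.NewRaftStorage("storage/demo-" + id)

		var peers []string
		for _, p := range ids {
			if p != id {
				peers = append(peers, p)
			}
		}
		// InitThreadNode registers the node's channel with the transport and starts its listener.
		node, quitChan := nodes.InitThreadNode(id, peers, db, storage, false, transport)
		nodes.Start(node, quitChan)
	}

	time.Sleep(2 * time.Second) // give the cluster time to elect a leader
}
```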
# Testing
## Unit tests
The system is tested module by module along five dimensions: leader election, log replication, crash recovery, network partitions, and client interaction. The tests mix in fine-grained simulation of message states, aiming to verify early and mid-project that the code matches the design and to catch large problems before they compound. An example of this fine-grained control is sketched below.
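For example, a test can delay or cut individual links through the shared transport and simulation context. This is a sketch of two hypothetical helpers, assuming a ThreadTransport set up as in the wiring sketch above; SetBehavior, SetConnectivity, and ResetConnectivity are the APIs from simulate_ctx.go and thread_transport.go.
```go
// degradeLinks delays messages from node 1 to node 2 by 100ms and cuts the 1<->3 link.
func degradeLinks(transport *nodes.ThreadTransport) {
	transport.Ctx.SetBehavior("1", "2", nodes.DelayRpc, 100*time.Millisecond, 0)
	transport.SetConnectivity("1", "3", false)
	transport.SetConnectivity("3", "1", false)
}

// restoreLinks undoes the degradation once the scenario has been exercised.
func restoreLinks(transport *nodes.ThreadTransport) {
	transport.Ctx.SetBehavior("1", "2", nodes.NormalRpc, 0, 0)
	transport.ResetConnectivity()
}
```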
## Fuzz testing
The fuzz tests cover two dimensions: multi-system randomization (basic), which varies the number of nodes and the system's random-time configuration, and single-system randomization (robust), which injects multiple random faults into one system; a final round (plus) combines both dimensions.
The tests use properties from Raft's TLA+ specification as assertions, checking the system's stability while it runs; one such invariant check is sketched below.
The fuzz tests not only cover everything the unit tests do, they also uncovered more edge-case anomalies through randomized runs and, via invariant checks on system state, confirm that the system stays correct and available over long runs under different configurations.
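A sketch of the kind of state-invariant assertion used, here the Election Safety property (at most one leader per term). It assumes the thread-based tests can inspect node state directly, which the threadTest helpers do; the helper name is illustrative.
```go
// checkElectionSafety asserts that no two nodes consider themselves leader in the same term.
func checkElectionSafety(t *testing.T, cluster []*nodes.Node) {
	leaders := make(map[int]string) // term -> leader id
	for _, n := range cluster {
		n.Mu.Lock()
		state, term, id := n.State, n.CurrTerm, n.SelfId
		n.Mu.Unlock()
		if state != nodes.Leader {
			continue
		}
		if other, ok := leaders[term]; ok && other != id {
			t.Fatalf("invariant violated: term %d has two leaders: %s and %s", term, other, id)
		}
		leaders[term] = id
	}
}
```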
![alt text](pics/plus.png)
![alt text](pics/robust.png)
## Brief bug log
Ambiguity around LogId 0; log-numbering mismatches between interfaces
Candidates getting stuck when random election timeouts happened to collide
Non-atomic persistence of critical data; confusing it with the state machine
Clients lacking a unique request id, causing the system to execute requests twice
Incorrect lock usage while refactoring the system
The asynchronous-semantics trap of pseudo-synchronous interfaces
Bugs from mixing tests and the system (delays causing timeouts, incomplete shutdown, file-system errors, bad locking)
# Environment and running
Developed on WSL + Ubuntu
go mod download installs the dependencies
./scripts/build.sh builds main in the repository root
./scripts/run.sh runs three nodes; input can be typed in the terminal, the leader node (n1) prints send logs, and the other nodes print receive logs. After input, the script exits on timeout (the run time can be adjusted inside the script).
# Notes
The first run of a script needs execute permission: chmod +x <script>
A "tcp listen error" usually means a previous process did not exit cleanly and is still holding the port
lsof -i :9091 to find the pid
kill -9 <pid> to kill the process
## About testing
Nodes are created by spawning new processes; creating them as threads would cause duplicate rpc registration
# todo list
Handling of message communication failures
Local persistence of the kv data
Crash and recovery (plus the corresponding tests)
./scripts/build.sh builds main in the repository root (needed for the process-level tests)
# References
In Search of an Understandable Consensus Algorithm
Consensus: Bridging Theory and Practice
Raft TLA+ Specification
Apart from the logprovider folder, where the Go logging wrapper follows a blog post, everything in the project is original, independent work.
# Division of work
| Name | Work | Contribution |
|--------|--------|--------|
| 李度 | Raft system design + implementation, test design + implementation | 75% |
| 马也驰 | Raft system design, test design + implementation | 25% |
![alt text](pics/plan.png)

+25 −24 cmd/main.go

@ -3,6 +3,7 @@ package main
import (
"flag"
"fmt"
"github.com/syndtr/goleveldb/leveldb"
"os"
"os/signal"
"simple-kv-store/internal/logprovider"
@ -10,7 +11,6 @@ import (
"strconv"
"strings"
"syscall"
"github.com/syndtr/goleveldb/leveldb"
"go.uber.org/zap"
)
@ -29,11 +29,9 @@ func main() {
signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)
port := flag.String("port", ":9091", "rpc listen port")
cluster := flag.String("cluster", "127.0.0.1:9092,127.0.0.1:9093", "comma sep")
cluster := flag.String("cluster", "127.0.0.1:9091,127.0.0.1:9092,127.0.0.1:9093", "comma sep")
id := flag.String("id", "1", "node ID")
pipe := flag.String("pipe", "", "input from scripts")
isLeader := flag.Bool("isleader", false, "init node state")
isNewDb := flag.Bool("isNewDb", true, "new test or restart")
isRestart := flag.Bool("isRestart", false, "new test or restart")
// 参数解析
flag.Parse()
@ -42,43 +40,46 @@ func main() {
idCnt := 1
selfi, err := strconv.Atoi(*id)
if err != nil {
log.Error("figure id only")
log.Fatal("figure id only")
}
for _, addr := range clusters {
if idCnt == selfi {
idCnt++ // 命令行cluster按id排序传入,记录时跳过自己的id,先保证所有节点互相记录的id一致
continue
}
idClusterPairs[strconv.Itoa(idCnt)] = addr
idClusterPairs[strconv.Itoa(idCnt)] = addr
idCnt++
}
if *isNewDb {
os.RemoveAll("leveldb/simple-kv-store" + *id)
// storage/文件夹下为node重要数据持久化数据库,节点一旦创建成功就不能被删除
if !*isRestart {
os.RemoveAll("storage/node" + *id)
}
// 打开或创建每个结点自己的数据库
// 创建每个结点自己的数据库。这里一开始理解上有些误区,状态机的状态恢复应该靠节点的持久化log,
// 而用leveldb模拟状态机,造成了状态机本身的持久化,因此通过删去旧db避免这一矛盾
// 因此leveldb/文件夹下为状态机模拟数据库,每次节点启动都需要删除该数据库
os.RemoveAll("leveldb/simple-kv-store" + *id)
db, err := leveldb.OpenFile("leveldb/simple-kv-store" + *id, nil)
if err != nil {
log.Fatal("Failed to open database: ", zap.Error(err))
}
defer db.Close() // 确保数据库在使用完毕后关闭
iter := db.NewIterator(nil, nil)
defer iter.Release()
// 计数
count := 0
for iter.Next() {
count++
}
fmt.Printf(*id + "结点目前有数据:%d\n", count)
// 打开或创建节点数据持久化文件
storage := nodes.NewRaftStorage("storage/node" + *id)
defer storage.Close()
// 初始化
node := nodes.InitRPCNode(*id, *port, idClusterPairs, db, storage, *isRestart)
node := nodes.Init(*id, idClusterPairs, *pipe, db)
log.Info("id: " + *id + "节点开始监听: " + *port + "端口")
// 监听rpc
node.Rpc(*port)
// 开启 raft
nodes.Start(node, *isLeader)
quitChan := make(chan struct{}, 1)
nodes.Start(node, quitChan)
sig := <-sigs
fmt.Println("node_" + *id + "接收到信号:", sig)
fmt.Println("node_"+ *id +"接收到信号:", sig)
close(quitChan)
}

+122 −46 internal/client/client_node.go

@ -1,9 +1,10 @@
package clientPkg
import (
"net/rpc"
"math/rand"
"simple-kv-store/internal/logprovider"
"simple-kv-store/internal/nodes"
"time"
"go.uber.org/zap"
)
@ -11,9 +12,11 @@ import (
var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel)
type Client struct {
// 连接的server端节点(node1)
ServerId string
Address string
ClientId string // 每个client唯一标识
NextLogId int
// 连接的server端节点群
PeerIds []string
Transport nodes.Transport
}
type Status = uint8
@ -24,32 +27,83 @@ const (
Fail
)
func (client *Client) Write(kvCall nodes.LogEntryCall) Status {
log.Info("client write request key :" + kvCall.LogE.Key)
c, err := rpc.DialHTTP("tcp", client.Address)
if err != nil {
log.Error("dialing: ", zap.Error(err))
return Fail
}
func NewClient(clientId string, peerIds []string, transport nodes.Transport) *Client {
return &Client{ClientId: clientId, NextLogId: 0, PeerIds: peerIds, Transport: transport}
}
defer func(server *rpc.Client) {
err := c.Close()
func getRandomAddress(peerIds []string) string {
// 随机选一个 id
randomKey := peerIds[rand.Intn(len(peerIds))]
return randomKey
}
func (client *Client) FindActiveNode() nodes.ClientInterface {
var err error
var c nodes.ClientInterface
for { // 直到找到一个可连接的节点(保证至少一个节点活着)
peerId := getRandomAddress(client.PeerIds)
c, err = client.Transport.DialHTTPWithTimeout("tcp", "", peerId)
if err != nil {
log.Error("client close err: ", zap.Error(err))
log.Error("dialing: ", zap.Error(err))
} else {
log.Sugar().Infof("client发现活跃节点[%s]", peerId)
return c
}
}(c)
}
}
var reply nodes.ServerReply
callErr := c.Call("Node.WriteKV", kvCall, &reply) // RPC
if callErr != nil {
log.Error("dialing: ", zap.Error(callErr))
return Fail
func (client *Client) CloseRpcClient(c nodes.ClientInterface) {
err := c.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
}
}
func (client *Client) Write(kv nodes.LogEntry) Status {
defer logprovider.DebugTraceback("client")
log.Info("client write request key :" + kv.Key)
kvCall := nodes.LogEntryCall{LogE: kv,
Id: nodes.LogEntryCallId{ClientId: client.ClientId, LogId: client.NextLogId}}
client.NextLogId++
c := client.FindActiveNode()
var err error
timeout := time.Second
deadline := time.Now().Add(timeout)
for { // 根据存活节点的反馈,直到找到leader
if time.Now().After(deadline) {
log.Error("系统繁忙,疑似出错")
return Fail
}
var reply nodes.ServerReply
reply.Isleader = false
callErr := client.Transport.CallWithTimeout(c, "Node.WriteKV", &kvCall, &reply) // RPC
if callErr != nil { // dial和call之间可能崩溃,重新找存活节点
log.Error("dialing: ", zap.Error(callErr))
client.CloseRpcClient(c)
c = client.FindActiveNode()
continue
}
if reply.Isconnect { // 发送成功
return Ok
} else { // 失败
return Fail
if !reply.Isleader { // 对方不是leader,根据反馈找leader
leaderId := reply.LeaderId
client.CloseRpcClient(c)
if leaderId == "" { // 这个节点不知道leader是谁,再随机找
c = client.FindActiveNode()
} else { // dial leader
c, err = client.Transport.DialHTTPWithTimeout("tcp", "", leaderId)
if err != nil { // dial失败,重新找下一个存活节点
c = client.FindActiveNode()
}
}
} else { // 成功
client.CloseRpcClient(c)
return Ok
}
}
}
@ -58,37 +112,59 @@ func (client *Client) Read(key string, value *string) Status { // 查不到value
if value == nil {
return Fail
}
c, err := rpc.DialHTTP("tcp", client.Address)
if err != nil {
log.Error("dialing: ", zap.Error(err))
return Fail
}
defer func(server *rpc.Client) {
err := c.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
var c nodes.ClientInterface
for {
c = client.FindActiveNode()
var reply nodes.ServerReply
callErr := client.Transport.CallWithTimeout(c, "Node.ReadKey", &key, &reply) // RPC
if callErr != nil {
log.Error("dialing: ", zap.Error(callErr))
client.CloseRpcClient(c)
continue
}
}(c)
var reply nodes.ServerReply
callErr := c.Call("Node.ReadKey", key, &reply) // RPC
if callErr != nil {
log.Error("dialing: ", zap.Error(callErr))
return Fail
}
if reply.Isconnect { // 发送成功
// 目前一定发送成功
if reply.HaveValue {
*value = reply.Value
client.CloseRpcClient(c)
return Ok
} else {
client.CloseRpcClient(c)
return NotFound
}
}
}
func (client *Client) FindLeader() string {
var arg struct{}
var reply nodes.FindLeaderReply
reply.Isleader = false
c := client.FindActiveNode()
var err error
for !reply.Isleader { // 根据存活节点的反馈,直到找到leader
callErr := client.Transport.CallWithTimeout(c, "Node.FindLeader", &arg, &reply) // RPC
if callErr != nil { // dial和call之间可能崩溃,重新找存活节点
log.Error("dialing: ", zap.Error(callErr))
client.CloseRpcClient(c)
c = client.FindActiveNode()
continue
}
} else { // 失败
return Fail
if !reply.Isleader { // 对方不是leader,根据反馈找leader
client.CloseRpcClient(c)
c, err = client.Transport.DialHTTPWithTimeout("tcp", "", reply.LeaderId)
for err != nil { // 重新找下一个存活节点
c = client.FindActiveNode()
}
} else { // 成功
client.CloseRpcClient(c)
return reply.LeaderId
}
}
log.Fatal("客户端会一直找存活节点,不会运行到这里")
return "fault"
}

+15 −0 internal/logprovider/traceback.go

@ -0,0 +1,15 @@
package logprovider
import (
"fmt"
"os"
"runtime/debug"
)
func DebugTraceback(errFuncName string) {
if r := recover(); r != nil {
msg := fmt.Sprintf("panic in goroutine: %v\n%s", r, debug.Stack())
f, _ := os.Create(errFuncName + ".log")
fmt.Fprint(f, msg)
f.Close()
}
}

+217 −77 internal/nodes/init.go

@ -1,13 +1,12 @@
package nodes
import (
"io"
"errors"
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"simple-kv-store/internal/logprovider"
"strconv"
"time"
"github.com/syndtr/goleveldb/leveldb"
@ -16,88 +15,48 @@ import (
var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel)
func newNode(address string) *Public_node_info {
return &Public_node_info{
connect: false,
address: address,
}
}
func Init(id string, nodeAddr map[string]string, pipe string, db *leveldb.DB) *Node {
ns := make(map[string]*Public_node_info)
for id, addr := range nodeAddr {
ns[id] = newNode(addr)
// 运行在进程上的初始化 + rpc注册
func InitRPCNode(SelfId string, port string, nodeAddr map[string]string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool) *Node {
var nodeIds []string
for id := range nodeAddr {
nodeIds = append(nodeIds, id)
}
// 创建节点
return &Node{
selfId: id,
nodes: ns,
pipeAddr: pipe,
maxLogId: 0,
log: make(map[int]LogEntry),
db: db,
node := &Node{
SelfId: SelfId,
LeaderId: "",
Nodes: nodeIds,
MaxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了
CurrTerm: 1,
Log: make([]RaftLogEntry, 0),
CommitIndex: -1,
LastApplied: -1,
NextIndex: make(map[string]int),
MatchIndex: make(map[string]int),
Db: db,
Storage: rstorage,
Transport: &HTTPTransport{NodeMap: nodeAddr},
RTTable: NewRTTable(),
SeenRequests: make(map[LogEntryCallId]bool),
IsFinish: false,
}
}
func Start(node *Node, isLeader bool) {
if isLeader {
node.state = Candidate // 需要身份转变
} else {
node.state = Follower
node.initLeaderState()
if isRestart {
node.CurrTerm = rstorage.GetCurrentTerm()
node.VotedFor = rstorage.GetVotedFor()
node.Log = rstorage.GetLogEntries()
log.Sugar().Infof("[%s]从重启中恢复log数量: %d", SelfId, len(node.Log))
}
go func() {
for {
switch node.state {
case Follower:
case Candidate:
// candidate发布一个监听输入线程后,变成leader
node.state = Leader
go func() {
if node.pipeAddr == "" { // 客户端远程调用server_node方法
log.Info("请运行客户端进程进行读写")
} else { // 命令行提供了管道,支持管道(键盘)输入
pipe, err := os.Open(node.pipeAddr)
if err != nil {
log.Error("Failed to open pipe")
}
defer pipe.Close()
// 不断读取管道中的输入
buffer := make([]byte, 256)
for {
n, err := pipe.Read(buffer)
if err != nil && err != io.EOF {
log.Error("Error reading from pipe")
}
if n > 0 {
input := string(buffer[:n])
// 将用户输入封装成一个 LogEntry
kv := LogEntry{input, ""} // 目前键盘输入key,value 0
logId := node.maxLogId
node.maxLogId++
node.log[logId] = kv
log.Info("send : logId = " + strconv.Itoa(logId) + ", key = " + input)
// 广播给其它节点
kvCall := LogEntryCall{kv, Normal}
node.BroadCastKV(logId, kvCall)
// 持久化
node.db.Put([]byte(kv.Key), []byte(kv.Value), nil)
}
}
}
}()
case Leader:
time.Sleep(50 * time.Millisecond)
}
}
}()
log.Sugar().Infof("[%s]开始监听" + port + "端口", SelfId)
node.ListenPort(port)
return node
}
func (node *Node) Rpc(port string) {
func (node *Node) ListenPort(port string) {
err := rpc.Register(node)
if err != nil {
log.Fatal("rpc register failed", zap.Error(err))
@ -115,3 +74,184 @@ func (node *Node) Rpc(port string) {
}
}()
}
// 线程模拟的初始化
func InitThreadNode(SelfId string, peerIds []string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool, threadTransport *ThreadTransport) (*Node, chan struct{}) {
rpcChan := make(chan RPCRequest, 100) // 要监听的chan
// 创建节点
node := &Node{
SelfId: SelfId,
LeaderId: "",
Nodes: peerIds,
MaxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了
CurrTerm: 1,
Log: make([]RaftLogEntry, 0),
CommitIndex: -1,
LastApplied: -1,
NextIndex: make(map[string]int),
MatchIndex: make(map[string]int),
Db: db,
Storage: rstorage,
Transport: threadTransport,
RTTable: NewRTTable(),
SeenRequests: make(map[LogEntryCallId]bool),
IsFinish: false,
}
node.initLeaderState()
if isRestart {
node.CurrTerm = rstorage.GetCurrentTerm()
node.VotedFor = rstorage.GetVotedFor()
node.Log = rstorage.GetLogEntries()
log.Sugar().Infof("[%s]从重启中恢复log数量: %d", SelfId, len(node.Log))
}
threadTransport.RegisterNodeChan(SelfId, rpcChan)
quitChan := make(chan struct{}, 1)
go node.listenForChan(rpcChan, quitChan)
return node, quitChan
}
func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{}) {
defer logprovider.DebugTraceback("listen")
for {
select {
case req := <-rpcChan:
switch req.Behavior {
case DelayRpc:
threadTran, ok := node.Transport.(*ThreadTransport)
if !ok {
log.Fatal("无效的delayRpc模式")
}
duration, ok2 := threadTran.Ctx.GetDelay(req.SourceId, node.SelfId)
if !ok2 {
log.Fatal("没有设置对应的delay时间")
}
go node.switchReq(req, duration)
case FailRpc:
continue
default:
go node.switchReq(req, 0)
}
case <-quitChan:
node.Mu.Lock()
defer node.Mu.Unlock()
log.Sugar().Infof("[%s] 监听线程收到退出信号", node.SelfId)
node.IsFinish = true
node.Db.Close()
node.Storage.Close()
return
}
}
}
func (node *Node) switchReq(req RPCRequest, delayTime time.Duration) {
defer logprovider.DebugTraceback("switch")
time.Sleep(delayTime)
switch req.ServiceMethod {
case "Node.AppendEntries":
arg, ok := req.Args.(*AppendEntriesArg)
resp, ok2 := req.Reply.(*AppendEntriesReply)
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for AppendEntries")
} else {
var respCopy AppendEntriesReply
err := node.AppendEntries(arg, &respCopy)
*resp = respCopy
req.Done <- err
}
case "Node.RequestVote":
arg, ok := req.Args.(*RequestVoteArgs)
resp, ok2 := req.Reply.(*RequestVoteReply)
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for RequestVote")
} else {
var respCopy RequestVoteReply
err := node.RequestVote(arg, &respCopy)
*resp = respCopy
req.Done <- err
}
case "Node.WriteKV":
arg, ok := req.Args.(*LogEntryCall)
resp, ok2 := req.Reply.(*ServerReply)
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for WriteKV")
} else {
req.Done <- node.WriteKV(arg, resp)
}
case "Node.ReadKey":
arg, ok := req.Args.(*string)
resp, ok2 := req.Reply.(*ServerReply)
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for ReadKey")
} else {
req.Done <- node.ReadKey(arg, resp)
}
case "Node.FindLeader":
arg, ok := req.Args.(struct{})
resp, ok2 := req.Reply.(*FindLeaderReply)
if !ok || !ok2 {
req.Done <- errors.New("type assertion failed for FindLeader")
} else {
req.Done <- node.FindLeader(arg, resp)
}
default:
req.Done <- fmt.Errorf("未知方法: %s", req.ServiceMethod)
}
}
// 共同部分和启动
func (n *Node) initLeaderState() {
for _, peerId := range n.Nodes {
n.NextIndex[peerId] = len(n.Log) // 发送日志的下一个索引
n.MatchIndex[peerId] = 0 // 复制日志的最新匹配索引
}
}
func Start(node *Node, quitChan chan struct{}) {
node.Mu.Lock()
node.State = Follower // 所有节点以 Follower 状态启动
node.Mu.Unlock()
node.ResetElectionTimer() // 启动选举超时定时器
go func() {
defer logprovider.DebugTraceback("start")
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-quitChan:
fmt.Printf("[%s] Raft start 退出...\n", node.SelfId)
return // 退出 goroutine
case <-ticker.C:
node.Mu.Lock()
state := node.State
node.Mu.Unlock()
switch state {
case Follower:
// 监听心跳超时
case Leader:
// 发送心跳
node.ResetElectionTimer() // leader 不主动触发选举
node.BroadCastKV()
}
}
}
}()
}

+19 −6 internal/nodes/log.go

@ -1,19 +1,32 @@
package nodes
const (
Normal State = iota + 1
Delay
Fail
)
import "strconv"
type LogEntry struct {
Key string
Value string
}
func (LogE *LogEntry) print() string {
return "key: " + LogE.Key + ", value: " + LogE.Value
}
type RaftLogEntry struct {
LogE LogEntry
LogId int
Term int
}
func (RLogE *RaftLogEntry) print() string {
return "logid: " + strconv.Itoa(RLogE.LogId) + ", term: " + strconv.Itoa(RLogE.Term) + ", " + RLogE.LogE.print()
}
type LogEntryCall struct {
Id LogEntryCallId
LogE LogEntry
CallState State
}
type LogEntryCallId struct {
ClientId string
LogId int
}
type KVReply struct {

+43 −72 internal/nodes/node.go

@ -1,13 +1,10 @@
package nodes
import (
"math/rand"
"net/rpc"
"strconv"
"sync"
"time"
"github.com/syndtr/goleveldb/leveldb"
"go.uber.org/zap"
)
type State = uint8
@ -18,84 +15,58 @@ const (
Leader
)
type Public_node_info struct {
connect bool
address string
}
type Node struct {
Mu sync.Mutex
MuElection sync.Mutex
// 当前节点id
selfId string
SelfId string
// 记录的leader(不能用votedfor:投票的leader可能没有收到多数票)
LeaderId string
// 除当前节点外其他节点信息
nodes map[string]*Public_node_info
//管道名
pipeAddr string
// 除当前节点外其他节点id
Nodes []string
// 当前节点状态
state State
State State
// 任期
CurrTerm int
// 简单的kv存储
log map[int]LogEntry
Log []RaftLogEntry
// leader用来标记新log
maxLogId int
// leader用来标记新log, = log.len
MaxLogId int
db *leveldb.DB
}
// 已提交的index
CommitIndex int
func (node *Node) BroadCastKV(logId int, kvCall LogEntryCall) {
// 遍历所有节点
for id, _ := range node.nodes {
go func(id string, kv LogEntryCall) {
var reply KVReply
node.sendKV(id, logId, kvCall, &reply)
}(id, kvCall)
}
}
// 最后应用(写到db)的index
LastApplied int
func (node *Node) sendKV(id string, logId int, kvCall LogEntryCall, reply *KVReply) {
switch kvCall.CallState {
case Fail:
log.Info("模拟发送失败")
// 这么写向所有的node发送都失败,也可以随机数确定是否失败
case Delay:
log.Info("模拟发送延迟")
// 随机延迟0-5ms
time.Sleep(time.Millisecond * time.Duration(rand.Intn(5)))
default:
}
client, err := rpc.DialHTTP("tcp", node.nodes[id].address)
if err != nil {
log.Error("dialing: ", zap.Error(err))
return
}
defer func(client *rpc.Client) {
err := client.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
}
}(client)
arg := LogIdAndEntry{logId, kvCall.LogE}
callErr := client.Call("Node.ReceiveKV", arg, reply) // RPC
if callErr != nil {
log.Error("dialing node_"+id+"fail: ", zap.Error(callErr))
}
}
// 需要发送给每个节点的下一个索引
NextIndex map[string]int
// RPC call
func (node *Node) ReceiveKV(arg LogIdAndEntry, reply *KVReply) error {
log.Info("node_" + node.selfId + " receive: logId = " + strconv.Itoa(arg.LogId) + ", key = " + arg.Entry.Key)
entry, ok := node.log[arg.LogId]
if !ok {
node.log[arg.LogId] = entry
}
// 持久化
node.db.Put([]byte(arg.Entry.Key), []byte(arg.Entry.Value), nil)
reply.Reply = true // rpc call需要有reply,但实际上调用是否成功是error返回值决定
return nil
// 已经发送给每个节点的最大索引
MatchIndex map[string]int
// 存kv(模拟状态机)
Db *leveldb.DB
// 持久化节点数据(currterm votedfor log)
Storage *RaftStorage
VotedFor string
ElectionTimer *time.Timer
// 通信方式
Transport Transport
// 系统的随机时间
RTTable *RandomTimeTable
// 已经处理过的客户端请求
SeenRequests map[LogEntryCallId]bool
IsFinish bool
}

+194 −0 internal/nodes/node_storage.go

@ -0,0 +1,194 @@
package nodes
import (
"encoding/json"
"strconv"
"strings"
"sync"
"github.com/syndtr/goleveldb/leveldb"
"go.uber.org/zap"
)
// RaftStorage 结构,持久化 currentTerm、votedFor 和 logEntries
type RaftStorage struct {
mu sync.Mutex
db *leveldb.DB
filePath string
isfinish bool
}
// NewRaftStorage 创建 Raft 存储
func NewRaftStorage(filePath string) *RaftStorage {
db, err := leveldb.OpenFile(filePath, nil)
if err != nil {
log.Fatal("无法打开 LevelDB:", zap.Error(err))
}
return &RaftStorage{
db: db,
filePath: filePath,
isfinish: false,
}
}
// SetCurrentTerm 设置当前 term
func (rs *RaftStorage) SetCurrentTerm(term int) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.isfinish {
return
}
err := rs.db.Put([]byte("current_term"), []byte(strconv.Itoa(term)), nil)
if err != nil {
log.Error("SetCurrentTerm 持久化失败:", zap.Error(err))
}
}
// GetCurrentTerm 获取当前 term
func (rs *RaftStorage) GetCurrentTerm() int {
rs.mu.Lock()
defer rs.mu.Unlock()
data, err := rs.db.Get([]byte("current_term"), nil)
if err != nil {
return 0 // 默认 term = 0
}
term, _ := strconv.Atoi(string(data))
return term
}
// SetVotedFor 记录投票给谁
func (rs *RaftStorage) SetVotedFor(candidate string) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.isfinish {
return
}
err := rs.db.Put([]byte("voted_for"), []byte(candidate), nil)
if err != nil {
log.Error("SetVotedFor 持久化失败:", zap.Error(err))
}
}
// GetVotedFor 获取投票对象
func (rs *RaftStorage) GetVotedFor() string {
rs.mu.Lock()
defer rs.mu.Unlock()
data, err := rs.db.Get([]byte("voted_for"), nil)
if err != nil {
return ""
}
return string(data)
}
// SetTermAndVote 原子更新 term 和 vote
func (rs *RaftStorage) SetTermAndVote(term int, candidate string) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.isfinish {
return
}
batch := new(leveldb.Batch)
batch.Put([]byte("current_term"), []byte(strconv.Itoa(term)))
batch.Put([]byte("voted_for"), []byte(candidate))
err := rs.db.Write(batch, nil) // 原子提交
if err != nil {
log.Error("SetTermAndVote 持久化失败:", zap.Error(err))
}
}
// AppendLog 追加日志
func (rs *RaftStorage) AppendLog(entry RaftLogEntry) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.db == nil {
return
}
// 序列化日志
batch := new(leveldb.Batch)
data, _ := json.Marshal(entry)
key := "log_" + strconv.Itoa(entry.LogId)
batch.Put([]byte(key), data)
lastIndex := strconv.Itoa(entry.LogId)
batch.Put([]byte("last_log_index"), []byte(lastIndex))
err := rs.db.Write(batch, nil)
if err != nil {
log.Error("AppendLog 持久化失败:", zap.Error(err))
}
}
// GetLastLogIndex 获取最新日志的 index
func (rs *RaftStorage) GetLastLogIndex() int {
rs.mu.Lock()
defer rs.mu.Unlock()
data, err := rs.db.Get([]byte("last_log_index"), nil)
if err != nil {
return -1
}
index, _ := strconv.Atoi(string(data))
return index
}
// WriteLog 批量写入日志(保证原子性)
func (rs *RaftStorage) WriteLog(entries []RaftLogEntry) {
if len(entries) == 0 {
return
}
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.isfinish {
return
}
batch := new(leveldb.Batch)
for _, entry := range entries {
data, _ := json.Marshal(entry)
key := "log_" + strconv.Itoa(entry.LogId)
batch.Put([]byte(key), data)
}
// 更新最新日志索引
lastIndex := strconv.Itoa(entries[len(entries)-1].LogId)
batch.Put([]byte("last_log_index"), []byte(lastIndex))
err := rs.db.Write(batch, nil)
if err != nil {
log.Error("WriteLog 持久化失败:", zap.Error(err))
}
}
// GetLogEntries 获取所有日志
func (rs *RaftStorage) GetLogEntries() []RaftLogEntry {
rs.mu.Lock()
defer rs.mu.Unlock()
var logs []RaftLogEntry
iter := rs.db.NewIterator(nil, nil) // 遍历所有键值
defer iter.Release()
for iter.Next() {
key := string(iter.Key())
if strings.HasPrefix(key, "log_") { // 过滤日志 key
var entry RaftLogEntry
if err := json.Unmarshal(iter.Value(), &entry); err == nil {
logs = append(logs, entry)
} else {
log.Error("解析日志失败:", zap.Error(err))
}
}
}
return logs
}
// Close 关闭数据库
func (rs *RaftStorage) Close() {
rs.mu.Lock()
defer rs.mu.Unlock()
rs.db.Close()
rs.isfinish = true
}

+46 −0 internal/nodes/random_timetable.go

@ -0,0 +1,46 @@
package nodes
import (
"math/rand"
"sync"
"time"
)
type RandomTimeTable struct {
Mu sync.Mutex
electionTimeOut time.Duration
israndom bool
// heartbeat 50ms
// rpcTimeout 50ms
// follower变candidate 500ms
// 等待选举成功时间 300ms
}
func NewRTTable() *RandomTimeTable {
return &RandomTimeTable{
israndom: true,
}
}
func (rttable *RandomTimeTable) GetElectionTimeout() time.Duration {
rttable.Mu.Lock()
defer rttable.Mu.Unlock()
if rttable.israndom {
return time.Duration(500+rand.Intn(500)) * time.Millisecond
} else {
return rttable.electionTimeOut
}
}
func (rttable *RandomTimeTable) SetElectionTimeout(t time.Duration) {
rttable.Mu.Lock()
defer rttable.Mu.Unlock()
rttable.israndom = false
rttable.electionTimeOut = t
}
func (rttable *RandomTimeTable) ResetElectionTimeout() {
rttable.Mu.Lock()
defer rttable.Mu.Unlock()
rttable.israndom = true
}

+111 −0 internal/nodes/real_transport.go

@ -0,0 +1,111 @@
package nodes
import (
"errors"
"fmt"
"net/rpc"
"time"
)
// 真实rpc通讯的transport类型实现
type HTTPTransport struct{
NodeMap map[string]string // id到addr的映射
}
// 封装有超时的dial
func (t *HTTPTransport) DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) {
done := make(chan struct{})
var client *rpc.Client
var err error
go func() {
client, err = rpc.DialHTTP(network, t.NodeMap[peerId])
close(done)
}()
select {
case <-done:
return &HTTPClient{rpcClient: client}, err
case <-time.After(50 * time.Millisecond):
return nil, fmt.Errorf("dial timeout: %s", t.NodeMap[peerId])
}
}
func (t *HTTPTransport) CallWithTimeout(clientInterface ClientInterface, serviceMethod string, args interface{}, reply interface{}) error {
c, ok := clientInterface.(*HTTPClient)
client := c.rpcClient
if !ok {
return fmt.Errorf("invalid client type")
}
done := make(chan error, 1)
go func() {
switch serviceMethod {
case "Node.AppendEntries":
arg, ok := args.(*AppendEntriesArg)
resp, ok2 := reply.(*AppendEntriesReply)
if !ok || !ok2 {
done <- errors.New("type assertion failed for AppendEntries")
return
}
done <- client.Call(serviceMethod, arg, resp)
case "Node.RequestVote":
arg, ok := args.(*RequestVoteArgs)
resp, ok2 := reply.(*RequestVoteReply)
if !ok || !ok2 {
done <- errors.New("type assertion failed for RequestVote")
return
}
done <- client.Call(serviceMethod, arg, resp)
case "Node.WriteKV":
arg, ok := args.(*LogEntryCall)
resp, ok2 := reply.(*ServerReply)
if !ok || !ok2 {
done <- errors.New("type assertion failed for WriteKV")
return
}
done <- client.Call(serviceMethod, arg, resp)
case "Node.ReadKey":
arg, ok := args.(*string)
resp, ok2 := reply.(*ServerReply)
if !ok || !ok2 {
done <- errors.New("type assertion failed for ReadKey")
return
}
done <- client.Call(serviceMethod, arg, resp)
case "Node.FindLeader":
arg, ok := args.(struct{})
resp, ok2 := reply.(*FindLeaderReply)
if !ok || !ok2 {
done <- errors.New("type assertion failed for FindLeader")
return
}
done <- client.Call(serviceMethod, arg, resp)
default:
done <- fmt.Errorf("unknown service method: %s", serviceMethod)
}
}()
select {
case err := <-done:
return err
case <-time.After(50 * time.Millisecond): // 设置超时时间
return fmt.Errorf("call timeout: %s", serviceMethod)
}
}
type HTTPClient struct {
rpcClient *rpc.Client
}
func (h *HTTPClient) Close() error {
return h.rpcClient.Close()
}

+271 −0 internal/nodes/replica.go

@ -0,0 +1,271 @@
package nodes
import (
"simple-kv-store/internal/logprovider"
"sort"
"strconv"
"sync"
"go.uber.org/zap"
)
type AppendEntriesArg struct {
Term int
LeaderId string
PrevLogIndex int
PrevLogTerm int
Entries []RaftLogEntry
LeaderCommit int
}
type AppendEntriesReply struct {
Term int
Success bool
}
// leader收到新内容要广播,以及心跳广播(同步自己的log)
func (node *Node) BroadCastKV() {
log.Sugar().Infof("leader[%s]广播消息", node.SelfId)
defer logprovider.DebugTraceback("broadcast")
failCount := 0
// 这里增加一个锁,防止并发修改成功计数
var failMutex sync.Mutex
// 遍历所有节点
for _, id := range node.Nodes {
go func(id string) {
defer logprovider.DebugTraceback("send")
node.sendKV(id, &failCount, &failMutex)
}(id)
}
}
func (node *Node) sendKV(peerId string, failCount *int, failMutex *sync.Mutex) {
node.Mu.Lock()
selfId := node.SelfId
node.Mu.Unlock()
client, err := node.Transport.DialHTTPWithTimeout("tcp", selfId, peerId)
if err != nil {
node.Mu.Lock()
log.Error("[" + node.SelfId + "]dialling [" + peerId + "] fail: ", zap.Error(err))
failMutex.Lock()
*failCount++
if *failCount == len(node.Nodes) / 2 + 1 { // 无法联系超过半数:自己有问题,降级
node.LeaderId = ""
node.State = Follower
node.ResetElectionTimer()
}
failMutex.Unlock()
node.Mu.Unlock()
return
}
defer func(client ClientInterface) {
err := client.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
}
}(client)
node.Mu.Lock()
NextIndex := node.NextIndex[peerId]
// log.Info("NextIndex " + strconv.Itoa(NextIndex))
for {
if NextIndex < 0 {
log.Fatal("assert >= 0 here")
}
sendEntries := node.Log[NextIndex:]
arg := AppendEntriesArg{
Term: node.CurrTerm,
PrevLogIndex: NextIndex - 1,
Entries: sendEntries,
LeaderCommit: node.CommitIndex,
LeaderId: node.SelfId,
}
if arg.PrevLogIndex >= 0 {
arg.PrevLogTerm = node.Log[arg.PrevLogIndex].Term
}
// 记录关键数据后解锁
currTerm := node.CurrTerm
currState := node.State
MaxLogId := node.MaxLogId
var appendReply AppendEntriesReply
appendReply.Success = false
node.Mu.Unlock()
callErr := node.Transport.CallWithTimeout(client, "Node.AppendEntries", &arg, &appendReply) // RPC
node.Mu.Lock()
if node.CurrTerm != currTerm || node.MaxLogId != MaxLogId || node.State != currState {
node.Mu.Unlock()
return
}
if callErr != nil {
log.Error("[" + node.SelfId + "]calling [" + peerId + "] fail: ", zap.Error(callErr))
failMutex.Lock()
*failCount++
if *failCount == len(node.Nodes) / 2 + 1 { // 无法联系超过半数:自己有问题,降级
log.Info("term=" + strconv.Itoa(node.CurrTerm) + "的Leader[" + node.SelfId + "]无法联系到半数节点, 降级为 Follower")
node.LeaderId = ""
node.State = Follower
node.ResetElectionTimer()
}
failMutex.Unlock()
node.Mu.Unlock()
return
}
if appendReply.Term != node.CurrTerm {
log.Sugar().Infof("term=%s的leader[%s]因为[%s]收到更高的term=%s, 转换为follower",
strconv.Itoa(node.CurrTerm), node.SelfId, peerId, strconv.Itoa(appendReply.Term))
node.LeaderId = ""
node.CurrTerm = appendReply.Term
node.State = Follower
node.VotedFor = ""
node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor)
node.ResetElectionTimer()
node.Mu.Unlock()
return
}
if appendReply.Success {
break
}
NextIndex-- // 失败往前传一格
}
// 不变成follower情况下
node.NextIndex[peerId] = node.MaxLogId + 1
node.MatchIndex[peerId] = node.MaxLogId
node.updateCommitIndex()
node.Mu.Unlock()
}
func (node *Node) updateCommitIndex() {
if node.Mu.TryLock() {
log.Fatal("这里要保证有锁")
}
if node.IsFinish {
return
}
totalNodes := len(node.Nodes)
// 收集所有 MatchIndex 并排序
MatchIndexes := make([]int, 0, totalNodes)
for _, index := range node.MatchIndex {
MatchIndexes = append(MatchIndexes, index)
}
sort.Ints(MatchIndexes) // 排序
// 计算多数派 CommitIndex
majorityIndex := MatchIndexes[totalNodes/2] // 取 N/2 位置上的索引(多数派)
// 确保这个索引的日志条目属于当前 term,防止提交旧 term 的日志
if majorityIndex > node.CommitIndex && majorityIndex < len(node.Log) && node.Log[majorityIndex].Term == node.CurrTerm {
node.CommitIndex = majorityIndex
log.Info("Leader[" + node.SelfId + "]更新 CommitIndex: " + strconv.Itoa(majorityIndex))
// 应用日志到状态机
node.applyCommittedLogs()
}
}
// 应用日志到状态机
func (node *Node) applyCommittedLogs() {
for node.LastApplied < node.CommitIndex {
node.LastApplied++
logEntry := node.Log[node.LastApplied]
log.Sugar().Infof("[%s]应用日志到状态机: " + logEntry.print(), node.SelfId)
err := node.Db.Put([]byte(logEntry.LogE.Key), []byte(logEntry.LogE.Value), nil)
if err != nil {
log.Error(node.SelfId + "应用状态机失败: ", zap.Error(err))
}
}
}
// RPC call
func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply) error {
defer logprovider.DebugTraceback("append")
node.Mu.Lock()
defer node.Mu.Unlock()
log.Sugar().Infof("[%s]在term=%d收到[%s]的AppendEntries", node.SelfId, node.CurrTerm, arg.LeaderId)
// 如果 term 过期,拒绝接受日志
if node.CurrTerm > arg.Term {
reply.Term = node.CurrTerm
reply.Success = false
return nil
}
node.LeaderId = arg.LeaderId // 记录Leader
// 如果term比自己高,或自己不是follower但收到相同term的心跳
if node.CurrTerm < arg.Term || node.State != Follower {
log.Sugar().Infof("[%s]发现更高 term(%s)", node.SelfId, strconv.Itoa(arg.Term))
node.CurrTerm = arg.Term
node.State = Follower
node.VotedFor = ""
// node.storage.SetTermAndVote(node.CurrTerm, node.VotedFor)
}
node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor)
// 检查 prevLogIndex 是否有效
if arg.PrevLogIndex >= len(node.Log) || (arg.PrevLogIndex >= 0 && node.Log[arg.PrevLogIndex].Term != arg.PrevLogTerm) {
reply.Term = node.CurrTerm
reply.Success = false
return nil
}
// 处理日志冲突(如果存在不同 term,则截断日志)
idx := arg.PrevLogIndex + 1
for i := idx; i < len(node.Log) && i-idx < len(arg.Entries); i++ {
if node.Log[i].Term != arg.Entries[i-idx].Term {
node.Log = node.Log[:idx]
break
}
}
// log.Info(strconv.Itoa(idx) + strconv.Itoa(len(node.Log)))
// 追加新的日志条目
for _, raftLogEntry := range arg.Entries {
log.Sugar().Infof("[%s]写入:" + raftLogEntry.print(), node.SelfId)
if idx < len(node.Log) {
node.Log[idx] = raftLogEntry
} else {
node.Log = append(node.Log, raftLogEntry)
}
idx++
}
// 暴力持久化
node.Storage.WriteLog(node.Log)
// 更新 MaxLogId
node.MaxLogId = len(node.Log) - 1
// 更新 CommitIndex
if arg.LeaderCommit < node.MaxLogId {
node.CommitIndex = arg.LeaderCommit
} else {
node.CommitIndex = node.MaxLogId
}
// 提交已提交的日志
node.applyCommittedLogs()
// 在成功接受日志或心跳后,重置选举超时
node.ResetElectionTimer()
reply.Term = node.CurrTerm
reply.Success = true
return nil
}

+68 −16 internal/nodes/server_node.go

@ -1,42 +1,94 @@
package nodes
import (
"strconv"
"simple-kv-store/internal/logprovider"
"github.com/syndtr/goleveldb/leveldb"
)
// leader node作为server为client注册的方法
type ServerReply struct{
Isconnect bool
Isleader bool
LeaderId string // 自己不是leader则返回leader
HaveValue bool
Value string
}
// RPC call
func (node *Node) WriteKV(kvCall LogEntryCall, reply *ServerReply) error {
logId := node.maxLogId
node.maxLogId++
node.log[logId] = kvCall.LogE
node.db.Put([]byte(kvCall.LogE.Key), []byte(kvCall.LogE.Value), nil)
log.Info("server write : logId = " + strconv.Itoa(logId) + ", key = " + kvCall.LogE.Key)
func (node *Node) WriteKV(kvCall *LogEntryCall, reply *ServerReply) error {
defer logprovider.DebugTraceback("write")
node.Mu.Lock()
defer node.Mu.Unlock()
log.Sugar().Infof("[%s]收到客户端write请求", node.SelfId)
// 自己不是leader,转交leader地址回复
if node.State != Leader {
reply.Isleader = false
reply.LeaderId = node.LeaderId // 可能是空,那client就随机再找一个节点
log.Sugar().Infof("[%s]转交给[%s]", node.SelfId, node.LeaderId)
return nil
}
if node.SeenRequests[kvCall.Id] {
log.Sugar().Infof("Leader [%s] 已处理过client[%s]的请求 %d, 跳过", node.SelfId, kvCall.Id.ClientId, kvCall.Id.LogId)
reply.Isleader = true
return nil
}
node.SeenRequests[kvCall.Id] = true
// 自己是leader,修改自己的记录并广播
node.MaxLogId++
logId := node.MaxLogId
rLogE := RaftLogEntry{kvCall.LogE, logId, node.CurrTerm}
node.Log = append(node.Log, rLogE)
node.Storage.AppendLog(rLogE)
log.Info("leader[" + node.SelfId + "]处理请求 : " + kvCall.LogE.print())
// 广播给其它节点
node.BroadCastKV(logId, kvCall)
reply.Isconnect = true
node.BroadCastKV()
reply.Isleader = true
return nil
}
// RPC call
func (node *Node) ReadKey(key string, reply *ServerReply) error {
log.Info("server read : " + key)
// 先只读leader自己
value, err := node.db.Get([]byte(key), nil)
func (node *Node) ReadKey(key *string, reply *ServerReply) error {
defer logprovider.DebugTraceback("read")
node.Mu.Lock()
defer node.Mu.Unlock()
log.Sugar().Infof("[%s]收到客户端read请求", node.SelfId)
// 先只读自己(无论自己是不是leader),也方便测试
value, err := node.Db.Get([]byte(*key), nil)
if err == leveldb.ErrNotFound {
reply.HaveValue = false
} else {
reply.HaveValue = true
reply.Value = string(value)
}
reply.Isconnect = true
reply.Isleader = true
return nil
}
// RPC call 测试中寻找当前leader
type FindLeaderReply struct{
Isleader bool
LeaderId string
}
func (node *Node) FindLeader(_ struct{}, reply *FindLeaderReply) error {
defer logprovider.DebugTraceback("find")
node.Mu.Lock()
defer node.Mu.Unlock()
// 自己不是leader,转交leader地址回复
if node.State != Leader {
reply.Isleader = false
if (node.LeaderId == "") {
log.Fatal("还没选出第一个leader")
return nil
}
reply.LeaderId = node.LeaderId
return nil
}
reply.LeaderId = node.SelfId
reply.Isleader = true
return nil
}

+65 −0 internal/nodes/simulate_ctx.go

@ -0,0 +1,65 @@
package nodes
import (
"fmt"
"sync"
"time"
)
// Ctx 结构体:管理不同节点之间的通信行为
type Ctx struct {
mu sync.Mutex
Behavior map[string]CallBehavior // (src,target) -> CallBehavior
Delay map[string]time.Duration // (src,target) -> 延迟时间
Retries map[string]int // 记录 (src,target) 的重发调用次数
}
// NewCtx 创建上下文
func NewCtx() *Ctx {
return &Ctx{
Behavior: make(map[string]CallBehavior),
Delay: make(map[string]time.Duration),
Retries: make(map[string]int),
}
}
// SetBehavior 设置 A->B 的 RPC 行为
func (c *Ctx) SetBehavior(src, dst string, behavior CallBehavior, delay time.Duration, retries int) {
c.mu.Lock()
defer c.mu.Unlock()
key := fmt.Sprintf("%s->%s", src, dst)
c.Behavior[key] = behavior
c.Delay[key] = delay
c.Retries[key] = retries
}
// GetBehavior 获取 A->B 的行为
func (c *Ctx) GetBehavior(src, dst string) (CallBehavior) {
c.mu.Lock()
defer c.mu.Unlock()
key := fmt.Sprintf("%s->%s", src, dst)
if state, exists := c.Behavior[key]; exists {
return state
}
return NormalRpc
}
func (c *Ctx) GetDelay(src, dst string) (t time.Duration, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
key := fmt.Sprintf("%s->%s", src, dst)
if t, ok = c.Delay[key]; ok {
return t, ok
}
return 0, ok
}
func (c *Ctx) GetRetries(src, dst string) (times int, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
key := fmt.Sprintf("%s->%s", src, dst)
if times, ok = c.Retries[key]; ok {
return times, ok
}
return 0, ok
}

+234 −0 internal/nodes/thread_transport.go

@ -0,0 +1,234 @@
package nodes
import (
"fmt"
"sync"
"time"
)
type CallBehavior = uint8
const (
NormalRpc CallBehavior = iota + 1
DelayRpc
RetryRpc
FailRpc
)
// RPC 请求结构
type RPCRequest struct {
ServiceMethod string
Args interface{}
Reply interface{}
Done chan error // 用于返回响应
SourceId string
// 模拟rpc请求状态
Behavior CallBehavior
}
// 线程版 Transport
type ThreadTransport struct {
mu sync.Mutex
nodeChans map[string]chan RPCRequest // 每个节点的消息通道
connectivityMap map[string]map[string]bool // 模拟网络分区
Ctx *Ctx
}
// 线程版 dial的返回clientinterface
type ThreadClient struct {
SourceId string
TargetId string
}
func (c *ThreadClient) Close() error {
return nil
}
// 初始化线程通信系统
func NewThreadTransport(ctx *Ctx) *ThreadTransport {
return &ThreadTransport{
nodeChans: make(map[string]chan RPCRequest),
connectivityMap: make(map[string]map[string]bool),
Ctx: ctx,
}
}
// 注册一个新节点chan
func (t *ThreadTransport) RegisterNodeChan(nodeId string, ch chan RPCRequest) {
t.mu.Lock()
defer t.mu.Unlock()
t.nodeChans[nodeId] = ch
// 初始化连通性(默认所有节点互相可达)
if _, exists := t.connectivityMap[nodeId]; !exists {
t.connectivityMap[nodeId] = make(map[string]bool)
}
for peerId := range t.nodeChans {
t.connectivityMap[nodeId][peerId] = true
t.connectivityMap[peerId][nodeId] = true
}
}
// 设置两个节点的连通性
func (t *ThreadTransport) SetConnectivity(from, to string, isConnected bool) {
t.mu.Lock()
defer t.mu.Unlock()
if _, exists := t.connectivityMap[from]; exists {
t.connectivityMap[from][to] = isConnected
}
}
func (t *ThreadTransport) ResetConnectivity() {
t.mu.Lock()
defer t.mu.Unlock()
for firstId:= range t.nodeChans {
for peerId:= range t.nodeChans {
if firstId != peerId {
t.connectivityMap[firstId][peerId] = true
t.connectivityMap[peerId][firstId] = true
}
}
}
}
// 获取节点的 channel
func (t *ThreadTransport) getNodeChan(nodeId string) (chan RPCRequest, bool) {
t.mu.Lock()
defer t.mu.Unlock()
ch, ok := t.nodeChans[nodeId]
return ch, ok
}
// 模拟 Dial 操作
func (t *ThreadTransport) DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) {
t.mu.Lock()
defer t.mu.Unlock()
if _, exists := t.nodeChans[peerId]; !exists {
return nil, fmt.Errorf("节点 [%s] 不存在", peerId)
}
return &ThreadClient{SourceId: myId, TargetId: peerId}, nil
}
// 模拟 Call 操作
func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error {
threadClient, ok := client.(*ThreadClient)
if !ok {
return fmt.Errorf("无效的caller")
}
var isConnected bool
if threadClient.SourceId == "" {
isConnected = true
} else {
t.mu.Lock()
isConnected = t.connectivityMap[threadClient.SourceId][threadClient.TargetId]
t.mu.Unlock()
}
if !isConnected {
return fmt.Errorf("网络分区: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId)
}
targetChan, exists := t.getNodeChan(threadClient.TargetId)
if !exists {
return fmt.Errorf("目标节点 [%s] 不存在", threadClient.TargetId)
}
done := make(chan error, 1)
behavior := t.Ctx.GetBehavior(threadClient.SourceId, threadClient.TargetId)
// 辅助函数:复制 replyCopy 到原始 reply
copyReply := func(dst, src interface{}) {
switch d := dst.(type) {
case *AppendEntriesReply:
*d = *(src.(*AppendEntriesReply))
case *RequestVoteReply:
*d = *(src.(*RequestVoteReply))
}
}
sendRequest := func(req RPCRequest, ch chan RPCRequest) bool {
select {
case ch <- req:
return true
default:
return false
}
}
switch behavior {
case RetryRpc:
retryTimes, ok := t.Ctx.GetRetries(threadClient.SourceId, threadClient.TargetId)
if !ok {
log.Fatal("没有设置对应的retry次数")
}
var lastErr error
for i := 0; i < retryTimes; i++ {
var replyCopy interface{}
useCopy := true
switch r := reply.(type) {
case *AppendEntriesReply:
tmp := *r
replyCopy = &tmp
case *RequestVoteReply:
tmp := *r
replyCopy = &tmp
default:
replyCopy = reply // 其他类型不复制
useCopy = false
}
request := RPCRequest{
ServiceMethod: serviceMethod,
Args: args,
Reply: replyCopy,
Done: done,
SourceId: threadClient.SourceId,
Behavior: NormalRpc,
}
if !sendRequest(request, targetChan) {
return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId)
}
select {
case err := <-done:
if err == nil && useCopy {
copyReply(reply, replyCopy)
}
if err == nil {
return nil
}
lastErr = err
case <-time.After(250 * time.Millisecond):
lastErr = fmt.Errorf("RPC 调用超时: %s", serviceMethod)
}
}
return lastErr
default:
request := RPCRequest{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
SourceId: threadClient.SourceId,
Behavior: behavior,
}
if !sendRequest(request, targetChan) {
return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId)
}
select {
case err := <-done:
return err
case <-time.After(250 * time.Millisecond):
return fmt.Errorf("RPC 调用超时: %s", serviceMethod)
}
}
}

+10 −0 internal/nodes/transport.go

@ -0,0 +1,10 @@
package nodes
type ClientInterface interface{
Close() error
}
type Transport interface {
DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error)
CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error
}

+215 −0 internal/nodes/vote.go

@ -0,0 +1,215 @@
package nodes
import (
"math/rand"
"simple-kv-store/internal/logprovider"
"strconv"
"sync"
"time"
"go.uber.org/zap"
)
type RequestVoteArgs struct {
Term int // 候选人的当前任期
CandidateId string // 候选人 ID
LastLogIndex int // 候选人最后一条日志的索引
LastLogTerm int // 候选人最后一条日志的任期
}
type RequestVoteReply struct {
Term int // 当前节点的最新任期
VoteGranted bool // 是否同意投票
}
func (n *Node) StartElection() {
defer logprovider.DebugTraceback("startElection")
n.Mu.Lock()
if n.IsFinish {
n.Mu.Unlock()
return
}
// 增加当前任期,转换为 Candidate
n.CurrTerm++
n.State = Candidate
n.VotedFor = n.SelfId // 自己投自己
n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor)
log.Sugar().Infof("[%s] 开始选举,当前任期: %d", n.SelfId, n.CurrTerm)
// 重新设置选举超时,防止重复选举
n.ResetElectionTimer()
// 构造 RequestVote 请求
var lastLogIndex int
var lastLogTerm int
if len(n.Log) == 0 {
lastLogIndex = 0
lastLogTerm = 0 // 论文中定义,空日志时 Term 设为 0
} else {
lastLogIndex = len(n.Log) - 1
lastLogTerm = n.Log[lastLogIndex].Term
}
args := RequestVoteArgs{
Term: n.CurrTerm,
CandidateId: n.SelfId,
LastLogIndex: lastLogIndex,
LastLogTerm: lastLogTerm,
}
// 并行向其他节点发送请求投票
var Mu sync.Mutex
totalNodes := len(n.Nodes)
grantedVotes := 1 // 自己的票
currTerm := n.CurrTerm
currState := n.State
n.Mu.Unlock()
for _, peerId := range n.Nodes {
go func(peerId string) {
defer logprovider.DebugTraceback("vote")
var reply RequestVoteReply
if n.sendRequestVote(peerId, &args, &reply) {
Mu.Lock()
defer Mu.Unlock()
n.Mu.Lock()
defer n.Mu.Unlock()
if currTerm != n.CurrTerm || currState != n.State {
return
}
if reply.Term > n.CurrTerm {
// 发现更高任期,回退为 Follower
log.Sugar().Infof("[%s] 发现更高的 Term (%d),回退为 Follower", n.SelfId, reply.Term)
n.CurrTerm = reply.Term
n.State = Follower
n.VotedFor = ""
n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor)
n.ResetElectionTimer()
return
}
if reply.VoteGranted {
grantedVotes++
}
if grantedVotes == totalNodes / 2 + 1 {
n.State = Leader
log.Sugar().Infof("[%s] 当选 Leader!", n.SelfId)
n.initLeaderState()
}
}
}(peerId)
}
// 等待选举结果
time.Sleep(300 * time.Millisecond)
Mu.Lock()
defer Mu.Unlock()
n.Mu.Lock()
defer n.Mu.Unlock()
if n.State == Candidate {
log.Sugar().Infof("[%s] 选举超时,等待后将重新发起选举", n.SelfId)
// n.State = Follower 这里不修改,如果appendentries收到term合理的心跳,再变回follower
n.ResetElectionTimer()
}
}
func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *RequestVoteReply) bool {
log.Sugar().Infof("[%s] 请求 [%s] 投票", node.SelfId, peerId)
client, err := node.Transport.DialHTTPWithTimeout("tcp", node.SelfId, peerId)
if err != nil {
log.Error("[" + node.SelfId + "]dialing [" + peerId + "] fail: ", zap.Error(err))
return false
}
defer func(client ClientInterface) {
err := client.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
}
}(client)
callErr := node.Transport.CallWithTimeout(client, "Node.RequestVote", args, reply) // RPC
if callErr != nil {
log.Error("[" + node.SelfId + "]calling [" + peerId + "] fail: ", zap.Error(callErr))
}
return callErr == nil
}
func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error {
defer logprovider.DebugTraceback("requestVote")
n.Mu.Lock()
defer n.Mu.Unlock()
// 如果候选人的任期小于当前任期,则拒绝投票
if args.Term < n.CurrTerm {
reply.Term = n.CurrTerm
reply.VoteGranted = false
return nil
}
// 如果请求的 Term 更高,则更新当前 Term 并回退为 Follower
if args.Term > n.CurrTerm {
n.CurrTerm = args.Term
n.State = Follower
n.VotedFor = ""
n.ResetElectionTimer() // 重新设置选举超时
}
// 检查是否已经投过票,且是否投给了同一个候选人
if n.VotedFor == "" || n.VotedFor == args.CandidateId {
// 检查日志是否足够新
var lastLogIndex int
var lastLogTerm int
if len(n.Log) == 0 {
lastLogIndex = -1
lastLogTerm = 0
} else {
lastLogIndex = len(n.Log) - 1
lastLogTerm = n.Log[lastLogIndex].Term
}
if args.LastLogTerm > lastLogTerm ||
(args.LastLogTerm == lastLogTerm && args.LastLogIndex >= lastLogIndex) {
// 够新就投票给候选人
n.VotedFor = args.CandidateId
log.Sugar().Infof("在term(%s), [%s]投票给[%s]", strconv.Itoa(n.CurrTerm), n.SelfId, n.VotedFor)
reply.VoteGranted = true
n.ResetElectionTimer()
} else {
reply.VoteGranted = false
}
} else {
reply.VoteGranted = false
}
n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor)
reply.Term = n.CurrTerm
return nil
}
// follower 一段时间内没收到appendentries心跳,就变成candidate发起选举
func (node *Node) ResetElectionTimer() {
node.MuElection.Lock()
defer node.MuElection.Unlock()
if node.ElectionTimer == nil {
node.ElectionTimer = time.NewTimer(node.RTTable.GetElectionTimeout())
go func() {
defer logprovider.DebugTraceback("reset")
for {
<-node.ElectionTimer.C
node.StartElection()
}
}()
} else {
node.ElectionTimer.Stop()
node.ElectionTimer.Reset(time.Duration(500+rand.Intn(500)) * time.Millisecond)
}
}

BIN pics/plan.png
Width: 2046 | Height: 966 | Size: 160 KiB

BIN pics/plus.png
Width: 1327 | Height: 855 | Size: 185 KiB

BIN pics/robust.png
Width: 1151 | Height: 539 | Size: 81 KiB

BIN raft第二次汇报.pptx


+0 −61 scripts/run.sh

@ -1,61 +0,0 @@
#!/bin/bash
# 设置运行时间限制:s
RUN_TIME=10
# 需要传递数据的管道
PIPE_NAME="/tmp/input_pipe"
# 启动节点1
echo "Starting Node 1..."
timeout $RUN_TIME ./main -id 1 -port ":9091" -cluster "127.0.0.1:9092,127.0.0.1:9093" -pipe "$PIPE_NAME" -isleader=true &
# 启动节点2
echo "Starting Node 2..."
timeout $RUN_TIME ./main -id 2 -port ":9092" -cluster "127.0.0.1:9091,127.0.0.1:9093" -pipe "$PIPE_NAME" &
# 启动节点3
echo "Starting Node 3..."
timeout $RUN_TIME ./main -id 3 -port ":9093" -cluster "127.0.0.1:9091,127.0.0.1:9092" -pipe "$PIPE_NAME"&
echo "All nodes started successfully!"
# 创建一个管道用于进程间通信
if [[ ! -p "$PIPE_NAME" ]]; then
mkfifo "$PIPE_NAME"
fi
# 捕获终端输入并通过管道传递给三个节点
echo "Enter input to send to nodes:"
start_time=$(date +%s)
while true; do
# 从终端读取用户输入
read -r user_input
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
# 如果运行时间大于限制时间,就退出
if [ $elapsed_time -ge $RUN_TIME ]; then
echo 'Timeout reached, normal exit now'
break
fi
# 如果输入为空,跳过
if [[ -z "$user_input" ]]; then
continue
fi
# 将用户输入发送到管道
echo "$user_input" > "$PIPE_NAME"
# 如果输入 "exit",结束脚本
if [[ "$user_input" == "exit" ]]; then
break
fi
done
# 删除管道
rm "$PIPE_NAME"
# 等待所有节点完成启动
wait

+7 −15 test/common.go

@ -8,29 +8,21 @@ import (
"strings"
)
func ExecuteNodeI(i int, isLeader bool, isNewDb bool, clusters []string) *exec.Cmd {
tmpClusters := append(clusters[:i], clusters[i+1:]...)
func ExecuteNodeI(i int, isRestart bool, clusters []string) *exec.Cmd {
port := fmt.Sprintf(":%d", uint16(9090)+uint16(i))
var isleader string
if isLeader {
isleader = "true"
var isRestartStr string
if isRestart {
isRestartStr = "true"
} else {
isleader = "false"
}
var isnewdb string
if isNewDb {
isnewdb = "true"
} else {
isnewdb = "false"
isRestartStr = "false"
}
cmd := exec.Command(
"../main",
"-id", strconv.Itoa(i + 1),
"-port", port,
"-cluster", strings.Join(tmpClusters, ","),
"-isleader=" + isleader,
"-isNewDb=" + isnewdb,
"-cluster", strings.Join(clusters, ","),
"-isRestart=" + isRestartStr,
)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

+0 −112 test/restart_follower_test.go

@ -1,112 +0,0 @@
package test
import (
"fmt"
"os/exec"
"simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"syscall"
"testing"
"time"
)
func TestFollowerRestart(t *testing.T) {
// 登记结点信息
n := 3
var clusters []string
for i := 0; i < n; i++ {
port := fmt.Sprintf("%d", uint16(9090)+uint16(i))
addr := "127.0.0.1:" + port
clusters = append(clusters, addr)
}
// 结点启动
var cmds []*exec.Cmd
for i := 0; i < n; i++ {
var cmd *exec.Cmd
if i == 0 {
cmd = ExecuteNodeI(i, true, true, clusters)
} else {
cmd = ExecuteNodeI(i, false, true, clusters)
}
if cmd == nil {
return
} else {
cmds = append(cmds, cmd)
}
}
time.Sleep(time.Second) // 等待启动完毕
// client启动, 连接leader
cWrite := clientPkg.Client{Address: clusters[0], ServerId: "1"}
// 写入
var s clientPkg.Status
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // 等待写入完毕
// 模拟最后一个结点崩溃
err := cmds[n - 1].Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
// 继续写入
for i := 5; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
// 恢复结点
cmd := ExecuteNodeI(n - 1, false, false, clusters)
if cmd == nil {
t.Errorf("recover test1 fail")
return
} else {
cmds[n - 1] = cmd
}
time.Sleep(time.Second) // 等待启动完毕
// client启动, 连接节点n-1(去读它的数据)
cRead := clientPkg.Client{Address: clusters[n - 1], ServerId: "n"}
// 读崩溃前写入数据
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
var value string
s = cRead.Read(key, &value)
if s != clientPkg.Ok {
t.Errorf("Read test1 fail")
}
}
// 读未写入数据
for i := 5; i < 15; i++ {
key := strconv.Itoa(i)
var value string
s = cRead.Read(key, &value)
if s != clientPkg.NotFound {
t.Errorf("Read test2 fail")
}
}
// 通知进程结束
for _, cmd := range cmds {
err := cmd.Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
}
}

+ 104
- 0
test/restart_node_test.go

@ -0,0 +1,104 @@
package test
import (
"fmt"
"os/exec"
"simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"syscall"
"testing"
"time"
)
func TestNodeRestart(t *testing.T) {
// 登记结点信息
n := 5
var clusters []string
var peerIds []string
addressMap := make(map[string]string)
for i := 0; i < n; i++ {
port := fmt.Sprintf("%d", uint16(9090)+uint16(i))
addr := "127.0.0.1:" + port
clusters = append(clusters, addr)
addressMap[strconv.Itoa(i + 1)] = addr
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var cmds []*exec.Cmd
for i := 0; i < n; i++ {
cmd := ExecuteNodeI(i, false, clusters)
cmds = append(cmds, cmd)
}
// 通知所有进程结束
defer func(){
for _, cmd := range cmds {
err := cmd.Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
}
}()
time.Sleep(time.Second) // 等待启动完毕
// client启动, 连接任意节点
transport := &nodes.HTTPTransport{NodeMap: addressMap}
cWrite := clientPkg.NewClient("0", peerIds, transport)
// 写入
var s clientPkg.Status
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := cWrite.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // 等待写入完毕
// 模拟结点轮流崩溃
for i := 0; i < n; i++ {
err := cmds[i].Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
time.Sleep(time.Second)
cmd := ExecuteNodeI(i, true, clusters)
if cmd == nil {
t.Errorf("recover test1 fail")
return
} else {
cmds[i] = cmd
}
time.Sleep(time.Second) // 等待启动完毕
}
// client启动
cRead := clientPkg.NewClient("0", peerIds, transport)
// 读写入数据
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
var value string
s = cRead.Read(key, &value)
if s != clientPkg.Ok {
t.Errorf("Read test1 fail")
}
}
// 读未写入数据
for i := 5; i < 15; i++ {
key := strconv.Itoa(i)
var value string
s = cRead.Read(key, &value)
if s != clientPkg.NotFound {
t.Errorf("Read test2 fail")
}
}
}

+ 22
- 25
test/server_client_test.go

@ -15,50 +15,57 @@ func TestServerClient(t *testing.T) {
// 登记结点信息
n := 5
var clusters []string
var peerIds []string
addressMap := make(map[string]string)
for i := 0; i < n; i++ {
port := fmt.Sprintf("%d", uint16(9090)+uint16(i))
addr := "127.0.0.1:" + port
clusters = append(clusters, addr)
addressMap[strconv.Itoa(i + 1)] = addr
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var cmds []*exec.Cmd
for i := 0; i < n; i++ {
var cmd *exec.Cmd
if i == 0 {
cmd = ExecuteNodeI(i, true, true, clusters)
} else {
cmd = ExecuteNodeI(i, false, true, clusters)
}
if cmd == nil {
return
} else {
cmds = append(cmds, cmd)
}
cmd := ExecuteNodeI(i, false, clusters)
cmds = append(cmds, cmd)
}
// 通知所有进程结束
defer func(){
for _, cmd := range cmds {
err := cmd.Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
}
}()
time.Sleep(time.Second) // 等待启动完毕
// client启动
c := clientPkg.Client{Address: "127.0.0.1:9090", ServerId: "1"}
transport := &nodes.HTTPTransport{NodeMap: addressMap}
c := clientPkg.NewClient("0", peerIds, transport)
// 写入
var s clientPkg.Status
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := c.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
s := c.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // 等待写入完毕
// 读写入数据
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
var value string
s = c.Read(key, &value)
if s != clientPkg.Ok && value != "hello" + key {
if s != clientPkg.Ok || value != "hello" {
t.Errorf("Read test1 fail")
}
}
@ -72,14 +79,4 @@ func TestServerClient(t *testing.T) {
t.Errorf("Read test2 fail")
}
}
// 通知进程结束
for _, cmd := range cmds {
err := cmd.Process.Signal(syscall.SIGTERM)
if err != nil {
fmt.Println("Error sending signal:", err)
return
}
}
}

+ 305
- 0
threadTest/common.go

@ -0,0 +1,305 @@
package threadTest
import (
"fmt"
"os"
clientPkg "simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
"github.com/syndtr/goleveldb/leveldb"
)
func ExecuteNodeI(id string, isRestart bool, peerIds []string, threadTransport *nodes.ThreadTransport) (*nodes.Node, chan struct{}) {
if !isRestart {
os.RemoveAll("storage/node" + id)
}
// create a temporary directory for the leveldb data
dbPath, err := os.MkdirTemp("", "simple-kv-store-"+id+"-")
if err != nil {
panic(fmt.Sprintf("failed to create temporary leveldb directory: %s", err))
}
// create a temporary directory for the raft persistent state
storagePath, err := os.MkdirTemp("", "raft-storage-"+id+"-")
if err != nil {
panic(fmt.Sprintf("failed to create temporary storage directory: %s", err))
}
db, err := leveldb.OpenFile(dbPath, nil)
if err != nil {
panic(fmt.Sprintf("Failed to open database: %s", err))
}
// initialize the raft persistent storage
storage := nodes.NewRaftStorage(storagePath)
var otherIds []string
for _, ids := range peerIds {
if ids != id {
otherIds = append(otherIds, ids) // keep every peer id except our own
}
}
// initialize the node
node, quitChan := nodes.InitThreadNode(id, otherIds, db, storage, isRestart, threadTransport)
// start the raft main loop
go nodes.Start(node, quitChan)
return node, quitChan
}
func ExecuteStaticNodeI(id string, isRestart bool, peerIds []string, threadTransport *nodes.ThreadTransport) (*nodes.Node, chan struct{}) {
if !isRestart {
os.RemoveAll("storage/node" + id)
}
os.RemoveAll("leveldb/simple-kv-store" + id)
db, err := leveldb.OpenFile("leveldb/simple-kv-store"+id, nil)
if err != nil {
fmt.Println("Failed to open database: ", err)
}
// open or create the node's persistent state files
storage := nodes.NewRaftStorage("storage/node" + id)
var otherIds []string
for _, ids := range peerIds {
if ids != id {
otherIds = append(otherIds, ids) // keep every peer id except our own
}
}
// initialize the node
node, quitChan := nodes.InitThreadNode(id, otherIds, db, storage, isRestart, threadTransport)
// the raft main loop is intentionally not started here; the static tests drive elections by hand
// go nodes.Start(node, quitChan)
return node, quitChan
}
func StopElectionReset(nodeCollections []*nodes.Node) {
for i := 0; i < len(nodeCollections); i++ {
node := nodeCollections[i]
go func(node *nodes.Node) {
ticker := time.NewTicker(400 * time.Millisecond)
defer ticker.Stop()
for {
<-ticker.C
node.ResetElectionTimer() // keep resetting so the node never starts an election on its own
}
}(node)
}
}
func SendKvCall(kvCall *nodes.LogEntryCall, node *nodes.Node) {
node.Mu.Lock()
defer node.Mu.Unlock()
node.MaxLogId++
logId := node.MaxLogId
rLogE := nodes.RaftLogEntry{LogE: kvCall.LogE, LogId: logId, Term: node.CurrTerm}
node.Log = append(node.Log, rLogE)
node.Storage.AppendLog(rLogE)
// broadcast the new entry to the other nodes
node.BroadCastKV()
}
func ClientWriteLog(t *testing.T, startLogid int, endLogid int, cWrite *clientPkg.Client) {
var s clientPkg.Status
for i := startLogid; i < endLogid; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s = cWrite.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
}
func FindLeader(t *testing.T, nodeCollections []*nodes.Node) (i int) {
for i, node := range nodeCollections {
if node.State == nodes.Leader {
return i
}
}
t.Errorf("系统目前没有leader")
t.FailNow()
return 0
}
func CheckOneLeader(t *testing.T, nodeCollections []*nodes.Node) {
cnt := 0
for _, node := range nodeCollections {
node.Mu.Lock()
if node.State == nodes.Leader {
cnt++
}
node.Mu.Unlock()
}
if cnt != 1 {
t.Errorf("实际有%d个leader(!=1)", cnt)
t.FailNow()
}
}
func CheckNoLeader(t *testing.T, nodeCollections []*nodes.Node) {
cnt := 0
for _, node := range nodeCollections {
node.Mu.Lock()
if node.State == nodes.Leader {
cnt++
}
node.Mu.Unlock()
}
if cnt != 0 {
t.Errorf("实际有%d个leader(!=0)", cnt)
t.FailNow()
}
}
func CheckZeroOrOneLeader(t *testing.T, nodeCollections []*nodes.Node) {
cnt := 0
for _, node := range nodeCollections {
node.Mu.Lock()
if node.State == nodes.Leader {
cnt++
}
node.Mu.Unlock()
}
if cnt > 1 {
errmsg := fmt.Sprintf("%d个节点中,实际有%d个leader(>1)", len(nodeCollections), cnt)
WriteFailLog(nodeCollections[0].SelfId, errmsg)
t.Error(errmsg)
t.FailNow()
}
}
func CheckIsLeader(t *testing.T, node *nodes.Node) {
node.Mu.Lock()
defer node.Mu.Unlock()
if node.State != nodes.Leader {
t.Errorf("[%s]不是leader", node.SelfId)
t.FailNow()
}
}
func CheckTerm(t *testing.T, node *nodes.Node, targetTerm int) {
node.Mu.Lock()
defer node.Mu.Unlock()
if node.CurrTerm != targetTerm {
t.Errorf("[%s]实际term=%d (!=%d)", node.SelfId, node.CurrTerm, targetTerm)
t.FailNow()
}
}
func CheckLogNum(t *testing.T, node *nodes.Node, targetnum int) {
node.Mu.Lock()
defer node.Mu.Unlock()
if len(node.Log) != targetnum {
t.Errorf("[%s]实际logNum=%d (!=%d)", node.SelfId, len(node.Log), targetnum)
t.FailNow()
}
}
func CheckSameLog(t *testing.T, nodeCollections []*nodes.Node) {
nodeCollections[0].Mu.Lock()
defer nodeCollections[0].Mu.Unlock()
standard_node := nodeCollections[0]
for i, node := range nodeCollections {
if i != 0 {
node.Mu.Lock()
if len(node.Log) != len(standard_node.Log) {
errmsg := fmt.Sprintf("[%s]和[%s]日志数量不一致", nodeCollections[0].SelfId, node.SelfId)
WriteFailLog(node.SelfId, errmsg)
t.Error(errmsg)
t.FailNow()
}
for idx, log := range node.Log {
standard_log := standard_node.Log[idx]
if log.Term != standard_log.Term ||
log.LogE.Key != standard_log.LogE.Key ||
log.LogE.Value != standard_log.LogE.Value {
errmsg := fmt.Sprintf("log entry %d differs between [%s] and [%s]", idx, standard_node.SelfId, node.SelfId)
WriteFailLog(node.SelfId, errmsg)
t.Error(errmsg)
t.FailNow()
}
}
node.Mu.Unlock()
}
}
}
func CheckLeaderInvariant(t *testing.T, nodeCollections []*nodes.Node) {
leaderCnt := make(map[int]bool)
for _, node := range nodeCollections {
node.Mu.Lock()
if node.State == nodes.Leader {
if _, exist := leaderCnt[node.CurrTerm]; exist {
errmsg := fmt.Sprintf("在%d有多个leader(%s)", node.CurrTerm, node.SelfId)
WriteFailLog(node.SelfId, errmsg)
t.Error(errmsg)
} else {
leaderCnt[node.CurrTerm] = true
}
}
node.Mu.Unlock()
}
}
func CheckLogInvariant(t *testing.T, nodeCollections []*nodes.Node) {
nodeCollections[0].Mu.Lock()
defer nodeCollections[0].Mu.Unlock()
standard_node := nodeCollections[0]
standard_len := len(standard_node.Log)
for i, node := range nodeCollections {
if i != 0 {
node.Mu.Lock()
len2 := len(node.Log)
var shorti int
if len2 < standard_len {
shorti = len2
} else {
shorti = standard_len
}
if shorti == 0 {
node.Mu.Unlock()
continue
}
alreadySame := false
for i := shorti - 1; i >= 0; i-- {
standard_log := standard_node.Log[i]
log := node.Log[i]
if alreadySame {
if log.Term != standard_log.Term ||
log.LogE.Key != standard_log.LogE.Key ||
log.LogE.Value != standard_log.LogE.Value {
errmsg := fmt.Sprintf("[%s]和[%s]日志id%d不一致", standard_node.SelfId, node.SelfId, i)
WriteFailLog(node.SelfId, errmsg)
t.Error(errmsg)
t.FailNow()
}
} else {
if log.Term == standard_log.Term &&
log.LogE.Key == standard_log.LogE.Key &&
log.LogE.Value == standard_log.LogE.Value {
alreadySame = true
}
}
}
node.Mu.Unlock()
}
}
}
func WriteFailLog(name string, errmsg string) {
f, _ := os.Create(name + ".log")
fmt.Fprint(f, errmsg)
f.Close()
}
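To make the intended use of these helpers concrete, here is a minimal sketch of how they compose into a test, written in the same style as the tests in this package; like those tests, it assumes the cluster settles within a one-second sleep:

```go
package threadTest

import (
	"strconv"
	"testing"
	"time"

	"simple-kv-store/internal/nodes"
)

// Sketch only: starts a small cluster with ExecuteNodeI and asserts the two
// invariants the fuzz tests rely on.
func TestInvariantSketch(t *testing.T) {
	peerCount := 3
	var peerIds []string
	for i := 0; i < peerCount; i++ {
		peerIds = append(peerIds, strconv.Itoa(i+1))
	}
	transport := nodes.NewThreadTransport(nodes.NewCtx())
	var nodeCollections []*nodes.Node
	var quitCollections []chan struct{}
	for i := 0; i < peerCount; i++ {
		node, quit := ExecuteNodeI(strconv.Itoa(i+1), false, peerIds, transport)
		nodeCollections = append(nodeCollections, node)
		quitCollections = append(quitCollections, quit)
	}
	defer func() {
		for _, quit := range quitCollections {
			close(quit)
		}
	}()
	time.Sleep(time.Second) // let the first election finish, as the other tests do
	CheckLeaderInvariant(t, nodeCollections) // at most one leader per term
	CheckLogInvariant(t, nodeCollections)    // logs agree on their common prefix
}
```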

+ 313
- 0
threadTest/election_test.go

@ -0,0 +1,313 @@
package threadTest
import (
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
)
func TestInitElection(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
}
func TestRepeatElection(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
go nodeCollections[0].StartElection()
go nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 3)
}
func TestBelowHalfCandidateElection(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
go nodeCollections[0].StartElection()
go nodeCollections[1].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
for i := 0; i < n; i++ {
CheckTerm(t, nodeCollections[i], 2)
}
}
func TestOverHalfCandidateElection(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
go nodeCollections[0].StartElection()
go nodeCollections[1].StartElection()
go nodeCollections[2].StartElection()
time.Sleep(time.Second)
CheckZeroOrOneLeader(t, nodeCollections)
for i := 0; i < n; i++ {
CheckTerm(t, nodeCollections[i], 2)
}
}
func TestRepeatVoteRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
ctx.SetBehavior("1", "2", nodes.RetryRpc, 0, 2)
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < n; i++ {
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2)
ctx.SetBehavior("2", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2)
}
go nodeCollections[0].StartElection()
go nodeCollections[1].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckTerm(t, nodeCollections[0], 3)
}
func TestFailVoteRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
ctx.SetBehavior("1", "2", nodes.FailRpc, 0, 0)
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
ctx.SetBehavior("1", "3", nodes.FailRpc, 0, 0)
ctx.SetBehavior("1", "4", nodes.FailRpc, 0, 0)
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckNoLeader(t, nodeCollections)
CheckTerm(t, nodeCollections[0], 3)
}
func TestDelayVoteRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.DelayRpc, time.Second, 0)
}
nodeCollections[0].StartElection()
time.Sleep(2 * time.Second)
CheckNoLeader(t, nodeCollections)
for i := 0; i < n; i++ {
CheckTerm(t, nodeCollections[i], 2)
}
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.DelayRpc, 50 * time.Millisecond, 0)
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
for i := 0; i < n; i++ {
CheckTerm(t, nodeCollections[i], 3)
}
}

+ 445
- 0
threadTest/fuzz/fuzz_test.go

@ -0,0 +1,445 @@
package fuzz
import (
"fmt"
"math/rand"
"os"
"runtime/debug"
"sync"
"testing"
"time"
clientPkg "simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"simple-kv-store/threadTest"
"strconv"
)
// 1. Fuzz a randomly sized cluster with random per-link message behavior.
func FuzzRaftBasic(f *testing.F) {
var seenSeeds sync.Map
// 添加初始种子
f.Add(int64(1))
fmt.Println("Running")
f.Fuzz(func(t *testing.T, seed int64) {
if _, loaded := seenSeeds.LoadOrStore(seed, true); loaded {
t.Skipf("Seed %d already tested, skipping...", seed)
return
}
defer func() {
if r := recover(); r != nil {
msg := fmt.Sprintf("goroutine panic: %v\n%s", r, debug.Stack())
f, _ := os.Create("panic_goroutine.log")
fmt.Fprint(f, msg)
f.Close()
}
}()
r := rand.New(rand.NewSource(seed)) // 使用局部 rand
n := 3 + 2*(r.Intn(4))
fmt.Printf("随机了%d个节点\n", n)
logs := (r.Intn(10))
fmt.Printf("随机了%d份日志\n", logs)
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1))
}
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
for i := 0; i < n; i++ {
node, quitChan := threadTest.ExecuteNodeI(strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1), false, peerIds, threadTransport)
nodeCollections = append(nodeCollections, node)
node.RTTable.SetElectionTimeout(750 * time.Millisecond)
quitCollections = append(quitCollections, quitChan)
}
// 模拟 a-b 通讯行为
faultyNodes := injectRandomBehavior(ctx, r, peerIds)
time.Sleep(time.Second)
clientObj := clientPkg.NewClient("0", peerIds, threadTransport)
for i := 0; i < logs; i++ {
key := fmt.Sprintf("k%d", i)
log := nodes.LogEntry{Key: key, Value: "v"}
clientObj.Write(log)
}
time.Sleep(time.Second)
var rightNodeCollections []*nodes.Node
for _, node := range nodeCollections {
if !faultyNodes[node.SelfId] {
rightNodeCollections = append(rightNodeCollections, node)
}
}
threadTest.CheckSameLog(t, rightNodeCollections)
threadTest.CheckLeaderInvariant(t, nodeCollections)
for _, quitChan := range quitCollections {
close(quitChan)
}
time.Sleep(time.Second)
for i := 0; i < n; i++ {
// 确保完成退出
nodeCollections[i].Mu.Lock()
if !nodeCollections[i].IsFinish {
nodeCollections[i].IsFinish = true
}
nodeCollections[i].Mu.Unlock()
os.RemoveAll("leveldb/simple-kv-store" + strconv.Itoa(int(seed)) + "." + strconv.Itoa(i+1))
os.RemoveAll("storage/node" + strconv.Itoa(int(seed)) + "." + strconv.Itoa(i+1))
}
})
}
// Inject random RPC behavior between node pairs; returns the set of faulty node ids.
func injectRandomBehavior(ctx *nodes.Ctx, r *rand.Rand, peers []string) map[string]bool /*id:Isfault*/ {
behaviors := []nodes.CallBehavior{
nodes.FailRpc,
nodes.DelayRpc,
nodes.RetryRpc,
}
n := len(peers)
maxFaulty := r.Intn(n/2 + 1) // 随机选择 0 ~ n/2 个出问题的节点
// 随机选择出问题的节点
shuffled := append([]string(nil), peers...)
r.Shuffle(n, func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
faultyNodes := make(map[string]bool)
for i := 0; i < maxFaulty; i++ {
faultyNodes[shuffled[i]] = true
}
for _, one := range peers {
if faultyNodes[one] {
b := behaviors[r.Intn(len(behaviors))]
delay := time.Duration(r.Intn(100)) * time.Millisecond
switch b {
case nodes.FailRpc:
fmt.Printf("[%s]的异常行为是fail\n", one)
case nodes.DelayRpc:
fmt.Printf("[%s]的异常行为是delay\n", one)
case nodes.RetryRpc:
fmt.Printf("[%s]的异常行为是retry\n", one)
}
for _, two := range peers {
if one == two {
continue
}
if faultyNodes[one] && faultyNodes[two] {
ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0)
ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0)
} else {
ctx.SetBehavior(one, two, b, delay, 2)
ctx.SetBehavior(two, one, b, delay, 2)
}
}
}
}
return faultyNodes
}
// 2. Inject random faults into a single long-running cluster.
func FuzzRaftRobust(f *testing.F) {
var seenSeeds sync.Map
var fuzzMu sync.Mutex
// 添加初始种子
f.Add(int64(0))
fmt.Println("Running")
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i+1))
}
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
quitCollections := make(map[string]chan struct{})
nodeCollections := make(map[string]*nodes.Node)
for i := 0; i < n; i++ {
id := strconv.Itoa(i+1)
node, quitChan := threadTest.ExecuteNodeI(id, false, peerIds, threadTransport)
nodeCollections[id] = node
quitCollections[id] = quitChan
}
f.Fuzz(func(t *testing.T, seed int64) {
fuzzMu.Lock()
defer fuzzMu.Unlock()
if _, loaded := seenSeeds.LoadOrStore(seed, true); loaded {
t.Skipf("Seed %d already tested, skipping...", seed)
return
}
defer func() {
if r := recover(); r != nil {
msg := fmt.Sprintf("goroutine panic: %v\n%s", r, debug.Stack())
f, _ := os.Create("panic_goroutine.log")
fmt.Fprint(f, msg)
f.Close()
}
}()
r := rand.New(rand.NewSource(seed)) // 使用局部 rand
clientObj := clientPkg.NewClient("0", peerIds, threadTransport)
faultyNodes := injectRandomBehavior2(ctx, r, peerIds, threadTransport, quitCollections)
key := fmt.Sprintf("k%d", seed % 10)
log := nodes.LogEntry{Key: key, Value: "v"}
clientObj.Write(log)
time.Sleep(time.Second)
var rightNodeCollections []*nodes.Node
for _, node := range nodeCollections {
_, exist := faultyNodes[node.SelfId]
if !exist {
rightNodeCollections = append(rightNodeCollections, node)
}
}
threadTest.CheckLogInvariant(t, rightNodeCollections)
threadTest.CheckLeaderInvariant(t, rightNodeCollections)
// ResetFaultyNodes
threadTransport.ResetConnectivity()
for id, isrestart := range faultyNodes {
if !isrestart {
for _, peerIds := range peerIds {
if id == peerIds {
continue
}
ctx.SetBehavior(id, peerIds, nodes.NormalRpc, 0, 0)
ctx.SetBehavior(peerIds, id, nodes.NormalRpc, 0, 0)
}
} else {
newNode, quitChan := threadTest.ExecuteNodeI(id, true, peerIds, threadTransport)
quitCollections[id] = quitChan
nodeCollections[id] = newNode
}
fmt.Printf("[%s]恢复异常\n", id)
}
})
for _, quitChan := range quitCollections {
close(quitChan)
}
time.Sleep(time.Second)
for id, node := range nodeCollections {
// 确保完成退出
node.Mu.Lock()
if !node.IsFinish {
node.IsFinish = true
}
node.Mu.Unlock()
os.RemoveAll("leveldb/simple-kv-store" + id)
os.RemoveAll("storage/node" + id)
}
}
// 3. Combined: repeated rounds of fault injection against a randomized cluster.
func FuzzRaftPlus(f *testing.F) {
var seenSeeds sync.Map
// 添加初始种子
f.Add(int64(0))
fmt.Println("Running")
f.Fuzz(func(t *testing.T, seed int64) {
if _, loaded := seenSeeds.LoadOrStore(seed, true); loaded {
t.Skipf("Seed %d already tested, skipping...", seed)
return
}
defer func() {
if r := recover(); r != nil {
msg := fmt.Sprintf("goroutine panic: %v\n%s", r, debug.Stack())
f, _ := os.Create("panic_goroutine.log")
fmt.Fprint(f, msg)
f.Close()
}
}()
r := rand.New(rand.NewSource(seed)) // 使用局部 rand
n := 3 + 2*(r.Intn(4))
fmt.Printf("随机了%d个节点\n", n)
ElectionTimeOut := 500 + r.Intn(500)
fmt.Printf("随机的投票超时时间:%d\n", ElectionTimeOut)
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1))
}
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
quitCollections := make(map[string]chan struct{})
nodeCollections := make(map[string]*nodes.Node)
for i := 0; i < n; i++ {
id := strconv.Itoa(int(seed))+"."+strconv.Itoa(i+1)
node, quitChan := threadTest.ExecuteNodeI(id, false, peerIds, threadTransport)
nodeCollections[id] = node
node.RTTable.SetElectionTimeout(time.Duration(ElectionTimeOut) * time.Millisecond)
quitCollections[id] = quitChan
}
clientObj := clientPkg.NewClient("0", peerIds, threadTransport)
for i := 0; i < 5; i++ { // five rounds of fault injection
fmt.Printf("第%d轮异常注入开始\n", i + 1)
faultyNodes := injectRandomBehavior2(ctx, r, peerIds, threadTransport, quitCollections)
key := fmt.Sprintf("k%d", i)
log := nodes.LogEntry{Key: key, Value: "v"}
clientObj.Write(log)
time.Sleep(time.Second)
var rightNodeCollections []*nodes.Node
for _, node := range nodeCollections {
_, exist := faultyNodes[node.SelfId]
if !exist {
rightNodeCollections = append(rightNodeCollections, node)
}
}
threadTest.CheckLogInvariant(t, rightNodeCollections)
threadTest.CheckLeaderInvariant(t, rightNodeCollections)
// ResetFaultyNodes
threadTransport.ResetConnectivity()
for id, isrestart := range faultyNodes {
if !isrestart {
for _, peerId := range peerIds {
if id == peerId {
continue
}
ctx.SetBehavior(id, peerId, nodes.NormalRpc, 0, 0)
ctx.SetBehavior(peerId, id, nodes.NormalRpc, 0, 0)
}
} else {
newNode, quitChan := threadTest.ExecuteNodeI(id, true, peerIds, threadTransport)
quitCollections[id] = quitChan
nodeCollections[id] = newNode
}
fmt.Printf("[%s]恢复异常\n", id)
}
}
for _, quitChan := range quitCollections {
close(quitChan)
}
time.Sleep(time.Second)
for id, node := range nodeCollections {
// 确保完成退出
node.Mu.Lock()
if !node.IsFinish {
node.IsFinish = true
}
node.Mu.Unlock()
os.RemoveAll("leveldb/simple-kv-store" + id)
os.RemoveAll("storage/node" + id)
}
})
}
func injectRandomBehavior2(ctx *nodes.Ctx, r *rand.Rand, peers []string, tran *nodes.ThreadTransport, quitCollections map[string]chan struct{}) map[string]bool /*id:needRestart*/ {
n := len(peers)
maxFaulty := r.Intn(n/2 + 1) // 随机选择 0 ~ n/2 个出问题的节点
// 随机选择出问题的节点
shuffled := append([]string(nil), peers...)
r.Shuffle(n, func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
faultyNodes := make(map[string]bool)
for i := 0; i < maxFaulty; i++ {
faultyNodes[shuffled[i]] = false
}
PartitionNodes := make(map[string]bool)
for _, one := range peers {
_, exist := faultyNodes[one]
if exist {
b := r.Intn(5)
switch b {
case 0:
fmt.Printf("[%s]的异常行为是fail\n", one)
for _, two := range peers {
if one == two {
continue
}
ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0)
ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0)
}
case 1:
fmt.Printf("[%s]的异常行为是delay\n", one)
t := r.Intn(100)
fmt.Printf("[%s]的delay time = %d\n", one, t)
delay := time.Duration(t) * time.Millisecond
for _, two := range peers {
if one == two {
continue
}
_, exist2 := faultyNodes[two]
if exist2 {
ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0)
ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0)
} else {
ctx.SetBehavior(one, two, nodes.DelayRpc, delay, 0)
ctx.SetBehavior(two, one, nodes.DelayRpc, delay, 0)
}
}
case 2:
fmt.Printf("[%s]的异常行为是retry\n", one)
for _, two := range peers {
if one == two {
continue
}
_, exist2 := faultyNodes[two]
if exist2 {
ctx.SetBehavior(one, two, nodes.FailRpc, 0, 0)
ctx.SetBehavior(two, one, nodes.FailRpc, 0, 0)
} else {
ctx.SetBehavior(one, two, nodes.RetryRpc, 0, 2)
ctx.SetBehavior(two, one, nodes.RetryRpc, 0, 2)
}
}
case 3:
fmt.Printf("[%s]的异常行为是stop\n", one)
faultyNodes[one] = true
close(quitCollections[one])
case 4:
fmt.Printf("[%s]的异常行为是partition\n", one)
PartitionNodes[one] = true
}
}
}
for id := range PartitionNodes {
for _, two := range peers {
if !PartitionNodes[two] {
tran.SetConnectivity(id, two, false)
tran.SetConnectivity(two, id, false)
}
}
}
return faultyNodes
}
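These targets run under Go's native fuzzing, e.g. `go test -fuzz=FuzzRaftBasic -fuzztime=60s ./threadTest/fuzz`. When a seed trips an invariant, the offending behavior combination can be pinned down as an ordinary test; below is a sketch of such a regression test, reusing only APIs already exercised above (the node ids, delay, and retry count are illustrative, not taken from a real failure):

```go
package fuzz

import (
	"testing"
	"time"

	"simple-kv-store/internal/nodes"
	"simple-kv-store/threadTest"
)

// Sketch only: replays one fault combination found by the fuzzer as a fixed test.
func TestRegressionDelayPlusRetry(t *testing.T) {
	peerIds := []string{"r1", "r2", "r3"}
	ctx := nodes.NewCtx()
	transport := nodes.NewThreadTransport(ctx)
	var nodeCollections []*nodes.Node
	var quitCollections []chan struct{}
	for _, id := range peerIds {
		node, quit := threadTest.ExecuteNodeI(id, false, peerIds, transport)
		nodeCollections = append(nodeCollections, node)
		quitCollections = append(quitCollections, quit)
	}
	defer func() {
		for _, quit := range quitCollections {
			close(quit)
		}
	}()

	// The combination under test: r1 -> r2 delayed, r2 -> r1 retried.
	ctx.SetBehavior("r1", "r2", nodes.DelayRpc, 80*time.Millisecond, 0)
	ctx.SetBehavior("r2", "r1", nodes.RetryRpc, 0, 2)

	time.Sleep(2 * time.Second)
	threadTest.CheckLeaderInvariant(t, nodeCollections)
	threadTest.CheckLogInvariant(t, nodeCollections)
}
```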

+ 326
- 0
threadTest/log_replication_test.go

@ -0,0 +1,326 @@
package threadTest
import (
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
)
func TestNormalReplication(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
time.Sleep(time.Second)
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
func TestParallelReplication(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
go nodeCollections[0].BroadCastKV()
}
time.Sleep(time.Second)
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
func TestFollowerLagging(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
close(quitCollections[1])
time.Sleep(time.Second)
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
node, q := ExecuteStaticNodeI("2", true, peerIds, threadTransport)
quitCollections[1] = q
nodeCollections[1] = node
nodeCollections[1].State = nodes.Follower
StopElectionReset(nodeCollections[1:2])
nodeCollections[0].BroadCastKV()
time.Sleep(time.Second)
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
func TestFailLogAppendRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < n; i++ {
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.FailRpc, 0, 0)
}
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
time.Sleep(time.Second)
for i := 1; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 0)
}
}
func TestRepeatLogAppendRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < n; i++ {
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2)
}
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
time.Sleep(time.Second)
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
func TestDelayLogAppendRpc(t *testing.T) {
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteStaticNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
StopElectionReset(nodeCollections)
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
for i := 0; i < n; i++ {
nodeCollections[i].State = nodes.Follower
}
nodeCollections[0].StartElection()
time.Sleep(time.Second)
CheckOneLeader(t, nodeCollections)
CheckIsLeader(t, nodeCollections[0])
CheckTerm(t, nodeCollections[0], 2)
for i := 0; i < n; i++ {
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.DelayRpc, time.Second, 0)
}
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
time.Sleep(time.Millisecond * 100)
for i := 0; i < n; i++ {
ctx.SetBehavior("1", nodeCollections[i].SelfId, nodes.NormalRpc, 0, 0)
}
for i := 5; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
go SendKvCall(&nodes.LogEntryCall{LogE: newlog}, nodeCollections[0])
}
time.Sleep(time.Second * 2)
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
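The tests above rely on fixed time.Sleep intervals before calling CheckLogNum, which keeps them simple but timing-sensitive. A small polling helper is one way to tighten that; a sketch (not used by the tests in this diff):

```go
package threadTest

import (
	"testing"
	"time"

	"simple-kv-store/internal/nodes"
)

// Sketch only: poll a node's log length instead of sleeping for a fixed interval.
func waitForLogNum(t *testing.T, node *nodes.Node, want int, timeout time.Duration) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		node.Mu.Lock()
		got := len(node.Log)
		node.Mu.Unlock()
		if got >= want {
			return
		}
		time.Sleep(20 * time.Millisecond)
	}
	t.Fatalf("timed out waiting for %d log entries", want)
}
```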

+ 245
- 0
threadTest/network_partition_test.go

@ -0,0 +1,245 @@
package threadTest
import (
"fmt"
clientPkg "simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"strings"
"testing"
"time"
)
func TestBasicConnectivity(t *testing.T) {
transport := nodes.NewThreadTransport(nodes.NewCtx())
transport.RegisterNodeChan("1", make(chan nodes.RPCRequest, 10))
transport.RegisterNodeChan("2", make(chan nodes.RPCRequest, 10))
// 断开 A 和 B
transport.SetConnectivity("1", "2", false)
err := transport.CallWithTimeout(&nodes.ThreadClient{SourceId: "1", TargetId: "2"}, "Node.AppendEntries", &nodes.AppendEntriesArg{}, &nodes.AppendEntriesReply{})
if err == nil {
t.Errorf("Expected network partition error, but got nil")
}
// 恢复连接
transport.SetConnectivity("1", "2", true)
err = transport.CallWithTimeout(&nodes.ThreadClient{SourceId: "1", TargetId: "2"}, "Node.AppendEntries", &nodes.AppendEntriesArg{}, &nodes.AppendEntriesReply{})
if err != nil && !strings.Contains(err.Error(), "RPC 调用超时") {
t.Errorf("Expected the call to reach the peer (at most a timeout), but got error: %v", err)
}
}
func TestSingelPartition(t *testing.T) {
// 登记结点信息
n := 3
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // 等待启动完毕
fmt.Println("开始分区模拟1")
var leaderNo int
for i := 0; i < n; i++ {
if nodeCollections[i].State == nodes.Leader {
leaderNo = i
for j := 0; j < n; j++ {
if i != j { // 切断其它节点到leader的消息
threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[i].SelfId, false)
}
}
}
}
time.Sleep(2 * time.Second)
if nodeCollections[leaderNo].State == nodes.Leader {
t.Errorf("分区退选失败")
}
// 恢复网络
for j := 0; j < n; j++ {
if leaderNo != j { // 恢复其它节点到leader的消息
threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[leaderNo].SelfId, true)
}
}
time.Sleep(1 * time.Second)
var leaderCnt int
for i := 0; i < n; i++ {
if nodeCollections[i].State == nodes.Leader {
leaderCnt++
leaderNo = i
}
}
if leaderCnt != 1 {
t.Errorf("多leader产生")
}
fmt.Println("开始分区模拟2")
for j := 0; j < n; j++ {
if leaderNo != j { // 切断leader到其它节点的消息
threadTransport.SetConnectivity(nodeCollections[leaderNo].SelfId, nodeCollections[j].SelfId, false)
}
}
time.Sleep(1 * time.Second)
if nodeCollections[leaderNo].State == nodes.Leader {
t.Errorf("分区退选失败")
}
leaderCnt = 0
for j := 0; j < n; j++ {
if nodeCollections[j].State == nodes.Leader {
leaderCnt++
}
}
if leaderCnt != 1 {
t.Errorf("多leader产生")
}
// client启动
c := clientPkg.NewClient("0", peerIds, threadTransport)
var s clientPkg.Status
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s = c.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // 等待写入完毕
// 恢复网络
for j := 0; j < n; j++ {
if leaderNo != j {
threadTransport.SetConnectivity(nodeCollections[leaderNo].SelfId, nodeCollections[j].SelfId, true)
}
}
time.Sleep(time.Second)
// 日志一致性检查
for i := 0; i < n; i++ {
if len(nodeCollections[i].Log) != 5 {
t.Errorf("日志数量不一致:" + strconv.Itoa(len(nodeCollections[i].Log)))
}
}
}
func TestQuorumPartition(t *testing.T) {
// 登记结点信息
n := 5 // 奇数,模拟不超过半数节点分区
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // 等待启动完毕
fmt.Println("开始分区模拟1")
for i := 0; i < n / 2; i++ {
for j := n / 2; j < n; j++ {
threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[i].SelfId, false)
threadTransport.SetConnectivity(nodeCollections[i].SelfId, nodeCollections[j].SelfId, false)
}
}
time.Sleep(2 * time.Second)
leaderCnt := 0
for i := 0; i < n / 2; i++ {
if nodeCollections[i].State == nodes.Leader {
leaderCnt++
}
}
if leaderCnt != 0 {
t.Errorf("少数分区不应该产生leader")
}
for i := n / 2; i < n; i++ {
if nodeCollections[i].State == nodes.Leader {
leaderCnt++
}
}
if leaderCnt != 1 {
t.Errorf("多数分区应该产生一个leader")
}
// client启动
c := clientPkg.NewClient("0", peerIds, threadTransport)
var s clientPkg.Status
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s = c.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // 等待写入完毕
// 恢复网络
for i := 0; i < n / 2; i++ {
for j := n / 2; j < n; j++ {
threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[i].SelfId, true)
threadTransport.SetConnectivity(nodeCollections[i].SelfId, nodeCollections[j].SelfId, true)
}
}
time.Sleep(1 * time.Second)
leaderCnt = 0
for j := 0; j < n; j++ {
if nodeCollections[j].State == nodes.Leader {
leaderCnt++
}
}
if leaderCnt != 1 {
t.Errorf("多leader产生")
}
// 日志一致性检查
for i := 0; i < n; i++ {
if len(nodeCollections[i].Log) != 5 {
t.Errorf("日志数量不一致:" + strconv.Itoa(len(nodeCollections[i].Log)))
}
}
}
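Both partition tests cut and restore links pair by pair; the symmetric cut can be factored into a small helper. A sketch (illustrative only, not part of this diff):

```go
package threadTest

import "simple-kv-store/internal/nodes"

// Sketch only: cut or restore every link between two groups of nodes, in both directions.
func setPartition(transport *nodes.ThreadTransport, groupA, groupB []string, connected bool) {
	for _, a := range groupA {
		for _, b := range groupB {
			transport.SetConnectivity(a, b, connected)
			transport.SetConnectivity(b, a, connected)
		}
	}
}
```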

+ 129
- 0
threadTest/restart_node_test.go

@ -0,0 +1,129 @@
package threadTest
import (
"simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
)
func TestNodeRestart(t *testing.T) {
// 登记结点信息
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
_, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
}
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // 等待启动完毕
// client启动, 连接任意节点
cWrite := clientPkg.NewClient("0", peerIds, threadTransport)
// 写入
ClientWriteLog(t, 0, 5, cWrite)
time.Sleep(time.Second) // 等待写入完毕
// 模拟结点轮流崩溃
for i := 0; i < n; i++ {
close(quitCollections[i])
time.Sleep(time.Second)
_, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), true, peerIds, threadTransport)
quitCollections[i] = quitChan
time.Sleep(time.Second) // 等待启动完毕
}
// client启动
cRead := clientPkg.NewClient("0", peerIds, threadTransport)
// 读写入数据
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
var value string
s := cRead.Read(key, &value)
if s != clientPkg.Ok {
t.Errorf("Read test1 fail")
}
}
// 读未写入数据
for i := 5; i < 15; i++ {
key := strconv.Itoa(i)
var value string
s := cRead.Read(key, &value)
if s != clientPkg.NotFound {
t.Errorf("Read test2 fail")
}
}
}
func TestRestartWhileWriting(t *testing.T) {
// 登记结点信息
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // 等待启动完毕
leaderIdx := FindLeader(t, nodeCollections)
// client启动, 连接任意节点
cWrite := clientPkg.NewClient("0", peerIds, threadTransport)
// 写入
go ClientWriteLog(t, 0, 5, cWrite)
go func() {
close(quitCollections[leaderIdx])
n, quitChan := ExecuteNodeI(strconv.Itoa(leaderIdx + 1), true, peerIds, threadTransport)
quitCollections[leaderIdx] = quitChan
nodeCollections[leaderIdx] = n
}()
time.Sleep(time.Second) // 等待启动完毕
// client启动
cRead := clientPkg.NewClient("0", peerIds, threadTransport)
// 读写入数据
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)
var value string
s := cRead.Read(key, &value)
if s != clientPkg.Ok {
t.Errorf("Read test1 fail")
}
}
CheckLogNum(t, nodeCollections[leaderIdx], 5)
}

+ 192
- 0
threadTest/server_client_test.go

@ -0,0 +1,192 @@
package threadTest
import (
"simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
)
func TestServerClient(t *testing.T) {
// 登记结点信息
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
threadTransport := nodes.NewThreadTransport(nodes.NewCtx())
for i := 0; i < n; i++ {
_, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
}
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // 等待启动完毕
// client启动
c := clientPkg.NewClient("0", peerIds, threadTransport)
// 写入
var s clientPkg.Status
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s = c.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // 等待写入完毕
// 读写入数据
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
var value string
s = c.Read(key, &value)
if s != clientPkg.Ok || value != "hello" {
t.Errorf("Read test1 fail")
}
}
// 读未写入数据
for i := 10; i < 15; i++ {
key := strconv.Itoa(i)
var value string
s = c.Read(key, &value)
if s != clientPkg.NotFound {
t.Errorf("Read test2 fail")
}
}
}
func TestRepeatClientReq(t *testing.T) {
// 登记结点信息
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // 等待启动完毕
// client启动
c := clientPkg.NewClient("0", peerIds, threadTransport)
for i := 0; i < n; i++ {
ctx.SetBehavior("", nodeCollections[i].SelfId, nodes.RetryRpc, 0, 2)
}
// 写入
var s clientPkg.Status
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"}
s = c.Write(newlog)
if s != clientPkg.Ok {
t.Errorf("write test fail")
}
}
time.Sleep(time.Second) // 等待写入完毕
// 读写入数据
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
var value string
s = c.Read(key, &value)
if s != clientPkg.Ok || value != "hello" {
t.Errorf("Read test1 fail")
}
}
// 读未写入数据
for i := 10; i < 15; i++ {
key := strconv.Itoa(i)
var value string
s = c.Read(key, &value)
if s != clientPkg.NotFound {
t.Errorf("Read test2 fail")
}
}
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 10)
}
}
func TestParallelClientReq(t *testing.T) {
// 登记结点信息
n := 5
var peerIds []string
for i := 0; i < n; i++ {
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var quitCollections []chan struct{}
var nodeCollections []*nodes.Node
ctx := nodes.NewCtx()
threadTransport := nodes.NewThreadTransport(ctx)
for i := 0; i < n; i++ {
n, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport)
quitCollections = append(quitCollections, quitChan)
nodeCollections = append(nodeCollections, n)
}
// 通知所有node结束
defer func(){
for _, quitChan := range quitCollections {
close(quitChan)
}
}()
time.Sleep(time.Second) // 等待启动完毕
// client启动
c1 := clientPkg.NewClient("0", peerIds, threadTransport)
c2 := clientPkg.NewClient("1", peerIds, threadTransport)
// 写入
go ClientWriteLog(t, 0, 10, c1)
go ClientWriteLog(t, 0, 10, c2)
time.Sleep(time.Second) // 等待写入完毕
// 读写入数据
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)
var value string
s := c1.Read(key, &value)
if s != clientPkg.Ok || value != "hello" {
t.Errorf("Read test1 fail")
}
}
for i := 0; i < n; i++ {
CheckLogNum(t, nodeCollections[i], 20)
}
}
