Browse Source

分区测试(以及适应细粒度的修改)

ld
augurier 5 months ago
parent
commit
f669abd7ad
15 changed files with 457 additions and 262 deletions
  1. +8
    -4
      internal/client/client_node.go
  2. +49
    -49
      internal/nodes/init.go
  3. +0
    -8
      internal/nodes/log.go
  4. +17
    -17
      internal/nodes/node.go
  5. +1
    -1
      internal/nodes/real_transport.go
  6. +100
    -91
      internal/nodes/replica.go
  7. +17
    -23
      internal/nodes/server_node.go
  8. +53
    -7
      internal/nodes/thread_transport.go
  9. +1
    -1
      internal/nodes/transport.go
  10. +57
    -57
      internal/nodes/vote.go
  11. +1
    -1
      test/restart_node_test.go
  12. +1
    -1
      test/server_client_test.go
  13. +150
    -0
      threadTest/network_partition_test.go
  14. +1
    -1
      threadTest/restart_node_test.go
  15. +1
    -1
      threadTest/server_client_test.go

+ 8
- 4
internal/client/client_node.go View File

@ -35,7 +35,7 @@ func (client *Client) FindActiveNode() nodes.ClientInterface {
var c nodes.ClientInterface var c nodes.ClientInterface
for { // 直到找到一个可连接的节点(保证至少一个节点活着) for { // 直到找到一个可连接的节点(保证至少一个节点活着)
peerId := getRandomAddress(client.PeerIds) peerId := getRandomAddress(client.PeerIds)
c, err = client.Transport.DialHTTPWithTimeout("tcp", peerId)
c, err = client.Transport.DialHTTPWithTimeout("tcp", "", peerId)
if err != nil { if err != nil {
log.Error("dialing: ", zap.Error(err)) log.Error("dialing: ", zap.Error(err))
} else { } else {
@ -72,9 +72,13 @@ func (client *Client) Write(kvCall nodes.LogEntryCall) Status {
if !reply.Isleader { // 对方不是leader,根据反馈找leader if !reply.Isleader { // 对方不是leader,根据反馈找leader
leaderId := reply.LeaderId leaderId := reply.LeaderId
client.CloseRpcClient(c) client.CloseRpcClient(c)
c, err = client.Transport.DialHTTPWithTimeout("tcp", leaderId)
for err != nil { // 重新找下一个存活节点
if leaderId == "" { // 这个节点不知道leader是谁,再随机找
c = client.FindActiveNode() c = client.FindActiveNode()
} else { // dial leader
c, err = client.Transport.DialHTTPWithTimeout("tcp", "", leaderId)
for err != nil { // dial失败,重新找下一个存活节点
c = client.FindActiveNode()
}
} }
} else { // 成功 } else { // 成功
client.CloseRpcClient(c) client.CloseRpcClient(c)
@ -132,7 +136,7 @@ func (client *Client) FindLeader() string {
if !reply.Isleader { // 对方不是leader,根据反馈找leader if !reply.Isleader { // 对方不是leader,根据反馈找leader
client.CloseRpcClient(c) client.CloseRpcClient(c)
c, err = client.Transport.DialHTTPWithTimeout("tcp", reply.LeaderId)
c, err = client.Transport.DialHTTPWithTimeout("tcp", "", reply.LeaderId)
for err != nil { // 重新找下一个存活节点 for err != nil { // 重新找下一个存活节点
c = client.FindActiveNode() c = client.FindActiveNode()
} }

+ 49
- 49
internal/nodes/init.go View File

@ -16,7 +16,7 @@ import (
var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel) var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel)
// 运行在进程上的初始化 + rpc注册 // 运行在进程上的初始化 + rpc注册
func InitRPCNode(selfId string, port string, nodeAddr map[string]string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool) *Node {
func InitRPCNode(SelfId string, port string, nodeAddr map[string]string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool) *Node {
var nodeIds []string var nodeIds []string
for id := range nodeAddr { for id := range nodeAddr {
nodeIds = append(nodeIds, id) nodeIds = append(nodeIds, id)
@ -24,29 +24,29 @@ func InitRPCNode(selfId string, port string, nodeAddr map[string]string, db *lev
// 创建节点 // 创建节点
node := &Node{ node := &Node{
selfId: selfId,
leaderId: "",
nodes: nodeIds,
maxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了
currTerm: 1,
log: make([]RaftLogEntry, 0),
commitIndex: -1,
lastApplied: -1,
nextIndex: make(map[string]int),
matchIndex: make(map[string]int),
db: db,
storage: rstorage,
transport: &HTTPTransport{NodeMap: nodeAddr},
SelfId: SelfId,
LeaderId: "",
Nodes: nodeIds,
MaxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了
CurrTerm: 1,
Log: make([]RaftLogEntry, 0),
CommitIndex: -1,
LastApplied: -1,
NextIndex: make(map[string]int),
MatchIndex: make(map[string]int),
Db: db,
Storage: rstorage,
Transport: &HTTPTransport{NodeMap: nodeAddr},
} }
node.initLeaderState() node.initLeaderState()
if isRestart { if isRestart {
node.currTerm = rstorage.GetCurrentTerm()
node.votedFor = rstorage.GetVotedFor()
node.log = rstorage.GetLogEntries()
log.Sugar().Infof("[%s]从重启中恢复log数量: %d", selfId, len(node.log))
node.CurrTerm = rstorage.GetCurrentTerm()
node.VotedFor = rstorage.GetVotedFor()
node.Log = rstorage.GetLogEntries()
log.Sugar().Infof("[%s]从重启中恢复log数量: %d", SelfId, len(node.Log))
} }
log.Sugar().Infof("[%s]开始监听" + port + "端口", selfId)
log.Sugar().Infof("[%s]开始监听" + port + "端口", SelfId)
node.ListenPort(port) node.ListenPort(port)
return node return node
@ -73,33 +73,33 @@ func (node *Node) ListenPort(port string) {
} }
// 线程模拟的初始化 // 线程模拟的初始化
func InitThreadNode(selfId string, peerIds []string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool, threadTransport *ThreadTransport) (*Node, chan struct{}) {
func InitThreadNode(SelfId string, peerIds []string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool, threadTransport *ThreadTransport) (*Node, chan struct{}) {
rpcChan := make(chan RPCRequest, 100) // 要监听的chan rpcChan := make(chan RPCRequest, 100) // 要监听的chan
// 创建节点 // 创建节点
node := &Node{ node := &Node{
selfId: selfId,
leaderId: "",
nodes: peerIds,
maxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了
currTerm: 1,
log: make([]RaftLogEntry, 0),
commitIndex: -1,
lastApplied: -1,
nextIndex: make(map[string]int),
matchIndex: make(map[string]int),
db: db,
storage: rstorage,
transport: threadTransport,
SelfId: SelfId,
LeaderId: "",
Nodes: peerIds,
MaxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了
CurrTerm: 1,
Log: make([]RaftLogEntry, 0),
CommitIndex: -1,
LastApplied: -1,
NextIndex: make(map[string]int),
MatchIndex: make(map[string]int),
Db: db,
Storage: rstorage,
Transport: threadTransport,
} }
node.initLeaderState() node.initLeaderState()
if isRestart { if isRestart {
node.currTerm = rstorage.GetCurrentTerm()
node.votedFor = rstorage.GetVotedFor()
node.log = rstorage.GetLogEntries()
log.Sugar().Infof("[%s]从重启中恢复log数量: %d", selfId, len(node.log))
node.CurrTerm = rstorage.GetCurrentTerm()
node.VotedFor = rstorage.GetVotedFor()
node.Log = rstorage.GetLogEntries()
log.Sugar().Infof("[%s]从重启中恢复log数量: %d", SelfId, len(node.Log))
} }
threadTransport.RegisterNodeChan(selfId, rpcChan)
threadTransport.RegisterNodeChan(SelfId, rpcChan)
quitChan := make(chan struct{}, 1) quitChan := make(chan struct{}, 1)
go node.listenForChan(rpcChan, quitChan) go node.listenForChan(rpcChan, quitChan)
@ -107,7 +107,7 @@ func InitThreadNode(selfId string, peerIds []string, db *leveldb.DB, rstorage *R
} }
func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{}) { func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{}) {
defer node.db.Close()
defer node.Db.Close()
for { for {
select { select {
@ -162,7 +162,7 @@ func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{})
req.Done <- fmt.Errorf("未知方法: %s", req.ServiceMethod) req.Done <- fmt.Errorf("未知方法: %s", req.ServiceMethod)
} }
case <-quitChan: case <-quitChan:
log.Sugar().Infof("[%s] 监听线程收到退出信号", node.selfId)
log.Sugar().Infof("[%s] 监听线程收到退出信号", node.SelfId)
return return
} }
} }
@ -170,14 +170,14 @@ func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{})
// 共同部分和启动 // 共同部分和启动
func (n *Node) initLeaderState() { func (n *Node) initLeaderState() {
for _, peerId := range n.nodes {
n.nextIndex[peerId] = len(n.log) // 发送日志的下一个索引
n.matchIndex[peerId] = 0 // 复制日志的最新匹配索引
for _, peerId := range n.Nodes {
n.NextIndex[peerId] = len(n.Log) // 发送日志的下一个索引
n.MatchIndex[peerId] = 0 // 复制日志的最新匹配索引
} }
} }
func Start(node *Node, quitChan chan struct{}) { func Start(node *Node, quitChan chan struct{}) {
node.state = Follower // 所有节点以 Follower 状态启动
node.State = Follower // 所有节点以 Follower 状态启动
node.resetElectionTimer() // 启动选举超时定时器 node.resetElectionTimer() // 启动选举超时定时器
go func() { go func() {
@ -187,20 +187,20 @@ func Start(node *Node, quitChan chan struct{}) {
for { for {
select { select {
case <-quitChan: case <-quitChan:
fmt.Printf("[%s] Raft start 退出...\n", node.selfId)
fmt.Printf("[%s] Raft start 退出...\n", node.SelfId)
return // 退出 goroutine return // 退出 goroutine
case <-ticker.C: case <-ticker.C:
switch node.state {
switch node.State {
case Follower: case Follower:
// 监听心跳超时 // 监听心跳超时
fmt.Printf("[%s] is a follower, 监听中...\n", node.selfId)
// fmt.Printf("[%s] is a follower, 监听中...\n", node.SelfId)
case Leader: case Leader:
// 发送心跳 // 发送心跳
fmt.Printf("[%s] is the leader, 发送心跳...\n", node.selfId)
// fmt.Printf("[%s] is the leader, 发送心跳...\n", node.SelfId)
node.resetElectionTimer() // leader 不主动触发选举 node.resetElectionTimer() // leader 不主动触发选举
node.BroadCastKV(Normal)
node.BroadCastKV()
} }
} }
} }

+ 0
- 8
internal/nodes/log.go View File

@ -2,13 +2,6 @@ package nodes
import "strconv" import "strconv"
type CallMode = uint8
const (
Normal CallMode = iota + 1
Delay
Fail
)
type LogEntry struct { type LogEntry struct {
Key string Key string
Value string Value string
@ -28,7 +21,6 @@ func (RLogE *RaftLogEntry) print() string {
type LogEntryCall struct { type LogEntryCall struct {
LogE LogEntry LogE LogEntry
CallState CallMode
} }
type KVReply struct { type KVReply struct {

+ 17
- 17
internal/nodes/node.go View File

@ -16,48 +16,48 @@ const (
) )
type Node struct { type Node struct {
mu sync.Mutex
Mu sync.Mutex
// 当前节点id // 当前节点id
selfId string
SelfId string
// 记录的leader(不能用votedfor:投票的leader可能没有收到多数票) // 记录的leader(不能用votedfor:投票的leader可能没有收到多数票)
leaderId string
LeaderId string
// 除当前节点外其他节点id // 除当前节点外其他节点id
nodes []string
Nodes []string
// 当前节点状态 // 当前节点状态
state State
State State
// 任期 // 任期
currTerm int
CurrTerm int
// 简单的kv存储 // 简单的kv存储
log []RaftLogEntry
Log []RaftLogEntry
// leader用来标记新log, = log.len // leader用来标记新log, = log.len
maxLogId int
MaxLogId int
// 已提交的index // 已提交的index
commitIndex int
CommitIndex int
// 最后应用(写到db)的index // 最后应用(写到db)的index
lastApplied int
LastApplied int
// 需要发送给每个节点的下一个索引 // 需要发送给每个节点的下一个索引
nextIndex map[string]int
NextIndex map[string]int
// 已经发送给每个节点的最大索引 // 已经发送给每个节点的最大索引
matchIndex map[string]int
MatchIndex map[string]int
// 存kv(模拟状态机) // 存kv(模拟状态机)
db *leveldb.DB
Db *leveldb.DB
// 持久化节点数据(currterm votedfor log) // 持久化节点数据(currterm votedfor log)
storage *RaftStorage
Storage *RaftStorage
votedFor string
electionTimer *time.Timer
VotedFor string
ElectionTimer *time.Timer
// 通信方式 // 通信方式
transport Transport
Transport Transport
} }

+ 1
- 1
internal/nodes/real_transport.go View File

@ -13,7 +13,7 @@ type HTTPTransport struct{
} }
// 封装有超时的dial // 封装有超时的dial
func (t *HTTPTransport) DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error) {
func (t *HTTPTransport) DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) {
done := make(chan struct{}) done := make(chan struct{})
var client *rpc.Client var client *rpc.Client
var err error var err error

+ 100
- 91
internal/nodes/replica.go View File

@ -1,10 +1,9 @@
package nodes package nodes
import ( import (
"math/rand"
"sort" "sort"
"strconv" "strconv"
"time"
"sync"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -24,31 +23,31 @@ type AppendEntriesReply struct {
} }
// leader收到新内容要广播,以及心跳广播(同步自己的log) // leader收到新内容要广播,以及心跳广播(同步自己的log)
func (node *Node) BroadCastKV(callMode CallMode) {
func (node *Node) BroadCastKV() {
log.Sugar().Infof("leader[%s]广播消息", node.SelfId)
failCount := 0
// 这里增加一个锁,防止并发修改成功计数
var failMutex sync.Mutex
// 遍历所有节点 // 遍历所有节点
for _, id := range node.nodes {
go func(id string, kv CallMode) {
node.sendKV(id, callMode)
}(id, callMode)
for _, id := range node.Nodes {
go func(id string) {
node.sendKV(id, &failCount, &failMutex)
}(id)
} }
} }
func (node *Node) sendKV(peerId string, callMode CallMode) {
switch callMode {
case Fail:
log.Info("模拟发送失败")
// 这么写向所有的node发送都失败,也可以随机数确定是否失败
case Delay:
log.Info("模拟发送延迟")
// 随机延迟0-5ms
time.Sleep(time.Millisecond * time.Duration(rand.Intn(5)))
default:
}
client, err := node.transport.DialHTTPWithTimeout("tcp", peerId)
func (node *Node) sendKV(peerId string, failCount *int, failMutex *sync.Mutex) {
client, err := node.Transport.DialHTTPWithTimeout("tcp", node.SelfId, peerId)
if err != nil { if err != nil {
log.Error(node.selfId + "dialling [" + peerId + "] fail: ", zap.Error(err))
log.Error("[" + node.SelfId + "]dialling [" + peerId + "] fail: ", zap.Error(err))
failMutex.Lock()
*failCount++
if *failCount == len(node.Nodes) / 2 + 1 { // 无法联系超过半数:自己有问题,降级
node.LeaderId = ""
node.State = Follower
node.resetElectionTimer()
}
failMutex.Unlock()
return return
} }
@ -59,69 +58,78 @@ func (node *Node) sendKV(peerId string, callMode CallMode) {
} }
}(client) }(client)
node.mu.Lock()
defer node.mu.Unlock()
node.Mu.Lock()
defer node.Mu.Unlock()
var appendReply AppendEntriesReply var appendReply AppendEntriesReply
appendReply.Success = false appendReply.Success = false
nextIndex := node.nextIndex[peerId]
// log.Info("nextindex " + strconv.Itoa(nextIndex))
NextIndex := node.NextIndex[peerId]
// log.Info("NextIndex " + strconv.Itoa(NextIndex))
for (!appendReply.Success) { for (!appendReply.Success) {
if nextIndex < 0 {
if NextIndex < 0 {
log.Fatal("assert >= 0 here") log.Fatal("assert >= 0 here")
} }
sendEntries := node.log[nextIndex:]
sendEntries := node.Log[NextIndex:]
arg := AppendEntriesArg{ arg := AppendEntriesArg{
Term: node.currTerm,
PrevLogIndex: nextIndex - 1,
Term: node.CurrTerm,
PrevLogIndex: NextIndex - 1,
Entries: sendEntries, Entries: sendEntries,
LeaderCommit: node.commitIndex,
LeaderId: node.selfId,
LeaderCommit: node.CommitIndex,
LeaderId: node.SelfId,
} }
if arg.PrevLogIndex >= 0 { if arg.PrevLogIndex >= 0 {
arg.PrevLogTerm = node.log[arg.PrevLogIndex].Term
arg.PrevLogTerm = node.Log[arg.PrevLogIndex].Term
} }
callErr := node.transport.CallWithTimeout(client, "Node.AppendEntries", &arg, &appendReply) // RPC
callErr := node.Transport.CallWithTimeout(client, "Node.AppendEntries", &arg, &appendReply) // RPC
if callErr != nil { if callErr != nil {
log.Error(node.selfId + "calling [" + peerId + "] fail: ", zap.Error(callErr))
log.Error("[" + node.SelfId + "]calling [" + peerId + "] fail: ", zap.Error(callErr))
failMutex.Lock()
*failCount++
if *failCount == len(node.Nodes) / 2 + 1 { // 无法联系超过半数:自己有问题,降级
node.LeaderId = ""
node.State = Follower
node.resetElectionTimer()
}
failMutex.Unlock()
return return
} }
if appendReply.Term != node.currTerm {
log.Info("term=" + strconv.Itoa(node.currTerm) + "的Leader[" + node.selfId + "]收到更高的 term=" + strconv.Itoa(appendReply.Term) + ",转换为 Follower")
node.currTerm = appendReply.Term
node.state = Follower
node.votedFor = ""
node.storage.SetTermAndVote(node.currTerm, node.votedFor)
if appendReply.Term != node.CurrTerm {
log.Info("term=" + strconv.Itoa(node.CurrTerm) + "的Leader[" + node.SelfId + "]收到更高的 term=" + strconv.Itoa(appendReply.Term) + ",转换为 Follower")
node.LeaderId = ""
node.CurrTerm = appendReply.Term
node.State = Follower
node.VotedFor = ""
node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor)
node.resetElectionTimer() node.resetElectionTimer()
return return
} }
nextIndex-- // 失败往前传一格
NextIndex-- // 失败往前传一格
} }
// 不变成follower情况下 // 不变成follower情况下
node.nextIndex[peerId] = node.maxLogId + 1
node.matchIndex[peerId] = node.maxLogId
node.NextIndex[peerId] = node.MaxLogId + 1
node.MatchIndex[peerId] = node.MaxLogId
node.updateCommitIndex() node.updateCommitIndex()
} }
func (node *Node) updateCommitIndex() { func (node *Node) updateCommitIndex() {
totalNodes := len(node.nodes)
totalNodes := len(node.Nodes)
// 收集所有 matchIndex 并排序
matchIndexes := make([]int, 0, totalNodes)
for _, index := range node.matchIndex {
matchIndexes = append(matchIndexes, index)
// 收集所有 MatchIndex 并排序
MatchIndexes := make([]int, 0, totalNodes)
for _, index := range node.MatchIndex {
MatchIndexes = append(MatchIndexes, index)
} }
sort.Ints(matchIndexes) // 排序
sort.Ints(MatchIndexes) // 排序
// 计算多数派 commitIndex
majorityIndex := matchIndexes[totalNodes/2] // 取 N/2 位置上的索引(多数派)
// 计算多数派 CommitIndex
majorityIndex := MatchIndexes[totalNodes/2] // 取 N/2 位置上的索引(多数派)
// 确保这个索引的日志条目属于当前 term,防止提交旧 term 的日志 // 确保这个索引的日志条目属于当前 term,防止提交旧 term 的日志
if majorityIndex > node.commitIndex && majorityIndex < len(node.log) && node.log[majorityIndex].Term == node.currTerm {
node.commitIndex = majorityIndex
log.Info("Leader[" + node.selfId + "]更新 commitIndex: " + strconv.Itoa(majorityIndex))
if majorityIndex > node.CommitIndex && majorityIndex < len(node.Log) && node.Log[majorityIndex].Term == node.CurrTerm {
node.CommitIndex = majorityIndex
log.Info("Leader[" + node.SelfId + "]更新 CommitIndex: " + strconv.Itoa(majorityIndex))
// 应用日志到状态机 // 应用日志到状态机
node.applyCommittedLogs() node.applyCommittedLogs()
@ -130,13 +138,13 @@ func (node *Node) updateCommitIndex() {
// 应用日志到状态机 // 应用日志到状态机
func (node *Node) applyCommittedLogs() { func (node *Node) applyCommittedLogs() {
for node.lastApplied < node.commitIndex {
node.lastApplied++
logEntry := node.log[node.lastApplied]
log.Sugar().Infof("[%s]应用日志到状态机: " + logEntry.print(), node.selfId)
err := node.db.Put([]byte(logEntry.LogE.Key), []byte(logEntry.LogE.Value), nil)
for node.LastApplied < node.CommitIndex {
node.LastApplied++
logEntry := node.Log[node.LastApplied]
log.Sugar().Infof("[%s]应用日志到状态机: " + logEntry.print(), node.SelfId)
err := node.Db.Put([]byte(logEntry.LogE.Key), []byte(logEntry.LogE.Value), nil)
if err != nil { if err != nil {
log.Error(node.selfId + "应用状态机失败: ", zap.Error(err))
log.Error(node.SelfId + "应用状态机失败: ", zap.Error(err))
} }
} }
} }
@ -147,65 +155,66 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply
// defer func() { // defer func() {
// log.Sugar().Infof("AppendEntries 处理时间: %v", time.Since(start)) // log.Sugar().Infof("AppendEntries 处理时间: %v", time.Since(start))
// }() // }()
node.mu.Lock()
defer node.mu.Unlock()
log.Sugar().Infof("[%s]收到[%s]的AppendEntries", node.SelfId, arg.LeaderId)
node.Mu.Lock()
defer node.Mu.Unlock()
// 如果 term 过期,拒绝接受日志 // 如果 term 过期,拒绝接受日志
if node.currTerm > arg.Term {
*reply = AppendEntriesReply{node.currTerm, false}
if node.CurrTerm > arg.Term {
*reply = AppendEntriesReply{node.CurrTerm, false}
return nil return nil
} }
node.leaderId = arg.LeaderId // 记录Leader
node.LeaderId = arg.LeaderId // 记录Leader
// 如果term比自己高,或自己不是follower但收到相同term的心跳 // 如果term比自己高,或自己不是follower但收到相同term的心跳
if node.currTerm < arg.Term || node.state != Follower {
log.Sugar().Infof("[%s]发现更高 term(%s)", node.selfId, strconv.Itoa(arg.Term))
node.currTerm = arg.Term
node.state = Follower
node.votedFor = ""
// node.storage.SetTermAndVote(node.currTerm, node.votedFor)
if node.CurrTerm < arg.Term || node.State != Follower {
log.Sugar().Infof("[%s]发现更高 term(%s)", node.SelfId, strconv.Itoa(arg.Term))
node.CurrTerm = arg.Term
node.State = Follower
node.VotedFor = ""
// node.storage.SetTermAndVote(node.CurrTerm, node.VotedFor)
} }
node.storage.SetTermAndVote(node.currTerm, node.votedFor)
node.Storage.SetTermAndVote(node.CurrTerm, node.VotedFor)
// 检查 prevLogIndex 是否有效 // 检查 prevLogIndex 是否有效
if arg.PrevLogIndex >= len(node.log) || (arg.PrevLogIndex >= 0 && node.log[arg.PrevLogIndex].Term != arg.PrevLogTerm) {
*reply = AppendEntriesReply{node.currTerm, false}
if arg.PrevLogIndex >= len(node.Log) || (arg.PrevLogIndex >= 0 && node.Log[arg.PrevLogIndex].Term != arg.PrevLogTerm) {
*reply = AppendEntriesReply{node.CurrTerm, false}
return nil return nil
} }
// 处理日志冲突(如果存在不同 term,则截断日志) // 处理日志冲突(如果存在不同 term,则截断日志)
idx := arg.PrevLogIndex + 1 idx := arg.PrevLogIndex + 1
for i := idx; i < len(node.log) && i-idx < len(arg.Entries); i++ {
if node.log[i].Term != arg.Entries[i-idx].Term {
node.log = node.log[:idx]
for i := idx; i < len(node.Log) && i-idx < len(arg.Entries); i++ {
if node.Log[i].Term != arg.Entries[i-idx].Term {
node.Log = node.Log[:idx]
break break
} }
} }
// log.Info(strconv.Itoa(idx) + strconv.Itoa(len(node.log)))
// log.Info(strconv.Itoa(idx) + strconv.Itoa(len(node.Log)))
// 追加新的日志条目 // 追加新的日志条目
for _, raftLogEntry := range arg.Entries { for _, raftLogEntry := range arg.Entries {
log.Sugar().Infof("[%s]写入:" + raftLogEntry.print(), node.selfId)
if idx < len(node.log) {
node.log[idx] = raftLogEntry
log.Sugar().Infof("[%s]写入:" + raftLogEntry.print(), node.SelfId)
if idx < len(node.Log) {
node.Log[idx] = raftLogEntry
} else { } else {
node.log = append(node.log, raftLogEntry)
node.Log = append(node.Log, raftLogEntry)
} }
idx++ idx++
} }
// 暴力持久化 // 暴力持久化
node.storage.WriteLog(node.log)
node.Storage.WriteLog(node.Log)
// 更新 maxLogId
node.maxLogId = len(node.log) - 1
// 更新 MaxLogId
node.MaxLogId = len(node.Log) - 1
// 更新 commitIndex
if arg.LeaderCommit < node.maxLogId {
node.commitIndex = arg.LeaderCommit
// 更新 CommitIndex
if arg.LeaderCommit < node.MaxLogId {
node.CommitIndex = arg.LeaderCommit
} else { } else {
node.commitIndex = node.maxLogId
node.CommitIndex = node.MaxLogId
} }
// 提交已提交的日志 // 提交已提交的日志
@ -213,6 +222,6 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply
// 在成功接受日志或心跳后,重置选举超时 // 在成功接受日志或心跳后,重置选举超时
node.resetElectionTimer() node.resetElectionTimer()
*reply = AppendEntriesReply{node.currTerm, true}
*reply = AppendEntriesReply{node.CurrTerm, true}
return nil return nil
} }

+ 17
- 23
internal/nodes/server_node.go View File

@ -1,8 +1,6 @@
package nodes package nodes
import ( import (
"strconv"
"github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb"
) )
@ -15,38 +13,34 @@ type ServerReply struct{
} }
// RPC call // RPC call
func (node *Node) WriteKV(kvCall *LogEntryCall, reply *ServerReply) error { func (node *Node) WriteKV(kvCall *LogEntryCall, reply *ServerReply) error {
log.Sugar().Infof("[%s]收到客户端write请求", node.selfId)
log.Sugar().Infof("[%s]收到客户端write请求", node.SelfId)
// 自己不是leader,转交leader地址回复 // 自己不是leader,转交leader地址回复
if node.state != Leader {
if node.State != Leader {
reply.Isleader = false reply.Isleader = false
if (node.leaderId == "") {
log.Fatal("还没选出第一个leader")
return nil
}
reply.LeaderId = node.leaderId
log.Sugar().Infof("[%s]转交给[%s]", node.selfId, node.leaderId)
reply.LeaderId = node.LeaderId // 可能是空,那client就随机再找一个节点
log.Sugar().Infof("[%s]转交给[%s]", node.SelfId, node.LeaderId)
return nil return nil
} }
// 自己是leader,修改自己的记录并广播 // 自己是leader,修改自己的记录并广播
node.maxLogId++
logId := node.maxLogId
rLogE := RaftLogEntry{kvCall.LogE, logId, node.currTerm}
node.log = append(node.log, rLogE)
node.storage.AppendLog(rLogE)
log.Info("leader[" + node.selfId + "]处理请求 : " + kvCall.LogE.print() + ", 模拟方式 : " + strconv.Itoa(int(kvCall.CallState)))
node.MaxLogId++
logId := node.MaxLogId
rLogE := RaftLogEntry{kvCall.LogE, logId, node.CurrTerm}
node.Log = append(node.Log, rLogE)
node.Storage.AppendLog(rLogE)
log.Info("leader[" + node.SelfId + "]处理请求 : " + kvCall.LogE.print())
// 广播给其它节点 // 广播给其它节点
node.BroadCastKV(kvCall.CallState)
node.BroadCastKV()
reply.Isleader = true reply.Isleader = true
return nil return nil
} }
// RPC call // RPC call
func (node *Node) ReadKey(key *string, reply *ServerReply) error { func (node *Node) ReadKey(key *string, reply *ServerReply) error {
log.Sugar().Infof("[%s]收到客户端read请求", node.selfId)
log.Sugar().Infof("[%s]收到客户端read请求", node.SelfId)
// 先只读自己(无论自己是不是leader),也方便测试 // 先只读自己(无论自己是不是leader),也方便测试
value, err := node.db.Get([]byte(*key), nil)
value, err := node.Db.Get([]byte(*key), nil)
if err == leveldb.ErrNotFound { if err == leveldb.ErrNotFound {
reply.HaveValue = false reply.HaveValue = false
} else { } else {
@ -64,17 +58,17 @@ type FindLeaderReply struct{
} }
func (node *Node) FindLeader(_ struct{}, reply *FindLeaderReply) error { func (node *Node) FindLeader(_ struct{}, reply *FindLeaderReply) error {
// 自己不是leader,转交leader地址回复 // 自己不是leader,转交leader地址回复
if node.state != Leader {
if node.State != Leader {
reply.Isleader = false reply.Isleader = false
if (node.leaderId == "") {
if (node.LeaderId == "") {
log.Fatal("还没选出第一个leader") log.Fatal("还没选出第一个leader")
return nil return nil
} }
reply.LeaderId = node.leaderId
reply.LeaderId = node.LeaderId
return nil return nil
} }
reply.LeaderId = node.selfId
reply.LeaderId = node.SelfId
reply.Isleader = true reply.Isleader = true
return nil return nil
} }

+ 53
- 7
internal/nodes/thread_transport.go View File

@ -18,11 +18,13 @@ type RPCRequest struct {
type ThreadTransport struct { type ThreadTransport struct {
mu sync.Mutex mu sync.Mutex
nodeChans map[string]chan RPCRequest // 每个节点的消息通道 nodeChans map[string]chan RPCRequest // 每个节点的消息通道
connectivityMap map[string]map[string]bool // 模拟网络分区
} }
// 线程版 dial的返回clientinterface // 线程版 dial的返回clientinterface
type ThreadClient struct { type ThreadClient struct {
targetId string
SourceId string
TargetId string
} }
func (c *ThreadClient) Close() error { func (c *ThreadClient) Close() error {
@ -33,6 +35,7 @@ func (c *ThreadClient) Close() error {
func NewThreadTransport() *ThreadTransport { func NewThreadTransport() *ThreadTransport {
return &ThreadTransport{ return &ThreadTransport{
nodeChans: make(map[string]chan RPCRequest), nodeChans: make(map[string]chan RPCRequest),
connectivityMap: make(map[string]map[string]bool),
} }
} }
@ -41,6 +44,24 @@ func (t *ThreadTransport) RegisterNodeChan(nodeId string, ch chan RPCRequest) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
t.nodeChans[nodeId] = ch t.nodeChans[nodeId] = ch
// 初始化连通性(默认所有节点互相可达)
if _, exists := t.connectivityMap[nodeId]; !exists {
t.connectivityMap[nodeId] = make(map[string]bool)
}
for peerId := range t.nodeChans {
t.connectivityMap[nodeId][peerId] = true
t.connectivityMap[peerId][nodeId] = true
}
}
// 设置两个节点的连通性
func (t *ThreadTransport) SetConnectivity(from, to string, isConnected bool) {
t.mu.Lock()
defer t.mu.Unlock()
if _, exists := t.connectivityMap[from]; exists {
t.connectivityMap[from][to] = isConnected
}
} }
// 获取节点的 channel // 获取节点的 channel
@ -52,27 +73,41 @@ func (t *ThreadTransport) getNodeChan(nodeId string) (chan RPCRequest, bool) {
} }
// 模拟 Dial 操作 // 模拟 Dial 操作
func (t *ThreadTransport) DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error) {
func (t *ThreadTransport) DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
if _, exists := t.nodeChans[peerId]; !exists { if _, exists := t.nodeChans[peerId]; !exists {
return nil, fmt.Errorf("节点 [%s] 不存在", peerId) return nil, fmt.Errorf("节点 [%s] 不存在", peerId)
} }
return &ThreadClient{targetId: peerId}, nil
return &ThreadClient{SourceId: myId, TargetId: peerId}, nil
} }
// 模拟 Call 操作 // 模拟 Call 操作
func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error { func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error {
threadClient, ok := client.(*ThreadClient) threadClient, ok := client.(*ThreadClient)
if !ok { if !ok {
return fmt.Errorf("无效的客户端")
return fmt.Errorf("无效的caller")
} }
var isConnected bool
if threadClient.SourceId == "" { // 来自客户端的连接
isConnected = true
} else {
t.mu.Lock()
isConnected = t.connectivityMap[threadClient.SourceId][threadClient.TargetId] // 检查连通性
t.mu.Unlock()
}
if !isConnected {
return fmt.Errorf("network partition: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId)
}
// 获取目标节点的 channel // 获取目标节点的 channel
targetChan, exists := t.getNodeChan(threadClient.targetId)
targetChan, exists := t.getNodeChan(threadClient.TargetId)
if !exists { if !exists {
return fmt.Errorf("目标节点 [%s] 不存在", threadClient.targetId)
return fmt.Errorf("目标节点 [%s] 不存在", threadClient.TargetId)
} }
// 创建响应通道(用于返回 RPC 结果) // 创建响应通道(用于返回 RPC 结果)
@ -91,11 +126,22 @@ func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod
// 等待响应或超时 // 等待响应或超时
select { select {
case err := <-done: case err := <-done:
if threadClient.SourceId == "" { // 来自客户端的连接
isConnected = true
} else {
t.mu.Lock()
isConnected = t.connectivityMap[threadClient.TargetId][threadClient.SourceId] // 检查连通性
t.mu.Unlock()
}
if !isConnected {
return fmt.Errorf("network partition: %s cannot reach %s", threadClient.TargetId, threadClient.SourceId)
}
return err return err
case <-time.After(100 * time.Millisecond): case <-time.After(100 * time.Millisecond):
return fmt.Errorf("RPC 调用超时: %s", serviceMethod) return fmt.Errorf("RPC 调用超时: %s", serviceMethod)
} }
default: default:
return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.targetId)
return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId)
} }
} }

+ 1
- 1
internal/nodes/transport.go View File

@ -5,6 +5,6 @@ type ClientInterface interface{
} }
type Transport interface { type Transport interface {
DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error)
DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error)
CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error
} }

+ 57
- 57
internal/nodes/vote.go View File

@ -22,15 +22,15 @@ type RequestVoteReply struct {
} }
func (n *Node) startElection() { func (n *Node) startElection() {
n.mu.Lock()
defer n.mu.Unlock()
n.Mu.Lock()
defer n.Mu.Unlock()
// 增加当前任期,转换为 Candidate // 增加当前任期,转换为 Candidate
n.currTerm++
n.state = Candidate
n.votedFor = n.selfId // 自己投自己
n.storage.SetTermAndVote(n.currTerm, n.votedFor)
n.CurrTerm++
n.State = Candidate
n.VotedFor = n.SelfId // 自己投自己
n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor)
log.Sugar().Infof("[%s] 开始选举,当前任期: %d", n.selfId, n.currTerm)
log.Sugar().Infof("[%s] 开始选举,当前任期: %d", n.SelfId, n.CurrTerm)
// 重新设置选举超时,防止重复选举 // 重新设置选举超时,防止重复选举
n.resetElectionTimer() n.resetElectionTimer()
@ -39,40 +39,40 @@ func (n *Node) startElection() {
var lastLogIndex int var lastLogIndex int
var lastLogTerm int var lastLogTerm int
if len(n.log) == 0 {
if len(n.Log) == 0 {
lastLogIndex = 0 lastLogIndex = 0
lastLogTerm = 0 // 论文中定义,空日志时 Term 设为 0 lastLogTerm = 0 // 论文中定义,空日志时 Term 设为 0
} else { } else {
lastLogIndex = len(n.log) - 1
lastLogTerm = n.log[lastLogIndex].Term
lastLogIndex = len(n.Log) - 1
lastLogTerm = n.Log[lastLogIndex].Term
} }
args := RequestVoteArgs{ args := RequestVoteArgs{
Term: n.currTerm,
CandidateId: n.selfId,
Term: n.CurrTerm,
CandidateId: n.SelfId,
LastLogIndex: lastLogIndex, LastLogIndex: lastLogIndex,
LastLogTerm: lastLogTerm, LastLogTerm: lastLogTerm,
} }
// 并行向其他节点发送请求投票 // 并行向其他节点发送请求投票
var mu sync.Mutex
totalNodes := len(n.nodes)
var Mu sync.Mutex
totalNodes := len(n.Nodes)
grantedVotes := 1 // 自己的票 grantedVotes := 1 // 自己的票
for _, peerId := range n.nodes {
for _, peerId := range n.Nodes {
go func(peerId string) { go func(peerId string) {
reply := RequestVoteReply{} reply := RequestVoteReply{}
if n.sendRequestVote(peerId, &args, &reply) { if n.sendRequestVote(peerId, &args, &reply) {
mu.Lock()
Mu.Lock()
if reply.Term > n.currTerm {
if reply.Term > n.CurrTerm {
// 发现更高任期,回退为 Follower // 发现更高任期,回退为 Follower
log.Sugar().Infof("[%s] 发现更高的 Term (%d),回退为 Follower", n.selfId, reply.Term)
n.currTerm = reply.Term
n.state = Follower
n.votedFor = ""
n.storage.SetTermAndVote(n.currTerm, n.votedFor)
log.Sugar().Infof("[%s] 发现更高的 Term (%d),回退为 Follower", n.SelfId, reply.Term)
n.CurrTerm = reply.Term
n.State = Follower
n.VotedFor = ""
n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor)
n.resetElectionTimer() n.resetElectionTimer()
mu.Unlock()
Mu.Unlock()
return return
} }
@ -81,32 +81,32 @@ func (n *Node) startElection() {
} }
if grantedVotes == totalNodes / 2 + 1 { if grantedVotes == totalNodes / 2 + 1 {
n.state = Leader
log.Sugar().Infof("[%s] 当选 Leader!", n.selfId)
n.State = Leader
log.Sugar().Infof("[%s] 当选 Leader!", n.SelfId)
n.initLeaderState() n.initLeaderState()
} }
mu.Unlock()
Mu.Unlock()
} }
}(peerId) }(peerId)
} }
// 等待选举结果 // 等待选举结果
time.Sleep(300 * time.Millisecond) time.Sleep(300 * time.Millisecond)
mu.Lock()
if n.state == Candidate {
log.Sugar().Infof("[%s] 选举超时,重新发起选举", n.selfId)
// n.state = Follower 这里不修改,如果appendentries收到term合理的心跳,再变回follower
Mu.Lock()
if n.State == Candidate {
log.Sugar().Infof("[%s] 选举超时,重新发起选举", n.SelfId)
// n.State = Follower 这里不修改,如果appendentries收到term合理的心跳,再变回follower
n.resetElectionTimer() n.resetElectionTimer()
} }
mu.Unlock()
Mu.Unlock()
} }
func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *RequestVoteReply) bool { func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *RequestVoteReply) bool {
log.Sugar().Infof("[%s] 请求 [%s] 投票", node.selfId, peerId)
client, err := node.transport.DialHTTPWithTimeout("tcp", peerId)
log.Sugar().Infof("[%s] 请求 [%s] 投票", node.SelfId, peerId)
client, err := node.Transport.DialHTTPWithTimeout("tcp", node.SelfId, peerId)
if err != nil { if err != nil {
log.Error(node.selfId + "dialing [" + peerId + "] fail: ", zap.Error(err))
log.Error("[" + node.SelfId + "]dialing [" + peerId + "] fail: ", zap.Error(err))
return false return false
} }
@ -117,50 +117,50 @@ func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *R
} }
}(client) }(client)
callErr := node.transport.CallWithTimeout(client, "Node.RequestVote", args, reply) // RPC
callErr := node.Transport.CallWithTimeout(client, "Node.RequestVote", args, reply) // RPC
if callErr != nil { if callErr != nil {
log.Error(node.selfId + "calling [" + peerId + "] fail: ", zap.Error(callErr))
log.Error("[" + node.SelfId + "]calling [" + peerId + "] fail: ", zap.Error(callErr))
} }
return callErr == nil return callErr == nil
} }
func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error { func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error {
n.mu.Lock()
defer n.mu.Unlock()
n.Mu.Lock()
defer n.Mu.Unlock()
// 如果候选人的任期小于当前任期,则拒绝投票 // 如果候选人的任期小于当前任期,则拒绝投票
if args.Term < n.currTerm {
reply.Term = n.currTerm
if args.Term < n.CurrTerm {
reply.Term = n.CurrTerm
reply.VoteGranted = false reply.VoteGranted = false
return nil return nil
} }
// 如果请求的 Term 更高,则更新当前 Term 并回退为 Follower // 如果请求的 Term 更高,则更新当前 Term 并回退为 Follower
if args.Term > n.currTerm {
n.currTerm = args.Term
n.state = Follower
n.votedFor = ""
if args.Term > n.CurrTerm {
n.CurrTerm = args.Term
n.State = Follower
n.VotedFor = ""
n.resetElectionTimer() // 重新设置选举超时 n.resetElectionTimer() // 重新设置选举超时
} }
// 检查是否已经投过票,且是否投给了同一个候选人 // 检查是否已经投过票,且是否投给了同一个候选人
if n.votedFor == "" || n.votedFor == args.CandidateId {
if n.VotedFor == "" || n.VotedFor == args.CandidateId {
// 检查日志是否足够新 // 检查日志是否足够新
var lastLogIndex int var lastLogIndex int
var lastLogTerm int var lastLogTerm int
if len(n.log) == 0 {
if len(n.Log) == 0 {
lastLogIndex = -1 lastLogIndex = -1
lastLogTerm = 0 lastLogTerm = 0
} else { } else {
lastLogIndex = len(n.log) - 1
lastLogTerm = n.log[lastLogIndex].Term
lastLogIndex = len(n.Log) - 1
lastLogTerm = n.Log[lastLogIndex].Term
} }
if args.LastLogTerm > lastLogTerm || if args.LastLogTerm > lastLogTerm ||
(args.LastLogTerm == lastLogTerm && args.LastLogIndex >= lastLogIndex) { (args.LastLogTerm == lastLogTerm && args.LastLogIndex >= lastLogIndex) {
// 够新就投票给候选人 // 够新就投票给候选人
n.votedFor = args.CandidateId
log.Sugar().Infof("在term(%s), [%s]投票给[%s]", strconv.Itoa(n.currTerm), n.selfId, n.votedFor)
n.VotedFor = args.CandidateId
log.Sugar().Infof("在term(%s), [%s]投票给[%s]", strconv.Itoa(n.CurrTerm), n.SelfId, n.VotedFor)
reply.VoteGranted = true reply.VoteGranted = true
n.resetElectionTimer() n.resetElectionTimer()
} else { } else {
@ -170,23 +170,23 @@ func (n *Node) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) error
reply.VoteGranted = false reply.VoteGranted = false
} }
n.storage.SetTermAndVote(n.currTerm, n.votedFor)
reply.Term = n.currTerm
n.Storage.SetTermAndVote(n.CurrTerm, n.VotedFor)
reply.Term = n.CurrTerm
return nil return nil
} }
// follower 500-1000ms内没收到appendentries心跳,就变成candidate发起选举 // follower 500-1000ms内没收到appendentries心跳,就变成candidate发起选举
func (node *Node) resetElectionTimer() { func (node *Node) resetElectionTimer() {
if node.electionTimer == nil {
node.electionTimer = time.NewTimer(time.Duration(500+rand.Intn(500)) * time.Millisecond)
if node.ElectionTimer == nil {
node.ElectionTimer = time.NewTimer(time.Duration(500+rand.Intn(500)) * time.Millisecond)
go func() { go func() {
for { for {
<-node.electionTimer.C
<-node.ElectionTimer.C
node.startElection() node.startElection()
} }
}() }()
} else { } else {
node.electionTimer.Stop()
node.electionTimer.Reset(time.Duration(500+rand.Intn(500)) * time.Millisecond)
node.ElectionTimer.Stop()
node.ElectionTimer.Reset(time.Duration(500+rand.Intn(500)) * time.Millisecond)
} }
} }

+ 1
- 1
test/restart_node_test.go View File

@ -52,7 +52,7 @@ func TestNodeRestart(t *testing.T) {
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
key := strconv.Itoa(i) key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"} newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
s := cWrite.Write(nodes.LogEntryCall{LogE: newlog})
if s != clientPkg.Ok { if s != clientPkg.Ok {
t.Errorf("write test fail") t.Errorf("write test fail")
} }

+ 1
- 1
test/server_client_test.go View File

@ -52,7 +52,7 @@ func TestServerClient(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
key := strconv.Itoa(i) key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"} newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := c.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
s := c.Write(nodes.LogEntryCall{LogE: newlog})
if s != clientPkg.Ok { if s != clientPkg.Ok {
t.Errorf("write test fail") t.Errorf("write test fail")
} }

+ 150
- 0
threadTest/network_partition_test.go View File

@ -0,0 +1,150 @@
package threadTest
import (
"fmt"
clientPkg "simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"strings"
"testing"
"time"
)
// TestBasicConnectivity exercises ThreadTransport.SetConnectivity directly,
// without starting any Raft nodes: a call across a severed link must fail, and
// once the link is restored the call must get past the partition check and end
// in an RPC timeout instead.
func TestBasicConnectivity(t *testing.T) {
	transport := nodes.NewThreadTransport()
	transport.RegisterNodeChan("1", make(chan nodes.RPCRequest, 10))
	transport.RegisterNodeChan("2", make(chan nodes.RPCRequest, 10))

	// Cut the link between node 1 and node 2.
	transport.SetConnectivity("1", "2", false)
	err := transport.CallWithTimeout(&nodes.ThreadClient{SourceId: "1", TargetId: "2"}, "Node.AppendEntries", &nodes.AppendEntriesArg{}, &nodes.AppendEntriesReply{})
	if err == nil {
		t.Errorf("Expected network partition error, but got nil")
	}

	// Restore the link. Nothing in this test reads from the registered
	// channels, so the request is presumably never answered and the call is
	// expected to end with a timeout error rather than a partition error.
	transport.SetConnectivity("1", "2", true)
	err = transport.CallWithTimeout(&nodes.ThreadClient{SourceId: "1", TargetId: "2"}, "Node.AppendEntries", &nodes.AppendEntriesArg{}, &nodes.AppendEntriesReply{})
	if err == nil {
		// Guard before err.Error(): a nil error here would otherwise panic
		// with a nil dereference instead of producing a readable failure.
		t.Fatalf("Expected RPC timeout error after restoring connectivity, but got nil")
	}
	if !strings.Contains(err.Error(), "RPC 调用超时") {
		t.Errorf("Expected RPC timeout error, but got: %v", err)
	}
}
// TestElectionWithPartition starts a three-node cluster on a ThreadTransport
// and drives it through two partition scenarios:
//
//  1. messages from the other nodes to the current leader are cut, then
//     restored; the old leader must step down and exactly one leader must
//     exist afterwards;
//  2. messages from the leader to the other nodes are cut; the old leader must
//     step down, exactly one leader must exist, client writes must succeed,
//     and after the link is restored every node must hold all five entries.
//
// The test relies on fixed sleeps to let elections and replication settle.
func TestElectionWithPartition(t *testing.T) {
	// Register node ids "1".."n".
	n := 3
	var peerIds []string
	for i := 0; i < n; i++ {
		peerIds = append(peerIds, strconv.Itoa(i+1))
	}

	// Start the nodes.
	var quitCollections []chan struct{}
	var nodeCollections []*nodes.Node
	threadTransport := nodes.NewThreadTransport()
	for i := 0; i < n; i++ {
		// Named "node" so the node count n is not shadowed inside the loop.
		node, quitChan := ExecuteNodeI(strconv.Itoa(i+1), false, peerIds, threadTransport)
		quitCollections = append(quitCollections, quitChan)
		nodeCollections = append(nodeCollections, node)
	}

	// Tell every node to shut down when the test finishes.
	defer func() {
		for _, quitChan := range quitCollections {
			close(quitChan)
		}
	}()

	time.Sleep(2 * time.Second) // wait for startup to finish

	fmt.Println("开始分区模拟1")
	var leaderNo int
	leaderFound := false
	for i := 0; i < n; i++ {
		if nodeCollections[i].State == nodes.Leader {
			leaderNo = i
			leaderFound = true
			for j := 0; j < n; j++ {
				if i != j { // cut messages from the other nodes to the leader
					threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[i].SelfId, false)
				}
			}
		}
	}
	if !leaderFound {
		// Without a leader nothing was partitioned and leaderNo would still be
		// its zero value, so the step-down check below would pass vacuously.
		t.Fatalf("分区前未选出leader")
	}

	time.Sleep(2 * time.Second)

	if nodeCollections[leaderNo].State == nodes.Leader {
		t.Errorf("分区退选失败")
	}

	// Restore the network.
	for j := 0; j < n; j++ {
		if leaderNo != j { // restore messages from the other nodes to the old leader
			threadTransport.SetConnectivity(nodeCollections[j].SelfId, nodeCollections[leaderNo].SelfId, true)
		}
	}

	time.Sleep(1 * time.Second)

	var leaderCnt int
	for i := 0; i < n; i++ {
		if nodeCollections[i].State == nodes.Leader {
			leaderCnt++
			leaderNo = i
		}
	}
	if leaderCnt != 1 {
		t.Errorf("多leader产生")
	}

	fmt.Println("开始分区模拟2")
	for j := 0; j < n; j++ {
		if leaderNo != j { // cut messages from the leader to the other nodes
			threadTransport.SetConnectivity(nodeCollections[leaderNo].SelfId, nodeCollections[j].SelfId, false)
		}
	}

	time.Sleep(1 * time.Second)

	if nodeCollections[leaderNo].State == nodes.Leader {
		t.Errorf("分区退选失败")
	}

	leaderCnt = 0
	for j := 0; j < n; j++ {
		if nodeCollections[j].State == nodes.Leader {
			leaderCnt++
		}
	}
	if leaderCnt != 1 {
		t.Errorf("多leader产生")
	}

	// Start a client and write five entries while the partition is in place.
	c := clientPkg.Client{PeerIds: peerIds, Transport: threadTransport}
	var s clientPkg.Status
	for i := 0; i < 5; i++ {
		key := strconv.Itoa(i)
		newlog := nodes.LogEntry{Key: key, Value: "hello"}
		s = c.Write(nodes.LogEntryCall{LogE: newlog})
		if s != clientPkg.Ok {
			t.Errorf("write test fail")
		}
	}

	time.Sleep(time.Second) // wait for the writes to complete

	// Restore the network.
	for j := 0; j < n; j++ {
		if leaderNo != j {
			threadTransport.SetConnectivity(nodeCollections[leaderNo].SelfId, nodeCollections[j].SelfId, true)
		}
	}

	time.Sleep(time.Second)

	// Log consistency check: every node should hold exactly the five entries.
	for i := 0; i < n; i++ {
		if len(nodeCollections[i].Log) != 5 {
			// Constant format string with a verb (instead of concatenation)
			// keeps go vet's printf check quiet; the output is unchanged.
			t.Errorf("日志数量不一致:%d", len(nodeCollections[i].Log))
		}
	}
}

+ 1
- 1
threadTest/restart_node_test.go View File

@ -39,7 +39,7 @@ func TestNodeRestart(t *testing.T) {
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
key := strconv.Itoa(i) key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"} newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
s = cWrite.Write(nodes.LogEntryCall{LogE: newlog})
if s != clientPkg.Ok { if s != clientPkg.Ok {
t.Errorf("write test fail") t.Errorf("write test fail")
} }

+ 1
- 1
threadTest/server_client_test.go View File

@ -40,7 +40,7 @@ func TestServerClient(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
key := strconv.Itoa(i) key := strconv.Itoa(i)
newlog := nodes.LogEntry{Key: key, Value: "hello"} newlog := nodes.LogEntry{Key: key, Value: "hello"}
s := c.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
s = c.Write(nodes.LogEntryCall{LogE: newlog})
if s != clientPkg.Ok { if s != clientPkg.Ok {
t.Errorf("write test fail") t.Errorf("write test fail")
} }

Loading…
Cancel
Save