Browse Source

增加transport接口,支持线程模拟

ld
augurier 5 months ago
parent
commit
0b7590ee2a
15 changed files with 614 additions and 142 deletions
  1. +7
    -10
      cmd/main.go
  2. +19
    -27
      internal/client/client_node.go
  3. +140
    -69
      internal/nodes/init.go
  4. +5
    -7
      internal/nodes/node.go
  5. +111
    -0
      internal/nodes/real_transport.go
  6. +13
    -13
      internal/nodes/replica.go
  7. +2
    -2
      internal/nodes/server_node.go
  8. +101
    -0
      internal/nodes/thread_transport.go
  9. +10
    -0
      internal/nodes/transport.go
  10. +7
    -8
      internal/nodes/vote.go
  11. +6
    -4
      test/restart_node_test.go
  12. +5
    -2
      test/server_client_test.go
  13. +38
    -0
      threadTest/common.go
  14. +81
    -0
      threadTest/restart_node_test.go
  15. +69
    -0
      threadTest/server_client_test.go

+ 7
- 10
cmd/main.go View File

@ -31,7 +31,7 @@ func main() {
port := flag.String("port", ":9091", "rpc listen port")
cluster := flag.String("cluster", "127.0.0.1:9091,127.0.0.1:9092,127.0.0.1:9093", "comma sep")
id := flag.String("id", "1", "node ID")
isRestart := flag.Bool("isRestart", true, "new test or restart")
isRestart := flag.Bool("isRestart", false, "new test or restart")
// 参数解析
flag.Parse()
@ -51,7 +51,7 @@ func main() {
idCnt++
}
if *isRestart {
if !*isRestart {
os.RemoveAll("storage/node" + *id + ".json")
}
@ -59,27 +59,24 @@ func main() {
// 而用leveldb模拟状态机,造成了状态机本身的持久化,因此暂时通过删去旧db避免这一矛盾
os.RemoveAll("leveldb/simple-kv-store" + *id)
db, err := leveldb.OpenFile("leveldb/simple-kv-store"+*id, nil)
db, err := leveldb.OpenFile("leveldb/simple-kv-store" + *id, nil)
if err != nil {
log.Fatal("Failed to open database: ", zap.Error(err))
}
defer db.Close() // 确保数据库在使用完毕后关闭
iter := db.NewIterator(nil, nil)
defer iter.Release()
// 打开或创建节点数据持久化文件
storage := nodes.NewRaftStorage("storage/node" + *id + ".json")
// 初始化
node := nodes.Init(*id, idClusterPairs, db, storage, !*isRestart)
node := nodes.InitRPCNode(*id, *port, idClusterPairs, db, storage, !*isRestart)
log.Sugar().Infof("[%s]开始监听" + *port + "端口", *id)
// 监听rpc
node.Rpc(*port)
// 开启 raft
nodes.Start(node)
quitChan := make(chan struct{}, 1)
nodes.Start(node, quitChan)
sig := <-sigs
fmt.Println("node_"+ *id +"接收到信号:", sig)
close(quitChan)
}

+ 19
- 27
internal/client/client_node.go View File

@ -2,7 +2,6 @@ package clientPkg
import (
"math/rand"
"net/rpc"
"simple-kv-store/internal/logprovider"
"simple-kv-store/internal/nodes"
@ -13,7 +12,8 @@ var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel)
type Client struct {
// 连接的server端节点群
Address map[string]string
PeerIds []string
Transport nodes.Transport
}
type Status = uint8
@ -24,35 +24,28 @@ const (
Fail
)
func getRandomAddress(addressMap map[string]string) string {
keys := make([]string, 0, len(addressMap))
// 获取所有 key
for key := range addressMap {
keys = append(keys, key)
}
// 随机选一个 key
randomKey := keys[rand.Intn(len(keys))]
return addressMap[randomKey]
func getRandomAddress(peerIds []string) string {
// 随机选一个 id
randomKey := peerIds[rand.Intn(len(peerIds))]
return randomKey
}
func (client *Client) FindActiveNode() *rpc.Client {
func (client *Client) FindActiveNode() nodes.ClientInterface {
var err error
var c *rpc.Client
var c nodes.ClientInterface
for { // 直到找到一个可连接的节点(保证至少一个节点活着)
addr := getRandomAddress(client.Address)
c, err = nodes.DialHTTPWithTimeout("tcp", addr)
peerId := getRandomAddress(client.PeerIds)
c, err = client.Transport.DialHTTPWithTimeout("tcp", peerId)
if err != nil {
log.Error("dialing: ", zap.Error(err))
} else {
log.Sugar().Info("client发现活跃节点地址[%s]", addr)
log.Sugar().Infof("client发现活跃节点[%s]", peerId)
return c
}
}
}
func (client *Client) CloseRpcClient(c *rpc.Client) {
func (client *Client) CloseRpcClient(c nodes.ClientInterface) {
err := c.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
@ -68,7 +61,7 @@ func (client *Client) Write(kvCall nodes.LogEntryCall) Status {
var err error
for !reply.Isleader { // 根据存活节点的反馈,直到找到leader
callErr := nodes.CallWithTimeout(c, "Node.WriteKV", &kvCall, &reply) // RPC
callErr := client.Transport.CallWithTimeout(c, "Node.WriteKV", &kvCall, &reply) // RPC
if callErr != nil { // dial和call之间可能崩溃,重新找存活节点
log.Error("dialing: ", zap.Error(callErr))
client.CloseRpcClient(c)
@ -77,9 +70,9 @@ func (client *Client) Write(kvCall nodes.LogEntryCall) Status {
}
if !reply.Isleader { // 对方不是leader,根据反馈找leader
addr := reply.LeaderAddress
leaderId := reply.LeaderId
client.CloseRpcClient(c)
c, err = nodes.DialHTTPWithTimeout("tcp", addr)
c, err = client.Transport.DialHTTPWithTimeout("tcp", leaderId)
for err != nil { // 重新找下一个存活节点
c = client.FindActiveNode()
}
@ -97,12 +90,12 @@ func (client *Client) Read(key string, value *string) Status { // 查不到value
if value == nil {
return Fail
}
var c *rpc.Client
var c nodes.ClientInterface
for {
c = client.FindActiveNode()
var reply nodes.ServerReply
callErr := nodes.CallWithTimeout(c, "Node.ReadKey", &key, &reply) // RPC
callErr := client.Transport.CallWithTimeout(c, "Node.ReadKey", &key, &reply) // RPC
if callErr != nil {
log.Error("dialing: ", zap.Error(callErr))
client.CloseRpcClient(c)
@ -129,7 +122,7 @@ func (client *Client) FindLeader() string {
var err error
for !reply.Isleader { // 根据存活节点的反馈,直到找到leader
callErr := nodes.CallWithTimeout(c, "Node.FindLeader", &arg, &reply) // RPC
callErr := client.Transport.CallWithTimeout(c, "Node.FindLeader", &arg, &reply) // RPC
if callErr != nil { // dial和call之间可能崩溃,重新找存活节点
log.Error("dialing: ", zap.Error(callErr))
client.CloseRpcClient(c)
@ -138,9 +131,8 @@ func (client *Client) FindLeader() string {
}
if !reply.Isleader { // 对方不是leader,根据反馈找leader
addr := client.Address[reply.LeaderId]
client.CloseRpcClient(c)
c, err = nodes.DialHTTPWithTimeout("tcp", addr)
c, err = client.Transport.DialHTTPWithTimeout("tcp", reply.LeaderId)
for err != nil { // 重新找下一个存活节点
c = client.FindActiveNode()
}

+ 140
- 69
internal/nodes/init.go View File

@ -1,7 +1,7 @@
package nodes
import (
// "context"
"errors"
"fmt"
"net"
"net/http"
@ -15,24 +15,18 @@ import (
var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel)
func newNode(address string) *Public_node_info {
return &Public_node_info{
connect: false,
address: address,
}
}
func Init(selfId string, nodeAddr map[string]string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool) *Node {
ns := make(map[string]*Public_node_info)
for id, addr := range nodeAddr {
ns[id] = newNode(addr)
// 运行在进程上的初始化 + rpc注册
func InitRPCNode(selfId string, port string, nodeAddr map[string]string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool) *Node {
var nodeIds []string
for id := range nodeAddr {
nodeIds = append(nodeIds, id)
}
// 创建节点
node := &Node{
selfId: selfId,
leaderId: "",
nodes: ns,
nodes: nodeIds,
maxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了
currTerm: 1,
log: make([]RaftLogEntry, 0),
@ -42,6 +36,7 @@ func Init(selfId string, nodeAddr map[string]string, db *leveldb.DB, rstorage *R
matchIndex: make(map[string]int),
db: db,
storage: rstorage,
transport: &HTTPTransport{NodeMap: nodeAddr},
}
node.initLeaderState()
if isRestart {
@ -51,40 +46,13 @@ func Init(selfId string, nodeAddr map[string]string, db *leveldb.DB, rstorage *R
log.Sugar().Infof("[%s]从重启中恢复log数量: %d", selfId, len(node.log))
}
return node
}
func (n *Node) initLeaderState() {
for peerId := range n.nodes {
n.nextIndex[peerId] = len(n.log) // 发送日志的下一个索引
n.matchIndex[peerId] = 0 // 复制日志的最新匹配索引
}
}
func Start(node *Node) {
node.state = Follower // 所有节点以 Follower 状态启动
node.resetElectionTimer() // 启动选举超时定时器
log.Sugar().Infof("[%s]开始监听" + port + "端口", selfId)
node.ListenPort(port)
go func() {
for {
switch node.state {
case Follower:
// 监听心跳超时
fmt.Printf("[%s] is a follower, 监听中...\n", node.selfId)
case Leader:
// 发送心跳
fmt.Printf("[%s] is the leader, 发送心跳...\n", node.selfId)
node.resetElectionTimer() // leader不主动触发选举
node.BroadCastKV(Normal)
}
time.Sleep(50 * time.Millisecond)
}
}()
return node
}
// 初始时注册rpc方法
func (node *Node) Rpc(port string) {
func (node *Node) ListenPort(port string) {
err := rpc.Register(node)
if err != nil {
@ -104,37 +72,140 @@ func (node *Node) Rpc(port string) {
}()
}
// 封装有超时的dial
func DialHTTPWithTimeout(network, address string) (*rpc.Client, error) {
done := make(chan struct{})
var client *rpc.Client
var err error
// 线程模拟的初始化
func InitThreadNode(selfId string, peerIds []string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool, threadTransport *ThreadTransport) (*Node, chan struct{}) {
rpcChan := make(chan RPCRequest, 100) // 要监听的chan
// 创建节点
node := &Node{
selfId: selfId,
leaderId: "",
nodes: peerIds,
maxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了
currTerm: 1,
log: make([]RaftLogEntry, 0),
commitIndex: -1,
lastApplied: -1,
nextIndex: make(map[string]int),
matchIndex: make(map[string]int),
db: db,
storage: rstorage,
transport: threadTransport,
}
node.initLeaderState()
if isRestart {
node.currTerm = rstorage.GetCurrentTerm()
node.votedFor = rstorage.GetVotedFor()
node.log = rstorage.GetLogEntries()
log.Sugar().Infof("[%s]从重启中恢复log数量: %d", selfId, len(node.log))
}
go func() {
client, err = rpc.DialHTTP(network, address)
close(done)
}()
threadTransport.RegisterNodeChan(selfId, rpcChan)
quitChan := make(chan struct{}, 1)
go node.listenForChan(rpcChan, quitChan)
return node, quitChan
}
select {
case <-done:
return client, err
case <-time.After(50 * time.Millisecond):
return nil, fmt.Errorf("dial timeout: %s", address)
// listenForChan is the thread-simulation equivalent of an RPC server loop.
// It services RPCRequest messages arriving on rpcChan until quitChan is
// closed, dispatching each request to the matching Node method and sending
// that method's error result (or a dispatch failure) back on req.Done.
func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{}) {
	// This loop owns the node's leveldb handle for its lifetime; release it
	// when the simulated server shuts down.
	defer node.db.Close()
	for {
		select {
		case req := <-rpcChan:
			// Manual dispatch: net/rpc's reflection is replaced by explicit
			// type assertions on the (Args, Reply) pair for each method.
			switch req.ServiceMethod {
			case "Node.AppendEntries":
				arg, ok := req.Args.(*AppendEntriesArg)
				resp, ok2 := req.Reply.(*AppendEntriesReply)
				if !ok || !ok2 {
					req.Done <- errors.New("type assertion failed for AppendEntries")
				} else {
					req.Done <- node.AppendEntries(arg, resp)
				}
			case "Node.RequestVote":
				arg, ok := req.Args.(*RequestVoteArgs)
				resp, ok2 := req.Reply.(*RequestVoteReply)
				if !ok || !ok2 {
					req.Done <- errors.New("type assertion failed for RequestVote")
				} else {
					req.Done <- node.RequestVote(arg, resp)
				}
			case "Node.WriteKV":
				arg, ok := req.Args.(*LogEntryCall)
				resp, ok2 := req.Reply.(*ServerReply)
				if !ok || !ok2 {
					req.Done <- errors.New("type assertion failed for WriteKV")
				} else {
					req.Done <- node.WriteKV(arg, resp)
				}
			case "Node.ReadKey":
				arg, ok := req.Args.(*string)
				resp, ok2 := req.Reply.(*ServerReply)
				if !ok || !ok2 {
					req.Done <- errors.New("type assertion failed for ReadKey")
				} else {
					req.Done <- node.ReadKey(arg, resp)
				}
			case "Node.FindLeader":
				// NOTE(review): FindLeader expects a bare struct{} value, not a
				// pointer like the other methods — confirm callers pass struct{}{}.
				arg, ok := req.Args.(struct{})
				resp, ok2 := req.Reply.(*FindLeaderReply)
				if !ok || !ok2 {
					req.Done <- errors.New("type assertion failed for FindLeader")
				} else {
					req.Done <- node.FindLeader(arg, resp)
				}
			default:
				req.Done <- fmt.Errorf("未知方法: %s", req.ServiceMethod)
			}
		case <-quitChan:
			log.Sugar().Infof("[%s] 监听线程收到退出信号", node.selfId)
			return
		}
	}
}
// 共同部分和启动
// initLeaderState resets the per-peer replication bookkeeping a leader
// maintains. It is called on every node at init; the values only matter
// while the node is acting as leader.
func (n *Node) initLeaderState() {
	for _, peerId := range n.nodes {
		n.nextIndex[peerId] = len(n.log) // index of the next log entry to send to this peer
		n.matchIndex[peerId] = 0         // highest log index known to be replicated on this peer
	}
}
// 封装有超时的call
func CallWithTimeout[T1 any, T2 any](client *rpc.Client, serviceMethod string, args *T1, reply *T2) error {
done := make(chan error, 1)
func Start(node *Node, quitChan chan struct{}) {
node.state = Follower // 所有节点以 Follower 状态启动
node.resetElectionTimer() // 启动选举超时定时器
go func() {
done <- client.Call(serviceMethod, args, reply)
}()
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
select {
case err := <-done:
return err
case <-time.After(50 * time.Millisecond):
return fmt.Errorf("call timeout: %s", serviceMethod)
}
for {
select {
case <-quitChan:
fmt.Printf("[%s] Raft start 退出...\n", node.selfId)
return // 退出 goroutine
case <-ticker.C:
switch node.state {
case Follower:
// 监听心跳超时
fmt.Printf("[%s] is a follower, 监听中...\n", node.selfId)
case Leader:
// 发送心跳
fmt.Printf("[%s] is the leader, 发送心跳...\n", node.selfId)
node.resetElectionTimer() // leader 不主动触发选举
node.BroadCastKV(Normal)
}
}
}
}()
}

+ 5
- 7
internal/nodes/node.go View File

@ -15,11 +15,6 @@ const (
Leader
)
type Public_node_info struct {
connect bool
address string
}
type Node struct {
mu sync.Mutex
// 当前节点id
@ -27,8 +22,8 @@ type Node struct {
// 记录的leader(不能用votedfor:投票的leader可能没有收到多数票)
leaderId string
// 除当前节点外其他节点信息
nodes map[string]*Public_node_info
// 除当前节点外其他节点id
nodes []string
// 当前节点状态
state State
@ -61,5 +56,8 @@ type Node struct {
votedFor string
electionTimer *time.Timer
// 通信方式
transport Transport
}

+ 111
- 0
internal/nodes/real_transport.go View File

@ -0,0 +1,111 @@
package nodes
import (
"errors"
"fmt"
"net/rpc"
"time"
)
// HTTPTransport is the Transport implementation that reaches peers over
// real net/rpc HTTP connections. (gofmt fix: space before the brace.)
type HTTPTransport struct {
	NodeMap map[string]string // peer id -> network address ("host:port")
}
// DialHTTPWithTimeout dials the peer's HTTP RPC endpoint, giving up after
// 50ms. On success it returns the connection wrapped as a ClientInterface;
// on failure it returns a nil interface (never a wrapper around a nil
// *rpc.Client, which the original could produce).
func (t *HTTPTransport) DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error) {
	done := make(chan struct{})
	var client *rpc.Client
	var err error
	go func() {
		client, err = rpc.DialHTTP(network, t.NodeMap[peerId])
		close(done)
	}()
	select {
	case <-done:
		if err != nil {
			return nil, err
		}
		return &HTTPClient{rpcClient: client}, nil
	case <-time.After(50 * time.Millisecond):
		// The dial goroutine may still succeed after we give up; close the
		// late connection so it is not leaked.
		go func() {
			<-done
			if err == nil && client != nil {
				client.Close()
			}
		}()
		return nil, fmt.Errorf("dial timeout: %s", t.NodeMap[peerId])
	}
}
// CallWithTimeout issues a synchronous net/rpc call on a previously dialled
// HTTPClient and gives up after 50ms. args/reply must be the concrete types
// expected by serviceMethod; mismatches are reported as errors before any
// network traffic happens.
func (t *HTTPTransport) CallWithTimeout(clientInterface ClientInterface, serviceMethod string, args interface{}, reply interface{}) error {
	// BUG FIX: check the type assertion BEFORE dereferencing — the original
	// read c.rpcClient first and would panic with a nil-pointer dereference
	// when handed a non-HTTP ClientInterface implementation.
	c, ok := clientInterface.(*HTTPClient)
	if !ok {
		return fmt.Errorf("invalid client type")
	}
	client := c.rpcClient
	done := make(chan error, 1)
	go func() {
		// Validate the (args, reply) pair per method so type mistakes surface
		// as clear errors rather than opaque encoding failures.
		switch serviceMethod {
		case "Node.AppendEntries":
			arg, ok := args.(*AppendEntriesArg)
			resp, ok2 := reply.(*AppendEntriesReply)
			if !ok || !ok2 {
				done <- errors.New("type assertion failed for AppendEntries")
				return
			}
			done <- client.Call(serviceMethod, arg, resp)
		case "Node.RequestVote":
			arg, ok := args.(*RequestVoteArgs)
			resp, ok2 := reply.(*RequestVoteReply)
			if !ok || !ok2 {
				done <- errors.New("type assertion failed for RequestVote")
				return
			}
			done <- client.Call(serviceMethod, arg, resp)
		case "Node.WriteKV":
			arg, ok := args.(*LogEntryCall)
			resp, ok2 := reply.(*ServerReply)
			if !ok || !ok2 {
				done <- errors.New("type assertion failed for WriteKV")
				return
			}
			done <- client.Call(serviceMethod, arg, resp)
		case "Node.ReadKey":
			arg, ok := args.(*string)
			resp, ok2 := reply.(*ServerReply)
			if !ok || !ok2 {
				done <- errors.New("type assertion failed for ReadKey")
				return
			}
			done <- client.Call(serviceMethod, arg, resp)
		case "Node.FindLeader":
			arg, ok := args.(struct{})
			resp, ok2 := reply.(*FindLeaderReply)
			if !ok || !ok2 {
				done <- errors.New("type assertion failed for FindLeader")
				return
			}
			done <- client.Call(serviceMethod, arg, resp)
		default:
			done <- fmt.Errorf("unknown service method: %s", serviceMethod)
		}
	}()
	select {
	case err := <-done:
		return err
	case <-time.After(50 * time.Millisecond): // per-call deadline
		return fmt.Errorf("call timeout: %s", serviceMethod)
	}
}
// HTTPClient wraps a *rpc.Client so it satisfies ClientInterface.
type HTTPClient struct {
	rpcClient *rpc.Client
}

// Close releases the underlying RPC connection.
func (h *HTTPClient) Close() error {
	return h.rpcClient.Close()
}

+ 13
- 13
internal/nodes/replica.go View File

@ -2,7 +2,6 @@ package nodes
import (
"math/rand"
"net/rpc"
"sort"
"strconv"
"time"
@ -27,14 +26,14 @@ type AppendEntriesReply struct {
// leader收到新内容要广播,以及心跳广播(同步自己的log)
func (node *Node) BroadCastKV(callMode CallMode) {
// 遍历所有节点
for id := range node.nodes {
for _, id := range node.nodes {
go func(id string, kv CallMode) {
node.sendKV(id, callMode)
}(id, callMode)
}
}
func (node *Node) sendKV(id string, callMode CallMode) {
func (node *Node) sendKV(peerId string, callMode CallMode) {
switch callMode {
case Fail:
@ -47,13 +46,13 @@ func (node *Node) sendKV(id string, callMode CallMode) {
default:
}
client, err := DialHTTPWithTimeout("tcp", node.nodes[id].address)
client, err := node.transport.DialHTTPWithTimeout("tcp", peerId)
if err != nil {
log.Error("dialing: ", zap.Error(err))
log.Error(node.selfId + "dialling [" + peerId + "] fail: ", zap.Error(err))
return
}
defer func(client *rpc.Client) {
defer func(client ClientInterface) {
err := client.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
@ -65,7 +64,7 @@ func (node *Node) sendKV(id string, callMode CallMode) {
var appendReply AppendEntriesReply
appendReply.Success = false
nextIndex := node.nextIndex[id]
nextIndex := node.nextIndex[peerId]
// log.Info("nextindex " + strconv.Itoa(nextIndex))
for (!appendReply.Success) {
if nextIndex < 0 {
@ -82,13 +81,14 @@ func (node *Node) sendKV(id string, callMode CallMode) {
if arg.PrevLogIndex >= 0 {
arg.PrevLogTerm = node.log[arg.PrevLogIndex].Term
}
callErr := CallWithTimeout(client, "Node.AppendEntries", &arg, &appendReply) // RPC
callErr := node.transport.CallWithTimeout(client, "Node.AppendEntries", &arg, &appendReply) // RPC
if callErr != nil {
log.Error("dialing node_"+ id +"fail: ", zap.Error(callErr))
log.Error(node.selfId + "calling [" + peerId + "] fail: ", zap.Error(callErr))
return
}
if appendReply.Term != node.currTerm {
log.Info("Leader[" + node.selfId + "]收到更高的 term=" + strconv.Itoa(appendReply.Term) + ",转换为 Follower")
log.Info("term=" + strconv.Itoa(node.currTerm) + "的Leader[" + node.selfId + "]收到更高的 term=" + strconv.Itoa(appendReply.Term) + ",转换为 Follower")
node.currTerm = appendReply.Term
node.state = Follower
node.votedFor = ""
@ -100,8 +100,8 @@ func (node *Node) sendKV(id string, callMode CallMode) {
}
// 不变成follower情况下
node.nextIndex[id] = node.maxLogId + 1
node.matchIndex[id] = node.maxLogId
node.nextIndex[peerId] = node.maxLogId + 1
node.matchIndex[peerId] = node.maxLogId
node.updateCommitIndex()
}
@ -164,7 +164,7 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply
node.currTerm = arg.Term
node.state = Follower
node.votedFor = ""
node.storage.SetTermAndVote(node.currTerm, node.votedFor)
// node.storage.SetTermAndVote(node.currTerm, node.votedFor)
}
node.storage.SetTermAndVote(node.currTerm, node.votedFor)

+ 2
- 2
internal/nodes/server_node.go View File

@ -9,7 +9,7 @@ import (
// leader node作为server为client注册的方法
type ServerReply struct{
Isleader bool
LeaderAddress string // 自己不是leader则返回leader地址
LeaderId string // 自己不是leader则返回leader
HaveValue bool
Value string
}
@ -24,7 +24,7 @@ func (node *Node) WriteKV(kvCall *LogEntryCall, reply *ServerReply) error {
log.Fatal("还没选出第一个leader")
return nil
}
reply.LeaderAddress = node.nodes[node.leaderId].address
reply.LeaderId = node.leaderId
log.Sugar().Infof("[%s]转交给[%s]", node.selfId, node.leaderId)
return nil
}

+ 101
- 0
internal/nodes/thread_transport.go View File

@ -0,0 +1,101 @@
package nodes
import (
"fmt"
"sync"
"time"
)
// RPCRequest is one simulated RPC call: the method name, its argument and
// reply payloads, and a channel on which the callee reports completion.
type RPCRequest struct {
	ServiceMethod string
	Args          interface{}
	Reply         interface{} // written in place by the callee before Done is signalled
	Done          chan error  // carries the call's result back to the caller
}
// ThreadTransport is the Transport implementation for in-process (goroutine)
// simulation: every "RPC" is a message passed over a Go channel.
type ThreadTransport struct {
	mu        sync.Mutex                 // guards nodeChans
	nodeChans map[string]chan RPCRequest // per-node inbound request channel
}
// ThreadClient is the ClientInterface returned by the simulated dial; it
// only remembers which node the caller "connected" to.
type ThreadClient struct {
	targetId string
}

// Close is a no-op: there is no real connection to tear down.
func (c *ThreadClient) Close() error {
	return nil
}
// NewThreadTransport builds an empty in-process transport. Nodes must
// register their channels via RegisterNodeChan before calls can reach them.
func NewThreadTransport() *ThreadTransport {
	t := &ThreadTransport{}
	t.nodeChans = make(map[string]chan RPCRequest)
	return t
}
// RegisterNodeChan publishes nodeId's inbound channel so peers can call it.
func (t *ThreadTransport) RegisterNodeChan(nodeId string, ch chan RPCRequest) {
	t.mu.Lock()
	t.nodeChans[nodeId] = ch
	t.mu.Unlock()
}
// getNodeChan looks up the inbound channel registered for nodeId.
func (t *ThreadTransport) getNodeChan(nodeId string) (chan RPCRequest, bool) {
	t.mu.Lock()
	defer t.mu.Unlock()
	c, found := t.nodeChans[nodeId]
	return c, found
}
// DialHTTPWithTimeout simulates dialling a peer: it only verifies that the
// peer has registered a channel, then hands back a ThreadClient naming it.
// The network argument exists for interface compatibility and is ignored.
func (t *ThreadTransport) DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error) {
	if _, ok := t.getNodeChan(peerId); !ok {
		return nil, fmt.Errorf("节点 [%s] 不存在", peerId)
	}
	return &ThreadClient{targetId: peerId}, nil
}
// CallWithTimeout delivers a simulated RPC to the target node's channel and
// waits up to 100ms for the callee to signal completion on Done.
func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error {
	threadClient, ok := client.(*ThreadClient)
	if !ok {
		return fmt.Errorf("无效的客户端")
	}
	// Look up the target node's inbound channel.
	targetChan, exists := t.getNodeChan(threadClient.targetId)
	if !exists {
		return fmt.Errorf("目标节点 [%s] 不存在", threadClient.targetId)
	}
	// Buffered (size 1) so the callee's send never blocks, even if this
	// caller has already timed out and stopped listening.
	done := make(chan error, 1)
	// Assemble and send the request.
	request := RPCRequest{
		ServiceMethod: serviceMethod,
		Args:          args,
		Reply:         reply,
		Done:          done,
	}
	select {
	case targetChan <- request:
		// Delivered; wait for the callee or the deadline.
		// NOTE(review): after a timeout the callee may still be writing into
		// reply while the caller proceeds — potential data race; confirm
		// callers discard reply on timeout.
		select {
		case err := <-done:
			return err
		case <-time.After(100 * time.Millisecond):
			return fmt.Errorf("RPC 调用超时: %s", serviceMethod)
		}
	default:
		// Non-blocking send: if the target's queue is full, fail fast
		// rather than blocking the caller.
		return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.targetId)
	}
}

+ 10
- 0
internal/nodes/transport.go View File

@ -0,0 +1,10 @@
package nodes
type ClientInterface interface{
Close() error
}
// Transport abstracts how nodes and clients reach a peer: real net/rpc over
// HTTP (HTTPTransport) or in-process channels (ThreadTransport).
type Transport interface {
	// DialHTTPWithTimeout connects to the peer identified by peerId; the
	// network argument is transport-specific (e.g. "tcp") and may be ignored.
	DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error)
	// CallWithTimeout performs one synchronous call on a dialled client.
	CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error
}

+ 7
- 8
internal/nodes/vote.go View File

@ -2,7 +2,6 @@ package nodes
import (
"math/rand"
"net/rpc"
"strconv"
"sync"
"time"
@ -59,7 +58,7 @@ func (n *Node) startElection() {
totalNodes := len(n.nodes)
grantedVotes := 1 // 自己的票
for peerId := range n.nodes {
for _, peerId := range n.nodes {
go func(peerId string) {
reply := RequestVoteReply{}
if n.sendRequestVote(peerId, &args, &reply) {
@ -104,23 +103,23 @@ func (n *Node) startElection() {
}
func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *RequestVoteReply) bool {
log.Sugar().Infof("[%s] 请求 [%s] 投票给自己", node.selfId, peerId)
client, err := DialHTTPWithTimeout("tcp", node.nodes[peerId].address)
log.Sugar().Infof("[%s] 请求 [%s] 投票", node.selfId, peerId)
client, err := node.transport.DialHTTPWithTimeout("tcp", peerId)
if err != nil {
log.Error("dialing: ", zap.Error(err))
log.Error(node.selfId + "dialing [" + peerId + "] fail: ", zap.Error(err))
return false
}
defer func(client *rpc.Client) {
defer func(client ClientInterface) {
err := client.Close()
if err != nil {
log.Error("client close err: ", zap.Error(err))
}
}(client)
callErr := CallWithTimeout(client, "Node.RequestVote", args, reply) // RPC
callErr := node.transport.CallWithTimeout(client, "Node.RequestVote", args, reply) // RPC
if callErr != nil {
log.Error("dialing node_"+peerId+"fail: ", zap.Error(callErr))
log.Error(node.selfId + "calling [" + peerId + "] fail: ", zap.Error(callErr))
}
return callErr == nil
}

+ 6
- 4
test/restart_node_test.go View File

@ -15,18 +15,20 @@ func TestNodeRestart(t *testing.T) {
// 登记结点信息
n := 5
var clusters []string
var peerIds []string
addressMap := make(map[string]string)
for i := 0; i < n; i++ {
port := fmt.Sprintf("%d", uint16(9090)+uint16(i))
addr := "127.0.0.1:" + port
clusters = append(clusters, addr)
addressMap[strconv.Itoa(i + 1)] = addr
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var cmds []*exec.Cmd
for i := 0; i < n; i++ {
cmd := ExecuteNodeI(i, true, clusters)
cmd := ExecuteNodeI(i, false, clusters)
cmds = append(cmds, cmd)
}
@ -43,7 +45,7 @@ func TestNodeRestart(t *testing.T) {
time.Sleep(time.Second) // 等待启动完毕
// client启动, 连接任意节点
cWrite := clientPkg.Client{Address: addressMap}
cWrite := clientPkg.Client{PeerIds: peerIds, Transport: &nodes.HTTPTransport{NodeMap: addressMap}}
// 写入
var s clientPkg.Status
@ -66,7 +68,7 @@ func TestNodeRestart(t *testing.T) {
}
time.Sleep(time.Second)
cmd := ExecuteNodeI(i, false, clusters)
cmd := ExecuteNodeI(i, true, clusters)
if cmd == nil {
t.Errorf("recover test1 fail")
return
@ -78,7 +80,7 @@ func TestNodeRestart(t *testing.T) {
// client启动
cRead := clientPkg.Client{Address: addressMap}
cRead := clientPkg.Client{PeerIds: peerIds, Transport: &nodes.HTTPTransport{NodeMap: addressMap}}
// 读写入数据
for i := 0; i < 5; i++ {
key := strconv.Itoa(i)

+ 5
- 2
test/server_client_test.go View File

@ -15,18 +15,20 @@ func TestServerClient(t *testing.T) {
// 登记结点信息
n := 5
var clusters []string
var peerIds []string
addressMap := make(map[string]string)
for i := 0; i < n; i++ {
port := fmt.Sprintf("%d", uint16(9090)+uint16(i))
addr := "127.0.0.1:" + port
clusters = append(clusters, addr)
addressMap[strconv.Itoa(i + 1)] = addr
peerIds = append(peerIds, strconv.Itoa(i + 1))
}
// 结点启动
var cmds []*exec.Cmd
for i := 0; i < n; i++ {
cmd := ExecuteNodeI(i, true, clusters)
cmd := ExecuteNodeI(i, false, clusters)
cmds = append(cmds, cmd)
}
@ -43,7 +45,7 @@ func TestServerClient(t *testing.T) {
time.Sleep(time.Second) // 等待启动完毕
// client启动
c := clientPkg.Client{Address: addressMap}
c := clientPkg.Client{PeerIds: peerIds, Transport: &nodes.HTTPTransport{NodeMap: addressMap}}
// 写入
var s clientPkg.Status
@ -56,6 +58,7 @@ func TestServerClient(t *testing.T) {
}
}
time.Sleep(time.Second) // 等待写入完毕
// 读写入数据
for i := 0; i < 10; i++ {
key := strconv.Itoa(i)

+ 38
- 0
threadTest/common.go View File

@ -0,0 +1,38 @@
package threadTest
import (
"fmt"
"os"
"simple-kv-store/internal/nodes"
"github.com/syndtr/goleveldb/leveldb"
)
// ExecuteNodeI boots one simulated raft node in-process (its server loop and
// raft ticker run as goroutines) and returns the node with its quit channel.
// id is the node's cluster id; isRestart preserves persisted raft state;
// peerIds is the full cluster membership including id itself.
func ExecuteNodeI(id string, isRestart bool, peerIds []string, threadTransport *nodes.ThreadTransport) (*nodes.Node, chan struct{}) {
	if !isRestart {
		// Fresh start: drop any persisted raft state from a previous run.
		os.RemoveAll("storage/node" + id + ".json")
	}

	// The leveldb instance simulates the state machine; remove the old one so
	// a restart replays the raft log into an empty state machine.
	os.RemoveAll("leveldb/simple-kv-store" + id)
	db, err := leveldb.OpenFile("leveldb/simple-kv-store"+id, nil)
	if err != nil {
		// BUG FIX: the original only printed the error and carried on with a
		// nil db, crashing later; fail fast instead.
		panic(fmt.Sprintf("failed to open database for node %s: %v", id, err))
	}

	// Open or create the node's raft persistence file.
	storage := nodes.NewRaftStorage("storage/node" + id + ".json")

	// Peers are every cluster member except ourselves.
	var otherIds []string
	for _, peer := range peerIds {
		if peer != id {
			otherIds = append(otherIds, peer)
		}
	}

	// Initialize and start raft.
	node, quitChan := nodes.InitThreadNode(id, otherIds, db, storage, isRestart, threadTransport)
	go nodes.Start(node, quitChan)
	return node, quitChan
}

+ 81
- 0
threadTest/restart_node_test.go View File

@ -0,0 +1,81 @@
package threadTest
import (
"simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
)
// TestNodeRestart boots an n-node simulated cluster, writes a few keys,
// crashes and restarts each node in turn, then verifies that written keys
// survive and never-written keys stay absent.
func TestNodeRestart(t *testing.T) {
	// Register cluster membership.
	n := 5
	var peerIds []string
	for i := 0; i < n; i++ {
		peerIds = append(peerIds, strconv.Itoa(i+1))
	}

	// Start the nodes as goroutines sharing one ThreadTransport.
	var quitCollections []chan struct{}
	threadTransport := nodes.NewThreadTransport()
	for i := 0; i < n; i++ {
		_, quitChan := ExecuteNodeI(strconv.Itoa(i+1), false, peerIds, threadTransport)
		quitCollections = append(quitCollections, quitChan)
	}

	// Tell every node to shut down when the test ends.
	defer func() {
		for _, quitChan := range quitCollections {
			close(quitChan)
		}
	}()

	time.Sleep(time.Second) // wait for startup

	// Start a client connected to an arbitrary node.
	cWrite := clientPkg.Client{PeerIds: peerIds, Transport: threadTransport}

	// Write some keys.
	var s clientPkg.Status
	for i := 0; i < 5; i++ {
		key := strconv.Itoa(i)
		newlog := nodes.LogEntry{Key: key, Value: "hello"}
		// BUG FIX: "=" instead of ":=" — the original shadowed s.
		s = cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
		if s != clientPkg.Ok {
			t.Errorf("write test fail")
		}
	}
	time.Sleep(time.Second) // wait for replication

	// Simulate the nodes crashing and recovering one at a time.
	for i := 0; i < n; i++ {
		close(quitCollections[i])
		time.Sleep(time.Second)
		_, quitChan := ExecuteNodeI(strconv.Itoa(i+1), true, peerIds, threadTransport)
		quitCollections[i] = quitChan
		time.Sleep(time.Second) // wait for the node to come back
	}

	// Read back the written keys.
	cRead := clientPkg.Client{PeerIds: peerIds, Transport: threadTransport}
	for i := 0; i < 5; i++ {
		key := strconv.Itoa(i)
		var value string
		s = cRead.Read(key, &value)
		if s != clientPkg.Ok {
			t.Errorf("Read test1 fail")
		}
	}

	// Keys that were never written must be reported absent.
	for i := 5; i < 15; i++ {
		key := strconv.Itoa(i)
		var value string
		s = cRead.Read(key, &value)
		if s != clientPkg.NotFound {
			t.Errorf("Read test2 fail")
		}
	}
}

+ 69
- 0
threadTest/server_client_test.go View File

@ -0,0 +1,69 @@
package threadTest
import (
"simple-kv-store/internal/client"
"simple-kv-store/internal/nodes"
"strconv"
"testing"
"time"
)
// TestServerClient boots an n-node simulated cluster, writes keys through a
// client, then verifies reads of both written and never-written keys.
func TestServerClient(t *testing.T) {
	// Register cluster membership.
	n := 5
	var peerIds []string
	for i := 0; i < n; i++ {
		peerIds = append(peerIds, strconv.Itoa(i+1))
	}

	// Start the nodes as goroutines sharing one ThreadTransport.
	var quitCollections []chan struct{}
	threadTransport := nodes.NewThreadTransport()
	for i := 0; i < n; i++ {
		_, quitChan := ExecuteNodeI(strconv.Itoa(i+1), false, peerIds, threadTransport)
		quitCollections = append(quitCollections, quitChan)
	}

	// Tell every node to shut down when the test ends.
	defer func() {
		for _, quitChan := range quitCollections {
			close(quitChan)
		}
	}()

	time.Sleep(time.Second) // wait for startup

	// Start a client.
	c := clientPkg.Client{PeerIds: peerIds, Transport: threadTransport}

	// Write some keys.
	var s clientPkg.Status
	for i := 0; i < 10; i++ {
		key := strconv.Itoa(i)
		newlog := nodes.LogEntry{Key: key, Value: "hello"}
		// BUG FIX: "=" instead of ":=" — the original shadowed s.
		s = c.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal})
		if s != clientPkg.Ok {
			t.Errorf("write test fail")
		}
	}
	time.Sleep(time.Second) // wait for replication

	// Read back the written keys.
	for i := 0; i < 10; i++ {
		key := strconv.Itoa(i)
		var value string
		s = c.Read(key, &value)
		if s != clientPkg.Ok || value != "hello" {
			t.Errorf("Read test1 fail")
		}
	}

	// Keys that were never written must be reported absent.
	for i := 10; i < 15; i++ {
		key := strconv.Itoa(i)
		var value string
		s = c.Read(key, &value)
		if s != clientPkg.NotFound {
			t.Errorf("Read test2 fail")
		}
	}
}

Loading…
Cancel
Save