diff --git a/cmd/main.go b/cmd/main.go index 7df8f39..60842a6 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -31,7 +31,7 @@ func main() { port := flag.String("port", ":9091", "rpc listen port") cluster := flag.String("cluster", "127.0.0.1:9091,127.0.0.1:9092,127.0.0.1:9093", "comma sep") id := flag.String("id", "1", "node ID") - isRestart := flag.Bool("isRestart", true, "new test or restart") + isRestart := flag.Bool("isRestart", false, "new test or restart") // 参数解析 flag.Parse() @@ -51,7 +51,7 @@ func main() { idCnt++ } - if *isRestart { + if !*isRestart { os.RemoveAll("storage/node" + *id + ".json") } @@ -59,27 +59,24 @@ func main() { // 而用leveldb模拟状态机,造成了状态机本身的持久化,因此暂时通过删去旧db避免这一矛盾 os.RemoveAll("leveldb/simple-kv-store" + *id) - db, err := leveldb.OpenFile("leveldb/simple-kv-store"+*id, nil) + db, err := leveldb.OpenFile("leveldb/simple-kv-store" + *id, nil) if err != nil { log.Fatal("Failed to open database: ", zap.Error(err)) } defer db.Close() // 确保数据库在使用完毕后关闭 - iter := db.NewIterator(nil, nil) - defer iter.Release() // 打开或创建节点数据持久化文件 storage := nodes.NewRaftStorage("storage/node" + *id + ".json") // 初始化 - node := nodes.Init(*id, idClusterPairs, db, storage, !*isRestart) + node := nodes.InitRPCNode(*id, *port, idClusterPairs, db, storage, !*isRestart) - log.Sugar().Infof("[%s]开始监听" + *port + "端口", *id) - // 监听rpc - node.Rpc(*port) // 开启 raft - nodes.Start(node) + quitChan := make(chan struct{}, 1) + nodes.Start(node, quitChan) sig := <-sigs fmt.Println("node_"+ *id +"接收到信号:", sig) + close(quitChan) } diff --git a/internal/client/client_node.go b/internal/client/client_node.go index 6669c6f..a59a9aa 100644 --- a/internal/client/client_node.go +++ b/internal/client/client_node.go @@ -2,7 +2,6 @@ package clientPkg import ( "math/rand" - "net/rpc" "simple-kv-store/internal/logprovider" "simple-kv-store/internal/nodes" @@ -13,7 +12,8 @@ var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel) type Client struct { // 连接的server端节点群 - Address map[string]string + PeerIds []string + Transport nodes.Transport } type Status = uint8 @@ -24,35 +24,28 @@ const ( Fail ) -func getRandomAddress(addressMap map[string]string) string { - keys := make([]string, 0, len(addressMap)) - - // 获取所有 key - for key := range addressMap { - keys = append(keys, key) - } - - // 随机选一个 key - randomKey := keys[rand.Intn(len(keys))] - return addressMap[randomKey] +func getRandomAddress(peerIds []string) string { + // 随机选一个 id + randomKey := peerIds[rand.Intn(len(peerIds))] + return randomKey } -func (client *Client) FindActiveNode() *rpc.Client { +func (client *Client) FindActiveNode() nodes.ClientInterface { var err error - var c *rpc.Client + var c nodes.ClientInterface for { // 直到找到一个可连接的节点(保证至少一个节点活着) - addr := getRandomAddress(client.Address) - c, err = nodes.DialHTTPWithTimeout("tcp", addr) + peerId := getRandomAddress(client.PeerIds) + c, err = client.Transport.DialHTTPWithTimeout("tcp", peerId) if err != nil { log.Error("dialing: ", zap.Error(err)) } else { - log.Sugar().Info("client发现活跃节点地址[%s]", addr) + log.Sugar().Infof("client发现活跃节点[%s]", peerId) return c } } } -func (client *Client) CloseRpcClient(c *rpc.Client) { +func (client *Client) CloseRpcClient(c nodes.ClientInterface) { err := c.Close() if err != nil { log.Error("client close err: ", zap.Error(err)) @@ -68,7 +61,7 @@ func (client *Client) Write(kvCall nodes.LogEntryCall) Status { var err error for !reply.Isleader { // 根据存活节点的反馈,直到找到leader - callErr := nodes.CallWithTimeout(c, "Node.WriteKV", &kvCall, &reply) // RPC + callErr := client.Transport.CallWithTimeout(c, "Node.WriteKV", &kvCall, &reply) // RPC if callErr != nil { // dial和call之间可能崩溃,重新找存活节点 log.Error("dialing: ", zap.Error(callErr)) client.CloseRpcClient(c) @@ -77,9 +70,9 @@ func (client *Client) Write(kvCall nodes.LogEntryCall) Status { } if !reply.Isleader { // 对方不是leader,根据反馈找leader - addr := reply.LeaderAddress + leaderId := reply.LeaderId client.CloseRpcClient(c) - c, err = nodes.DialHTTPWithTimeout("tcp", addr) + c, err = client.Transport.DialHTTPWithTimeout("tcp", leaderId) for err != nil { // 重新找下一个存活节点 c = client.FindActiveNode() } @@ -97,12 +90,12 @@ func (client *Client) Read(key string, value *string) Status { // 查不到value if value == nil { return Fail } - var c *rpc.Client + var c nodes.ClientInterface for { c = client.FindActiveNode() var reply nodes.ServerReply - callErr := nodes.CallWithTimeout(c, "Node.ReadKey", &key, &reply) // RPC + callErr := client.Transport.CallWithTimeout(c, "Node.ReadKey", &key, &reply) // RPC if callErr != nil { log.Error("dialing: ", zap.Error(callErr)) client.CloseRpcClient(c) @@ -129,7 +122,7 @@ func (client *Client) FindLeader() string { var err error for !reply.Isleader { // 根据存活节点的反馈,直到找到leader - callErr := nodes.CallWithTimeout(c, "Node.FindLeader", &arg, &reply) // RPC + callErr := client.Transport.CallWithTimeout(c, "Node.FindLeader", &arg, &reply) // RPC if callErr != nil { // dial和call之间可能崩溃,重新找存活节点 log.Error("dialing: ", zap.Error(callErr)) client.CloseRpcClient(c) @@ -138,9 +131,8 @@ func (client *Client) FindLeader() string { } if !reply.Isleader { // 对方不是leader,根据反馈找leader - addr := client.Address[reply.LeaderId] client.CloseRpcClient(c) - c, err = nodes.DialHTTPWithTimeout("tcp", addr) + c, err = client.Transport.DialHTTPWithTimeout("tcp", reply.LeaderId) for err != nil { // 重新找下一个存活节点 c = client.FindActiveNode() } diff --git a/internal/nodes/init.go b/internal/nodes/init.go index b0f4aa8..4193d14 100644 --- a/internal/nodes/init.go +++ b/internal/nodes/init.go @@ -1,7 +1,7 @@ package nodes import ( - // "context" + "errors" "fmt" "net" "net/http" @@ -15,24 +15,18 @@ import ( var log, _ = logprovider.CreateDefaultZapLogger(zap.InfoLevel) -func newNode(address string) *Public_node_info { - return &Public_node_info{ - connect: false, - address: address, - } -} - -func Init(selfId string, nodeAddr map[string]string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool) *Node { - ns := make(map[string]*Public_node_info) - for id, addr := range nodeAddr { - ns[id] = newNode(addr) +// 运行在进程上的初始化 + rpc注册 +func InitRPCNode(selfId string, port string, nodeAddr map[string]string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool) *Node { + var nodeIds []string + for id := range nodeAddr { + nodeIds = append(nodeIds, id) } // 创建节点 node := &Node{ selfId: selfId, leaderId: "", - nodes: ns, + nodes: nodeIds, maxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了 currTerm: 1, log: make([]RaftLogEntry, 0), @@ -42,6 +36,7 @@ func Init(selfId string, nodeAddr map[string]string, db *leveldb.DB, rstorage *R matchIndex: make(map[string]int), db: db, storage: rstorage, + transport: &HTTPTransport{NodeMap: nodeAddr}, } node.initLeaderState() if isRestart { @@ -51,40 +46,13 @@ func Init(selfId string, nodeAddr map[string]string, db *leveldb.DB, rstorage *R log.Sugar().Infof("[%s]从重启中恢复log数量: %d", selfId, len(node.log)) } - return node -} - -func (n *Node) initLeaderState() { - for peerId := range n.nodes { - n.nextIndex[peerId] = len(n.log) // 发送日志的下一个索引 - n.matchIndex[peerId] = 0 // 复制日志的最新匹配索引 - } -} - -func Start(node *Node) { - node.state = Follower // 所有节点以 Follower 状态启动 - node.resetElectionTimer() // 启动选举超时定时器 + log.Sugar().Infof("[%s]开始监听" + port + "端口", selfId) + node.ListenPort(port) - go func() { - for { - switch node.state { - case Follower: - // 监听心跳超时 - fmt.Printf("[%s] is a follower, 监听中...\n", node.selfId) - - case Leader: - // 发送心跳 - fmt.Printf("[%s] is the leader, 发送心跳...\n", node.selfId) - node.resetElectionTimer() // leader不主动触发选举 - node.BroadCastKV(Normal) - } - time.Sleep(50 * time.Millisecond) - } - }() + return node } -// 初始时注册rpc方法 -func (node *Node) Rpc(port string) { +func (node *Node) ListenPort(port string) { err := rpc.Register(node) if err != nil { @@ -104,37 +72,140 @@ func (node *Node) Rpc(port string) { }() } -// 封装有超时的dial -func DialHTTPWithTimeout(network, address string) (*rpc.Client, error) { - done := make(chan struct{}) - var client *rpc.Client - var err error +// 线程模拟的初始化 +func InitThreadNode(selfId string, peerIds []string, db *leveldb.DB, rstorage *RaftStorage, isRestart bool, threadTransport *ThreadTransport) (*Node, chan struct{}) { + rpcChan := make(chan RPCRequest, 100) // 要监听的chan + // 创建节点 + node := &Node{ + selfId: selfId, + leaderId: "", + nodes: peerIds, + maxLogId: -1, // 后来发现论文中是从1开始的(初始0),但不想改了 + currTerm: 1, + log: make([]RaftLogEntry, 0), + commitIndex: -1, + lastApplied: -1, + nextIndex: make(map[string]int), + matchIndex: make(map[string]int), + db: db, + storage: rstorage, + transport: threadTransport, + } + node.initLeaderState() + if isRestart { + node.currTerm = rstorage.GetCurrentTerm() + node.votedFor = rstorage.GetVotedFor() + node.log = rstorage.GetLogEntries() + log.Sugar().Infof("[%s]从重启中恢复log数量: %d", selfId, len(node.log)) + } - go func() { - client, err = rpc.DialHTTP(network, address) - close(done) - }() + threadTransport.RegisterNodeChan(selfId, rpcChan) + quitChan := make(chan struct{}, 1) + go node.listenForChan(rpcChan, quitChan) + + return node, quitChan +} - select { - case <-done: - return client, err - case <-time.After(50 * time.Millisecond): - return nil, fmt.Errorf("dial timeout: %s", address) +func (node *Node) listenForChan(rpcChan chan RPCRequest, quitChan chan struct{}) { + defer node.db.Close() + + for { + select { + case req := <-rpcChan: + switch req.ServiceMethod { + case "Node.AppendEntries": + arg, ok := req.Args.(*AppendEntriesArg) + resp, ok2 := req.Reply.(*AppendEntriesReply) + if !ok || !ok2 { + req.Done <- errors.New("type assertion failed for AppendEntries") + } else { + req.Done <- node.AppendEntries(arg, resp) + } + + case "Node.RequestVote": + arg, ok := req.Args.(*RequestVoteArgs) + resp, ok2 := req.Reply.(*RequestVoteReply) + if !ok || !ok2 { + req.Done <- errors.New("type assertion failed for RequestVote") + } else { + req.Done <- node.RequestVote(arg, resp) + } + + case "Node.WriteKV": + arg, ok := req.Args.(*LogEntryCall) + resp, ok2 := req.Reply.(*ServerReply) + if !ok || !ok2 { + req.Done <- errors.New("type assertion failed for WriteKV") + } else { + req.Done <- node.WriteKV(arg, resp) + } + + case "Node.ReadKey": + arg, ok := req.Args.(*string) + resp, ok2 := req.Reply.(*ServerReply) + if !ok || !ok2 { + req.Done <- errors.New("type assertion failed for ReadKey") + } else { + req.Done <- node.ReadKey(arg, resp) + } + + case "Node.FindLeader": + arg, ok := req.Args.(struct{}) + resp, ok2 := req.Reply.(*FindLeaderReply) + if !ok || !ok2 { + req.Done <- errors.New("type assertion failed for FindLeader") + } else { + req.Done <- node.FindLeader(arg, resp) + } + + default: + req.Done <- fmt.Errorf("未知方法: %s", req.ServiceMethod) + } + case <-quitChan: + log.Sugar().Infof("[%s] 监听线程收到退出信号", node.selfId) + return + } + } +} + +// 共同部分和启动 +func (n *Node) initLeaderState() { + for _, peerId := range n.nodes { + n.nextIndex[peerId] = len(n.log) // 发送日志的下一个索引 + n.matchIndex[peerId] = 0 // 复制日志的最新匹配索引 } } -// 封装有超时的call -func CallWithTimeout[T1 any, T2 any](client *rpc.Client, serviceMethod string, args *T1, reply *T2) error { - done := make(chan error, 1) +func Start(node *Node, quitChan chan struct{}) { + node.state = Follower // 所有节点以 Follower 状态启动 + node.resetElectionTimer() // 启动选举超时定时器 go func() { - done <- client.Call(serviceMethod, args, reply) - }() + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() - select { - case err := <-done: - return err - case <-time.After(50 * time.Millisecond): - return fmt.Errorf("call timeout: %s", serviceMethod) - } + for { + select { + case <-quitChan: + fmt.Printf("[%s] Raft start 退出...\n", node.selfId) + return // 退出 goroutine + + case <-ticker.C: + switch node.state { + case Follower: + // 监听心跳超时 + fmt.Printf("[%s] is a follower, 监听中...\n", node.selfId) + + case Leader: + // 发送心跳 + fmt.Printf("[%s] is the leader, 发送心跳...\n", node.selfId) + node.resetElectionTimer() // leader 不主动触发选举 + node.BroadCastKV(Normal) + } + } + } + }() } + + + diff --git a/internal/nodes/node.go b/internal/nodes/node.go index 568dc4d..0e3f053 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -15,11 +15,6 @@ const ( Leader ) -type Public_node_info struct { - connect bool - address string -} - type Node struct { mu sync.Mutex // 当前节点id @@ -27,8 +22,8 @@ type Node struct { // 记录的leader(不能用votedfor:投票的leader可能没有收到多数票) leaderId string - // 除当前节点外其他节点信息 - nodes map[string]*Public_node_info + // 除当前节点外其他节点id + nodes []string // 当前节点状态 state State @@ -61,5 +56,8 @@ type Node struct { votedFor string electionTimer *time.Timer + + // 通信方式 + transport Transport } diff --git a/internal/nodes/real_transport.go b/internal/nodes/real_transport.go new file mode 100644 index 0000000..fea359a --- /dev/null +++ b/internal/nodes/real_transport.go @@ -0,0 +1,111 @@ +package nodes + +import ( + "errors" + "fmt" + "net/rpc" + "time" +) + +// 真实rpc通讯的transport类型实现 +type HTTPTransport struct{ + NodeMap map[string]string // id到addr的映射 +} + +// 封装有超时的dial +func (t *HTTPTransport) DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error) { + done := make(chan struct{}) + var client *rpc.Client + var err error + + go func() { + client, err = rpc.DialHTTP(network, t.NodeMap[peerId]) + close(done) + }() + + select { + case <-done: + return &HTTPClient{rpcClient: client}, err + case <-time.After(50 * time.Millisecond): + return nil, fmt.Errorf("dial timeout: %s", t.NodeMap[peerId]) + } +} + +func (t *HTTPTransport) CallWithTimeout(clientInterface ClientInterface, serviceMethod string, args interface{}, reply interface{}) error { + c, ok := clientInterface.(*HTTPClient) + client := c.rpcClient + if !ok { + return fmt.Errorf("invalid client type") + } + + done := make(chan error, 1) + + go func() { + switch serviceMethod { + case "Node.AppendEntries": + arg, ok := args.(*AppendEntriesArg) + resp, ok2 := reply.(*AppendEntriesReply) + if !ok || !ok2 { + done <- errors.New("type assertion failed for AppendEntries") + return + } + done <- client.Call(serviceMethod, arg, resp) + + case "Node.RequestVote": + arg, ok := args.(*RequestVoteArgs) + resp, ok2 := reply.(*RequestVoteReply) + if !ok || !ok2 { + done <- errors.New("type assertion failed for RequestVote") + return + } + done <- client.Call(serviceMethod, arg, resp) + + case "Node.WriteKV": + arg, ok := args.(*LogEntryCall) + resp, ok2 := reply.(*ServerReply) + if !ok || !ok2 { + done <- errors.New("type assertion failed for WriteKV") + return + } + done <- client.Call(serviceMethod, arg, resp) + + case "Node.ReadKey": + arg, ok := args.(*string) + resp, ok2 := reply.(*ServerReply) + if !ok || !ok2 { + done <- errors.New("type assertion failed for ReadKey") + return + } + done <- client.Call(serviceMethod, arg, resp) + + case "Node.FindLeader": + arg, ok := args.(struct{}) + resp, ok2 := reply.(*FindLeaderReply) + if !ok || !ok2 { + done <- errors.New("type assertion failed for FindLeader") + return + } + done <- client.Call(serviceMethod, arg, resp) + + default: + done <- fmt.Errorf("unknown service method: %s", serviceMethod) + } + }() + + select { + case err := <-done: + return err + case <-time.After(50 * time.Millisecond): // 设置超时时间 + return fmt.Errorf("call timeout: %s", serviceMethod) + } +} + +type HTTPClient struct { + rpcClient *rpc.Client +} + +func (h *HTTPClient) Close() error { + return h.rpcClient.Close() +} + + diff --git a/internal/nodes/replica.go b/internal/nodes/replica.go index 41be21f..49401d0 100644 --- a/internal/nodes/replica.go +++ b/internal/nodes/replica.go @@ -2,7 +2,6 @@ package nodes import ( "math/rand" - "net/rpc" "sort" "strconv" "time" @@ -27,14 +26,14 @@ type AppendEntriesReply struct { // leader收到新内容要广播,以及心跳广播(同步自己的log) func (node *Node) BroadCastKV(callMode CallMode) { // 遍历所有节点 - for id := range node.nodes { + for _, id := range node.nodes { go func(id string, kv CallMode) { node.sendKV(id, callMode) }(id, callMode) } } -func (node *Node) sendKV(id string, callMode CallMode) { +func (node *Node) sendKV(peerId string, callMode CallMode) { switch callMode { case Fail: @@ -47,13 +46,13 @@ func (node *Node) sendKV(id string, callMode CallMode) { default: } - client, err := DialHTTPWithTimeout("tcp", node.nodes[id].address) + client, err := node.transport.DialHTTPWithTimeout("tcp", peerId) if err != nil { - log.Error("dialing: ", zap.Error(err)) + log.Error(node.selfId + "dialling [" + peerId + "] fail: ", zap.Error(err)) return } - defer func(client *rpc.Client) { + defer func(client ClientInterface) { err := client.Close() if err != nil { log.Error("client close err: ", zap.Error(err)) @@ -65,7 +64,7 @@ func (node *Node) sendKV(id string, callMode CallMode) { var appendReply AppendEntriesReply appendReply.Success = false - nextIndex := node.nextIndex[id] + nextIndex := node.nextIndex[peerId] // log.Info("nextindex " + strconv.Itoa(nextIndex)) for (!appendReply.Success) { if nextIndex < 0 { @@ -82,13 +81,14 @@ func (node *Node) sendKV(id string, callMode CallMode) { if arg.PrevLogIndex >= 0 { arg.PrevLogTerm = node.log[arg.PrevLogIndex].Term } - callErr := CallWithTimeout(client, "Node.AppendEntries", &arg, &appendReply) // RPC + callErr := node.transport.CallWithTimeout(client, "Node.AppendEntries", &arg, &appendReply) // RPC if callErr != nil { - log.Error("dialing node_"+ id +"fail: ", zap.Error(callErr)) + log.Error(node.selfId + "calling [" + peerId + "] fail: ", zap.Error(callErr)) + return } if appendReply.Term != node.currTerm { - log.Info("Leader[" + node.selfId + "]收到更高的 term=" + strconv.Itoa(appendReply.Term) + ",转换为 Follower") + log.Info("term=" + strconv.Itoa(node.currTerm) + "的Leader[" + node.selfId + "]收到更高的 term=" + strconv.Itoa(appendReply.Term) + ",转换为 Follower") node.currTerm = appendReply.Term node.state = Follower node.votedFor = "" @@ -100,8 +100,8 @@ func (node *Node) sendKV(id string, callMode CallMode) { } // 不变成follower情况下 - node.nextIndex[id] = node.maxLogId + 1 - node.matchIndex[id] = node.maxLogId + node.nextIndex[peerId] = node.maxLogId + 1 + node.matchIndex[peerId] = node.maxLogId node.updateCommitIndex() } @@ -164,7 +164,7 @@ func (node *Node) AppendEntries(arg *AppendEntriesArg, reply *AppendEntriesReply node.currTerm = arg.Term node.state = Follower node.votedFor = "" - node.storage.SetTermAndVote(node.currTerm, node.votedFor) + // node.storage.SetTermAndVote(node.currTerm, node.votedFor) } node.storage.SetTermAndVote(node.currTerm, node.votedFor) diff --git a/internal/nodes/server_node.go b/internal/nodes/server_node.go index 1e7d90f..daa1089 100644 --- a/internal/nodes/server_node.go +++ b/internal/nodes/server_node.go @@ -9,7 +9,7 @@ import ( // leader node作为server为client注册的方法 type ServerReply struct{ Isleader bool - LeaderAddress string // 自己不是leader则返回leader地址 + LeaderId string // 自己不是leader则返回leader HaveValue bool Value string } @@ -24,7 +24,7 @@ func (node *Node) WriteKV(kvCall *LogEntryCall, reply *ServerReply) error { log.Fatal("还没选出第一个leader") return nil } - reply.LeaderAddress = node.nodes[node.leaderId].address + reply.LeaderId = node.leaderId log.Sugar().Infof("[%s]转交给[%s]", node.selfId, node.leaderId) return nil } diff --git a/internal/nodes/thread_transport.go b/internal/nodes/thread_transport.go new file mode 100644 index 0000000..879577d --- /dev/null +++ b/internal/nodes/thread_transport.go @@ -0,0 +1,101 @@ +package nodes + +import ( + "fmt" + "sync" + "time" +) + +// RPC 请求结构 +type RPCRequest struct { + ServiceMethod string + Args interface{} + Reply interface{} + Done chan error // 用于返回响应 +} + +// 线程版 Transport +type ThreadTransport struct { + mu sync.Mutex + nodeChans map[string]chan RPCRequest // 每个节点的消息通道 +} + +// 线程版 dial的返回clientinterface +type ThreadClient struct { + targetId string +} + +func (c *ThreadClient) Close() error { + return nil +} + +// 初始化线程通信系统 +func NewThreadTransport() *ThreadTransport { + return &ThreadTransport{ + nodeChans: make(map[string]chan RPCRequest), + } +} + +// 注册一个新节点chan +func (t *ThreadTransport) RegisterNodeChan(nodeId string, ch chan RPCRequest) { + t.mu.Lock() + defer t.mu.Unlock() + t.nodeChans[nodeId] = ch +} + +// 获取节点的 channel +func (t *ThreadTransport) getNodeChan(nodeId string) (chan RPCRequest, bool) { + t.mu.Lock() + defer t.mu.Unlock() + ch, ok := t.nodeChans[nodeId] + return ch, ok +} + +// 模拟 Dial 操作 +func (t *ThreadTransport) DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error) { + t.mu.Lock() + defer t.mu.Unlock() + if _, exists := t.nodeChans[peerId]; !exists { + return nil, fmt.Errorf("节点 [%s] 不存在", peerId) + } + + return &ThreadClient{targetId: peerId}, nil +} + +// 模拟 Call 操作 +func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error { + threadClient, ok := client.(*ThreadClient) + if !ok { + return fmt.Errorf("无效的客户端") + } + + // 获取目标节点的 channel + targetChan, exists := t.getNodeChan(threadClient.targetId) + if !exists { + return fmt.Errorf("目标节点 [%s] 不存在", threadClient.targetId) + } + + // 创建响应通道(用于返回 RPC 结果) + done := make(chan error, 1) + + // 发送请求 + request := RPCRequest{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + + select { + case targetChan <- request: + // 等待响应或超时 + select { + case err := <-done: + return err + case <-time.After(100 * time.Millisecond): + return fmt.Errorf("RPC 调用超时: %s", serviceMethod) + } + default: + return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.targetId) + } +} diff --git a/internal/nodes/transport.go b/internal/nodes/transport.go new file mode 100644 index 0000000..40f4cdd --- /dev/null +++ b/internal/nodes/transport.go @@ -0,0 +1,10 @@ +package nodes + +type ClientInterface interface{ + Close() error +} + +type Transport interface { + DialHTTPWithTimeout(network string, peerId string) (ClientInterface, error) + CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error +} diff --git a/internal/nodes/vote.go b/internal/nodes/vote.go index cbe5654..a87a8d6 100644 --- a/internal/nodes/vote.go +++ b/internal/nodes/vote.go @@ -2,7 +2,6 @@ package nodes import ( "math/rand" - "net/rpc" "strconv" "sync" "time" @@ -59,7 +58,7 @@ func (n *Node) startElection() { totalNodes := len(n.nodes) grantedVotes := 1 // 自己的票 - for peerId := range n.nodes { + for _, peerId := range n.nodes { go func(peerId string) { reply := RequestVoteReply{} if n.sendRequestVote(peerId, &args, &reply) { @@ -104,23 +103,23 @@ func (n *Node) startElection() { } func (node *Node) sendRequestVote(peerId string, args *RequestVoteArgs, reply *RequestVoteReply) bool { - log.Sugar().Infof("[%s] 请求 [%s] 投票给自己", node.selfId, peerId) - client, err := DialHTTPWithTimeout("tcp", node.nodes[peerId].address) + log.Sugar().Infof("[%s] 请求 [%s] 投票", node.selfId, peerId) + client, err := node.transport.DialHTTPWithTimeout("tcp", peerId) if err != nil { - log.Error("dialing: ", zap.Error(err)) + log.Error(node.selfId + "dialing [" + peerId + "] fail: ", zap.Error(err)) return false } - defer func(client *rpc.Client) { + defer func(client ClientInterface) { err := client.Close() if err != nil { log.Error("client close err: ", zap.Error(err)) } }(client) - callErr := CallWithTimeout(client, "Node.RequestVote", args, reply) // RPC + callErr := node.transport.CallWithTimeout(client, "Node.RequestVote", args, reply) // RPC if callErr != nil { - log.Error("dialing node_"+peerId+"fail: ", zap.Error(callErr)) + log.Error(node.selfId + "calling [" + peerId + "] fail: ", zap.Error(callErr)) } return callErr == nil } diff --git a/test/restart_node_test.go b/test/restart_node_test.go index a82f465..a9136f6 100644 --- a/test/restart_node_test.go +++ b/test/restart_node_test.go @@ -15,18 +15,20 @@ func TestNodeRestart(t *testing.T) { // 登记结点信息 n := 5 var clusters []string + var peerIds []string addressMap := make(map[string]string) for i := 0; i < n; i++ { port := fmt.Sprintf("%d", uint16(9090)+uint16(i)) addr := "127.0.0.1:" + port clusters = append(clusters, addr) addressMap[strconv.Itoa(i + 1)] = addr + peerIds = append(peerIds, strconv.Itoa(i + 1)) } // 结点启动 var cmds []*exec.Cmd for i := 0; i < n; i++ { - cmd := ExecuteNodeI(i, true, clusters) + cmd := ExecuteNodeI(i, false, clusters) cmds = append(cmds, cmd) } @@ -43,7 +45,7 @@ func TestNodeRestart(t *testing.T) { time.Sleep(time.Second) // 等待启动完毕 // client启动, 连接任意节点 - cWrite := clientPkg.Client{Address: addressMap} + cWrite := clientPkg.Client{PeerIds: peerIds, Transport: &nodes.HTTPTransport{NodeMap: addressMap}} // 写入 var s clientPkg.Status @@ -66,7 +68,7 @@ func TestNodeRestart(t *testing.T) { } time.Sleep(time.Second) - cmd := ExecuteNodeI(i, false, clusters) + cmd := ExecuteNodeI(i, true, clusters) if cmd == nil { t.Errorf("recover test1 fail") return @@ -78,7 +80,7 @@ func TestNodeRestart(t *testing.T) { // client启动 - cRead := clientPkg.Client{Address: addressMap} + cRead := clientPkg.Client{PeerIds: peerIds, Transport: &nodes.HTTPTransport{NodeMap: addressMap}} // 读写入数据 for i := 0; i < 5; i++ { key := strconv.Itoa(i) diff --git a/test/server_client_test.go b/test/server_client_test.go index 88050e8..822c548 100644 --- a/test/server_client_test.go +++ b/test/server_client_test.go @@ -15,18 +15,20 @@ func TestServerClient(t *testing.T) { // 登记结点信息 n := 5 var clusters []string + var peerIds []string addressMap := make(map[string]string) for i := 0; i < n; i++ { port := fmt.Sprintf("%d", uint16(9090)+uint16(i)) addr := "127.0.0.1:" + port clusters = append(clusters, addr) addressMap[strconv.Itoa(i + 1)] = addr + peerIds = append(peerIds, strconv.Itoa(i + 1)) } // 结点启动 var cmds []*exec.Cmd for i := 0; i < n; i++ { - cmd := ExecuteNodeI(i, true, clusters) + cmd := ExecuteNodeI(i, false, clusters) cmds = append(cmds, cmd) } @@ -43,7 +45,7 @@ func TestServerClient(t *testing.T) { time.Sleep(time.Second) // 等待启动完毕 // client启动 - c := clientPkg.Client{Address: addressMap} + c := clientPkg.Client{PeerIds: peerIds, Transport: &nodes.HTTPTransport{NodeMap: addressMap}} // 写入 var s clientPkg.Status @@ -56,6 +58,7 @@ func TestServerClient(t *testing.T) { } } + time.Sleep(time.Second) // 等待写入完毕 // 读写入数据 for i := 0; i < 10; i++ { key := strconv.Itoa(i) diff --git a/threadTest/common.go b/threadTest/common.go new file mode 100644 index 0000000..b276305 --- /dev/null +++ b/threadTest/common.go @@ -0,0 +1,38 @@ +package threadTest + +import ( + "fmt" + "os" + "simple-kv-store/internal/nodes" + + "github.com/syndtr/goleveldb/leveldb" +) + +func ExecuteNodeI(id string, isRestart bool, peerIds []string, threadTransport *nodes.ThreadTransport) (*nodes.Node, chan struct{}) { + if !isRestart { + os.RemoveAll("storage/node" + id + ".json") + } + + os.RemoveAll("leveldb/simple-kv-store" + id) + + db, err := leveldb.OpenFile("leveldb/simple-kv-store"+id, nil) + if err != nil { + fmt.Println("Failed to open database: ", err) + } + + // 打开或创建节点数据持久化文件 + storage := nodes.NewRaftStorage("storage/node" + id + ".json") + + var otherIds []string + for _, ids := range peerIds { + if ids != id { + otherIds = append(otherIds, ids) // 删除目标元素 + } + } + // 初始化 + node, quitChan := nodes.InitThreadNode(id, otherIds, db, storage, isRestart, threadTransport) + + // 开启 raft + go nodes.Start(node, quitChan) + return node, quitChan +} diff --git a/threadTest/restart_node_test.go b/threadTest/restart_node_test.go new file mode 100644 index 0000000..2f57593 --- /dev/null +++ b/threadTest/restart_node_test.go @@ -0,0 +1,81 @@ +package threadTest + +import ( + "simple-kv-store/internal/client" + "simple-kv-store/internal/nodes" + "strconv" + "testing" + "time" +) + +func TestNodeRestart(t *testing.T) { + // 登记结点信息 + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + threadTransport := nodes.NewThreadTransport() + for i := 0; i < n; i++ { + _, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + } + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + time.Sleep(time.Second) // 等待启动完毕 + // client启动, 连接任意节点 + cWrite := clientPkg.Client{PeerIds: peerIds, Transport: threadTransport} + // 写入 + var s clientPkg.Status + for i := 0; i < 5; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + s := cWrite.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal}) + if s != clientPkg.Ok { + t.Errorf("write test fail") + } + } + time.Sleep(time.Second) // 等待写入完毕 + + // 模拟结点轮流崩溃 + for i := 0; i < n; i++ { + close(quitCollections[i]) + + time.Sleep(time.Second) + _, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), true, peerIds, threadTransport) + quitCollections[i] = quitChan + time.Sleep(time.Second) // 等待启动完毕 + } + + + // client启动 + cRead := clientPkg.Client{PeerIds: peerIds, Transport: threadTransport} + // 读写入数据 + for i := 0; i < 5; i++ { + key := strconv.Itoa(i) + var value string + s = cRead.Read(key, &value) + if s != clientPkg.Ok { + t.Errorf("Read test1 fail") + } + } + + // 读未写入数据 + for i := 5; i < 15; i++ { + key := strconv.Itoa(i) + var value string + s = cRead.Read(key, &value) + if s != clientPkg.NotFound { + t.Errorf("Read test2 fail") + } + } +} diff --git a/threadTest/server_client_test.go b/threadTest/server_client_test.go new file mode 100644 index 0000000..7768e66 --- /dev/null +++ b/threadTest/server_client_test.go @@ -0,0 +1,69 @@ +package threadTest + +import ( + "simple-kv-store/internal/client" + "simple-kv-store/internal/nodes" + "strconv" + "testing" + "time" +) + +func TestServerClient(t *testing.T) { + // 登记结点信息 + n := 5 + var peerIds []string + for i := 0; i < n; i++ { + peerIds = append(peerIds, strconv.Itoa(i + 1)) + } + + // 结点启动 + var quitCollections []chan struct{} + threadTransport := nodes.NewThreadTransport() + for i := 0; i < n; i++ { + _, quitChan := ExecuteNodeI(strconv.Itoa(i + 1), false, peerIds, threadTransport) + quitCollections = append(quitCollections, quitChan) + } + + // 通知所有node结束 + defer func(){ + for _, quitChan := range quitCollections { + close(quitChan) + } + }() + + time.Sleep(time.Second) // 等待启动完毕 + // client启动 + c := clientPkg.Client{PeerIds: peerIds, Transport: threadTransport} + + // 写入 + var s clientPkg.Status + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + newlog := nodes.LogEntry{Key: key, Value: "hello"} + s := c.Write(nodes.LogEntryCall{LogE: newlog, CallState: nodes.Normal}) + if s != clientPkg.Ok { + t.Errorf("write test fail") + } + } + + time.Sleep(time.Second) // 等待写入完毕 + // 读写入数据 + for i := 0; i < 10; i++ { + key := strconv.Itoa(i) + var value string + s = c.Read(key, &value) + if s != clientPkg.Ok || value != "hello" { + t.Errorf("Read test1 fail") + } + } + + // 读未写入数据 + for i := 10; i < 15; i++ { + key := strconv.Itoa(i) + var value string + s = c.Read(key, &value) + if s != clientPkg.NotFound { + t.Errorf("Read test2 fail") + } + } +}