package nodes
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// CallBehavior selects how ThreadTransport simulates delivery of an RPC.
// It is an alias, so any uint8 value is a CallBehavior.
type CallBehavior = uint8

// Simulated RPC behaviors. Numbering starts at 1; 0 is not a named behavior.
const (
	_ CallBehavior = iota // 0 is deliberately left unused

	// NormalRpc delivers the request without any simulated fault.
	NormalRpc
	// DelayRpc is forwarded to the receiving side as-is; presumably its
	// handler delays the answer — TODO confirm against the handler.
	DelayRpc
	// RetryRpc makes the caller resend the request (as NormalRpc) up to a
	// configured number of attempts.
	RetryRpc
	// FailRpc is forwarded to the receiving side as-is; presumably its
	// handler makes the call fail — TODO confirm against the handler.
	FailRpc
)
|
|
|
|
// RPC 请求结构
|
|
type RPCRequest struct {
|
|
ServiceMethod string
|
|
Args interface{}
|
|
Reply interface{}
|
|
|
|
Done chan error // 用于返回响应
|
|
|
|
SourceId string
|
|
// 模拟rpc请求状态
|
|
Behavior CallBehavior
|
|
}
|
|
|
|
// 线程版 Transport
|
|
type ThreadTransport struct {
|
|
mu sync.Mutex
|
|
nodeChans map[string]chan RPCRequest // 每个节点的消息通道
|
|
connectivityMap map[string]map[string]bool // 模拟网络分区
|
|
Ctx *Ctx
|
|
}
|
|
|
|
// ThreadClient is the client handle produced by ThreadTransport's dial.
// It holds no connection; it only records the calling node (SourceId) and
// the node being called (TargetId).
type ThreadClient struct {
	SourceId, TargetId string
}

// Close satisfies the client interface. An in-process client owns no
// resources, so there is nothing to release and it always returns nil.
func (*ThreadClient) Close() error {
	return nil
}
|
|
|
|
// 初始化线程通信系统
|
|
func NewThreadTransport(ctx *Ctx) *ThreadTransport {
|
|
return &ThreadTransport{
|
|
nodeChans: make(map[string]chan RPCRequest),
|
|
connectivityMap: make(map[string]map[string]bool),
|
|
Ctx: ctx,
|
|
}
|
|
}
|
|
|
|
// 注册一个新节点chan
|
|
func (t *ThreadTransport) RegisterNodeChan(nodeId string, ch chan RPCRequest) {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
t.nodeChans[nodeId] = ch
|
|
|
|
// 初始化连通性(默认所有节点互相可达)
|
|
if _, exists := t.connectivityMap[nodeId]; !exists {
|
|
t.connectivityMap[nodeId] = make(map[string]bool)
|
|
}
|
|
for peerId := range t.nodeChans {
|
|
t.connectivityMap[nodeId][peerId] = true
|
|
t.connectivityMap[peerId][nodeId] = true
|
|
}
|
|
}
|
|
|
|
// 设置两个节点的连通性
|
|
func (t *ThreadTransport) SetConnectivity(from, to string, isConnected bool) {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
if _, exists := t.connectivityMap[from]; exists {
|
|
t.connectivityMap[from][to] = isConnected
|
|
}
|
|
}
|
|
|
|
func (t *ThreadTransport) ResetConnectivity() {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
for firstId:= range t.nodeChans {
|
|
for peerId:= range t.nodeChans {
|
|
if firstId != peerId {
|
|
t.connectivityMap[firstId][peerId] = true
|
|
t.connectivityMap[peerId][firstId] = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 获取节点的 channel
|
|
func (t *ThreadTransport) getNodeChan(nodeId string) (chan RPCRequest, bool) {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
ch, ok := t.nodeChans[nodeId]
|
|
return ch, ok
|
|
}
|
|
|
|
// 模拟 Dial 操作
|
|
func (t *ThreadTransport) DialHTTPWithTimeout(network string, myId string, peerId string) (ClientInterface, error) {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
if _, exists := t.nodeChans[peerId]; !exists {
|
|
return nil, fmt.Errorf("节点 [%s] 不存在", peerId)
|
|
}
|
|
|
|
return &ThreadClient{SourceId: myId, TargetId: peerId}, nil
|
|
}
|
|
|
|
// 模拟 Call 操作
|
|
func (t *ThreadTransport) CallWithTimeout(client ClientInterface, serviceMethod string, args interface{}, reply interface{}) error {
|
|
threadClient, ok := client.(*ThreadClient)
|
|
if !ok {
|
|
return fmt.Errorf("无效的caller")
|
|
}
|
|
|
|
var isConnected bool
|
|
if threadClient.SourceId == "" {
|
|
isConnected = true
|
|
} else {
|
|
t.mu.Lock()
|
|
isConnected = t.connectivityMap[threadClient.SourceId][threadClient.TargetId]
|
|
t.mu.Unlock()
|
|
}
|
|
if !isConnected {
|
|
return fmt.Errorf("网络分区: %s cannot reach %s", threadClient.SourceId, threadClient.TargetId)
|
|
}
|
|
|
|
targetChan, exists := t.getNodeChan(threadClient.TargetId)
|
|
if !exists {
|
|
return fmt.Errorf("目标节点 [%s] 不存在", threadClient.TargetId)
|
|
}
|
|
|
|
done := make(chan error, 1)
|
|
behavior := t.Ctx.GetBehavior(threadClient.SourceId, threadClient.TargetId)
|
|
|
|
// 辅助函数:复制 replyCopy 到原始 reply
|
|
copyReply := func(dst, src interface{}) {
|
|
switch d := dst.(type) {
|
|
case *AppendEntriesReply:
|
|
*d = *(src.(*AppendEntriesReply))
|
|
case *RequestVoteReply:
|
|
*d = *(src.(*RequestVoteReply))
|
|
}
|
|
}
|
|
|
|
sendRequest := func(req RPCRequest, ch chan RPCRequest) bool {
|
|
select {
|
|
case ch <- req:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
switch behavior {
|
|
case RetryRpc:
|
|
retryTimes, ok := t.Ctx.GetRetries(threadClient.SourceId, threadClient.TargetId)
|
|
if !ok {
|
|
log.Fatal("没有设置对应的retry次数")
|
|
}
|
|
|
|
var lastErr error
|
|
for i := 0; i < retryTimes; i++ {
|
|
var replyCopy interface{}
|
|
useCopy := true
|
|
|
|
switch r := reply.(type) {
|
|
case *AppendEntriesReply:
|
|
tmp := *r
|
|
replyCopy = &tmp
|
|
case *RequestVoteReply:
|
|
tmp := *r
|
|
replyCopy = &tmp
|
|
default:
|
|
replyCopy = reply // 其他类型不复制
|
|
useCopy = false
|
|
}
|
|
|
|
request := RPCRequest{
|
|
ServiceMethod: serviceMethod,
|
|
Args: args,
|
|
Reply: replyCopy,
|
|
Done: done,
|
|
SourceId: threadClient.SourceId,
|
|
Behavior: NormalRpc,
|
|
}
|
|
|
|
if !sendRequest(request, targetChan) {
|
|
return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId)
|
|
}
|
|
|
|
select {
|
|
case err := <-done:
|
|
if err == nil && useCopy {
|
|
copyReply(reply, replyCopy)
|
|
}
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
lastErr = err
|
|
case <-time.After(250 * time.Millisecond):
|
|
lastErr = fmt.Errorf("RPC 调用超时: %s", serviceMethod)
|
|
}
|
|
}
|
|
return lastErr
|
|
|
|
default:
|
|
request := RPCRequest{
|
|
ServiceMethod: serviceMethod,
|
|
Args: args,
|
|
Reply: reply,
|
|
Done: done,
|
|
SourceId: threadClient.SourceId,
|
|
Behavior: behavior,
|
|
}
|
|
|
|
if !sendRequest(request, targetChan) {
|
|
return fmt.Errorf("目标节点 [%s] 无法接收请求", threadClient.TargetId)
|
|
}
|
|
|
|
select {
|
|
case err := <-done:
|
|
return err
|
|
case <-time.After(250 * time.Millisecond):
|
|
return fmt.Errorf("RPC 调用超时: %s", serviceMethod)
|
|
}
|
|
}
|
|
}
|