diff --git a/common/socket/ServerEvent.go b/common/socket/ServerEvent.go index 71a71fe31..5cc357cc1 100644 --- a/common/socket/ServerEvent.go +++ b/common/socket/ServerEvent.go @@ -1,24 +1,29 @@ package socket import ( + "blazing/common/socket/codec" + "blazing/cool" + "blazing/logic/service/player" + "blazing/modules/config/service" + "bytes" "context" "encoding/binary" "errors" - "io" "log" "os" "sync/atomic" "time" - "blazing/cool" - "blazing/logic/service/player" - "blazing/modules/config/service" - "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gtime" "github.com/panjf2000/gnet/v2" ) +const ( + minPacketLen = 17 + maxPacketLen = 10 * 1024 +) + func (s *Server) Boot(serverid, port uint32) error { // go s.bootws() s.serverid = serverid @@ -53,36 +58,19 @@ func (s *Server) Stop() error { func (s *Server) OnClose(c gnet.Conn, err error) (action gnet.Action) { defer func() { if err := recover(); err != nil { // 恢复 panic,err 为 panic 错误值 - // 1. 打印错误信息 if t, ok := c.Context().(*player.ClientData); ok { - if t.Player != nil { - if t.Player.Info != nil { - cool.Logger.Error(context.TODO(), "OnClose 错误:", cool.Config.ServerInfo.OnlineID, t.Player.Info.UserID, err) - t.Player.Service.Info.Save(*t.Player.Info) - } - + if t.Player != nil && t.Player.Info != nil { + cool.Logger.Error(context.TODO(), "OnClose 错误:", cool.Config.ServerInfo.OnlineID, t.Player.Info.UserID, err) + t.Player.Service.Info.Save(*t.Player.Info) } - } else { cool.Logger.Error(context.TODO(), "OnClose 错误:", cool.Config.ServerInfo.OnlineID, err) - } - } }() - // 识别 RST 导致的连接中断(错误信息含 "connection reset") - // if err != nil && (strings.Contains(err.Error(), "connection reset") || strings.Contains(err.Error(), "reset by peer")) { - // remoteIP := c.RemoteAddr().(*net.TCPAddr).IP.String() - // log.Printf("RST 攻击检测: 来源 %s, 累计攻击次数 %d", remoteIP) - - // // 防护逻辑:临时封禁异常 IP(可扩展为 IP 黑名单) - // // go s.tempBlockIP(remoteIP, 5*time.Minute) - // } - //fmt.Println(err, c.RemoteAddr().String(), "断开连接") atomic.AddInt64(&cool.Connected, -1) - //logging.Infof("conn[%v] disconnected", c.RemoteAddr().String()) v, _ := c.Context().(*player.ClientData) if v != nil { v.Close() @@ -90,23 +78,20 @@ func (s *Server) OnClose(c gnet.Conn, err error) (action gnet.Action) { v.Player.Save() //保存玩家数据 } } - - //} - //关闭连接 return } + func (s *Server) OnTick() (delay time.Duration, action gnet.Action) { g.Log().Async().Info(context.Background(), gtime.Now().ISO8601(), "服务器ID", cool.Config.ServerInfo.OnlineID, "链接数", atomic.LoadInt64(&cool.Connected)) if s.quit && atomic.LoadInt64(&cool.Connected) == 0 { - //执行正常退出逻辑 os.Exit(0) } return 30 * time.Second, gnet.None } + func (s *Server) OnBoot(eng gnet.Engine) gnet.Action { s.eng = eng - - service.NewServerService().SetServerID(s.serverid, s.port) //设置当前服务器端口 + service.NewServerService().SetServerID(s.serverid, s.port) return gnet.None } @@ -114,59 +99,52 @@ func (s *Server) OnOpen(conn gnet.Conn) (out []byte, action gnet.Action) { if s.network != "tcp" { return nil, gnet.Close } - if conn.Context() == nil { - conn.SetContext(player.NewClientData(conn)) //注入data + conn.SetContext(player.NewClientData(conn)) } - atomic.AddInt64(&cool.Connected, 1) - return nil, gnet.None } func (s *Server) OnTraffic(c gnet.Conn) (action gnet.Action) { defer func() { - if err := recover(); err != nil { // 恢复 panic,err 为 panic 错误值 - // 1. 打印错误信息 + if err := recover(); err != nil { if t, ok := c.Context().(*player.ClientData); ok { - if t.Player != nil { - if t.Player.Info != nil { - cool.Logger.Error(context.TODO(), "OnTraffic 错误:", cool.Config.ServerInfo.OnlineID, t.Player.Info.UserID, err) - t.Player.Service.Info.Save(*t.Player.Info) - - } - + if t.Player != nil && t.Player.Info != nil { + cool.Logger.Error(context.TODO(), "OnTraffic 错误:", cool.Config.ServerInfo.OnlineID, t.Player.Info.UserID, err) + t.Player.Service.Info.Save(*t.Player.Info) } - } - } }() ws := c.Context().(*player.ClientData).Wsmsg - if ws.Tcp { //升级失败时候防止缓冲区溢出 + if ws.Tcp { return s.handleTCP(c) - } - tt, len1 := ws.ReadBufferBytes(c) - if tt == gnet.Close { - + readAction, inboundLen := ws.ReadBufferBytes(c) + if readAction == gnet.Close { return gnet.Close } - ok, action := ws.Upgrade(c) - if action != gnet.None { //连接断开 + state, action := ws.Upgrade(c) + if action != gnet.None { return action } - if !ok { //升级失败,说明是tcp连接 - ws.Tcp = true - - return s.handleTCP(c) - + if state == player.UpgradeNeedMoreData { + return gnet.None + } + if state == player.UpgradeUseTCP { + return s.handleTCP(c) + } + + if inboundLen > 0 { + if _, err := c.Discard(inboundLen); err != nil { + return gnet.Close + } + ws.ResetInboundMirror() } - // fmt.Println(ws.Buf.Bytes()) - c.Discard(len1) messages, err := ws.Decode(c) if err != nil { @@ -177,91 +155,93 @@ func (s *Server) OnTraffic(c gnet.Conn) (action gnet.Action) { } for _, msg := range messages { - - s.onevent(c, msg.Payload) - //t.OnEvent(msg.Payload) + if !s.onevent(c, msg.Payload) { + return gnet.Close + } } - return gnet.None } -const maxBodyLen = 10 * 1024 // 业务最大包体长度,按需调整 func (s *Server) handleTCP(conn gnet.Conn) (action gnet.Action) { + client := conn.Context().(*player.ClientData) + if s.discorse && !client.IsCrossDomainChecked() { + handled, ready, action := handle(conn) + if action != gnet.None { + return action + } + if !ready { + return gnet.None + } + if handled { + client.MarkCrossDomainChecked() + return gnet.None + } + client.MarkCrossDomainChecked() + } - conn.Context().(*player.ClientData).IsCrossDomain.Do(func() { //跨域检测 - handle(conn) - }) - - // handle(c) - // 先读取4字节的包长度 - lenBuf, err := conn.Peek(4) - + body, err := s.codec.Decode(conn) if err != nil { - if errors.Is(err, io.ErrShortBuffer) { - return + if errors.Is(err, codec.ErrIncompletePacket) { + return gnet.None } return gnet.Close } - - bodyLen := binary.BigEndian.Uint32(lenBuf) - - if bodyLen > maxBodyLen { + if !s.onevent(conn, body) { return gnet.Close } - - if conn.InboundBuffered() < int(bodyLen) { - return - } - // 提取包体 - body, err := conn.Next(int(bodyLen)) - if err != nil { - if errors.Is(err, io.ErrShortBuffer) { - return - } - return gnet.Close - } - - s.onevent(conn, body) - if conn.InboundBuffered() > 0 { - if err := conn.Wake(nil); err != nil { // wake up the connection manually to avoid missing the leftover data - + if err := conn.Wake(nil); err != nil { return gnet.Close } } return action - } -// CROSS_DOMAIN 定义跨域策略文件内容 const CROSS_DOMAIN = "\x00" - -// TEXT 定义跨域请求的文本格式 const TEXT = "\x00" -func handle(c gnet.Conn) { +func handle(c gnet.Conn) (handled bool, ready bool, action gnet.Action) { + probeLen := c.InboundBuffered() + if probeLen == 0 { + return false, false, gnet.None + } + if probeLen > len(TEXT) { + probeLen = len(TEXT) + } - // 读取数据并检查是否为跨域请求 - data, err := c.Peek(len(TEXT)) + data, err := c.Peek(probeLen) if err != nil { log.Printf("Error reading cross-domain request: %v", err) - return + return false, false, gnet.Close } - - if string(data) == TEXT { //判断是否是跨域请求 - //log.Printf("Received cross-domain request from %s", c.RemoteAddr()) - // 处理跨域请求 - c.Write([]byte(CROSS_DOMAIN)) - c.Discard(len(TEXT)) - - return + if !bytes.Equal(data, []byte(TEXT[:probeLen])) { + return false, true, gnet.None } - - //return + if probeLen < len(TEXT) { + return false, false, gnet.None + } + if _, err := c.Write([]byte(CROSS_DOMAIN)); err != nil { + return false, true, gnet.Close + } + if _, err := c.Discard(len(TEXT)); err != nil { + return false, true, gnet.Close + } + return true, true, gnet.None } -func (s *Server) onevent(c gnet.Conn, v []byte) { +func (s *Server) onevent(c gnet.Conn, v []byte) bool { + if !isValidPacket(v) { + return false + } if t, ok := c.Context().(*player.ClientData); ok { t.PushEvent(v, s.workerPool.Submit) } + return true +} + +func isValidPacket(v []byte) bool { + if len(v) < minPacketLen || len(v) > maxPacketLen { + return false + } + return binary.BigEndian.Uint32(v[0:4]) == uint32(len(v)) } diff --git a/logic/service/player/pack.go b/logic/service/player/pack.go index 5e08f26ef..6ccc96070 100644 --- a/logic/service/player/pack.go +++ b/logic/service/player/pack.go @@ -23,6 +23,11 @@ import ( "github.com/panjf2000/gnet/v2" ) +const ( + minPacketLen = 17 + maxPacketLen = 10 * 1024 +) + // getUnderlyingValue 递归解析reflect.Value,解包指针、interface{}到底层具体类型 func getUnderlyingValue(val reflect.Value) (reflect.Value, error) { for { @@ -106,6 +111,12 @@ func (h *ClientData) PushEvent(v []byte, submit func(task func()) error) { if h == nil || h.IsClosed() { return } + if len(v) < minPacketLen || len(v) > maxPacketLen { + return + } + if binary.BigEndian.Uint32(v[0:4]) != uint32(len(v)) { + return + } var header common.TomeeHeader header.Len = binary.BigEndian.Uint32(v[0:4]) @@ -249,13 +260,13 @@ func (h *ClientData) OnEvent(data common.TomeeHeader) { } type ClientData struct { - IsCrossDomain sync.Once //是否跨域过 - Player *Player //客户实体 - ERROR_CONNUT int - Wsmsg *WsCodec - Conn gnet.Conn - LF *lockfree.Lockfree[common.TomeeHeader] - closed int32 + Player *Player //客户实体 + ERROR_CONNUT int + Wsmsg *WsCodec + Conn gnet.Conn + LF *lockfree.Lockfree[common.TomeeHeader] + closed int32 + crossDomainChecked uint32 } func (p *ClientData) IsClosed() bool { @@ -271,6 +282,14 @@ func (p *ClientData) Close() { } } +func (p *ClientData) IsCrossDomainChecked() bool { + return atomic.LoadUint32(&p.crossDomainChecked) == 1 +} + +func (p *ClientData) MarkCrossDomainChecked() { + atomic.StoreUint32(&p.crossDomainChecked, 1) +} + func (p *ClientData) GetPlayer(userid uint32) *Player { //TODO 这里待优化,可能存在内存泄漏问题 if p.Player == nil { p.Player = NewPlayer(p.Conn) diff --git a/logic/service/player/wscodec.go b/logic/service/player/wscodec.go index 5f83b2165..43b042cf1 100644 --- a/logic/service/player/wscodec.go +++ b/logic/service/player/wscodec.go @@ -2,6 +2,7 @@ package player import ( "bytes" + "encoding/binary" "errors" "io" @@ -11,12 +12,18 @@ import ( "github.com/panjf2000/gnet/v2/pkg/logging" ) +const ( + minTCPPacketLen = 17 + maxTCPPacketLen = 10 * 1024 + tomeeVersion = 49 +) + type WsCodec struct { - Tcp bool - Upgraded bool // 链接是否升级 - Buf bytes.Buffer // 从实际socket中读取到的数据缓存 - wsMsgBuf wsMessageBuf // ws 消息缓存 - //Isinitws bool + Tcp bool + Upgraded bool // 链接是否升级 + Buf bytes.Buffer // 从实际socket中读取到的数据缓存 + wsMsgBuf wsMessageBuf // ws 消息缓存 + bufferedInbound int // 已镜像到 Buf 中的 inbound 字节数 } type wsMessageBuf struct { @@ -24,92 +31,115 @@ type wsMessageBuf struct { cachedBuf bytes.Buffer } +type UpgradeState uint8 + +const ( + UpgradeNeedMoreData UpgradeState = iota + UpgradeUseTCP + UpgradeUseWS +) + type readWrite struct { io.Reader io.Writer } -func CompareLeftBytes(array1, array2 []byte, leftBytesCount int) bool { - // 检查切片长度是否足够比较左边的字节 - if len(array1) < leftBytesCount || len(array2) < leftBytesCount { - return false - } - - // 提取左边的字节切片 - left1 := array1[:leftBytesCount] - left2 := array2[:leftBytesCount] - - // 比较左边的字节切片 - for i := 0; i < leftBytesCount; i++ { - if left1[i] != left2[i] { - return false - } - } - - return true -} -func (w *WsCodec) Upgrade(c gnet.Conn) (ok bool, action gnet.Action) { +func (w *WsCodec) Upgrade(c gnet.Conn) (state UpgradeState, action gnet.Action) { if w.Upgraded { - ok = true + state = UpgradeUseWS return } - if w.Tcp { - ok = false + state = UpgradeUseTCP return } - buf := &w.Buf - if CompareLeftBytes(buf.Bytes(), []byte{0, 0}, 2) { - w.Tcp = true - return - } - tmpReader := bytes.NewReader(buf.Bytes()) - oldLen := tmpReader.Len() - //logging.Infof("do Upgrade") + buf := w.Buf.Bytes() + if looksLikeTCPPacket(buf) { + w.SwitchToTCP() + state = UpgradeUseTCP + return + } + if len(buf) == 0 { + state = UpgradeNeedMoreData + return + } + + tmpReader := bytes.NewReader(buf) + oldLen := tmpReader.Len() hs, err := ws.Upgrade(readWrite{tmpReader, c}) skipN := oldLen - tmpReader.Len() if err != nil { - if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { //数据不完整,不跳过 buf 中的 skipN 字节(此时 buf 中存放的仅是部分 "handshake data" bytes),下次再尝试读取 + if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { + state = UpgradeNeedMoreData return } - buf.Next(skipN) + w.Buf.Next(skipN) logging.Errorf("conn[%v] [err=%v]", c.RemoteAddr().String(), err.Error()) action = gnet.Close - //ok = true - //w.Tcp = true return } - buf.Next(skipN) - logging.Infof("conn[%v] upgrade websocket protocol! Handshake: %v", c.RemoteAddr().String(), hs) - ok = true + w.Buf.Next(skipN) + logging.Infof("conn[%v] upgrade websocket protocol! Handshake: %v", c.RemoteAddr().String(), hs) w.Upgraded = true + state = UpgradeUseWS return } + +func looksLikeTCPPacket(buf []byte) bool { + if len(buf) < 4 { + return false + } + packetLen := binary.BigEndian.Uint32(buf[:4]) + if packetLen < minTCPPacketLen || packetLen > maxTCPPacketLen { + return false + } + if len(buf) >= 5 && buf[4] != tomeeVersion { + return false + } + return true +} + func (w *WsCodec) ReadBufferBytes(c gnet.Conn) (gnet.Action, int) { size := c.InboundBuffered() - //buf := make([]byte, size) + if size < w.bufferedInbound { + w.bufferedInbound = 0 + } + if size == w.bufferedInbound { + return gnet.None, size + } + read, err := c.Peek(size) if err != nil { logging.Errorf("read err! %v", err) return gnet.Close, 0 } - // if read < size { - // logging.Errorf("read bytes len err! size: %d read: %d", size, read) - // return gnet.Close - // } - w.Buf.Write(read) + w.Buf.Write(read[w.bufferedInbound:]) + w.bufferedInbound = size return gnet.None, size } + +func (w *WsCodec) ResetInboundMirror() { + w.bufferedInbound = 0 +} + +func (w *WsCodec) SwitchToTCP() { + w.Tcp = true + w.Upgraded = false + w.bufferedInbound = 0 + w.Buf.Reset() + w.wsMsgBuf.curHeader = nil + w.wsMsgBuf.cachedBuf.Reset() +} + func (w *WsCodec) Decode(c gnet.Conn) (outs []wsutil.Message, err error) { - // fmt.Println("do Decode") messages, err := w.readWsMessages() if err != nil { logging.Errorf("Error reading message! %v", err) return nil, err } - if len(messages) <= 0 { //没有读到完整数据 不处理 + if len(messages) <= 0 { return } for _, message := range messages { @@ -131,9 +161,8 @@ func (w *WsCodec) readWsMessages() (messages []wsutil.Message, err error) { msgBuf := &w.wsMsgBuf in := &w.Buf for { - // 从 in 中读出 header,并将 header bytes 写入 msgBuf.cachedBuf if msgBuf.curHeader == nil { - if in.Len() < ws.MinHeaderSize { //头长度至少是2 + if in.Len() < ws.MinHeaderSize { return } var head ws.Header @@ -142,13 +171,13 @@ func (w *WsCodec) readWsMessages() (messages []wsutil.Message, err error) { if err != nil { return messages, err } - } else { //有可能不完整,构建新的 reader 读取 head,读取成功才实际对 in 进行读操作 + } else { tmpReader := bytes.NewReader(in.Bytes()) oldLen := tmpReader.Len() head, err = ws.ReadHeader(tmpReader) skipN := oldLen - tmpReader.Len() if err != nil { - if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { //数据不完整 + if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { return messages, nil } in.Next(skipN) @@ -163,21 +192,19 @@ func (w *WsCodec) readWsMessages() (messages []wsutil.Message, err error) { return nil, err } } - dataLen := (int)(msgBuf.curHeader.Length) - // 从 in 中读出 data,并将 data bytes 写入 msgBuf.cachedBuf - if dataLen > 0 { - if in.Len() < dataLen { //数据不完整 + dataLen := int(msgBuf.curHeader.Length) + if dataLen > 0 { + if in.Len() < dataLen { logging.Infof("incomplete data") return } - _, err = io.CopyN(&msgBuf.cachedBuf, in, int64(dataLen)) if err != nil { return } } - if msgBuf.curHeader.Fin { //当前 header 已经是一个完整消息 + if msgBuf.curHeader.Fin { messages, err = wsutil.ReadClientMessage(&msgBuf.cachedBuf, messages) if err != nil { return nil, err