From 6b639b882681ee3cd11a49a3b73193b46912e1c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=94=E5=BF=B5?= <1@72wo.cn> Date: Fri, 11 Jul 2025 21:04:28 +0800 Subject: [PATCH] =?UTF-8?q?feat(common):=20=E6=B7=BB=E5=8A=A0=20WebSocket?= =?UTF-8?q?=20=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 ClientData 结构中添加 WsCodec 字段 - 实现 WebSocket 升级和消息处理逻辑 - 添加 WebSocket 相关的依赖包 --- common/data/entity/Client.go | 8 ++ common/data/entity/wscodec.go | 161 ++++++++++++++++++++++++++++++++++ common/go.mod | 5 ++ common/go.sum | 10 +++ common/socket/ServerEvent.go | 58 ++++++++++-- 5 files changed, 236 insertions(+), 6 deletions(-) create mode 100644 common/data/entity/wscodec.go diff --git a/common/data/entity/Client.go b/common/data/entity/Client.go index 6f02f595b..c4876a895 100644 --- a/common/data/entity/Client.go +++ b/common/data/entity/Client.go @@ -7,6 +7,8 @@ type ClientData struct { player *Player //客户实体 //UserID uint32 m sync.Mutex + + wsmsg WsCodec } func (cd *ClientData) SetPlayer(player *Player) { @@ -19,6 +21,11 @@ func (cd *ClientData) GetPlayer() *Player { defer cd.m.Unlock() return cd.player } +func (cd *ClientData) Getwsmsg() *WsCodec { + cd.m.Lock() + defer cd.m.Unlock() + return &cd.wsmsg +} func (cd *ClientData) SetCrossDomain(isCrossDomain bool) { cd.m.Lock() defer cd.m.Unlock() @@ -35,6 +42,7 @@ func NewClientData() *ClientData { isCrossDomain: false, player: nil, m: sync.Mutex{}, + wsmsg: WsCodec{}, } return &cd diff --git a/common/data/entity/wscodec.go b/common/data/entity/wscodec.go new file mode 100644 index 000000000..1fcbe1836 --- /dev/null +++ b/common/data/entity/wscodec.go @@ -0,0 +1,161 @@ +package entity + +import ( + "bytes" + "errors" + "fmt" + "io" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" + "github.com/panjf2000/gnet/v2" + "github.com/panjf2000/gnet/v2/pkg/logging" +) + +type WsCodec struct { + upgraded bool // 链接是否升级 + Buf bytes.Buffer // 从实际socket中读取到的数据缓存 + wsMsgBuf wsMessageBuf // ws 消息缓存 + Isinitws bool +} + +type wsMessageBuf struct { + curHeader *ws.Header + cachedBuf bytes.Buffer +} + +type readWrite struct { + io.Reader + io.Writer +} + +func (w *WsCodec) Upgrade(c gnet.Conn) (ok bool, action gnet.Action) { + if w.upgraded { + ok = true + return + } + buf := &w.Buf + tmpReader := bytes.NewReader(buf.Bytes()) + oldLen := tmpReader.Len() + logging.Infof("do Upgrade") + + hs, err := ws.Upgrade(readWrite{tmpReader, c}) + skipN := oldLen - tmpReader.Len() + if err != nil { + if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { //数据不完整,不跳过 buf 中的 skipN 字节(此时 buf 中存放的仅是部分 "handshake data" bytes),下次再尝试读取 + return + } + buf.Next(skipN) + logging.Errorf("conn[%v] [err=%v]", c.RemoteAddr().String(), err.Error()) + action = gnet.Close + return + } + buf.Next(skipN) + logging.Infof("conn[%v] upgrade websocket protocol! Handshake: %v", c.RemoteAddr().String(), hs) + + ok = true + w.upgraded = true + return +} +func (w *WsCodec) ReadBufferBytes(c gnet.Conn) (gnet.Action, int) { + size := c.InboundBuffered() + //buf := make([]byte, size) + read, err := c.Peek(size) + if err != nil { + logging.Errorf("read err! %v", err) + return gnet.Close, 0 + } + // if read < size { + // logging.Errorf("read bytes len err! size: %d read: %d", size, read) + // return gnet.Close + // } + w.Buf.Write(read) + return gnet.None, size +} +func (w *WsCodec) Decode(c gnet.Conn) (outs []wsutil.Message, err error) { + fmt.Println("do Decode") + messages, err := w.readWsMessages() + if err != nil { + logging.Errorf("Error reading message! %v", err) + return nil, err + } + if messages == nil || len(messages) <= 0 { //没有读到完整数据 不处理 + return + } + for _, message := range messages { + if message.OpCode.IsControl() { + err = wsutil.HandleClientControlMessage(c, message) + if err != nil { + return + } + continue + } + if message.OpCode == ws.OpText || message.OpCode == ws.OpBinary { + outs = append(outs, message) + } + } + return +} + +func (w *WsCodec) readWsMessages() (messages []wsutil.Message, err error) { + msgBuf := &w.wsMsgBuf + in := &w.Buf + for { + // 从 in 中读出 header,并将 header bytes 写入 msgBuf.cachedBuf + if msgBuf.curHeader == nil { + if in.Len() < ws.MinHeaderSize { //头长度至少是2 + return + } + var head ws.Header + if in.Len() >= ws.MaxHeaderSize { + head, err = ws.ReadHeader(in) + if err != nil { + return messages, err + } + } else { //有可能不完整,构建新的 reader 读取 head,读取成功才实际对 in 进行读操作 + tmpReader := bytes.NewReader(in.Bytes()) + oldLen := tmpReader.Len() + head, err = ws.ReadHeader(tmpReader) + skipN := oldLen - tmpReader.Len() + if err != nil { + if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { //数据不完整 + return messages, nil + } + in.Next(skipN) + return nil, err + } + in.Next(skipN) + } + + msgBuf.curHeader = &head + err = ws.WriteHeader(&msgBuf.cachedBuf, head) + if err != nil { + return nil, err + } + } + dataLen := (int)(msgBuf.curHeader.Length) + // 从 in 中读出 data,并将 data bytes 写入 msgBuf.cachedBuf + if dataLen > 0 { + if in.Len() < dataLen { //数据不完整 + fmt.Println(in.Len(), dataLen) + logging.Infof("incomplete data") + return + } + + _, err = io.CopyN(&msgBuf.cachedBuf, in, int64(dataLen)) + if err != nil { + return + } + } + if msgBuf.curHeader.Fin { //当前 header 已经是一个完整消息 + messages, err = wsutil.ReadClientMessage(&msgBuf.cachedBuf, messages) + if err != nil { + return nil, err + } + msgBuf.cachedBuf.Reset() + } else { + logging.Infof("The data is split into multiple frames") + } + msgBuf.curHeader = nil + } +} diff --git a/common/go.mod b/common/go.mod index 1d67e2b26..067e7af9a 100644 --- a/common/go.mod +++ b/common/go.mod @@ -25,8 +25,13 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect + github.com/gobwas/ws v1.4.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/kavu/go_reuseport v1.5.0 // indirect + github.com/libp2p/go-reuseport v0.4.0 // indirect github.com/pointernil/bitset32 v0.0.1 // indirect github.com/yitter/idgenerator-go v1.3.3 // indirect google.golang.org/genproto v0.0.0-20230822172742-b8732ec3820d // indirect diff --git a/common/go.sum b/common/go.sum index d4172238f..004250e70 100644 --- a/common/go.sum +++ b/common/go.sum @@ -32,6 +32,12 @@ github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ4 github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= +github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= github.com/gogf/gf/contrib/nosql/redis/v2 v2.6.3/go.mod h1:2+evGu1xAlamaYuDdSqa7QCiwPTm1RrGsUFSMc8PyLc= github.com/gogf/gf/v2 v2.6.3 h1:DoqeuwU98wotpFoDSQEx8RZbmJdK8KdGiJtzJeqpyIo= github.com/gogf/gf/v2 v2.6.3/go.mod h1:x2XONYcI4hRQ/4gMNbWHmZrNzSEIg20s2NULbzom5k0= @@ -52,9 +58,13 @@ github.com/grokify/html-strip-tags-go v0.0.1 h1:0fThFwLbW7P/kOiTBs03FsJSV9RM2M/Q github.com/grokify/html-strip-tags-go v0.0.1/go.mod h1:2Su6romC5/1VXOQMaWL2yb618ARB8iVo6/DR99A6d78= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 h1:X5VWvz21y3gzm9Nw/kaUeku/1+uBhcekkmy4IkffJww= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1/go.mod h1:Zanoh4+gvIgluNqcfMVTJueD4wSS5hT7zTt4Mrutd90= +github.com/kavu/go_reuseport v1.5.0 h1:UNuiY2OblcqAtVDE8Gsg1kZz8zbBWg907sP1ceBV+bk= +github.com/kavu/go_reuseport v1.5.0/go.mod h1:CG8Ee7ceMFSMnx/xr25Vm0qXaj2Z4i5PWoUx+JZ5/CU= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= +github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= github.com/lunixbochs/struc v0.0.0-20241101090106-8d528fa2c543 h1:GxMuVb9tJajC1QpbQwYNY1ZAo1EIE8I+UclBjOfjz/M= github.com/lunixbochs/struc v0.0.0-20241101090106-8d528fa2c543/go.mod h1:vy1vK6wD6j7xX6O6hXe621WabdtNkou2h7uRtTfRMyg= github.com/magiconair/properties v1.8.6 h1:5ibWZ6iY0NctNGWo87LalDlEZ6R41TqbbDamhfG/Qzo= diff --git a/common/socket/ServerEvent.go b/common/socket/ServerEvent.go index 570780d19..a8c8540fc 100644 --- a/common/socket/ServerEvent.go +++ b/common/socket/ServerEvent.go @@ -2,21 +2,26 @@ package socket import ( "context" + "fmt" "log" "sync/atomic" "time" "blazing/common/data/entity" + "github.com/gobwas/ws/wsutil" "github.com/gogf/gf/v2/os/glog" "github.com/panjf2000/gnet/v2" "github.com/panjf2000/gnet/v2/pkg/logging" ) func (s *Server) Boot() error { + // go s.bootws() err := gnet.Run(s, s.network+"://"+s.addr, gnet.WithMulticore(true), gnet.WithTicker(true), + // gnet.WithReusePort(true), + // gnet.WithReuseAddr(true), gnet.WithSocketRecvBuffer(s.bufferSize)) if err != nil { return err @@ -64,19 +69,60 @@ func (s *Server) OnBoot(eng gnet.Engine) gnet.Action { return gnet.None } -func (s *Server) OnOpen(_ gnet.Conn) (out []byte, action gnet.Action) { +func (s *Server) OnOpen(conn gnet.Conn) (out []byte, action gnet.Action) { + if conn.Context() == nil { + conn.SetContext(entity.NewClientData()) //注入data + } + atomic.AddInt64(&s.connected, 1) return nil, gnet.None } -func (s *Server) OnTraffic(conn gnet.Conn) (action gnet.Action) { - if conn.Context() == nil { - conn.SetContext(entity.NewClientData()) //注入data +func (s *Server) OnTraffic(c gnet.Conn) (action gnet.Action) { + if s.network != "tcp" { + return gnet.Close } - if s.network == "tcp" { - return s.handleTcp(conn) + ws := c.Context().(*entity.ClientData).Getwsmsg() + tt, len1 := ws.ReadBufferBytes(c) + if tt == gnet.Close { + + return gnet.Close + } + + ok, action := ws.Upgrade(c) + if !ok { + s.handleTcp(c) + return gnet.None + } else { + fmt.Println(ws.Buf.Bytes()) + c.Read(make([]byte, len1)) + } + + if ws.Buf.Len() <= 0 { + return gnet.None + } + messages, err := ws.Decode(c) + if err != nil { + return gnet.Close + } + if messages == nil { + return + } + for _, message := range messages { + msgLen := len(message.Payload) + if msgLen > 128 { + logging.Infof("conn[%v] receive [op=%v] [msg=%v..., len=%d]", c.RemoteAddr().String(), message.OpCode, string(message.Payload[:128]), len(message.Payload)) + } else { + logging.Infof("conn[%v] receive [op=%v] [msg=%v, len=%d]", c.RemoteAddr().String(), message.OpCode, string(message.Payload), len(message.Payload)) + } + // This is the echo server + err = wsutil.WriteServerMessage(c, message.OpCode, message.Payload) + if err != nil { + logging.Infof("conn[%v] [err=%v]", c.RemoteAddr().String(), err.Error()) + return gnet.Close + } } return gnet.None