- 修复 `OnOpen` 中网络类型判断位置不正确的问题,提前过滤非 TCP 连接
- 移除 `OnTraffic` 中重复的网络类型判断
- 优化 `TomeeSocketCodec` 的解码逻辑,使用 `InboundBuffered` 和 `Next` 提高效率
- 调整 `ByteArray` 创建方法参数,避免可变参数带来的性能损耗
- 在 `ClientData` 中将 `IsCrossDomain` 改为 `sync.Once` 避免重复处理
- 使用 `AsyncWrite` 替代 `Write` 提升写入异步性
- 修复连接关闭流程,使用……(原文在此截断)
522 lines
11 KiB
Go
522 lines
11 KiB
Go
package bytearray
|
||
|
||
import (
|
||
"encoding/binary"
|
||
"errors"
|
||
"io"
|
||
"math"
|
||
"sync"
|
||
)
|
||
|
||
// ByteArray is a growable byte buffer with independent read and write
// positions and a configurable byte order (big- or little-endian) for
// multi-byte numeric values.
type ByteArray struct {
	buf []byte // backing storage; Length() reports len(buf)

	// posWrite and posRead are byte offsets into buf for the next write and
	// the next read respectively.
	posWrite, posRead int

	endian binary.ByteOrder // byte order for numeric values; may be nil (see GetEndian)
}
|
||
|
||
// 默认使用大端字节序
|
||
var defaultEndian binary.ByteOrder = binary.BigEndian
|
||
|
||
// bufferpool 用于重用ByteArray实例
|
||
var bufferpool = &sync.Pool{
|
||
New: func() interface{} {
|
||
return &ByteArray{endian: defaultEndian}
|
||
},
|
||
}
|
||
|
||
// CreateByteArray 创建一个新的ByteArray实例,使用指定的字节数组
|
||
func CreateByteArray(bytes []byte) *ByteArray {
|
||
var ba *ByteArray
|
||
if len(bytes) == 0 { //如果是0,则为新创建
|
||
ba = bufferpool.Get().(*ByteArray)
|
||
} else { //读序列
|
||
ba = &ByteArray{endian: defaultEndian}
|
||
}
|
||
|
||
ba.buf = append(ba.buf, bytes...)
|
||
ba.ResetPos()
|
||
return ba
|
||
}
|
||
|
||
// releaseByteArray 将ByteArray实例放回池中以便重用
|
||
func releaseByteArray(ba *ByteArray) {
|
||
ba.Reset()
|
||
bufferpool.Put(ba)
|
||
}
|
||
|
||
// Length 返回字节数组的总长度
|
||
func (ba *ByteArray) Length() int {
|
||
return len(ba.buf)
|
||
}
|
||
|
||
// Available 返回可读取的字节数
|
||
func (ba *ByteArray) Available() int {
|
||
return ba.Length() - ba.posRead
|
||
}
|
||
|
||
// SetEndian 设置字节序(大端或小端)
|
||
func (ba *ByteArray) SetEndian(endian binary.ByteOrder) {
|
||
ba.endian = endian
|
||
}
|
||
|
||
// GetEndian 获取当前字节序
|
||
func (ba *ByteArray) GetEndian() binary.ByteOrder {
|
||
if ba.endian == nil {
|
||
return defaultEndian
|
||
}
|
||
return ba.endian
|
||
}
|
||
|
||
// Grow 确保缓冲区有足够的空间
|
||
func (ba *ByteArray) Grow(size int) {
|
||
if size <= 0 {
|
||
return
|
||
}
|
||
|
||
required := ba.posWrite + size
|
||
if len(ba.buf) >= required {
|
||
return
|
||
}
|
||
|
||
newBuf := make([]byte, required)
|
||
copy(newBuf, ba.buf)
|
||
ba.buf = newBuf
|
||
}
|
||
|
||
// SetWritePos 设置写指针位置
|
||
func (ba *ByteArray) SetWritePos(pos int) error {
|
||
if pos < 0 || pos > ba.Length() {
|
||
return io.EOF
|
||
}
|
||
ba.posWrite = pos
|
||
return nil
|
||
}
|
||
|
||
// SetWriteEnd 将写指针设置到末尾
|
||
func (ba *ByteArray) SetWriteEnd() {
|
||
ba.posWrite = ba.Length()
|
||
}
|
||
|
||
// GetWritePos 获取写指针位置
|
||
func (ba *ByteArray) GetWritePos() int {
|
||
return ba.posWrite
|
||
}
|
||
|
||
// SetReadPos 设置读指针位置
|
||
func (ba *ByteArray) SetReadPos(pos int) error {
|
||
if pos < 0 || pos > ba.Length() {
|
||
return io.EOF
|
||
}
|
||
ba.posRead = pos
|
||
return nil
|
||
}
|
||
|
||
// SetReadEnd 将读指针设置到末尾
|
||
func (ba *ByteArray) SetReadEnd() {
|
||
ba.posRead = ba.Length()
|
||
}
|
||
|
||
// GetReadPos 获取读指针位置
|
||
func (ba *ByteArray) GetReadPos() int {
|
||
return ba.posRead
|
||
}
|
||
|
||
// ResetPos 重置读写指针到开始位置
|
||
func (ba *ByteArray) ResetPos() {
|
||
ba.posWrite = 0
|
||
ba.posRead = 0
|
||
}
|
||
|
||
// Reset 重置ByteArray,清空缓冲区并重置指针
|
||
func (ba *ByteArray) Reset() {
|
||
ba.buf = nil
|
||
ba.ResetPos()
|
||
}
|
||
|
||
// Bytes 返回完整的字节数组
|
||
func (ba *ByteArray) Bytes() []byte {
|
||
defer releaseByteArray(ba) //这里是写数组,写完后退出时释放线程池
|
||
return ba.buf
|
||
}
|
||
|
||
// BytesAvailable 返回从当前读指针位置到末尾的字节数组
|
||
func (ba *ByteArray) BytesAvailable() []byte {
|
||
return ba.buf[ba.posRead:]
|
||
}
|
||
|
||
// ========== 写入方法 ==========
|
||
|
||
// Write 写入字节数组
|
||
func (ba *ByteArray) Write(bytes []byte) (int, error) {
|
||
if len(bytes) == 0 {
|
||
return 0, nil
|
||
}
|
||
|
||
ba.Grow(len(bytes))
|
||
n := copy(ba.buf[ba.posWrite:], bytes)
|
||
ba.posWrite += n
|
||
return n, nil
|
||
}
|
||
|
||
// WriteByte 写入单个字节
|
||
func (ba *ByteArray) WriteByte(b byte) error {
|
||
ba.Grow(1)
|
||
ba.buf[ba.posWrite] = b
|
||
ba.posWrite++
|
||
return nil
|
||
}
|
||
|
||
// WriteInt8 写入int8
|
||
func (ba *ByteArray) WriteInt8(value int8) error {
|
||
return ba.WriteByte(byte(value))
|
||
}
|
||
|
||
// WriteInt16 写入int16,根据当前字节序处理
|
||
func (ba *ByteArray) WriteInt16(value int16) error {
|
||
return ba.writeNumber(value)
|
||
}
|
||
|
||
// WriteUInt16 写入uint16,根据当前字节序处理
|
||
func (ba *ByteArray) WriteUInt16(value uint16) error {
|
||
return ba.writeNumber(value)
|
||
}
|
||
|
||
// WriteInt32 写入int32,根据当前字节序处理
|
||
func (ba *ByteArray) WriteInt32(value int32) error {
|
||
return ba.writeNumber(value)
|
||
}
|
||
|
||
// WriteUInt32 写入uint32,根据当前字节序处理
|
||
func (ba *ByteArray) WriteUInt32(value uint32) error {
|
||
return ba.writeNumber(value)
|
||
}
|
||
|
||
// WriteInt64 写入int64,根据当前字节序处理
|
||
func (ba *ByteArray) WriteInt64(value int64) error {
|
||
return ba.writeNumber(value)
|
||
}
|
||
|
||
// Writeuint32 写入uint32,根据当前字节序处理
|
||
func (ba *ByteArray) Writeuint32(value uint32) error {
|
||
return ba.writeNumber(value)
|
||
}
|
||
|
||
// WriteFloat32 写入float32,根据当前字节序处理
|
||
func (ba *ByteArray) WriteFloat32(value float32) error {
|
||
return ba.writeNumber(math.Float32bits(value))
|
||
}
|
||
|
||
// WriteFloat64 写入float64,根据当前字节序处理
|
||
func (ba *ByteArray) WriteFloat64(value float64) error {
|
||
return ba.writeNumber(math.Float64bits(value))
|
||
}
|
||
|
||
// WriteBool 写入布尔值
|
||
func (ba *ByteArray) WriteBool(value bool) error {
|
||
var b byte
|
||
if value {
|
||
b = 1
|
||
} else {
|
||
b = 0
|
||
}
|
||
return ba.WriteByte(b)
|
||
}
|
||
|
||
// WriteString 写入字符串
|
||
func (ba *ByteArray) WriteString(value string) error {
|
||
_, err := ba.Write([]byte(value))
|
||
return err
|
||
}
|
||
|
||
// WriteUTF 写入UTF字符串(带长度前缀)
|
||
func (ba *ByteArray) WriteUTF(value string) error {
|
||
bytes := []byte(value)
|
||
if err := ba.WriteUInt16(uint16(len(bytes))); err != nil {
|
||
return err
|
||
}
|
||
_, err := ba.Write(bytes)
|
||
return err
|
||
}
|
||
|
||
// ReadUTF8Array 读取 UTF8 字符串数组(格式:先读取 Int32 长度,再读取多个 UTF 字符串)
|
||
func (ba *ByteArray) ReadUTF8Array() ([]string, error) {
|
||
count, err := ba.ReadInt32()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if count < 0 {
|
||
return nil, errors.New("invalid array length")
|
||
}
|
||
|
||
array := make([]string, 0, count)
|
||
for i := 0; i < int(count); i++ {
|
||
str, err := ba.ReadUTF()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
array = append(array, str)
|
||
}
|
||
return array, nil
|
||
}
|
||
|
||
// ReadInt32Array 读取 Int32 数组(格式:先读取 Int32 长度,再读取多个 Int32)
|
||
func (ba *ByteArray) ReadInt32Array() ([]int32, error) {
|
||
count, err := ba.ReadInt32()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if count < 0 {
|
||
return nil, errors.New("invalid array length")
|
||
}
|
||
|
||
array := make([]int32, 0, count)
|
||
for i := 0; i < int(count); i++ {
|
||
val, err := ba.ReadInt32()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
array = append(array, val)
|
||
}
|
||
return array, nil
|
||
}
|
||
|
||
// WriteUTF8 写入UTF8字符串(不带长度前缀)
|
||
func (ba *ByteArray) WriteUTF8(value string) error {
|
||
_, err := ba.Write([]byte(value))
|
||
return err
|
||
}
|
||
|
||
// 通用写入数值方法
|
||
func (ba *ByteArray) writeNumber(value interface{}) error {
|
||
var size int
|
||
switch value.(type) {
|
||
case int8, uint8:
|
||
size = 1
|
||
case int16, uint16:
|
||
size = 2
|
||
case int32, uint32, float32:
|
||
size = 4
|
||
case int64, uint64, float64:
|
||
size = 8
|
||
default:
|
||
return errors.New("unsupported number type")
|
||
}
|
||
|
||
ba.Grow(size)
|
||
|
||
switch v := value.(type) {
|
||
case int8:
|
||
ba.buf[ba.posWrite] = byte(v)
|
||
case uint8:
|
||
ba.buf[ba.posWrite] = v
|
||
case int16:
|
||
ba.endian.PutUint16(ba.buf[ba.posWrite:], uint16(v))
|
||
case uint16:
|
||
ba.endian.PutUint16(ba.buf[ba.posWrite:], v)
|
||
case int32:
|
||
ba.endian.PutUint32(ba.buf[ba.posWrite:], uint32(v))
|
||
case uint32:
|
||
ba.endian.PutUint32(ba.buf[ba.posWrite:], v)
|
||
case int64:
|
||
ba.endian.PutUint64(ba.buf[ba.posWrite:], uint64(v))
|
||
case uint64:
|
||
ba.endian.PutUint64(ba.buf[ba.posWrite:], v)
|
||
case float32:
|
||
ba.endian.PutUint32(ba.buf[ba.posWrite:], math.Float32bits(v))
|
||
case float64:
|
||
ba.endian.PutUint64(ba.buf[ba.posWrite:], math.Float64bits(v))
|
||
}
|
||
|
||
ba.posWrite += size
|
||
return nil
|
||
}
|
||
|
||
// ========== 读取方法 ==========
|
||
|
||
// Read 读取字节数组到指定缓冲区
|
||
func (ba *ByteArray) Read(bytes []byte) (int, error) {
|
||
if len(bytes) == 0 {
|
||
return 0, nil
|
||
}
|
||
|
||
if ba.posRead+len(bytes) > ba.Length() {
|
||
return 0, io.EOF
|
||
}
|
||
|
||
n := copy(bytes, ba.buf[ba.posRead:])
|
||
ba.posRead += n
|
||
return n, nil
|
||
}
|
||
|
||
// ReadByte 读取单个字节
|
||
func (ba *ByteArray) ReadByte() (byte, error) {
|
||
if ba.posRead >= ba.Length() {
|
||
return 0, io.EOF
|
||
}
|
||
|
||
b := ba.buf[ba.posRead]
|
||
ba.posRead++
|
||
return b, nil
|
||
}
|
||
|
||
// ReadInt8 读取int8
|
||
func (ba *ByteArray) ReadInt8() (int8, error) {
|
||
b, err := ba.ReadByte()
|
||
return int8(b), err
|
||
}
|
||
|
||
// ReadUInt8 读取uint8
|
||
func (ba *ByteArray) ReadUInt8() (uint8, error) {
|
||
return ba.ReadByte()
|
||
}
|
||
|
||
// ReadInt16 读取int16,根据当前字节序处理
|
||
func (ba *ByteArray) ReadInt16() (int16, error) {
|
||
var v uint16
|
||
if err := ba.readNumber(&v); err != nil {
|
||
return 0, err
|
||
}
|
||
return int16(v), nil
|
||
}
|
||
|
||
// ReadUInt16 读取uint16,根据当前字节序处理
|
||
func (ba *ByteArray) ReadUInt16() (uint16, error) {
|
||
var v uint16
|
||
if err := ba.readNumber(&v); err != nil {
|
||
return 0, err
|
||
}
|
||
return v, nil
|
||
}
|
||
|
||
// ReadInt32 读取int32,根据当前字节序处理
|
||
func (ba *ByteArray) ReadInt32() (int32, error) {
|
||
var v uint32
|
||
if err := ba.readNumber(&v); err != nil {
|
||
return 0, err
|
||
}
|
||
return int32(v), nil
|
||
}
|
||
|
||
// ReadUInt32 读取uint32,根据当前字节序处理
|
||
func (ba *ByteArray) ReadUInt32() (uint32, error) {
|
||
var v uint32
|
||
if err := ba.readNumber(&v); err != nil {
|
||
return 0, err
|
||
}
|
||
return v, nil
|
||
}
|
||
|
||
// ReadInt64 读取int64,根据当前字节序处理
|
||
func (ba *ByteArray) ReadInt64() (int64, error) {
|
||
var v uint32
|
||
if err := ba.readNumber(&v); err != nil {
|
||
return 0, err
|
||
}
|
||
return int64(v), nil
|
||
}
|
||
|
||
// Readuint32 读取uint32,根据当前字节序处理
|
||
func (ba *ByteArray) Readuint32() (uint32, error) {
|
||
var v uint32
|
||
if err := ba.readNumber(&v); err != nil {
|
||
return 0, err
|
||
}
|
||
return v, nil
|
||
}
|
||
|
||
// ReadFloat32 读取float32,根据当前字节序处理
|
||
func (ba *ByteArray) ReadFloat32() (float32, error) {
|
||
var v uint32
|
||
if err := ba.readNumber(&v); err != nil {
|
||
return 0, err
|
||
}
|
||
return math.Float32frombits(v), nil
|
||
}
|
||
|
||
// ReadFloat64 读取float64,根据当前字节序处理
|
||
func (ba *ByteArray) ReadFloat64() (float64, error) {
|
||
var v uint64
|
||
if err := ba.readNumber(&v); err != nil {
|
||
return 0, err
|
||
}
|
||
return math.Float64frombits(v), nil
|
||
}
|
||
|
||
// ReadBool 读取布尔值
|
||
func (ba *ByteArray) ReadBool() (bool, error) {
|
||
b, err := ba.ReadByte()
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return b != 0, nil
|
||
}
|
||
|
||
// ReadString 读取指定长度的字符串
|
||
func (ba *ByteArray) ReadString(length int) (string, error) {
|
||
if length < 0 {
|
||
return "", errors.New("invalid string length")
|
||
}
|
||
|
||
if ba.posRead+length > ba.Length() {
|
||
return "", io.EOF
|
||
}
|
||
|
||
str := string(ba.buf[ba.posRead : ba.posRead+length])
|
||
ba.posRead += length
|
||
return str, nil
|
||
}
|
||
|
||
// ReadUTF 读取UTF字符串(带长度前缀)
|
||
func (ba *ByteArray) ReadUTF() (string, error) {
|
||
length, err := ba.ReadUInt16()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return ba.ReadString(int(length))
|
||
}
|
||
|
||
// 通用读取数值方法
|
||
func (ba *ByteArray) readNumber(value interface{}) error {
|
||
var size int
|
||
switch value.(type) {
|
||
case *int16, *uint16:
|
||
size = 2
|
||
case *int32, *uint32, *float32:
|
||
size = 4
|
||
case *int64, *uint64, *float64:
|
||
size = 8
|
||
default:
|
||
return errors.New("unsupported number type")
|
||
}
|
||
|
||
if ba.posRead+size > ba.Length() {
|
||
return io.EOF
|
||
}
|
||
|
||
buf := ba.buf[ba.posRead : ba.posRead+size]
|
||
ba.posRead += size
|
||
|
||
switch v := value.(type) {
|
||
case *int16:
|
||
*v = int16(ba.endian.Uint16(buf))
|
||
case *uint16:
|
||
*v = ba.endian.Uint16(buf)
|
||
case *int32:
|
||
*v = int32(ba.endian.Uint32(buf))
|
||
case *uint32:
|
||
*v = ba.endian.Uint32(buf)
|
||
case *int64:
|
||
*v = int64(ba.endian.Uint64(buf))
|
||
case *uint64:
|
||
*v = ba.endian.Uint64(buf)
|
||
case *float32:
|
||
*v = math.Float32frombits(ba.endian.Uint32(buf))
|
||
case *float64:
|
||
*v = math.Float64frombits(ba.endian.Uint64(buf))
|
||
}
|
||
|
||
return nil
|
||
}
|