Files
bl/common/utils/bytearray/bytearray.go
昔念 ea4ca98e49 fix(socket): 修复连接处理逻辑并优化数据解码流程
- 修复 `OnOpen` 中网络类型判断位置不正确的问题,提前过滤非 TCP 连接
- 移除 `OnTraffic` 中重复的网络类型判断
- 优化 `TomeeSocketCodec` 的解码逻辑,使用 `InboundBuffered` 和 `Next` 提高效率
- 调整 `ByteArray` 创建方法参数,避免可变参数带来的性能损耗
- 在 `ClientData` 中将 `IsCrossDomain` 改为 `sync.Once` 避免重复处理
- 使用 `AsyncWrite` 替代 `Write` 提升写入异步性
- 修复连接关闭流程,使用
2025-11-01 14:31:19 +08:00

522 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package bytearray
import (
"encoding/binary"
"errors"
"io"
"math"
"sync"
)
// ByteArray is a growable byte buffer with independent read and write
// cursors. Multi-byte numbers are encoded and decoded using a configurable
// byte order (see SetEndian / GetEndian).
type ByteArray struct {
	buf      []byte           // backing storage; Length() reports len(buf)
	posWrite int              // offset at which the next write lands
	posRead  int              // offset from which the next read starts
	endian   binary.ByteOrder // byte order for numeric reads and writes
}
// defaultEndian is the byte order new ByteArray values start with (big-endian).
var defaultEndian = binary.ByteOrder(binary.BigEndian)
// bufferpool recycles *ByteArray values: CreateByteArray takes from it when
// given an empty slice, and releaseByteArray puts values back. Instances it
// has to allocate from scratch start out with defaultEndian.
var bufferpool = &sync.Pool{
	New: func() interface{} {
		fresh := new(ByteArray)
		fresh.endian = defaultEndian
		return fresh
	},
}
// CreateByteArray returns a ByteArray whose contents are a copy of bytes,
// with both cursors at offset 0 and the byte order set to defaultEndian.
//
// An empty (or nil) argument means "new buffer for writing": the instance is
// taken from bufferpool (Bytes() later returns it there). A non-empty
// argument is treated as a read sequence and gets a freshly allocated
// instance.
func CreateByteArray(bytes []byte) *ByteArray {
	var ba *ByteArray
	if len(bytes) == 0 {
		ba = bufferpool.Get().(*ByteArray)
		// A pooled instance may have been switched to another byte order via
		// SetEndian before it was released, and Reset does not restore it.
		// Re-apply the default so every caller gets the documented order.
		ba.endian = defaultEndian
	} else {
		ba = &ByteArray{endian: defaultEndian}
	}
	ba.buf = append(ba.buf, bytes...)
	ba.ResetPos()
	return ba
}
// releaseByteArray resets ba (Reset drops its buffer and rewinds both
// cursors) and puts it back into bufferpool for reuse. Reset does not touch
// the byte order, so a released instance keeps whatever SetEndian last chose.
// After Put, any later bufferpool.Get caller may receive this instance, so
// the releasing code must not keep using ba.
func releaseByteArray(ba *ByteArray) {
	ba.Reset()
	bufferpool.Put(ba)
}
// Length reports the total number of bytes held in the buffer, regardless of
// where the read and write cursors are.
func (ba *ByteArray) Length() int {
	n := len(ba.buf)
	return n
}

// Available reports how many bytes remain between the read cursor and the
// end of the buffer.
func (ba *ByteArray) Available() int {
	remaining := len(ba.buf) - ba.posRead
	return remaining
}
// SetEndian selects the byte order stored on ba for numeric reads and writes.
// Storing nil makes GetEndian fall back to defaultEndian.
func (ba *ByteArray) SetEndian(endian binary.ByteOrder) {
	ba.endian = endian
}

// GetEndian reports the byte order in effect, substituting defaultEndian when
// none has been set.
func (ba *ByteArray) GetEndian() binary.ByteOrder {
	if ba.endian != nil {
		return ba.endian
	}
	return defaultEndian
}
// Grow makes at least size bytes addressable starting at the write cursor,
// extending len(ba.buf) to posWrite+size when it is shorter. Newly exposed
// bytes are zero. Non-positive sizes are a no-op.
//
// The backing array grows geometrically (capacity doubles). A long run of
// small writes therefore costs amortized O(1) per byte, instead of
// re-allocating and copying the whole buffer on every call.
func (ba *ByteArray) Grow(size int) {
	if size <= 0 {
		return
	}
	required := ba.posWrite + size
	oldLen := len(ba.buf)
	if oldLen >= required {
		return
	}
	if cap(ba.buf) < required {
		newCap := 2 * cap(ba.buf)
		if newCap < required {
			newCap = required
		}
		newBuf := make([]byte, oldLen, newCap)
		copy(newBuf, ba.buf)
		ba.buf = newBuf
	}
	// Extend within capacity and clear the newly exposed tail. Callers then
	// see the same zero-filled bytes the previous exact-size make produced.
	ba.buf = ba.buf[:required]
	for i := oldLen; i < required; i++ {
		ba.buf[i] = 0
	}
}
// SetWritePos moves the write cursor to pos. Positions outside [0, Length()]
// are rejected with io.EOF and leave the cursor unchanged.
func (ba *ByteArray) SetWritePos(pos int) error {
	if 0 <= pos && pos <= len(ba.buf) {
		ba.posWrite = pos
		return nil
	}
	return io.EOF
}

// SetWriteEnd moves the write cursor just past the last byte, so the next
// write appends.
func (ba *ByteArray) SetWriteEnd() {
	ba.posWrite = len(ba.buf)
}

// GetWritePos reports the current write cursor offset.
func (ba *ByteArray) GetWritePos() int {
	return ba.posWrite
}
// SetReadPos moves the read cursor to pos. Positions outside [0, Length()]
// are rejected with io.EOF and leave the cursor unchanged.
func (ba *ByteArray) SetReadPos(pos int) error {
	if 0 <= pos && pos <= len(ba.buf) {
		ba.posRead = pos
		return nil
	}
	return io.EOF
}

// SetReadEnd moves the read cursor to the end of the buffer, after which
// Available reports 0.
func (ba *ByteArray) SetReadEnd() {
	ba.posRead = len(ba.buf)
}

// GetReadPos reports the current read cursor offset.
func (ba *ByteArray) GetReadPos() int {
	return ba.posRead
}
// ResetPos rewinds both the read and write cursors to offset 0 without
// touching the buffer contents.
func (ba *ByteArray) ResetPos() {
	ba.posRead, ba.posWrite = 0, 0
}

// Reset drops the buffer (it becomes nil) and rewinds both cursors. The byte
// order is left as it is.
func (ba *ByteArray) Reset() {
	ba.buf, ba.posRead, ba.posWrite = nil, 0, 0
}
// Bytes returns the whole underlying buffer (not a copy) and then hands ba to
// releaseByteArray, which resets it and returns it to bufferpool. This is
// meant for write-side buffers: once writing is finished, the instance is
// recycled.
//
// NOTE(review): after Bytes returns, ba may be handed out again by the pool.
// The caller must not use ba afterwards, including calling Bytes a second
// time. The returned slice stays valid, because Reset only drops ba's
// reference to it.
func (ba *ByteArray) Bytes() []byte {
	out := ba.buf
	releaseByteArray(ba)
	return out
}
// BytesAvailable returns the unread part of the buffer, from the read cursor
// to the end. The result aliases the internal storage (no copy), and the read
// cursor is not advanced.
func (ba *ByteArray) BytesAvailable() []byte {
	start := ba.posRead
	return ba.buf[start:]
}
// ========== Write methods ==========

// Write copies bytes into the buffer at the write cursor and advances the
// cursor. The buffer is grown when needed; if the cursor is not at the end,
// existing data is overwritten. It always returns (len(bytes), nil).
func (ba *ByteArray) Write(bytes []byte) (int, error) {
	n := len(bytes)
	if n > 0 {
		ba.Grow(n)
		n = copy(ba.buf[ba.posWrite:], bytes)
		ba.posWrite += n
	}
	return n, nil
}
// WriteByte stores b at the write cursor (growing the buffer if necessary)
// and advances the cursor by one. It never fails.
func (ba *ByteArray) WriteByte(b byte) error {
	ba.Grow(1)
	pos := ba.posWrite
	ba.buf[pos] = b
	ba.posWrite = pos + 1
	return nil
}
// WriteInt8 writes value as a single two's-complement byte.
func (ba *ByteArray) WriteInt8(value int8) error {
	return ba.WriteByte(uint8(value))
}

// WriteInt16 writes value as 2 bytes in the current byte order. The bit
// pattern is identical to WriteUInt16(uint16(value)).
func (ba *ByteArray) WriteInt16(value int16) error {
	return ba.WriteUInt16(uint16(value))
}

// WriteUInt16 writes value as 2 bytes in the current byte order.
func (ba *ByteArray) WriteUInt16(value uint16) error {
	return ba.writeNumber(value)
}

// WriteInt32 writes value as 4 bytes in the current byte order. The bit
// pattern is identical to WriteUInt32(uint32(value)).
func (ba *ByteArray) WriteInt32(value int32) error {
	return ba.WriteUInt32(uint32(value))
}

// WriteUInt32 writes value as 4 bytes in the current byte order.
func (ba *ByteArray) WriteUInt32(value uint32) error {
	return ba.writeNumber(value)
}

// WriteInt64 writes value as 8 bytes in the current byte order, using its
// two's-complement bit pattern.
func (ba *ByteArray) WriteInt64(value int64) error {
	return ba.writeNumber(uint64(value))
}

// Writeuint32 is an alias of WriteUInt32 that produces the same 4-byte
// encoding; it exists alongside WriteUInt32 with different capitalisation.
func (ba *ByteArray) Writeuint32(value uint32) error {
	return ba.WriteUInt32(value)
}
// WriteFloat32 writes the IEEE-754 bit pattern of value as 4 bytes in the
// current byte order.
func (ba *ByteArray) WriteFloat32(value float32) error {
	bits := math.Float32bits(value)
	return ba.writeNumber(bits)
}

// WriteFloat64 writes the IEEE-754 bit pattern of value as 8 bytes in the
// current byte order.
func (ba *ByteArray) WriteFloat64(value float64) error {
	bits := math.Float64bits(value)
	return ba.writeNumber(bits)
}
// WriteBool writes a single byte: 1 for true, 0 for false.
func (ba *ByteArray) WriteBool(value bool) error {
	if value {
		return ba.WriteByte(1)
	}
	return ba.WriteByte(0)
}
// WriteString writes the raw bytes of value at the write cursor, with no
// length prefix and no terminator.
func (ba *ByteArray) WriteString(value string) error {
	raw := []byte(value)
	if _, err := ba.Write(raw); err != nil {
		return err
	}
	return nil
}
// WriteUTF writes value prefixed by its byte length as a uint16 in the current
// byte order. This is the counterpart of ReadUTF.
//
// A string longer than math.MaxUint16 bytes cannot be described by the 16-bit
// prefix. Such strings are rejected with an error and nothing is written.
// Previously the prefix was silently truncated while the full body was still
// written, which desynchronises any reader.
func (ba *ByteArray) WriteUTF(value string) error {
	if len(value) > math.MaxUint16 {
		return errors.New("utf string too long for uint16 length prefix")
	}
	bytes := []byte(value)
	if err := ba.WriteUInt16(uint16(len(bytes))); err != nil {
		return err
	}
	_, err := ba.Write(bytes)
	return err
}
// ReadUTF8Array reads a string array encoded as an int32 element count
// followed by that many ReadUTF strings (uint16 length prefix + bytes).
//
// The count comes from the input data, so it is not trusted. A negative count
// is rejected. The slice pre-allocation is capped by what the remaining buffer
// could possibly hold, since every element needs at least its 2-byte prefix.
// Without that cap, a corrupt count near 2^31 would make make() try to
// reserve ~32 GiB before the first element read fails. Errors from the
// underlying reads are returned unchanged and the partial result is dropped.
func (ba *ByteArray) ReadUTF8Array() ([]string, error) {
	count, err := ba.ReadInt32()
	if err != nil {
		return nil, err
	}
	if count < 0 {
		return nil, errors.New("invalid array length")
	}
	capHint := int(count)
	if limit := ba.Available() / 2; capHint > limit {
		capHint = limit
	}
	array := make([]string, 0, capHint)
	for i := 0; i < int(count); i++ {
		str, err := ba.ReadUTF()
		if err != nil {
			return nil, err
		}
		array = append(array, str)
	}
	return array, nil
}
// ReadInt32Array reads an int32 array encoded as an int32 element count
// followed by that many int32 values in the current byte order.
//
// The count comes from the input data, so it is not trusted. A negative count
// is rejected. The slice pre-allocation is capped by the number of 4-byte
// values the remaining buffer can actually contain, so a corrupt count cannot
// trigger a multi-gigabyte allocation. Errors from the underlying reads are
// returned unchanged and the partial result is dropped.
func (ba *ByteArray) ReadInt32Array() ([]int32, error) {
	count, err := ba.ReadInt32()
	if err != nil {
		return nil, err
	}
	if count < 0 {
		return nil, errors.New("invalid array length")
	}
	capHint := int(count)
	if limit := ba.Available() / 4; capHint > limit {
		capHint = limit
	}
	array := make([]int32, 0, capHint)
	for i := 0; i < int(count); i++ {
		val, err := ba.ReadInt32()
		if err != nil {
			return nil, err
		}
		array = append(array, val)
	}
	return array, nil
}
// WriteUTF8 writes the bytes of value with no length prefix. It behaves
// exactly like WriteString.
func (ba *ByteArray) WriteUTF8(value string) error {
	return ba.WriteString(value)
}
// writeNumber encodes value at the write cursor and advances the cursor by
// the encoded size. Supported dynamic types and sizes:
//   - int8/uint8: 1 byte
//   - int16/uint16: 2 bytes
//   - int32/uint32/float32: 4 bytes
//   - int64/uint64/float64: 8 bytes (floats as IEEE-754 bits)
//
// Multi-byte values use GetEndian(). A ByteArray whose byte order is nil
// (zero value, or SetEndian(nil)) therefore falls back to defaultEndian,
// instead of panicking on a nil interface the way direct use of ba.endian
// did. Any other type returns an error and writes nothing.
func (ba *ByteArray) writeNumber(value interface{}) error {
	var size int
	switch value.(type) {
	case int8, uint8:
		size = 1
	case int16, uint16:
		size = 2
	case int32, uint32, float32:
		size = 4
	case int64, uint64, float64:
		size = 8
	default:
		return errors.New("unsupported number type")
	}
	ba.Grow(size)
	dst := ba.buf[ba.posWrite : ba.posWrite+size]
	order := ba.GetEndian()
	switch v := value.(type) {
	case int8:
		dst[0] = byte(v)
	case uint8:
		dst[0] = v
	case int16:
		order.PutUint16(dst, uint16(v))
	case uint16:
		order.PutUint16(dst, v)
	case int32:
		order.PutUint32(dst, uint32(v))
	case uint32:
		order.PutUint32(dst, v)
	case int64:
		order.PutUint64(dst, uint64(v))
	case uint64:
		order.PutUint64(dst, v)
	case float32:
		order.PutUint32(dst, math.Float32bits(v))
	case float64:
		order.PutUint64(dst, math.Float64bits(v))
	}
	ba.posWrite += size
	return nil
}
// ========== Read methods ==========

// Read copies exactly len(bytes) bytes from the read cursor into bytes and
// advances the cursor. An empty destination is a no-op returning (0, nil).
//
// NOTE(review): this is all-or-nothing. When fewer than len(bytes) bytes
// remain, it returns (0, io.EOF) and consumes nothing, even if some data is
// still available. That differs from the io.Reader convention of returning a
// short read. Generic consumers that read into a larger buffer (for example
// io.ReadAll) can silently miss the unread tail.
func (ba *ByteArray) Read(bytes []byte) (int, error) {
	want := len(bytes)
	switch {
	case want == 0:
		return 0, nil
	case ba.posRead+want > len(ba.buf):
		return 0, io.EOF
	}
	n := copy(bytes, ba.buf[ba.posRead:])
	ba.posRead += n
	return n, nil
}
// ReadByte returns the byte at the read cursor and advances the cursor by
// one. It returns (0, io.EOF) when nothing is left to read.
func (ba *ByteArray) ReadByte() (byte, error) {
	if ba.posRead < len(ba.buf) {
		b := ba.buf[ba.posRead]
		ba.posRead++
		return b, nil
	}
	return 0, io.EOF
}
// ReadInt8 reads one byte and reinterprets it as a two's-complement int8.
// On error it returns 0.
func (ba *ByteArray) ReadInt8() (int8, error) {
	raw, err := ba.ReadByte()
	if err != nil {
		return 0, err
	}
	return int8(raw), nil
}

// ReadUInt8 reads one byte as-is; it is equivalent to ReadByte.
func (ba *ByteArray) ReadUInt8() (uint8, error) {
	b, err := ba.ReadByte()
	return b, err
}
// ReadInt16 reads 2 bytes in the current byte order as a two's-complement
// int16. On error it returns 0.
func (ba *ByteArray) ReadInt16() (int16, error) {
	u, err := ba.ReadUInt16()
	return int16(u), err
}

// ReadUInt16 reads 2 bytes in the current byte order. On error it returns 0.
func (ba *ByteArray) ReadUInt16() (uint16, error) {
	var out uint16
	err := ba.readNumber(&out)
	if err != nil {
		return 0, err
	}
	return out, nil
}

// ReadInt32 reads 4 bytes in the current byte order as a two's-complement
// int32. On error it returns 0.
func (ba *ByteArray) ReadInt32() (int32, error) {
	u, err := ba.ReadUInt32()
	return int32(u), err
}

// ReadUInt32 reads 4 bytes in the current byte order. On error it returns 0.
func (ba *ByteArray) ReadUInt32() (uint32, error) {
	var out uint32
	err := ba.readNumber(&out)
	if err != nil {
		return 0, err
	}
	return out, nil
}
// ReadInt64 reads 8 bytes in the current byte order as a two's-complement
// int64. On error it returns 0.
//
// The value must be decoded through a uint64. The previous uint32 temporary
// made readNumber consume only 4 bytes, which yielded a wrong value and left
// the read cursor misaligned for every subsequent field.
func (ba *ByteArray) ReadInt64() (int64, error) {
	var v uint64
	if err := ba.readNumber(&v); err != nil {
		return 0, err
	}
	return int64(v), nil
}
// Readuint32 is an alias of ReadUInt32 and decodes the same 4-byte value; it
// exists alongside ReadUInt32 with different capitalisation.
func (ba *ByteArray) Readuint32() (uint32, error) {
	return ba.ReadUInt32()
}
// ReadFloat32 reads 4 bytes in the current byte order and reinterprets them
// as an IEEE-754 float32. On error it returns 0.
func (ba *ByteArray) ReadFloat32() (float32, error) {
	var bits uint32
	err := ba.readNumber(&bits)
	if err == nil {
		return math.Float32frombits(bits), nil
	}
	return 0, err
}

// ReadFloat64 reads 8 bytes in the current byte order and reinterprets them
// as an IEEE-754 float64. On error it returns 0.
func (ba *ByteArray) ReadFloat64() (float64, error) {
	var bits uint64
	err := ba.readNumber(&bits)
	if err == nil {
		return math.Float64frombits(bits), nil
	}
	return 0, err
}
// ReadBool reads one byte and reports whether it is non-zero. On error it
// returns false.
func (ba *ByteArray) ReadBool() (bool, error) {
	b, err := ba.ReadByte()
	return err == nil && b != 0, err
}
// ReadString reads length raw bytes at the read cursor and returns them as a
// string (a copy), then advances the cursor.
//
// A negative length returns an error. If fewer than length bytes remain, it
// returns io.EOF and consumes nothing.
func (ba *ByteArray) ReadString(length int) (string, error) {
	if length < 0 {
		return "", errors.New("invalid string length")
	}
	// Compare against the remaining byte count instead of computing
	// posRead+length. For a huge caller-supplied length that sum overflows,
	// passes the check, and makes the slice expression below panic.
	if length > len(ba.buf)-ba.posRead {
		return "", io.EOF
	}
	end := ba.posRead + length
	str := string(ba.buf[ba.posRead:end])
	ba.posRead = end
	return str, nil
}
// ReadUTF reads a uint16 byte-length prefix (current byte order) followed by
// that many bytes, and returns them as a string. This is the counterpart of
// WriteUTF.
func (ba *ByteArray) ReadUTF() (string, error) {
	n, err := ba.ReadUInt16()
	if err == nil {
		return ba.ReadString(int(n))
	}
	return "", err
}
// readNumber decodes a number at the read cursor into the variable value
// points to, and advances the cursor by the decoded size. Supported pointer
// types and sizes:
//   - *int8/*uint8: 1 byte
//   - *int16/*uint16: 2 bytes
//   - *int32/*uint32/*float32: 4 bytes
//   - *int64/*uint64/*float64: 8 bytes (floats as IEEE-754 bits)
//
// The 1-byte types mirror what writeNumber accepts. Multi-byte values use
// GetEndian(), so a nil byte order falls back to defaultEndian instead of
// panicking. An unsupported type returns an error. If fewer than size bytes
// remain, it returns io.EOF; in both cases nothing is consumed.
func (ba *ByteArray) readNumber(value interface{}) error {
	var size int
	switch value.(type) {
	case *int8, *uint8:
		size = 1
	case *int16, *uint16:
		size = 2
	case *int32, *uint32, *float32:
		size = 4
	case *int64, *uint64, *float64:
		size = 8
	default:
		return errors.New("unsupported number type")
	}
	// Compare against the remaining count so the check cannot overflow.
	if size > len(ba.buf)-ba.posRead {
		return io.EOF
	}
	src := ba.buf[ba.posRead : ba.posRead+size]
	ba.posRead += size
	order := ba.GetEndian()
	switch v := value.(type) {
	case *int8:
		*v = int8(src[0])
	case *uint8:
		*v = src[0]
	case *int16:
		*v = int16(order.Uint16(src))
	case *uint16:
		*v = order.Uint16(src)
	case *int32:
		*v = int32(order.Uint32(src))
	case *uint32:
		*v = order.Uint32(src)
	case *int64:
		*v = int64(order.Uint64(src))
	case *uint64:
		*v = order.Uint64(src)
	case *float32:
		*v = math.Float32frombits(order.Uint32(src))
	case *float64:
		*v = math.Float64frombits(order.Uint64(src))
	}
	return nil
}