All checks were successful
ci/woodpecker/push/my-first-workflow Pipeline was successful
refactor: 优化战斗循环中的宠物处理逻辑 refactor: 重构物品更新服务使用ORM模型
292 lines
8.4 KiB
Go
292 lines
8.4 KiB
Go
package struc
|
||
|
||
import (
|
||
"encoding/binary"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"reflect"
|
||
"strings"
|
||
)
|
||
|
||
type Fields []*Field
|
||
|
||
func (f Fields) SetByteOrder(order binary.ByteOrder) {
|
||
for _, field := range f {
|
||
if field != nil {
|
||
field.Order = order
|
||
}
|
||
}
|
||
}
|
||
|
||
func (f Fields) String() string {
|
||
fields := make([]string, len(f))
|
||
for i, field := range f {
|
||
if field != nil {
|
||
fields[i] = field.String()
|
||
}
|
||
}
|
||
return "{" + strings.Join(fields, ", ") + "}"
|
||
}
|
||
|
||
func (f Fields) Sizeof(val reflect.Value, options *Options) int {
|
||
for val.Kind() == reflect.Ptr {
|
||
val = val.Elem()
|
||
}
|
||
size := 0
|
||
for i, field := range f {
|
||
if field != nil {
|
||
size += field.Size(val.Field(i), options)
|
||
}
|
||
}
|
||
return size
|
||
}
|
||
|
||
func (f Fields) sizefrom(val reflect.Value, index []int) int {
|
||
field := val.FieldByIndex(index)
|
||
switch field.Kind() {
|
||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||
return int(field.Int())
|
||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||
n := int(field.Uint())
|
||
// all the builtin array length types are native int
|
||
// so this guards against weird truncation
|
||
if n < 0 {
|
||
return 0
|
||
}
|
||
return n
|
||
default:
|
||
name := val.Type().FieldByIndex(index).Name
|
||
panic(fmt.Sprintf("sizeof field %T.%s not an integer type", val.Interface(), name))
|
||
}
|
||
}
|
||
|
||
func (f Fields) Pack(buf []byte, val reflect.Value, options *Options) (int, error) {
|
||
for val.Kind() == reflect.Ptr {
|
||
val = val.Elem()
|
||
}
|
||
pos := 0
|
||
for i, field := range f {
|
||
if field == nil {
|
||
continue
|
||
}
|
||
v := val.Field(i)
|
||
length := field.Len
|
||
if field.Sizefrom != nil {
|
||
length = f.sizefrom(val, field.Sizefrom)
|
||
}
|
||
if length <= 0 && field.Slice {
|
||
length = v.Len()
|
||
}
|
||
if field.Sizeof != nil {
|
||
length := val.FieldByIndex(field.Sizeof).Len()
|
||
switch field.kind {
|
||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||
// allocating a new int here has fewer side effects (doesn't update the original struct)
|
||
// but it's a wasteful allocation
|
||
// the old method might work if we just cast the temporary int/uint to the target type
|
||
v = reflect.New(v.Type()).Elem()
|
||
v.SetInt(int64(length))
|
||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||
v = reflect.New(v.Type()).Elem()
|
||
v.SetUint(uint64(length))
|
||
default:
|
||
panic(fmt.Sprintf("sizeof field is not int or uint type: %s, %s", field.Name, v.Type()))
|
||
}
|
||
}
|
||
if n, err := field.Pack(buf[pos:], v, length, options); err != nil {
|
||
return n, err
|
||
} else {
|
||
pos += n
|
||
//g.Dump(v)
|
||
}
|
||
}
|
||
return pos, nil
|
||
}
|
||
|
||
// 提取魔法常量,便于配置和维护
|
||
const (
|
||
// MaxBufferSize 最大缓冲区大小,防止分配过大内存
|
||
MaxBufferSize = 1024 * 1024 // 1MB
|
||
// smallBufferSize 小缓冲区大小,复用栈上数组减少堆分配
|
||
smallBufferSize = 8
|
||
)
|
||
|
||
// -------------------------- 优化后的核心方法 --------------------------
|
||
func (f Fields) Unpack(r io.Reader, val reflect.Value, options *Options) error {
|
||
// 解引用指针,直到拿到非指针类型的值
|
||
for val.Kind() == reflect.Ptr {
|
||
val = val.Elem()
|
||
}
|
||
|
||
// 定义小缓冲区(栈上分配,减少堆内存开销)
|
||
var smallBuf [smallBufferSize]byte
|
||
var readBuf []byte
|
||
|
||
// 遍历所有字段
|
||
for fieldIdx, field := range f {
|
||
// 跳过空字段
|
||
if field == nil {
|
||
continue
|
||
}
|
||
|
||
// 获取当前字段的反射值
|
||
fieldVal := val.Field(fieldIdx)
|
||
// 获取字段长度(优先从Sizefrom读取,否则用默认Len)
|
||
fieldLen := field.Len
|
||
if field.Sizefrom != nil {
|
||
fieldLen = f.sizefrom(val, field.Sizefrom)
|
||
}
|
||
|
||
// 处理指针字段:如果指针未初始化,创建新实例
|
||
if fieldVal.Kind() == reflect.Ptr && !fieldVal.Elem().IsValid() {
|
||
fieldVal.Set(reflect.New(fieldVal.Type().Elem()))
|
||
}
|
||
|
||
// 处理结构体类型字段
|
||
if field.Type == Struct {
|
||
if err := f.unpackStructField(r, fieldVal, fieldLen, field, options); err != nil {
|
||
return fmt.Errorf("unpack struct field index %d: %w", fieldIdx, err)
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 处理非结构体类型字段(基础类型/自定义类型)
|
||
if err := f.unpackBasicField(r, fieldVal, field, fieldLen, smallBuf[:], &readBuf, options); err != nil {
|
||
return fmt.Errorf("unpack basic field index %d: %w", fieldIdx, err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 定义全局的最大安全切片长度(可根据业务调整,建议通过 options 配置)
|
||
const defaultMaxSafeSliceLen = 10000 // 1万,根据实际场景调整
|
||
|
||
// 新增错误类型,便于上层捕获
|
||
var (
|
||
ErrExceedMaxSliceLen = errors.New("slice length exceeds maximum safe limit")
|
||
ErrInvalidSliceLen = errors.New("slice length is negative or zero")
|
||
)
|
||
|
||
// unpackStructField 抽离重复的结构体解析逻辑,解决DRY问题
|
||
// 修复点:增加长度校验和内存分配防护
|
||
func (f Fields) unpackStructField(r io.Reader, fieldVal reflect.Value, length int, field *Field, options *Options) error {
|
||
// 修复1:基础长度校验,拒绝无效/超大长度
|
||
if length <= 0 {
|
||
return ErrInvalidSliceLen
|
||
}
|
||
|
||
// 修复2:获取最大允许的切片长度(优先使用 options 配置,无则用默认值)
|
||
maxSliceLen := defaultMaxSafeSliceLen
|
||
|
||
// 修复3:校验长度是否超过安全阈值,防止OOM
|
||
if length > maxSliceLen {
|
||
return fmt.Errorf("%w: requested %d, max allowed %d", ErrExceedMaxSliceLen, length, maxSliceLen)
|
||
}
|
||
|
||
// 处理切片/数组类型的结构体字段
|
||
if field.Slice {
|
||
var sliceVal reflect.Value
|
||
// 如果是数组(固定长度),直接使用原字段;如果是切片,创建指定长度的切片
|
||
if field.Array {
|
||
sliceVal = fieldVal
|
||
} else {
|
||
// 原逻辑:这里是OOM的核心触发点,现在已经提前做了长度校验
|
||
sliceVal = reflect.MakeSlice(fieldVal.Type(), length, length)
|
||
}
|
||
|
||
// 遍历切片/数组的每个元素,解析结构体
|
||
for elemIdx := 0; elemIdx < length; elemIdx++ {
|
||
elemVal := sliceVal.Index(elemIdx)
|
||
if err := f.unpackSingleStructElem(r, elemVal, options); err != nil {
|
||
return fmt.Errorf("slice elem %d: %w", elemIdx, err)
|
||
}
|
||
}
|
||
|
||
// 非数组类型需要将创建的切片赋值回原字段
|
||
if !field.Array {
|
||
fieldVal.Set(sliceVal)
|
||
}
|
||
} else {
|
||
// 处理单个结构体字段
|
||
if err := f.unpackSingleStructElem(r, fieldVal, options); err != nil {
|
||
return fmt.Errorf("single struct: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// -------------------------- 抽离的辅助方法:处理单个结构体元素 --------------------------
|
||
// unpackSingleStructElem 解析单个结构体元素的核心逻辑(原重复代码)
|
||
func (f Fields) unpackSingleStructElem(r io.Reader, elemVal reflect.Value, options *Options) error {
|
||
// 解析结构体的字段定义
|
||
structFields, err := parseFields(elemVal)
|
||
if err != nil {
|
||
return fmt.Errorf("parse struct fields: %w", err)
|
||
}
|
||
// 递归调用Unpack解析结构体数据
|
||
if err := structFields.Unpack(r, elemVal, options); err != nil {
|
||
return fmt.Errorf("unpack struct elem: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// -------------------------- 抽离的辅助方法:处理基础类型字段 --------------------------
|
||
// unpackBasicField 处理非结构体类型的字段(基础类型/自定义类型)
|
||
func (f Fields) unpackBasicField(
|
||
r io.Reader,
|
||
fieldVal reflect.Value,
|
||
field *Field,
|
||
length int,
|
||
smallBuf []byte,
|
||
readBuf *[]byte,
|
||
options *Options,
|
||
) error {
|
||
// 解析字段实际类型
|
||
fieldType := field.Type.Resolve(options)
|
||
|
||
// 处理自定义类型(实现Custom接口)
|
||
if fieldType == CustomType {
|
||
custom, ok := fieldVal.Addr().Interface().(Custom)
|
||
if !ok {
|
||
return errors.New("field does not implement Custom interface")
|
||
}
|
||
if err := custom.Unpack(r, length, options); err != nil {
|
||
return fmt.Errorf("custom unpack: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 计算需要读取的字节数
|
||
bufferSize := length * fieldType.Size()
|
||
// 检查缓冲区大小,防止内存溢出
|
||
if bufferSize > MaxBufferSize {
|
||
return fmt.Errorf("buffer size %d exceeds max %d", bufferSize, MaxBufferSize)
|
||
}
|
||
// 长度为0时直接返回
|
||
if bufferSize <= 0 {
|
||
return nil
|
||
}
|
||
|
||
// 复用小缓冲区(栈上)或分配堆缓冲区,减少内存分配开销
|
||
if bufferSize < smallBufferSize {
|
||
*readBuf = smallBuf[:bufferSize]
|
||
} else {
|
||
*readBuf = make([]byte, bufferSize)
|
||
}
|
||
|
||
// 读取指定长度的字节数据
|
||
if _, err := io.ReadFull(r, *readBuf); err != nil {
|
||
return fmt.Errorf("read data: %w", err)
|
||
}
|
||
|
||
// 解析字节数据到目标字段
|
||
if err := field.Unpack((*readBuf)[:bufferSize], fieldVal, length, options); err != nil {
|
||
return fmt.Errorf("field unpack: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|