Files
bl/common/utils/sturc/fields.go
xinian 932e199622
All checks were successful
ci/woodpecker/push/my-first-workflow Pipeline was successful
fix: 修复切片长度校验和内存分配防护问题
refactor: 优化战斗循环中的宠物处理逻辑
refactor: 重构物品更新服务使用ORM模型
2026-02-22 21:46:36 +08:00

292 lines
8.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package struc
import (
"encoding/binary"
"errors"
"fmt"
"io"
"reflect"
"strings"
)
type Fields []*Field
func (f Fields) SetByteOrder(order binary.ByteOrder) {
for _, field := range f {
if field != nil {
field.Order = order
}
}
}
func (f Fields) String() string {
fields := make([]string, len(f))
for i, field := range f {
if field != nil {
fields[i] = field.String()
}
}
return "{" + strings.Join(fields, ", ") + "}"
}
func (f Fields) Sizeof(val reflect.Value, options *Options) int {
for val.Kind() == reflect.Ptr {
val = val.Elem()
}
size := 0
for i, field := range f {
if field != nil {
size += field.Size(val.Field(i), options)
}
}
return size
}
func (f Fields) sizefrom(val reflect.Value, index []int) int {
field := val.FieldByIndex(index)
switch field.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return int(field.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
n := int(field.Uint())
// all the builtin array length types are native int
// so this guards against weird truncation
if n < 0 {
return 0
}
return n
default:
name := val.Type().FieldByIndex(index).Name
panic(fmt.Sprintf("sizeof field %T.%s not an integer type", val.Interface(), name))
}
}
func (f Fields) Pack(buf []byte, val reflect.Value, options *Options) (int, error) {
for val.Kind() == reflect.Ptr {
val = val.Elem()
}
pos := 0
for i, field := range f {
if field == nil {
continue
}
v := val.Field(i)
length := field.Len
if field.Sizefrom != nil {
length = f.sizefrom(val, field.Sizefrom)
}
if length <= 0 && field.Slice {
length = v.Len()
}
if field.Sizeof != nil {
length := val.FieldByIndex(field.Sizeof).Len()
switch field.kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// allocating a new int here has fewer side effects (doesn't update the original struct)
// but it's a wasteful allocation
// the old method might work if we just cast the temporary int/uint to the target type
v = reflect.New(v.Type()).Elem()
v.SetInt(int64(length))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
v = reflect.New(v.Type()).Elem()
v.SetUint(uint64(length))
default:
panic(fmt.Sprintf("sizeof field is not int or uint type: %s, %s", field.Name, v.Type()))
}
}
if n, err := field.Pack(buf[pos:], v, length, options); err != nil {
return n, err
} else {
pos += n
//g.Dump(v)
}
}
return pos, nil
}
// 提取魔法常量,便于配置和维护
const (
// MaxBufferSize 最大缓冲区大小,防止分配过大内存
MaxBufferSize = 1024 * 1024 // 1MB
// smallBufferSize 小缓冲区大小,复用栈上数组减少堆分配
smallBufferSize = 8
)
// -------------------------- 优化后的核心方法 --------------------------
func (f Fields) Unpack(r io.Reader, val reflect.Value, options *Options) error {
// 解引用指针,直到拿到非指针类型的值
for val.Kind() == reflect.Ptr {
val = val.Elem()
}
// 定义小缓冲区(栈上分配,减少堆内存开销)
var smallBuf [smallBufferSize]byte
var readBuf []byte
// 遍历所有字段
for fieldIdx, field := range f {
// 跳过空字段
if field == nil {
continue
}
// 获取当前字段的反射值
fieldVal := val.Field(fieldIdx)
// 获取字段长度优先从Sizefrom读取否则用默认Len
fieldLen := field.Len
if field.Sizefrom != nil {
fieldLen = f.sizefrom(val, field.Sizefrom)
}
// 处理指针字段:如果指针未初始化,创建新实例
if fieldVal.Kind() == reflect.Ptr && !fieldVal.Elem().IsValid() {
fieldVal.Set(reflect.New(fieldVal.Type().Elem()))
}
// 处理结构体类型字段
if field.Type == Struct {
if err := f.unpackStructField(r, fieldVal, fieldLen, field, options); err != nil {
return fmt.Errorf("unpack struct field index %d: %w", fieldIdx, err)
}
continue
}
// 处理非结构体类型字段(基础类型/自定义类型)
if err := f.unpackBasicField(r, fieldVal, field, fieldLen, smallBuf[:], &readBuf, options); err != nil {
return fmt.Errorf("unpack basic field index %d: %w", fieldIdx, err)
}
}
return nil
}
// 定义全局的最大安全切片长度(可根据业务调整,建议通过 options 配置)
const defaultMaxSafeSliceLen = 10000 // 1万根据实际场景调整
// 新增错误类型,便于上层捕获
var (
ErrExceedMaxSliceLen = errors.New("slice length exceeds maximum safe limit")
ErrInvalidSliceLen = errors.New("slice length is negative or zero")
)
// unpackStructField 抽离重复的结构体解析逻辑解决DRY问题
// 修复点:增加长度校验和内存分配防护
func (f Fields) unpackStructField(r io.Reader, fieldVal reflect.Value, length int, field *Field, options *Options) error {
// 修复1基础长度校验拒绝无效/超大长度
if length <= 0 {
return ErrInvalidSliceLen
}
// 修复2获取最大允许的切片长度优先使用 options 配置,无则用默认值)
maxSliceLen := defaultMaxSafeSliceLen
// 修复3校验长度是否超过安全阈值防止OOM
if length > maxSliceLen {
return fmt.Errorf("%w: requested %d, max allowed %d", ErrExceedMaxSliceLen, length, maxSliceLen)
}
// 处理切片/数组类型的结构体字段
if field.Slice {
var sliceVal reflect.Value
// 如果是数组(固定长度),直接使用原字段;如果是切片,创建指定长度的切片
if field.Array {
sliceVal = fieldVal
} else {
// 原逻辑这里是OOM的核心触发点现在已经提前做了长度校验
sliceVal = reflect.MakeSlice(fieldVal.Type(), length, length)
}
// 遍历切片/数组的每个元素,解析结构体
for elemIdx := 0; elemIdx < length; elemIdx++ {
elemVal := sliceVal.Index(elemIdx)
if err := f.unpackSingleStructElem(r, elemVal, options); err != nil {
return fmt.Errorf("slice elem %d: %w", elemIdx, err)
}
}
// 非数组类型需要将创建的切片赋值回原字段
if !field.Array {
fieldVal.Set(sliceVal)
}
} else {
// 处理单个结构体字段
if err := f.unpackSingleStructElem(r, fieldVal, options); err != nil {
return fmt.Errorf("single struct: %w", err)
}
}
return nil
}
// -------------------------- 抽离的辅助方法:处理单个结构体元素 --------------------------
// unpackSingleStructElem 解析单个结构体元素的核心逻辑(原重复代码)
func (f Fields) unpackSingleStructElem(r io.Reader, elemVal reflect.Value, options *Options) error {
// 解析结构体的字段定义
structFields, err := parseFields(elemVal)
if err != nil {
return fmt.Errorf("parse struct fields: %w", err)
}
// 递归调用Unpack解析结构体数据
if err := structFields.Unpack(r, elemVal, options); err != nil {
return fmt.Errorf("unpack struct elem: %w", err)
}
return nil
}
// -------------------------- 抽离的辅助方法:处理基础类型字段 --------------------------
// unpackBasicField 处理非结构体类型的字段(基础类型/自定义类型)
func (f Fields) unpackBasicField(
r io.Reader,
fieldVal reflect.Value,
field *Field,
length int,
smallBuf []byte,
readBuf *[]byte,
options *Options,
) error {
// 解析字段实际类型
fieldType := field.Type.Resolve(options)
// 处理自定义类型实现Custom接口
if fieldType == CustomType {
custom, ok := fieldVal.Addr().Interface().(Custom)
if !ok {
return errors.New("field does not implement Custom interface")
}
if err := custom.Unpack(r, length, options); err != nil {
return fmt.Errorf("custom unpack: %w", err)
}
return nil
}
// 计算需要读取的字节数
bufferSize := length * fieldType.Size()
// 检查缓冲区大小,防止内存溢出
if bufferSize > MaxBufferSize {
return fmt.Errorf("buffer size %d exceeds max %d", bufferSize, MaxBufferSize)
}
// 长度为0时直接返回
if bufferSize <= 0 {
return nil
}
// 复用小缓冲区(栈上)或分配堆缓冲区,减少内存分配开销
if bufferSize < smallBufferSize {
*readBuf = smallBuf[:bufferSize]
} else {
*readBuf = make([]byte, bufferSize)
}
// 读取指定长度的字节数据
if _, err := io.ReadFull(r, *readBuf); err != nil {
return fmt.Errorf("read data: %w", err)
}
// 解析字节数据到目标字段
if err := field.Unpack((*readBuf)[:bufferSize], fieldVal, length, options); err != nil {
return fmt.Errorf("field unpack: %w", err)
}
return nil
}