Files
bl/common/utils/sturc/fields.go

292 lines
8.4 KiB
Go
Raw Normal View History

package struc
import (
"encoding/binary"
2026-02-08 15:18:50 +08:00
"errors"
"fmt"
"io"
"reflect"
"strings"
)
type Fields []*Field
func (f Fields) SetByteOrder(order binary.ByteOrder) {
for _, field := range f {
if field != nil {
field.Order = order
}
}
}
func (f Fields) String() string {
fields := make([]string, len(f))
for i, field := range f {
if field != nil {
fields[i] = field.String()
}
}
return "{" + strings.Join(fields, ", ") + "}"
}
func (f Fields) Sizeof(val reflect.Value, options *Options) int {
for val.Kind() == reflect.Ptr {
val = val.Elem()
}
size := 0
for i, field := range f {
if field != nil {
size += field.Size(val.Field(i), options)
}
}
return size
}
func (f Fields) sizefrom(val reflect.Value, index []int) int {
field := val.FieldByIndex(index)
switch field.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return int(field.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
n := int(field.Uint())
// all the builtin array length types are native int
// so this guards against weird truncation
if n < 0 {
return 0
}
return n
default:
name := val.Type().FieldByIndex(index).Name
panic(fmt.Sprintf("sizeof field %T.%s not an integer type", val.Interface(), name))
}
}
func (f Fields) Pack(buf []byte, val reflect.Value, options *Options) (int, error) {
for val.Kind() == reflect.Ptr {
val = val.Elem()
}
pos := 0
for i, field := range f {
if field == nil {
continue
}
v := val.Field(i)
length := field.Len
if field.Sizefrom != nil {
length = f.sizefrom(val, field.Sizefrom)
}
if length <= 0 && field.Slice {
length = v.Len()
}
if field.Sizeof != nil {
length := val.FieldByIndex(field.Sizeof).Len()
switch field.kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// allocating a new int here has fewer side effects (doesn't update the original struct)
// but it's a wasteful allocation
// the old method might work if we just cast the temporary int/uint to the target type
v = reflect.New(v.Type()).Elem()
v.SetInt(int64(length))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
v = reflect.New(v.Type()).Elem()
v.SetUint(uint64(length))
default:
panic(fmt.Sprintf("sizeof field is not int or uint type: %s, %s", field.Name, v.Type()))
}
}
if n, err := field.Pack(buf[pos:], v, length, options); err != nil {
return n, err
} else {
pos += n
2025-06-24 22:09:05 +08:00
//g.Dump(v)
}
}
return pos, nil
}
2026-02-08 15:18:50 +08:00
// 提取魔法常量,便于配置和维护
const (
// MaxBufferSize 最大缓冲区大小,防止分配过大内存
MaxBufferSize = 1024 * 1024 // 1MB
// smallBufferSize 小缓冲区大小,复用栈上数组减少堆分配
smallBufferSize = 8
)
// -------------------------- 优化后的核心方法 --------------------------
func (f Fields) Unpack(r io.Reader, val reflect.Value, options *Options) error {
2026-02-08 15:18:50 +08:00
// 解引用指针,直到拿到非指针类型的值
for val.Kind() == reflect.Ptr {
val = val.Elem()
}
2026-02-08 15:18:50 +08:00
// 定义小缓冲区(栈上分配,减少堆内存开销)
var smallBuf [smallBufferSize]byte
var readBuf []byte
// 遍历所有字段
for fieldIdx, field := range f {
// 跳过空字段
if field == nil {
continue
}
2026-02-08 15:18:50 +08:00
// 获取当前字段的反射值
fieldVal := val.Field(fieldIdx)
// 获取字段长度优先从Sizefrom读取否则用默认Len
fieldLen := field.Len
if field.Sizefrom != nil {
2026-02-08 15:18:50 +08:00
fieldLen = f.sizefrom(val, field.Sizefrom)
}
2026-02-08 15:18:50 +08:00
// 处理指针字段:如果指针未初始化,创建新实例
if fieldVal.Kind() == reflect.Ptr && !fieldVal.Elem().IsValid() {
fieldVal.Set(reflect.New(fieldVal.Type().Elem()))
}
2026-02-08 15:18:50 +08:00
// 处理结构体类型字段
if field.Type == Struct {
2026-02-08 15:18:50 +08:00
if err := f.unpackStructField(r, fieldVal, fieldLen, field, options); err != nil {
return fmt.Errorf("unpack struct field index %d: %w", fieldIdx, err)
}
continue
2026-02-08 15:18:50 +08:00
}
// 处理非结构体类型字段(基础类型/自定义类型)
if err := f.unpackBasicField(r, fieldVal, field, fieldLen, smallBuf[:], &readBuf, options); err != nil {
return fmt.Errorf("unpack basic field index %d: %w", fieldIdx, err)
}
}
return nil
}
// 定义全局的最大安全切片长度(可根据业务调整,建议通过 options 配置)
const defaultMaxSafeSliceLen = 10000 // 1万根据实际场景调整
// 新增错误类型,便于上层捕获
var (
ErrExceedMaxSliceLen = errors.New("slice length exceeds maximum safe limit")
ErrInvalidSliceLen = errors.New("slice length is negative or zero")
)
2026-02-08 15:18:50 +08:00
// unpackStructField 抽离重复的结构体解析逻辑解决DRY问题
// 修复点:增加长度校验和内存分配防护
2026-02-08 15:18:50 +08:00
func (f Fields) unpackStructField(r io.Reader, fieldVal reflect.Value, length int, field *Field, options *Options) error {
// 修复1基础长度校验拒绝无效/超大长度
2026-02-08 15:18:50 +08:00
if length <= 0 {
return ErrInvalidSliceLen
}
// 修复2获取最大允许的切片长度优先使用 options 配置,无则用默认值)
maxSliceLen := defaultMaxSafeSliceLen
// 修复3校验长度是否超过安全阈值防止OOM
if length > maxSliceLen {
return fmt.Errorf("%w: requested %d, max allowed %d", ErrExceedMaxSliceLen, length, maxSliceLen)
2026-02-08 15:18:50 +08:00
}
// 处理切片/数组类型的结构体字段
if field.Slice {
var sliceVal reflect.Value
// 如果是数组(固定长度),直接使用原字段;如果是切片,创建指定长度的切片
if field.Array {
sliceVal = fieldVal
} else {
// 原逻辑这里是OOM的核心触发点现在已经提前做了长度校验
2026-02-08 15:18:50 +08:00
sliceVal = reflect.MakeSlice(fieldVal.Type(), length, length)
}
// 遍历切片/数组的每个元素,解析结构体
for elemIdx := 0; elemIdx < length; elemIdx++ {
elemVal := sliceVal.Index(elemIdx)
if err := f.unpackSingleStructElem(r, elemVal, options); err != nil {
return fmt.Errorf("slice elem %d: %w", elemIdx, err)
}
}
2026-02-08 15:18:50 +08:00
// 非数组类型需要将创建的切片赋值回原字段
if !field.Array {
fieldVal.Set(sliceVal)
}
} else {
// 处理单个结构体字段
if err := f.unpackSingleStructElem(r, fieldVal, options); err != nil {
return fmt.Errorf("single struct: %w", err)
}
}
return nil
}
// -------------------------- 抽离的辅助方法:处理单个结构体元素 --------------------------
// unpackSingleStructElem 解析单个结构体元素的核心逻辑(原重复代码)
func (f Fields) unpackSingleStructElem(r io.Reader, elemVal reflect.Value, options *Options) error {
// 解析结构体的字段定义
structFields, err := parseFields(elemVal)
if err != nil {
return fmt.Errorf("parse struct fields: %w", err)
}
// 递归调用Unpack解析结构体数据
if err := structFields.Unpack(r, elemVal, options); err != nil {
return fmt.Errorf("unpack struct elem: %w", err)
}
return nil
}
// -------------------------- 抽离的辅助方法:处理基础类型字段 --------------------------
// unpackBasicField 处理非结构体类型的字段(基础类型/自定义类型)
func (f Fields) unpackBasicField(
r io.Reader,
fieldVal reflect.Value,
field *Field,
length int,
smallBuf []byte,
readBuf *[]byte,
options *Options,
) error {
// 解析字段实际类型
fieldType := field.Type.Resolve(options)
// 处理自定义类型实现Custom接口
if fieldType == CustomType {
custom, ok := fieldVal.Addr().Interface().(Custom)
if !ok {
return errors.New("field does not implement Custom interface")
}
if err := custom.Unpack(r, length, options); err != nil {
return fmt.Errorf("custom unpack: %w", err)
}
return nil
}
// 计算需要读取的字节数
bufferSize := length * fieldType.Size()
// 检查缓冲区大小,防止内存溢出
if bufferSize > MaxBufferSize {
return fmt.Errorf("buffer size %d exceeds max %d", bufferSize, MaxBufferSize)
}
// 长度为0时直接返回
if bufferSize <= 0 {
return nil
}
// 复用小缓冲区(栈上)或分配堆缓冲区,减少内存分配开销
if bufferSize < smallBufferSize {
*readBuf = smallBuf[:bufferSize]
} else {
*readBuf = make([]byte, bufferSize)
}
// 读取指定长度的字节数据
if _, err := io.ReadFull(r, *readBuf); err != nil {
return fmt.Errorf("read data: %w", err)
}
2026-02-08 15:18:50 +08:00
// 解析字节数据到目标字段
if err := field.Unpack((*readBuf)[:bufferSize], fieldVal, length, options); err != nil {
return fmt.Errorf("field unpack: %w", err)
}
return nil
}