package struc import ( "encoding/binary" "errors" "fmt" "io" "reflect" "strings" ) type Fields []*Field func (f Fields) SetByteOrder(order binary.ByteOrder) { for _, field := range f { if field != nil { field.Order = order } } } func (f Fields) String() string { fields := make([]string, len(f)) for i, field := range f { if field != nil { fields[i] = field.String() } } return "{" + strings.Join(fields, ", ") + "}" } func (f Fields) Sizeof(val reflect.Value, options *Options) int { for val.Kind() == reflect.Ptr { val = val.Elem() } size := 0 for i, field := range f { if field != nil { size += field.Size(val.Field(i), options) } } return size } func (f Fields) sizefrom(val reflect.Value, index []int) int { field := val.FieldByIndex(index) switch field.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return int(field.Int()) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: n := int(field.Uint()) // all the builtin array length types are native int // so this guards against weird truncation if n < 0 { return 0 } return n default: name := val.Type().FieldByIndex(index).Name panic(fmt.Sprintf("sizeof field %T.%s not an integer type", val.Interface(), name)) } } func (f Fields) Pack(buf []byte, val reflect.Value, options *Options) (int, error) { for val.Kind() == reflect.Ptr { val = val.Elem() } pos := 0 for i, field := range f { if field == nil { continue } v := val.Field(i) length := field.Len if field.Sizefrom != nil { length = f.sizefrom(val, field.Sizefrom) } if length <= 0 && field.Slice { length = v.Len() } if field.Sizeof != nil { length := val.FieldByIndex(field.Sizeof).Len() switch field.kind { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: // allocating a new int here has fewer side effects (doesn't update the original struct) // but it's a wasteful allocation // the old method might work if we just cast the temporary int/uint to the target type v = reflect.New(v.Type()).Elem() v.SetInt(int64(length)) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: v = reflect.New(v.Type()).Elem() v.SetUint(uint64(length)) default: panic(fmt.Sprintf("sizeof field is not int or uint type: %s, %s", field.Name, v.Type())) } } if n, err := field.Pack(buf[pos:], v, length, options); err != nil { return n, err } else { pos += n //g.Dump(v) } } return pos, nil } // 提取魔法常量,便于配置和维护 const ( // MaxBufferSize 最大缓冲区大小,防止分配过大内存 MaxBufferSize = 1024 * 1024 // 1MB // smallBufferSize 小缓冲区大小,复用栈上数组减少堆分配 smallBufferSize = 8 ) // -------------------------- 优化后的核心方法 -------------------------- func (f Fields) Unpack(r io.Reader, val reflect.Value, options *Options) error { // 解引用指针,直到拿到非指针类型的值 for val.Kind() == reflect.Ptr { val = val.Elem() } // 定义小缓冲区(栈上分配,减少堆内存开销) var smallBuf [smallBufferSize]byte var readBuf []byte // 遍历所有字段 for fieldIdx, field := range f { // 跳过空字段 if field == nil { continue } // 获取当前字段的反射值 fieldVal := val.Field(fieldIdx) // 获取字段长度(优先从Sizefrom读取,否则用默认Len) fieldLen := field.Len if field.Sizefrom != nil { fieldLen = f.sizefrom(val, field.Sizefrom) } // 处理指针字段:如果指针未初始化,创建新实例 if fieldVal.Kind() == reflect.Ptr && !fieldVal.Elem().IsValid() { fieldVal.Set(reflect.New(fieldVal.Type().Elem())) } // 处理结构体类型字段 if field.Type == Struct { if err := f.unpackStructField(r, fieldVal, fieldLen, field, options); err != nil { return fmt.Errorf("unpack struct field index %d: %w", fieldIdx, err) } continue } // 处理非结构体类型字段(基础类型/自定义类型) if err := f.unpackBasicField(r, fieldVal, field, fieldLen, smallBuf[:], &readBuf, options); err != nil { return fmt.Errorf("unpack basic field index %d: %w", fieldIdx, err) } } return nil } // 定义全局的最大安全切片长度(可根据业务调整,建议通过 options 配置) const defaultMaxSafeSliceLen = 10000 // 1万,根据实际场景调整 // 新增错误类型,便于上层捕获 var ( ErrExceedMaxSliceLen = errors.New("slice length exceeds maximum safe limit") ErrInvalidSliceLen = errors.New("slice length is negative or zero") ) // unpackStructField 抽离重复的结构体解析逻辑,解决DRY问题 // 修复点:增加长度校验和内存分配防护 func (f Fields) unpackStructField(r io.Reader, fieldVal reflect.Value, length int, field *Field, options *Options) error { // 修复1:基础长度校验,拒绝无效/超大长度 if length <= 0 { return ErrInvalidSliceLen } // 修复2:获取最大允许的切片长度(优先使用 options 配置,无则用默认值) maxSliceLen := defaultMaxSafeSliceLen // 修复3:校验长度是否超过安全阈值,防止OOM if length > maxSliceLen { return fmt.Errorf("%w: requested %d, max allowed %d", ErrExceedMaxSliceLen, length, maxSliceLen) } // 处理切片/数组类型的结构体字段 if field.Slice { var sliceVal reflect.Value // 如果是数组(固定长度),直接使用原字段;如果是切片,创建指定长度的切片 if field.Array { sliceVal = fieldVal } else { // 原逻辑:这里是OOM的核心触发点,现在已经提前做了长度校验 sliceVal = reflect.MakeSlice(fieldVal.Type(), length, length) } // 遍历切片/数组的每个元素,解析结构体 for elemIdx := 0; elemIdx < length; elemIdx++ { elemVal := sliceVal.Index(elemIdx) if err := f.unpackSingleStructElem(r, elemVal, options); err != nil { return fmt.Errorf("slice elem %d: %w", elemIdx, err) } } // 非数组类型需要将创建的切片赋值回原字段 if !field.Array { fieldVal.Set(sliceVal) } } else { // 处理单个结构体字段 if err := f.unpackSingleStructElem(r, fieldVal, options); err != nil { return fmt.Errorf("single struct: %w", err) } } return nil } // -------------------------- 抽离的辅助方法:处理单个结构体元素 -------------------------- // unpackSingleStructElem 解析单个结构体元素的核心逻辑(原重复代码) func (f Fields) unpackSingleStructElem(r io.Reader, elemVal reflect.Value, options *Options) error { // 解析结构体的字段定义 structFields, err := parseFields(elemVal) if err != nil { return fmt.Errorf("parse struct fields: %w", err) } // 递归调用Unpack解析结构体数据 if err := structFields.Unpack(r, elemVal, options); err != nil { return fmt.Errorf("unpack struct elem: %w", err) } return nil } // -------------------------- 抽离的辅助方法:处理基础类型字段 -------------------------- // unpackBasicField 处理非结构体类型的字段(基础类型/自定义类型) func (f Fields) unpackBasicField( r io.Reader, fieldVal reflect.Value, field *Field, length int, smallBuf []byte, readBuf *[]byte, options *Options, ) error { // 解析字段实际类型 fieldType := field.Type.Resolve(options) // 处理自定义类型(实现Custom接口) if fieldType == CustomType { custom, ok := fieldVal.Addr().Interface().(Custom) if !ok { return errors.New("field does not implement Custom interface") } if err := custom.Unpack(r, length, options); err != nil { return fmt.Errorf("custom unpack: %w", err) } return nil } // 计算需要读取的字节数 bufferSize := length * fieldType.Size() // 检查缓冲区大小,防止内存溢出 if bufferSize > MaxBufferSize { return fmt.Errorf("buffer size %d exceeds max %d", bufferSize, MaxBufferSize) } // 长度为0时直接返回 if bufferSize <= 0 { return nil } // 复用小缓冲区(栈上)或分配堆缓冲区,减少内存分配开销 if bufferSize < smallBufferSize { *readBuf = smallBuf[:bufferSize] } else { *readBuf = make([]byte, bufferSize) } // 读取指定长度的字节数据 if _, err := io.ReadFull(r, *readBuf); err != nil { return fmt.Errorf("read data: %w", err) } // 解析字节数据到目标字段 if err := field.Unpack((*readBuf)[:bufferSize], fieldVal, length, options); err != nil { return fmt.Errorf("field unpack: %w", err) } return nil }