This commit is contained in:
@@ -2,6 +2,7 @@ package struc
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
@@ -103,83 +104,170 @@ func (f Fields) Pack(buf []byte, val reflect.Value, options *Options) (int, erro
|
||||
return pos, nil
|
||||
}
|
||||
|
||||
// 提取魔法常量,便于配置和维护
|
||||
const (
|
||||
// MaxBufferSize 最大缓冲区大小,防止分配过大内存
|
||||
MaxBufferSize = 1024 * 1024 // 1MB
|
||||
// smallBufferSize 小缓冲区大小,复用栈上数组减少堆分配
|
||||
smallBufferSize = 8
|
||||
)
|
||||
|
||||
// -------------------------- 优化后的核心方法 --------------------------
|
||||
func (f Fields) Unpack(r io.Reader, val reflect.Value, options *Options) error {
|
||||
// 解引用指针,直到拿到非指针类型的值
|
||||
for val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
var tmp [8]byte
|
||||
var buf []byte
|
||||
for i, field := range f {
|
||||
|
||||
// 定义小缓冲区(栈上分配,减少堆内存开销)
|
||||
var smallBuf [smallBufferSize]byte
|
||||
var readBuf []byte
|
||||
|
||||
// 遍历所有字段
|
||||
for fieldIdx, field := range f {
|
||||
// 跳过空字段
|
||||
if field == nil {
|
||||
continue
|
||||
}
|
||||
v := val.Field(i)
|
||||
length := field.Len
|
||||
|
||||
// 获取当前字段的反射值
|
||||
fieldVal := val.Field(fieldIdx)
|
||||
// 获取字段长度(优先从Sizefrom读取,否则用默认Len)
|
||||
fieldLen := field.Len
|
||||
if field.Sizefrom != nil {
|
||||
length = f.sizefrom(val, field.Sizefrom)
|
||||
fieldLen = f.sizefrom(val, field.Sizefrom)
|
||||
}
|
||||
if v.Kind() == reflect.Ptr && !v.Elem().IsValid() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
|
||||
// 处理指针字段:如果指针未初始化,创建新实例
|
||||
if fieldVal.Kind() == reflect.Ptr && !fieldVal.Elem().IsValid() {
|
||||
fieldVal.Set(reflect.New(fieldVal.Type().Elem()))
|
||||
}
|
||||
|
||||
// 处理结构体类型字段
|
||||
if field.Type == Struct {
|
||||
if field.Slice {
|
||||
vals := v
|
||||
if !field.Array {
|
||||
vals = reflect.MakeSlice(v.Type(), length, length)
|
||||
}
|
||||
for i := 0; i < length; i++ {
|
||||
v := vals.Index(i)
|
||||
fields, err := parseFields(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fields.Unpack(r, v, options); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !field.Array {
|
||||
v.Set(vals)
|
||||
}
|
||||
} else {
|
||||
// TODO: DRY (we repeat the inner loop above)
|
||||
fields, err := parseFields(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fields.Unpack(r, v, options); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := f.unpackStructField(r, fieldVal, fieldLen, field, options); err != nil {
|
||||
return fmt.Errorf("unpack struct field index %d: %w", fieldIdx, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理非结构体类型字段(基础类型/自定义类型)
|
||||
if err := f.unpackBasicField(r, fieldVal, field, fieldLen, smallBuf[:], &readBuf, options); err != nil {
|
||||
return fmt.Errorf("unpack basic field index %d: %w", fieldIdx, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------------- 抽离的辅助方法:处理结构体字段 --------------------------
|
||||
// unpackStructField 抽离重复的结构体解析逻辑,解决DRY问题
|
||||
func (f Fields) unpackStructField(r io.Reader, fieldVal reflect.Value, length int, field *Field, options *Options) error {
|
||||
// 长度为0时直接返回,避免无效循环
|
||||
if length <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理切片/数组类型的结构体字段
|
||||
if field.Slice {
|
||||
var sliceVal reflect.Value
|
||||
// 如果是数组(固定长度),直接使用原字段;如果是切片,创建指定长度的切片
|
||||
if field.Array {
|
||||
sliceVal = fieldVal
|
||||
} else {
|
||||
typ := field.Type.Resolve(options)
|
||||
if typ == CustomType {
|
||||
if err := v.Addr().Interface().(Custom).Unpack(r, length, options); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
size := length * field.Type.Resolve(options).Size()
|
||||
|
||||
// 添加大小限制,防止分配过大的内存
|
||||
const maxSize = 1024 * 1024 // 1MB 限制,可根据需求调整
|
||||
if size > maxSize {
|
||||
return fmt.Errorf("buffer size too large: %d bytes, max allowed: %d", size, maxSize)
|
||||
}
|
||||
|
||||
if size < 8 {
|
||||
buf = tmp[:size]
|
||||
} else {
|
||||
buf = make([]byte, size)
|
||||
}
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
err := field.Unpack(buf[:size], v, length, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sliceVal = reflect.MakeSlice(fieldVal.Type(), length, length)
|
||||
}
|
||||
|
||||
// 遍历切片/数组的每个元素,解析结构体
|
||||
for elemIdx := 0; elemIdx < length; elemIdx++ {
|
||||
elemVal := sliceVal.Index(elemIdx)
|
||||
if err := f.unpackSingleStructElem(r, elemVal, options); err != nil {
|
||||
return fmt.Errorf("slice elem %d: %w", elemIdx, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 非数组类型需要将创建的切片赋值回原字段
|
||||
if !field.Array {
|
||||
fieldVal.Set(sliceVal)
|
||||
}
|
||||
} else {
|
||||
// 处理单个结构体字段
|
||||
if err := f.unpackSingleStructElem(r, fieldVal, options); err != nil {
|
||||
return fmt.Errorf("single struct: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------------- 抽离的辅助方法:处理单个结构体元素 --------------------------
|
||||
// unpackSingleStructElem 解析单个结构体元素的核心逻辑(原重复代码)
|
||||
func (f Fields) unpackSingleStructElem(r io.Reader, elemVal reflect.Value, options *Options) error {
|
||||
// 解析结构体的字段定义
|
||||
structFields, err := parseFields(elemVal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse struct fields: %w", err)
|
||||
}
|
||||
// 递归调用Unpack解析结构体数据
|
||||
if err := structFields.Unpack(r, elemVal, options); err != nil {
|
||||
return fmt.Errorf("unpack struct elem: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------------- 抽离的辅助方法:处理基础类型字段 --------------------------
|
||||
// unpackBasicField 处理非结构体类型的字段(基础类型/自定义类型)
|
||||
func (f Fields) unpackBasicField(
|
||||
r io.Reader,
|
||||
fieldVal reflect.Value,
|
||||
field *Field,
|
||||
length int,
|
||||
smallBuf []byte,
|
||||
readBuf *[]byte,
|
||||
options *Options,
|
||||
) error {
|
||||
// 解析字段实际类型
|
||||
fieldType := field.Type.Resolve(options)
|
||||
|
||||
// 处理自定义类型(实现Custom接口)
|
||||
if fieldType == CustomType {
|
||||
custom, ok := fieldVal.Addr().Interface().(Custom)
|
||||
if !ok {
|
||||
return errors.New("field does not implement Custom interface")
|
||||
}
|
||||
if err := custom.Unpack(r, length, options); err != nil {
|
||||
return fmt.Errorf("custom unpack: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 计算需要读取的字节数
|
||||
bufferSize := length * fieldType.Size()
|
||||
// 检查缓冲区大小,防止内存溢出
|
||||
if bufferSize > MaxBufferSize {
|
||||
return fmt.Errorf("buffer size %d exceeds max %d", bufferSize, MaxBufferSize)
|
||||
}
|
||||
// 长度为0时直接返回
|
||||
if bufferSize <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 复用小缓冲区(栈上)或分配堆缓冲区,减少内存分配开销
|
||||
if bufferSize < smallBufferSize {
|
||||
*readBuf = smallBuf[:bufferSize]
|
||||
} else {
|
||||
*readBuf = make([]byte, bufferSize)
|
||||
}
|
||||
|
||||
// 读取指定长度的字节数据
|
||||
if _, err := io.ReadFull(r, *readBuf); err != nil {
|
||||
return fmt.Errorf("read data: %w", err)
|
||||
}
|
||||
|
||||
// 解析字节数据到目标字段
|
||||
if err := field.Unpack((*readBuf)[:bufferSize], fieldVal, length, options); err != nil {
|
||||
return fmt.Errorf("field unpack: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user