Files
bl/common/utils/concurrent-swiss-map/concurrent_swiss_map.go
昔念 1dc75b529d
All checks were successful
ci/woodpecker/push/my-first-workflow Pipeline was successful
```
feat(socket): 优化TCP连接处理性能

- 添加最小可读长度检查,避免无效Peek操作
- 修复数据部分解析逻辑,避免空切片分配

perf(utils): 优化并发哈希映射性能

- 将分段数量调整为CPU核心数
- 重写Range方法,移除channel和goroutine开销
- 添加原子标志控制遍历终止

perf(utils): 优化结构体序列化缓存机制

- 添加sync.Map缓存预处理结果
- 支持结构体、自定义类型、二进制类型分别缓存
- 减少重复反射
2026-02-22 10:59:41 +08:00

360 lines
8.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package csmap
import (
"blazing/cool"
"context"
"encoding/json"
"runtime"
"sync"
"sync/atomic"
"github.com/mhmtszr/concurrent-swiss-map/maphash"
"github.com/mhmtszr/concurrent-swiss-map/swiss"
)
type (
	// CsMap is a concurrent hash map split into shardCount independently
	// locked shards, each backed by a swiss.Map. Create one with New.
	CsMap[K comparable, V any] struct {
		// hasher hashes keys; hash % shardCount selects the owning shard.
		hasher func(key K) uint64
		// shards holds one partition per shard; built by New.
		shards []shard[K, V]
		// shardCount is the number of shards (len(shards) after New).
		shardCount uint64
		// size is the initial capacity hint that New splits across shards.
		size uint64
	}

	// HashShardPair bundles a key's hash with the shard that owns it, so the
	// hash computed for shard selection can be reused by the swiss.Map
	// *WithHash operations.
	HashShardPair[K comparable, V any] struct {
		shard shard[K, V]
		hash  uint64
	}

	// shard is one partition of a CsMap. The mutex is embedded by pointer,
	// so copies of a shard value (such as the one carried in HashShardPair)
	// still share the same lock and the same swiss.Map.
	shard[K comparable, V any] struct {
		items *swiss.Map[K, V]
		*sync.RWMutex
	}

	// OptFunc is a type that is used in New function for passing options.
	OptFunc[K comparable, V any] func(o *CsMap[K, V])
)
// New creates an empty *CsMap[K, V] configured by the given options.
//
// Defaults: keys are hashed with maphash.NewHasher[K]().Hash and the map is
// split into runtime.NumCPU() shards. WithSize sets a capacity hint that is
// divided evenly across the shards.
//
// Options are applied in order after the defaults are set. Nil options are
// ignored. A shard count of 0 would make the capacity computation below (and
// every later key lookup) divide by zero, so it falls back to the default.
// A nil hasher would panic on first use, so it also falls back to the default.
func New[K comparable, V any](options ...OptFunc[K, V]) *CsMap[K, V] {
	m := CsMap[K, V]{
		hasher:     maphash.NewHasher[K]().Hash,
		shardCount: uint64(runtime.NumCPU()),
	}
	for _, option := range options {
		if option != nil {
			option(&m)
		}
	}
	if m.shardCount == 0 {
		m.shardCount = uint64(runtime.NumCPU())
	}
	if m.hasher == nil {
		m.hasher = maphash.NewHasher[K]().Hash
	}

	// Per-shard initial capacity; +1 keeps it non-zero when size < shardCount.
	perShard := uint32(m.size/m.shardCount + 1)
	m.shards = make([]shard[K, V], m.shardCount)
	for i := range m.shards {
		m.shards[i] = shard[K, V]{items: swiss.NewMap[K, V](perShard), RWMutex: &sync.RWMutex{}}
	}
	return &m
}
// // Create creates *CsMap.
// //
// // Deprecated: New function should be used instead.
// func Create[K comparable, V any](options ...func(options *CsMap[K, V])) *CsMap[K, V] {
// m := CsMap[K, V]{
// hasher: maphash.NewHasher[K]().Hash,
// shardCount: 32,
// }
// for _, option := range options {
// option(&m)
// }
// m.shards = make([]shard[K, V], m.shardCount)
// for i := 0; i < int(m.shardCount); i++ {
// m.shards[i] = shard[K, V]{items: swiss.NewMap[K, V](uint32((m.size / m.shardCount) + 1)), RWMutex: &sync.RWMutex{}}
// }
// return &m
// }
func WithShardCount[K comparable, V any](count uint64) func(csMap *CsMap[K, V]) {
return func(csMap *CsMap[K, V]) {
csMap.shardCount = count
}
}
func WithCustomHasher[K comparable, V any](h func(key K) uint64) func(csMap *CsMap[K, V]) {
return func(csMap *CsMap[K, V]) {
csMap.hasher = h
}
}
func WithSize[K comparable, V any](size uint64) func(csMap *CsMap[K, V]) {
return func(csMap *CsMap[K, V]) {
csMap.size = size
}
}
// getShard hashes key once and returns that hash together with the shard
// selected by hash % shardCount. Callers reuse the hash for the swiss.Map
// *WithHash operations.
func (m *CsMap[K, V]) getShard(key K) HashShardPair[K, V] {
	h := m.hasher(key)
	idx := h % m.shardCount
	return HashShardPair[K, V]{shard: m.shards[idx], hash: h}
}
// Store sets the value for key, overwriting any existing value.
func (m *CsMap[K, V]) Store(key K, value V) {
	p := m.getShard(key)
	s := p.shard
	s.Lock()
	// Deferred so that a panic inside the map operation cannot leave the
	// shard write-locked forever. For example, comparing interface-typed
	// keys whose dynamic values are not comparable panics at run time.
	defer s.Unlock()
	s.items.PutWithHash(key, value, p.hash)
}
// Delete removes key from the map under the shard's write lock. It returns
// the result of the swiss.Map DeleteWithHash call, which is true when an
// entry was removed.
func (m *CsMap[K, V]) Delete(key K) bool {
	p := m.getShard(key)
	s := p.shard
	s.Lock()
	defer s.Unlock()
	return s.items.DeleteWithHash(key, p.hash)
}
// DeleteIf removes key only when it is present and condition returns true
// for its current value. The lookup, the condition and the delete all
// happen under one shard write lock. condition must therefore not access
// the same map, because the lock is not reentrant.
//
// It returns false when nothing was deleted; otherwise it returns the
// DeleteWithHash result.
func (m *CsMap[K, V]) DeleteIf(key K, condition func(value V) bool) bool {
	p := m.getShard(key)
	s := p.shard
	s.Lock()
	defer s.Unlock()
	current, found := s.items.GetWithHash(key, p.hash)
	if !found || !condition(current) {
		return false
	}
	return s.items.DeleteWithHash(key, p.hash)
}
// Load returns the value stored for key and whether it was found, reading
// under the shard's read lock.
func (m *CsMap[K, V]) Load(key K) (V, bool) {
	p := m.getShard(key)
	s := p.shard
	s.RLock()
	defer s.RUnlock()
	return s.items.GetWithHash(key, p.hash)
}
// Has reports whether key is present, checking under the shard's read lock.
func (m *CsMap[K, V]) Has(key K) bool {
	p := m.getShard(key)
	s := p.shard
	s.RLock()
	defer s.RUnlock()
	return s.items.HasWithHash(key, p.hash)
}
// Clear removes all entries, locking and clearing one shard at a time.
// It is not atomic across shards: a concurrent writer may repopulate a
// shard that has already been cleared.
func (m *CsMap[K, V]) Clear() {
	for idx := range m.shards {
		s := &m.shards[idx]
		s.Lock()
		s.items.Clear()
		s.Unlock()
	}
}
// Count returns the total number of entries. It sums shard by shard, each
// under that shard's read lock, so the total is not an atomic snapshot when
// other goroutines write concurrently.
func (m *CsMap[K, V]) Count() int {
	var total int
	for idx := range m.shards {
		s := &m.shards[idx]
		s.RLock()
		total += s.items.Count()
		s.RUnlock()
	}
	return total
}
// SetIfAbsent stores value under key only when key is not already present.
// The presence check and the store happen under the same shard write lock.
func (m *CsMap[K, V]) SetIfAbsent(key K, value V) {
	p := m.getShard(key)
	s := p.shard
	s.Lock()
	// Deferred so a panic in the map operations cannot leave the shard
	// locked. This matches Delete and DeleteIf.
	defer s.Unlock()
	if !s.items.HasWithHash(key, p.hash) {
		s.items.PutWithHash(key, value, p.hash)
	}
}
// SetIf lets conditionFn decide, atomically per key, whether and what to
// store. conditionFn receives the value returned by the lookup and whether
// key was found. When it returns set == true, the returned value is stored
// under key.
//
// conditionFn runs while the shard's write lock is held, so it must not
// access the same map again (the lock is not reentrant). The unlock is
// deferred so that a panic in conditionFn, if the caller recovers it, does
// not leave the shard locked forever.
func (m *CsMap[K, V]) SetIf(key K, conditionFn func(previousValue V, previousFound bool) (value V, set bool)) {
	p := m.getShard(key)
	s := p.shard
	s.Lock()
	defer s.Unlock()
	prev, found := s.items.GetWithHash(key, p.hash)
	if next, set := conditionFn(prev, found); set {
		s.items.PutWithHash(key, next, p.hash)
	}
}
// SetIfPresent overwrites the value for key only when key is already
// present. The presence check and the store happen under the same shard
// write lock.
func (m *CsMap[K, V]) SetIfPresent(key K, value V) {
	p := m.getShard(key)
	s := p.shard
	s.Lock()
	// Deferred so a panic in the map operations cannot leave the shard
	// locked. This matches Delete and DeleteIf.
	defer s.Unlock()
	if s.items.HasWithHash(key, p.hash) {
		s.items.PutWithHash(key, value, p.hash)
	}
}
// IsEmpty reports whether the map holds no entries. It stops at the first
// non-empty shard instead of locking and counting every shard as a full
// Count would. Like Count, the result is not an atomic snapshot under
// concurrent writes.
func (m *CsMap[K, V]) IsEmpty() bool {
	for idx := range m.shards {
		s := &m.shards[idx]
		s.RLock()
		n := s.items.Count()
		s.RUnlock()
		if n > 0 {
			return false
		}
	}
	return true
}
// Tuple is a single key/value pair copied out of a CsMap.
type Tuple[K comparable, V any] struct {
	// Key is the entry's key.
	Key K
	// Val is the value stored under Key.
	Val V
}
// -------------------------- Methods carried over from the original implementation --------------------------
// Note: every method listed below is defined in full in this file; none are omitted:
// New/WithShardCount/WithCustomHasher/WithSize/getShard/Store/Delete/DeleteIf/
// Load/Has/Clear/Count/SetIfAbsent/SetIf/SetIfPresent/IsEmpty/MarshalJSON/UnmarshalJSON
// -------------------------- 核心优化Range 方法 --------------------------
// Range 同步遍历所有分段,无 channel/goroutine/context 开销,保留 panic 恢复和提前终止
// 回调签名完全兼容:返回 true 终止遍历
func (m *CsMap[K, V]) Range(f func(key K, value V) (stop bool)) {
// 1. 提前判空:回调为 nil 直接返回
if f == nil {
return
}
// 2. 原子标志:控制是否终止遍历(替代 context
var stopFlag atomic.Bool
// 3. 遍历所有分段(同步执行,无额外 goroutine
for i := range m.shards {
// 检测终止标志:提前退出,避免无效遍历
if stopFlag.Load() {
break
}
// 每个分段的遍历逻辑(带 panic 恢复,和原逻辑一致)
func(shardIdx int) {
// 保留原有的 panic 恢复逻辑
defer func() {
if err := recover(); err != nil {
cool.Logger.Error(context.TODO(), "csmap Range shard panic 错误:", err)
}
}()
shard := &m.shards[shardIdx]
// 加读锁(并发安全,和原逻辑一致)
shard.RLock()
defer shard.RUnlock() // 延迟释放,避免锁泄漏
// 跳过空分段:核心优化点(减少无效遍历)
if shard.items.Count() == 0 {
return
}
// 遍历当前分段的元素(复用 swiss.Map 的 Iter 方法)
shard.items.Iter(func(k K, v V) (stop bool) {
// 检测终止标志:终止当前分段遍历
if stopFlag.Load() {
return true
}
// 执行用户回调,保留提前终止逻辑
if f(k, v) {
stopFlag.Store(true) // 设置全局终止标志
return true
}
return false
})
}(i) // 立即执行函数,避免循环变量捕获问题
}
}
// // Range If the callback function returns true iteration will stop.
// func (m *CsMap[K, V]) Range(f func(key K, value V) (stop bool)) {
// ch := make(chan Tuple[K, V], m.Count())
// ctx, cancel := context.WithCancel(context.Background())
// defer cancel()
// listenCompleted := m.listen(f, ch)
// m.produce(ctx, ch)
// listenCompleted.Wait()
// }
// MarshalJSON implements json.Marshaler. It copies all entries (via Range)
// into a plain Go map and encodes that map with encoding/json, so K must be
// a key type encoding/json accepts for maps.
func (m *CsMap[K, V]) MarshalJSON() ([]byte, error) {
	snapshot := make(map[K]V, m.Count())
	collect := func(k K, v V) (stop bool) {
		snapshot[k] = v
		return false
	}
	m.Range(collect)
	return json.Marshal(snapshot)
}
// UnmarshalJSON implements json.Unmarshaler. The input is decoded into a
// temporary Go map and every entry is then stored into m. Entries already
// in m whose keys do not appear in the input are kept (merge, not replace).
// Decoding errors are returned unchanged and leave m untouched.
//
// When decoding into a *CsMap field, encoding/json allocates a zero-value
// CsMap and then calls this method. That value has no shards and no hasher,
// so it is first initialized with New's defaults instead of panicking on
// the first Store. This initialization is not synchronized: a zero-value
// CsMap must not be shared between goroutines before it is initialized.
func (m *CsMap[K, V]) UnmarshalJSON(b []byte) error {
	// A nil map is fine here: json.Unmarshal allocates one as needed.
	var decoded map[K]V
	if err := json.Unmarshal(b, &decoded); err != nil {
		return err
	}
	if len(m.shards) == 0 {
		*m = *New[K, V]()
	}
	for key, val := range decoded {
		m.Store(key, val)
	}
	return nil
}
// produce streams every entry of every shard into ch, using one goroutine
// per shard, and closes ch once all of them have finished. Each goroutine
// stops early when ctx is cancelled, even while blocked on a full ch.
// Otherwise it could block forever holding the shard's read lock, which
// would in turn block every writer to that shard.
//
// A panic in a shard goroutine is recovered and logged; the shard's read
// lock is still released.
//
// NOTE(review): Range in this file does not call produce. Only the
// commented-out channel-based Range did. Consider removing it with listen.
func (m *CsMap[K, V]) produce(ctx context.Context, ch chan Tuple[K, V]) {
	var wg sync.WaitGroup
	wg.Add(len(m.shards))
	for i := range m.shards {
		go func(i int) {
			defer wg.Done()
			defer func() {
				if err := recover(); err != nil {
					cool.Logger.Error(context.TODO(), "csmap panic 错误:", err)
				}
			}()
			s := &m.shards[i]
			s.RLock()
			// Deferred so a recovered panic cannot leak the read lock.
			defer s.RUnlock()
			s.items.Iter(func(k K, v V) (stop bool) {
				select {
				case <-ctx.Done():
					return true
				case ch <- Tuple[K, V]{Key: k, Val: v}:
					return false
				}
			})
		}(i)
	}
	go func() {
		wg.Wait()
		close(ch)
	}()
}
// listen starts a goroutine that passes every Tuple received from ch to f.
// The goroutine exits when ch is closed, when f returns true, or when f
// panics (the panic is recovered and logged). The returned WaitGroup is
// released when that goroutine exits.
//
// NOTE(review): Range in this file does not call listen. Only the
// commented-out channel-based Range did.
func (m *CsMap[K, V]) listen(f func(key K, value V) (stop bool), ch chan Tuple[K, V]) *sync.WaitGroup {
	done := &sync.WaitGroup{}
	done.Add(1)
	go func() {
		defer done.Done()
		defer func() {
			if r := recover(); r != nil {
				cool.Logger.Error(context.TODO(), " csmap panic 错误:", r)
			}
		}()
		for {
			t, ok := <-ch
			if !ok || f(t.Key, t.Val) {
				return
			}
		}
	}()
	return done
}