diff --git a/common/utils/concurrent-swiss-map/.golangci.yml b/common/utils/concurrent-swiss-map/.golangci.yml new file mode 100644 index 000000000..f9cb8f2be --- /dev/null +++ b/common/utils/concurrent-swiss-map/.golangci.yml @@ -0,0 +1,46 @@ +run: + skip-dirs: + - swiss + - swiss/simd + - maphash + skip-files: + - "concurrent_swiss_map_benchmark_test.go" + skip-dirs-use-default: false + +linters-settings: + lll: + line-length: 140 + funlen: + lines: 70 + +linters: + disable-all: true + enable: + - bodyclose + - depguard + - errcheck + - exhaustive + - funlen + - goconst + - gocritic + - gocyclo + - revive + - gosimple + - govet + - gosec + - ineffassign + - lll + - misspell + - nakedret + - gofumpt + - staticcheck + - stylecheck + - typecheck + - unconvert + - unparam + - whitespace + +service: + golangci-lint-version: 1.50.x # use the fixed version to not introduce new linters unexpectedly + prepare: + - echo "here I can run custom commands, but no preparation needed for this repo" \ No newline at end of file diff --git a/common/utils/concurrent-swiss-map/.goreleaser.yml b/common/utils/concurrent-swiss-map/.goreleaser.yml new file mode 100644 index 000000000..04dc6bb66 --- /dev/null +++ b/common/utils/concurrent-swiss-map/.goreleaser.yml @@ -0,0 +1,27 @@ +project_name: concurrent-swiss-map + +release: + github: + name: concurrent-swiss-map + owner: mhmtszr + +before: + hooks: + - go mod tidy + +builds: + - skip: true + +changelog: + sort: asc + use: github + filters: + exclude: + - '^test:' + - '^docs:' + - '^chore:' + - 'merge conflict' + - Merge pull request + - Merge remote-tracking branch + - Merge branch + - go mod tidy \ No newline at end of file diff --git a/common/utils/concurrent-swiss-map/LICENSE b/common/utils/concurrent-swiss-map/LICENSE new file mode 100644 index 000000000..57eae8b7d --- /dev/null +++ b/common/utils/concurrent-swiss-map/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Mehmet Sezer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/common/utils/concurrent-swiss-map/README.md b/common/utils/concurrent-swiss-map/README.md new file mode 100644 index 000000000..7f6696fa1 --- /dev/null +++ b/common/utils/concurrent-swiss-map/README.md @@ -0,0 +1,103 @@ +# Concurrent Swiss Map [![GoDoc][doc-img]][doc] [![Build Status][ci-img]][ci] [![Coverage Status][cov-img]][cov] [![Go Report Card][go-report-img]][go-report] + +**Concurrent Swiss Map** is an open-source Go library that provides a high-performance, thread-safe generic concurrent hash map implementation designed to handle concurrent access efficiently. It's built with a focus on simplicity, speed, and reliability, making it a solid choice for scenarios where concurrent access to a hash map is crucial. + +Uses [dolthub/swiss](https://github.com/dolthub/swiss) map implementation under the hood. + +## Installation + +Supports 1.18+ Go versions because of Go Generics + +``` +go get github.com/mhmtszr/concurrent-swiss-map +``` + +## Usage + +New functions will be added soon... + +```go +package main + +import ( + "hash/fnv" + + csmap "github.com/mhmtszr/concurrent-swiss-map" +) + +func main() { + myMap := csmap.New[string, int]( + // set the number of map shards. the default value is 32. + csmap.WithShardCount[string, int](32), + + // if don't set custom hasher, use the built-in maphash. + csmap.WithCustomHasher[string, int](func(key string) uint64 { + hash := fnv.New64a() + hash.Write([]byte(key)) + return hash.Sum64() + }), + + // set the total capacity, every shard map has total capacity/shard count capacity. the default value is 0. + csmap.WithSize[string, int](1000), + ) + + key := "swiss-map" + myMap.Store(key, 10) + + val, ok := myMap.Load(key) + println("load val:", val, "exists:", ok) + + deleted := myMap.Delete(key) + println("deleted:", deleted) + + ok = myMap.Has(key) + println("has:", ok) + + empty := myMap.IsEmpty() + println("empty:", empty) + + myMap.SetIfAbsent(key, 11) + + myMap.Range(func(key string, value int) (stop bool) { + println("range:", key, value) + return true + }) + + count := myMap.Count() + println("count:", count) + + // Output: + // load val: 10 exists: true + // deleted: true + // has: false + // empty: true + // range: swiss-map 11 + // count: 1 +} +``` + +## Basic Architecture +![img.png](img.png) + +## Benchmark Test +Benchmark was made on: +- Apple M1 Max +- 32 GB memory + +Benchmark test results can be obtained by running [this file](concurrent_swiss_map_benchmark_test.go) on local computers. + +![benchmark.png](benchmark.png) + +### Benchmark Results + +- Memory usage of the concurrent swiss map is better than other map implementations in all checked test scenarios. +- In high concurrent systems, the concurrent swiss map is faster, but in systems containing few concurrent operations, it works similarly to RWMutexMap. + +[doc-img]: https://godoc.org/github.com/mhmtszr/concurrent-swiss-map?status.svg +[doc]: https://godoc.org/github.com/mhmtszr/concurrent-swiss-map +[ci-img]: https://github.com/mhmtszr/concurrent-swiss-map/actions/workflows/build-test.yml/badge.svg +[ci]: https://github.com/mhmtszr/concurrent-swiss-map/actions/workflows/build-test.yml +[cov-img]: https://codecov.io/gh/mhmtszr/concurrent-swiss-map/branch/master/graph/badge.svg +[cov]: https://codecov.io/gh/mhmtszr/concurrent-swiss-map +[go-report-img]: https://goreportcard.com/badge/github.com/mhmtszr/concurrent-swiss-map +[go-report]: https://goreportcard.com/report/github.com/mhmtszr/concurrent-swiss-map \ No newline at end of file diff --git a/common/utils/concurrent-swiss-map/benchmark.png b/common/utils/concurrent-swiss-map/benchmark.png new file mode 100644 index 000000000..3899f0c1b Binary files /dev/null and b/common/utils/concurrent-swiss-map/benchmark.png differ diff --git a/common/utils/concurrent-swiss-map/concurrent_swiss_map.go b/common/utils/concurrent-swiss-map/concurrent_swiss_map.go new file mode 100644 index 000000000..c584f6a48 --- /dev/null +++ b/common/utils/concurrent-swiss-map/concurrent_swiss_map.go @@ -0,0 +1,286 @@ +package csmap + +import ( + "context" + "encoding/json" + "sync" + + "github.com/mhmtszr/concurrent-swiss-map/maphash" + "github.com/panjf2000/ants/v2" + + "github.com/mhmtszr/concurrent-swiss-map/swiss" +) + +type CsMap[K comparable, V any] struct { + hasher func(key K) uint64 + shards []shard[K, V] + shardCount uint64 + size uint64 +} + +type HashShardPair[K comparable, V any] struct { + shard shard[K, V] + hash uint64 +} + +type shard[K comparable, V any] struct { + items *swiss.Map[K, V] + *sync.RWMutex +} + +// OptFunc is a type that is used in New function for passing options. +type OptFunc[K comparable, V any] func(o *CsMap[K, V]) + +// New function creates *CsMap[K, V]. +func New[K comparable, V any](options ...OptFunc[K, V]) *CsMap[K, V] { + m := CsMap[K, V]{ + hasher: maphash.NewHasher[K]().Hash, + shardCount: 32, + } + for _, option := range options { + option(&m) + } + + m.shards = make([]shard[K, V], m.shardCount) + + for i := 0; i < int(m.shardCount); i++ { + m.shards[i] = shard[K, V]{items: swiss.NewMap[K, V](uint32((m.size / m.shardCount) + 1)), RWMutex: &sync.RWMutex{}} + } + return &m +} + +// Create creates *CsMap. +// +// Deprecated: New function should be used instead. +func Create[K comparable, V any](options ...func(options *CsMap[K, V])) *CsMap[K, V] { + m := CsMap[K, V]{ + hasher: maphash.NewHasher[K]().Hash, + shardCount: 32, + } + for _, option := range options { + option(&m) + } + + m.shards = make([]shard[K, V], m.shardCount) + + for i := 0; i < int(m.shardCount); i++ { + m.shards[i] = shard[K, V]{items: swiss.NewMap[K, V](uint32((m.size / m.shardCount) + 1)), RWMutex: &sync.RWMutex{}} + } + return &m +} + +func WithShardCount[K comparable, V any](count uint64) func(csMap *CsMap[K, V]) { + return func(csMap *CsMap[K, V]) { + csMap.shardCount = count + } +} + +func WithCustomHasher[K comparable, V any](h func(key K) uint64) func(csMap *CsMap[K, V]) { + return func(csMap *CsMap[K, V]) { + csMap.hasher = h + } +} + +func WithSize[K comparable, V any](size uint64) func(csMap *CsMap[K, V]) { + return func(csMap *CsMap[K, V]) { + csMap.size = size + } +} + +func (m *CsMap[K, V]) getShard(key K) HashShardPair[K, V] { + u := m.hasher(key) + return HashShardPair[K, V]{ + hash: u, + shard: m.shards[u%m.shardCount], + } +} + +func (m *CsMap[K, V]) Store(key K, value V) { + hashShardPair := m.getShard(key) + shard := hashShardPair.shard + shard.Lock() + shard.items.PutWithHash(key, value, hashShardPair.hash) + shard.Unlock() +} + +func (m *CsMap[K, V]) Delete(key K) bool { + hashShardPair := m.getShard(key) + shard := hashShardPair.shard + shard.Lock() + defer shard.Unlock() + return shard.items.DeleteWithHash(key, hashShardPair.hash) +} + +func (m *CsMap[K, V]) DeleteIf(key K, condition func(value V) bool) bool { + hashShardPair := m.getShard(key) + shard := hashShardPair.shard + shard.Lock() + defer shard.Unlock() + value, ok := shard.items.GetWithHash(key, hashShardPair.hash) + if ok && condition(value) { + return shard.items.DeleteWithHash(key, hashShardPair.hash) + } + return false +} + +func (m *CsMap[K, V]) Load(key K) (V, bool) { + hashShardPair := m.getShard(key) + shard := hashShardPair.shard + shard.RLock() + defer shard.RUnlock() + return shard.items.GetWithHash(key, hashShardPair.hash) +} + +func (m *CsMap[K, V]) Has(key K) bool { + hashShardPair := m.getShard(key) + shard := hashShardPair.shard + shard.RLock() + defer shard.RUnlock() + return shard.items.HasWithHash(key, hashShardPair.hash) +} + +func (m *CsMap[K, V]) Clear() { + for i := range m.shards { + shard := m.shards[i] + + shard.Lock() + shard.items.Clear() + shard.Unlock() + } +} + +func (m *CsMap[K, V]) Count() int { + count := 0 + for i := range m.shards { + shard := m.shards[i] + shard.RLock() + count += shard.items.Count() + shard.RUnlock() + } + return count +} + +func (m *CsMap[K, V]) SetIfAbsent(key K, value V) { + hashShardPair := m.getShard(key) + shard := hashShardPair.shard + shard.Lock() + _, ok := shard.items.GetWithHash(key, hashShardPair.hash) + if !ok { + shard.items.PutWithHash(key, value, hashShardPair.hash) + } + shard.Unlock() +} + +func (m *CsMap[K, V]) SetIf(key K, conditionFn func(previousVale V, previousFound bool) (value V, set bool)) { + hashShardPair := m.getShard(key) + shard := hashShardPair.shard + shard.Lock() + value, found := shard.items.GetWithHash(key, hashShardPair.hash) + value, ok := conditionFn(value, found) + if ok { + shard.items.PutWithHash(key, value, hashShardPair.hash) + } + shard.Unlock() +} + +func (m *CsMap[K, V]) SetIfPresent(key K, value V) { + hashShardPair := m.getShard(key) + shard := hashShardPair.shard + shard.Lock() + _, ok := shard.items.GetWithHash(key, hashShardPair.hash) + if ok { + shard.items.PutWithHash(key, value, hashShardPair.hash) + } + shard.Unlock() +} + +func (m *CsMap[K, V]) IsEmpty() bool { + return m.Count() == 0 +} + +type Tuple[K comparable, V any] struct { + Key K + Val V +} + +// Range If the callback function returns true iteration will stop. +func (m *CsMap[K, V]) Range(f func(key K, value V) (stop bool)) { + ch := make(chan Tuple[K, V], m.Count()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + listenCompleted := m.listen(f, ch) + m.produce(ctx, ch) + listenCompleted.Wait() +} + +func (m *CsMap[K, V]) MarshalJSON() ([]byte, error) { + tmp := make(map[K]V, m.Count()) + m.Range(func(key K, value V) (stop bool) { + tmp[key] = value + return false + }) + return json.Marshal(tmp) +} + +func (m *CsMap[K, V]) UnmarshalJSON(b []byte) error { + tmp := make(map[K]V, m.Count()) + + if err := json.Unmarshal(b, &tmp); err != nil { + return err + } + + for key, val := range tmp { + m.Store(key, val) + } + return nil +} + +func (m *CsMap[K, V]) produce(ctx context.Context, ch chan Tuple[K, V]) { + var wg sync.WaitGroup + wg.Add(len(m.shards)) + + var producepool, _ = ants.NewPoolWithFuncGeneric(-1, func(i int) { + defer wg.Done() + + shard := m.shards[i] + shard.RLock() + shard.items.Iter(func(k K, v V) (stop bool) { + select { + case <-ctx.Done(): + return true + default: + ch <- Tuple[K, V]{Key: k, Val: v} + } + return false + }) + shard.RUnlock() + }) + + for i := range m.shards { + producepool.Invoke(i) + } + + pool.Submit(func() { + wg.Wait() + close(ch) + }) +} + +var pool, _ = ants.NewPool(-1) + +func (m *CsMap[K, V]) listen(f func(key K, value V) (stop bool), ch chan Tuple[K, V]) *sync.WaitGroup { + var wg sync.WaitGroup + wg.Add(1) + pool.Submit(func() { + defer wg.Done() + for t := range ch { + if stop := f(t.Key, t.Val); stop { + return + } + } + }) + + return &wg +} diff --git a/common/utils/concurrent-swiss-map/concurrent_swiss_map_benchmark_test.go b/common/utils/concurrent-swiss-map/concurrent_swiss_map_benchmark_test.go new file mode 100644 index 000000000..5391b8bce --- /dev/null +++ b/common/utils/concurrent-swiss-map/concurrent_swiss_map_benchmark_test.go @@ -0,0 +1,533 @@ +//nolint:all +package csmap_test + +// import ( +// "fmt" +// "runtime" +// "strconv" +// "sync" +// "testing" +// +// "github.com/mhmtszr/concurrent-swiss-map" +//) +// + +// var table = []struct { +// total int +// deletion int +// }{ +// { +// total: 100, +// deletion: 100, +// }, +// { +// total: 5000000, +// deletion: 5000000, +// }, +//} + +// func PrintMemUsage() { +// var m runtime.MemStats +// runtime.ReadMemStats(&m) +// // For info on each, see: https://golang.org/pkg/runtime/#MemStats +// fmt.Printf("Alloc = %v MiB", bToMb(m.Alloc)) +// fmt.Printf("\tTotalAlloc = %v MiB", bToMb(m.TotalAlloc)) +// fmt.Printf("\tSys = %v MiB", bToMb(m.Sys)) +// fmt.Printf("\tNumGC = %v\n", m.NumGC) +//} +// +// func bToMb(b uint64) uint64 { +// return b / 1024 / 1024 +//} + +// func BenchmarkConcurrentSwissMapGoMaxProcs1(b *testing.B) { +// runtime.GOMAXPROCS(1) +// debug.SetGCPercent(-1) +// debug.SetMemoryLimit(math.MaxInt64) +// for _, v := range table { +// b.Run(fmt.Sprintf("total: %d deletion: %d", v.total, v.deletion), func(b *testing.B) { +// for i := 0; i < b.N; i++ { +// m1 := csmap.Create[int, string]() +// var wg sync.WaitGroup +// wg.Add(3) +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(10, strconv.Itoa(i)) +// m1.Delete(10) +// }() +// } +// wg2.Wait() +// }() +// wg.Wait() +// +// wg.Add(v.deletion + v.total) +// for i := 0; i < v.deletion; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Delete(i) +// }() +// } +// +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Load(i) +// }() +// } +// wg.Wait() +// } +// }) +// } +// PrintMemUsage() +//} + +// func BenchmarkSyncMapGoMaxProcs1(b *testing.B) { +// runtime.GOMAXPROCS(1) +// debug.SetGCPercent(-1) +// debug.SetMemoryLimit(math.MaxInt64) +// for _, v := range table { +// b.Run(fmt.Sprintf("total: %d deletion: %d", v.total, v.deletion), func(b *testing.B) { +// for i := 0; i < b.N; i++ { +// var m1 sync.Map +// var wg sync.WaitGroup +// wg.Add(3) +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(10, strconv.Itoa(i)) +// m1.Delete(10) +// }() +// } +// wg2.Wait() +// }() +// wg.Wait() +// +// wg.Add(v.deletion + v.total) +// for i := 0; i < v.deletion; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Delete(i) +// }() +// } +// +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Load(i) +// }() +// } +// wg.Wait() +// } +// }) +// } +// PrintMemUsage() +//} + +// func BenchmarkRWMutexMapGoMaxProcs1(b *testing.B) { +// runtime.GOMAXPROCS(1) +// debug.SetGCPercent(-1) +// debug.SetMemoryLimit(math.MaxInt64) +// for _, v := range table { +// b.Run(fmt.Sprintf("total: %d deletion: %d", v.total, v.deletion), func(b *testing.B) { +// for i := 0; i < b.N; i++ { +// m1 := CreateTestRWMutexMap() +// var wg sync.WaitGroup +// wg.Add(3) +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(10, strconv.Itoa(i)) +// m1.Delete(10) +// }() +// } +// wg2.Wait() +// }() +// wg.Wait() +// +// wg.Add(v.deletion + v.total) +// for i := 0; i < v.deletion; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Delete(i) +// }() +// } +// +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Load(i) +// }() +// } +// wg.Wait() +// } +// }) +// } +// PrintMemUsage() +//} + +// func BenchmarkConcurrentSwissMapGoMaxProcsCore(b *testing.B) { +// debug.SetGCPercent(-1) +// debug.SetMemoryLimit(math.MaxInt64) +// for _, v := range table { +// b.Run(fmt.Sprintf("total: %d deletion: %d", v.total, v.deletion), func(b *testing.B) { +// for i := 0; i < b.N; i++ { +// m1 := csmap.Create[int, string]() +// var wg sync.WaitGroup +// wg.Add(3) +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(10, strconv.Itoa(i)) +// m1.Delete(10) +// }() +// } +// wg2.Wait() +// }() +// wg.Wait() +// +// wg.Add(v.deletion + v.total) +// for i := 0; i < v.deletion; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Delete(i) +// }() +// } +// +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Load(i) +// }() +// } +// wg.Wait() +// } +// }) +// } +// PrintMemUsage() +//} + +// func BenchmarkSyncMapGoMaxProcsCore(b *testing.B) { +// debug.SetGCPercent(-1) +// debug.SetMemoryLimit(math.MaxInt64) +// for _, v := range table { +// b.Run(fmt.Sprintf("total: %d deletion: %d", v.total, v.deletion), func(b *testing.B) { +// for i := 0; i < b.N; i++ { +// var m1 sync.Map +// var wg sync.WaitGroup +// wg.Add(3) +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(10, strconv.Itoa(i)) +// m1.Delete(10) +// }() +// } +// wg2.Wait() +// }() +// wg.Wait() +// +// wg.Add(v.deletion + v.total) +// for i := 0; i < v.deletion; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Delete(i) +// }() +// } +// +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Load(i) +// }() +// } +// wg.Wait() +// } +// }) +// } +// PrintMemUsage() +//} + +// func BenchmarkRWMutexMapGoMaxProcsCore(b *testing.B) { +// debug.SetGCPercent(-1) +// debug.SetMemoryLimit(math.MaxInt64) +// for _, v := range table { +// b.Run(fmt.Sprintf("total: %d deletion: %d", v.total, v.deletion), func(b *testing.B) { +// for i := 0; i < b.N; i++ { +// m1 := CreateTestRWMutexMap() +// var wg sync.WaitGroup +// wg.Add(3) +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(i, strconv.Itoa(i)) +// }() +// } +// wg2.Wait() +// }() +// +// go func() { +// defer wg.Done() +// var wg2 sync.WaitGroup +// wg2.Add(v.total) +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg2.Done() +// m1.Store(10, strconv.Itoa(i)) +// m1.Delete(10) +// }() +// } +// wg2.Wait() +// }() +// wg.Wait() +// +// wg.Add(v.deletion + v.total) +// for i := 0; i < v.deletion; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Delete(i) +// }() +// } +// +// for i := 0; i < v.total; i++ { +// i := i +// go func() { +// defer wg.Done() +// m1.Load(i) +// }() +// } +// wg.Wait() +// } +// }) +// } +// PrintMemUsage() +//} + +// type TestRWMutexMap struct { +// m map[int]string +// sync.RWMutex +//} +// +// func CreateTestRWMutexMap() *TestRWMutexMap { +// return &TestRWMutexMap{ +// m: make(map[int]string), +// } +//} +// +// func (m *TestRWMutexMap) Store(key int, value string) { +// m.Lock() +// defer m.Unlock() +// m.m[key] = value +//} +// +// func (m *TestRWMutexMap) Delete(key int) { +// m.Lock() +// defer m.Unlock() +// delete(m.m, key) +//} +// +// func (m *TestRWMutexMap) Load(key int) *string { +// m.RLock() +// defer m.RUnlock() +// s, ok := m.m[key] +// if !ok { +// return nil +// } +// return &s +//} diff --git a/common/utils/concurrent-swiss-map/concurrent_swiss_map_test.go b/common/utils/concurrent-swiss-map/concurrent_swiss_map_test.go new file mode 100644 index 000000000..f480c05cd --- /dev/null +++ b/common/utils/concurrent-swiss-map/concurrent_swiss_map_test.go @@ -0,0 +1,332 @@ +package csmap_test + +import ( + "strconv" + "sync" + "testing" + + csmap "github.com/mhmtszr/concurrent-swiss-map" +) + +func TestHas(t *testing.T) { + myMap := csmap.New[int, string]() + myMap.Store(1, "test") + if !myMap.Has(1) { + t.Fatal("1 should exists") + } +} + +func TestLoad(t *testing.T) { + myMap := csmap.New[int, string]() + myMap.Store(1, "test") + v, ok := myMap.Load(1) + v2, ok2 := myMap.Load(2) + if v != "test" || !ok { + t.Fatal("1 should test") + } + if v2 != "" || ok2 { + t.Fatal("2 should not exist") + } +} + +func TestDelete(t *testing.T) { + myMap := csmap.New[int, string]() + myMap.Store(1, "test") + ok1 := myMap.Delete(20) + ok2 := myMap.Delete(1) + if myMap.Has(1) { + t.Fatal("1 should be deleted") + } + if ok1 { + t.Fatal("ok1 should be false") + } + if !ok2 { + t.Fatal("ok2 should be true") + } +} + +func TestSetIfAbsent(t *testing.T) { + myMap := csmap.New[int, string]() + myMap.SetIfAbsent(1, "test") + if !myMap.Has(1) { + t.Fatal("1 should be exist") + } +} + +func TestSetIfPresent(t *testing.T) { + myMap := csmap.New[int, string]() + myMap.SetIfPresent(1, "test") + if myMap.Has(1) { + t.Fatal("1 should be not exist") + } + + myMap.Store(1, "test") + myMap.SetIfPresent(1, "new-test") + val, _ := myMap.Load(1) + if val != "new-test" { + t.Fatal("val should be new-test") + } +} + +func TestSetIf(t *testing.T) { + myMap := csmap.New[int, string]() + valueA := "value a" + myMap.SetIf(1, func(previousVale string, previousFound bool) (value string, set bool) { + // operate like a SetIfAbsent... + if !previousFound { + return valueA, true + } + return "", false + }) + value, _ := myMap.Load(1) + if value != valueA { + t.Fatal("value should value a") + } + + myMap.SetIf(1, func(previousVale string, previousFound bool) (value string, set bool) { + // operate like a SetIfAbsent... + if !previousFound { + return "bad", true + } + return "", false + }) + value, _ = myMap.Load(1) + if value != valueA { + t.Fatal("value should value a") + } +} + +func TestDeleteIf(t *testing.T) { + myMap := csmap.New[int, string]() + myMap.Store(1, "value b") + ok1 := myMap.DeleteIf(20, func(value string) bool { + t.Fatal("condition function should not have been called") + return false + }) + if ok1 { + t.Fatal("ok1 should be false") + } + + ok2 := myMap.DeleteIf(1, func(value string) bool { + if value != "value b" { + t.Fatal("condition function arg should be tests") + } + return false // don't delete + }) + if ok2 { + t.Fatal("ok1 should be false") + } + + ok3 := myMap.DeleteIf(1, func(value string) bool { + if value != "value b" { + t.Fatal("condition function arg should be tests") + } + return true // delete the entry + }) + if !ok3 { + t.Fatal("ok2 should be true") + } +} + +func TestCount(t *testing.T) { + myMap := csmap.New[int, string]() + myMap.SetIfAbsent(1, "test") + myMap.SetIfAbsent(2, "test2") + if myMap.Count() != 2 { + t.Fatal("count should be 2") + } +} + +func TestIsEmpty(t *testing.T) { + myMap := csmap.New[int, string]() + if !myMap.IsEmpty() { + t.Fatal("map should be empty") + } +} + +func TestRangeStop(t *testing.T) { + myMap := csmap.New[int, string]( + csmap.WithShardCount[int, string](1), + ) + myMap.SetIfAbsent(1, "test") + myMap.SetIfAbsent(2, "test2") + myMap.SetIfAbsent(3, "test2") + total := 0 + myMap.Range(func(key int, value string) (stop bool) { + total++ + return true + }) + if total != 1 { + t.Fatal("total should be 1") + } +} + +func TestRange(t *testing.T) { + myMap := csmap.New[int, string]() + myMap.SetIfAbsent(1, "test") + myMap.SetIfAbsent(2, "test2") + total := 0 + myMap.Range(func(key int, value string) (stop bool) { + total++ + return + }) + if total != 2 { + t.Fatal("total should be 2") + } +} + +func TestCustomHasherWithRange(t *testing.T) { + myMap := csmap.New[int, string]( + csmap.WithCustomHasher[int, string](func(key int) uint64 { + return 0 + }), + ) + myMap.SetIfAbsent(1, "test") + myMap.SetIfAbsent(2, "test2") + myMap.SetIfAbsent(3, "test2") + myMap.SetIfAbsent(4, "test2") + total := 0 + myMap.Range(func(key int, value string) (stop bool) { + total++ + return true + }) + if total != 1 { + t.Fatal("total should be 1, because currently range stops current shard only.") + } +} + +func TestDeleteFromRange(t *testing.T) { + myMap := csmap.New[string, int]( + csmap.WithSize[string, int](1024), + ) + + myMap.Store("aaa", 10) + myMap.Store("aab", 11) + myMap.Store("aac", 15) + myMap.Store("aad", 124) + myMap.Store("aaf", 987) + + myMap.Range(func(key string, value int) (stop bool) { + if value > 20 { + myMap.Delete(key) + } + return false + }) + if myMap.Count() != 3 { + t.Fatal("total should be 3, because currently range deletes values that bigger than 20.") + } +} + +func TestMarshal(t *testing.T) { + myMap := csmap.New[string, int]( + csmap.WithSize[string, int](1024), + ) + + myMap.Store("aaa", 10) + myMap.Store("aab", 11) + + b, _ := myMap.MarshalJSON() + + newMap := csmap.New[string, int]( + csmap.WithSize[string, int](1024), + ) + + _ = newMap.UnmarshalJSON(b) + + if myMap.Count() != 2 || !myMap.Has("aaa") || !myMap.Has("aab") { + t.Fatal("count should be 2 after unmarshal") + } +} + +func TestBasicConcurrentWriteDeleteCount(t *testing.T) { + myMap := csmap.New[int, string]( + csmap.WithShardCount[int, string](32), + csmap.WithSize[int, string](1000), + ) + + var wg sync.WaitGroup + wg.Add(1000000) + for i := 0; i < 1000000; i++ { + i := i + go func() { + defer wg.Done() + myMap.Store(i, strconv.Itoa(i)) + }() + } + wg.Wait() + wg.Add(1000000) + for i := 0; i < 1000000; i++ { + i := i + go func() { + defer wg.Done() + if !myMap.Has(i) { + t.Error(strconv.Itoa(i) + " should exist") + return + } + }() + } + + wg.Wait() + wg.Add(1000000) + + for i := 0; i < 1000000; i++ { + i := i + go func() { + defer wg.Done() + myMap.Delete(i) + }() + } + + wg.Wait() + wg.Add(1000000) + + for i := 0; i < 1000000; i++ { + i := i + go func() { + defer wg.Done() + if myMap.Has(i) { + t.Error(strconv.Itoa(i) + " should not exist") + return + } + }() + } + + wg.Wait() +} + +func TestClear(t *testing.T) { + myMap := csmap.New[int, string]() + loop := 10000 + for i := 0; i < loop; i++ { + myMap.Store(i, "test") + } + + myMap.Clear() + + if !myMap.IsEmpty() { + t.Fatal("count should be true") + } + + // store again + for i := 0; i < loop; i++ { + myMap.Store(i, "test") + } + + // get again + for i := 0; i < loop; i++ { + val, ok := myMap.Load(i) + if ok != true { + t.Fatal("ok should be true") + } + + if val != "test" { + t.Fatal("val should be test") + } + } + + // check again + count := myMap.Count() + if count != loop { + t.Fatal("count should be 1000") + } +} diff --git a/common/utils/concurrent-swiss-map/example/base/base.go b/common/utils/concurrent-swiss-map/example/base/base.go new file mode 100644 index 000000000..b3b7e7cb3 --- /dev/null +++ b/common/utils/concurrent-swiss-map/example/base/base.go @@ -0,0 +1,57 @@ +package main + +import ( + "hash/fnv" + + csmap "github.com/mhmtszr/concurrent-swiss-map" +) + +func main() { + myMap := csmap.New[string, int]( + // set the number of map shards. the default value is 32. + csmap.WithShardCount[string, int](32), + + // if don't set custom hasher, use the built-in maphash. + csmap.WithCustomHasher[string, int](func(key string) uint64 { + hash := fnv.New64a() + hash.Write([]byte(key)) + return hash.Sum64() + }), + + // set the total capacity, every shard map has total capacity/shard count capacity. the default value is 0. + csmap.WithSize[string, int](1000), + ) + + key := "swiss-map" + myMap.Store(key, 10) + + val, ok := myMap.Load(key) + println("load val:", val, "exists:", ok) + + deleted := myMap.Delete(key) + println("deleted:", deleted) + + ok = myMap.Has(key) + println("has:", ok) + + empty := myMap.IsEmpty() + println("empty:", empty) + + myMap.SetIfAbsent(key, 11) + + myMap.Range(func(key string, value int) (stop bool) { + println("range:", key, value) + return true + }) + + count := myMap.Count() + println("count:", count) + + // Output: + // load val: 10 exists: true + // deleted: true + // has: false + // empty: true + // range: swiss-map 11 + // count: 1 +} diff --git a/common/utils/concurrent-swiss-map/go.mod b/common/utils/concurrent-swiss-map/go.mod new file mode 100644 index 000000000..83da03a77 --- /dev/null +++ b/common/utils/concurrent-swiss-map/go.mod @@ -0,0 +1,3 @@ +module github.com/mhmtszr/concurrent-swiss-map + +go 1.18 diff --git a/common/utils/concurrent-swiss-map/go.sum b/common/utils/concurrent-swiss-map/go.sum new file mode 100644 index 000000000..e69de29bb diff --git a/common/utils/concurrent-swiss-map/img.png b/common/utils/concurrent-swiss-map/img.png new file mode 100644 index 000000000..a0e638275 Binary files /dev/null and b/common/utils/concurrent-swiss-map/img.png differ diff --git a/common/utils/limit/LICENSE b/common/utils/concurrent-swiss-map/maphash/LICENSE similarity index 99% rename from common/utils/limit/LICENSE rename to common/utils/concurrent-swiss-map/maphash/LICENSE index 800f2c7c2..261eeb9e9 100644 --- a/common/utils/limit/LICENSE +++ b/common/utils/concurrent-swiss-map/maphash/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2025 肖其顿 + Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/common/utils/concurrent-swiss-map/maphash/README.md b/common/utils/concurrent-swiss-map/maphash/README.md new file mode 100644 index 000000000..d91530f99 --- /dev/null +++ b/common/utils/concurrent-swiss-map/maphash/README.md @@ -0,0 +1,4 @@ +# maphash + +Hash any `comparable` type using Golang's fast runtime hash. +Uses [AES](https://en.wikipedia.org/wiki/AES_instruction_set) instructions when available. \ No newline at end of file diff --git a/common/utils/concurrent-swiss-map/maphash/hasher.go b/common/utils/concurrent-swiss-map/maphash/hasher.go new file mode 100644 index 000000000..ef53596a2 --- /dev/null +++ b/common/utils/concurrent-swiss-map/maphash/hasher.go @@ -0,0 +1,48 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package maphash + +import "unsafe" + +// Hasher hashes values of type K. +// Uses runtime AES-based hashing. +type Hasher[K comparable] struct { + hash hashfn + seed uintptr +} + +// NewHasher creates a new Hasher[K] with a random seed. +func NewHasher[K comparable]() Hasher[K] { + return Hasher[K]{ + hash: getRuntimeHasher[K](), + seed: newHashSeed(), + } +} + +// NewSeed returns a copy of |h| with a new hash seed. +func NewSeed[K comparable](h Hasher[K]) Hasher[K] { + return Hasher[K]{ + hash: h.hash, + seed: newHashSeed(), + } +} + +// Hash hashes |key|. +func (h Hasher[K]) Hash(key K) uint64 { + // promise to the compiler that pointer + // |p| does not escape the stack. + p := noescape(unsafe.Pointer(&key)) + return uint64(h.hash(p, h.seed)) +} diff --git a/common/utils/concurrent-swiss-map/maphash/runtime.go b/common/utils/concurrent-swiss-map/maphash/runtime.go new file mode 100644 index 000000000..b192dde8b --- /dev/null +++ b/common/utils/concurrent-swiss-map/maphash/runtime.go @@ -0,0 +1,117 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file incorporates work covered by the following copyright and +// permission notice: +// +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.18 || go1.19 +// +build go1.18 go1.19 + +package maphash + +import ( + "math/rand" + "unsafe" +) + +type hashfn func(unsafe.Pointer, uintptr) uintptr + +func getRuntimeHasher[K comparable]() (h hashfn) { + a := any(make(map[K]struct{})) + i := (*mapiface)(unsafe.Pointer(&a)) + h = i.typ.hasher + return +} + +//nolint:gosec +var hashSeed = rand.Int() + +func newHashSeed() uintptr { + return uintptr(hashSeed) +} + +// noescape hides a pointer from escape analysis. It is the identity function +// but escape analysis doesn't think the output depends on the input. +// noescape is inlined and currently compiles down to zero instructions. +// USE CAREFULLY! +// This was copied from the runtime (via pkg "strings"); see issues 23382 and 7921. +// +//go:nosplit +//go:nocheckptr +//nolint:staticcheck +func noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + return unsafe.Pointer(x ^ 0) +} + +type mapiface struct { + typ *maptype + val *hmap +} + +// go/src/runtime/type.go +type maptype struct { + typ _type + key *_type + elem *_type + bucket *_type + // function for hashing keys (ptr to key, seed) -> hash + hasher func(unsafe.Pointer, uintptr) uintptr + keysize uint8 + elemsize uint8 + bucketsize uint16 + flags uint32 +} + +// go/src/runtime/map.go +type hmap struct { + count int + flags uint8 + B uint8 + noverflow uint16 + // hash seed + hash0 uint32 + buckets unsafe.Pointer + oldbuckets unsafe.Pointer + nevacuate uintptr + // true type is *mapextra + // but we don't need this data + extra unsafe.Pointer +} + +// go/src/runtime/type.go +type ( + tflag uint8 + nameOff int32 + typeOff int32 +) + +// go/src/runtime/type.go +type _type struct { + size uintptr + ptrdata uintptr + hash uint32 + tflag tflag + align uint8 + fieldAlign uint8 + kind uint8 + equal func(unsafe.Pointer, unsafe.Pointer) bool + gcdata *byte + str nameOff + ptrToThis typeOff +} diff --git a/common/utils/concurrent-swiss-map/swiss/LICENSE b/common/utils/concurrent-swiss-map/swiss/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/common/utils/concurrent-swiss-map/swiss/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/common/utils/concurrent-swiss-map/swiss/README.md b/common/utils/concurrent-swiss-map/swiss/README.md new file mode 100644 index 000000000..cfb41531b --- /dev/null +++ b/common/utils/concurrent-swiss-map/swiss/README.md @@ -0,0 +1,2 @@ +# swiss +Golang port of Abseil's flat_hash_map diff --git a/common/utils/concurrent-swiss-map/swiss/bits.go b/common/utils/concurrent-swiss-map/swiss/bits.go new file mode 100644 index 000000000..e296ea6e4 --- /dev/null +++ b/common/utils/concurrent-swiss-map/swiss/bits.go @@ -0,0 +1,59 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//go:build !amd64 || nosimd + +//nolint:all + +package swiss + +import ( + "math/bits" + "unsafe" +) + +const ( + groupSize = 8 + maxAvgGroupLoad = 7 + + loBits uint64 = 0x0101010101010101 + hiBits uint64 = 0x8080808080808080 +) + +type bitset uint64 + +func metaMatchH2(m *metadata, h h2) bitset { + // https://graphics.stanford.edu/~seander/bithacks.html##ValueInWord + return hasZeroByte(castUint64(m) ^ (loBits * uint64(h))) +} + +func metaMatchEmpty(m *metadata) bitset { + return hasZeroByte(castUint64(m) ^ hiBits) +} + +func nextMatch(b *bitset) uint32 { + s := uint32(bits.TrailingZeros64(uint64(*b))) + *b &= ^(1 << s) // clear bit |s| + return s >> 3 // div by 8 +} + +func hasZeroByte(x uint64) bitset { + return bitset(((x - loBits) & ^(x)) & hiBits) +} + +func castUint64(m *metadata) uint64 { + return *(*uint64)((unsafe.Pointer)(m)) +} + +//go:linkname fastrand runtime.fastrand +func fastrand() uint32 diff --git a/common/utils/concurrent-swiss-map/swiss/bits_amd64.go b/common/utils/concurrent-swiss-map/swiss/bits_amd64.go new file mode 100644 index 000000000..a46474d77 --- /dev/null +++ b/common/utils/concurrent-swiss-map/swiss/bits_amd64.go @@ -0,0 +1,52 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//nolint:all +//go:build amd64 && !nosimd + +package swiss + +import ( + "github.com/mhmtszr/concurrent-swiss-map/swiss/simd" + "math/bits" + _ "unsafe" +) + +const ( + groupSize = 16 + maxAvgGroupLoad = 14 +) + +type bitset uint16 + +//nolint:all +func metaMatchH2(m *metadata, h h2) bitset { + b := simd.MatchMetadata((*[16]int8)(m), int8(h)) + return bitset(b) +} + +//nolint:all +func metaMatchEmpty(m *metadata) bitset { + b := simd.MatchMetadata((*[16]int8)(m), empty) + return bitset(b) +} + +//nolint:all +func nextMatch(b *bitset) (s uint32) { + s = uint32(bits.TrailingZeros16(uint16(*b))) + *b &= ^(1 << s) // clear bit |s| + return +} + +//go:linkname fastrand runtime.fastrand +func fastrand() uint32 diff --git a/common/utils/concurrent-swiss-map/swiss/map.go b/common/utils/concurrent-swiss-map/swiss/map.go new file mode 100644 index 000000000..8fea5235c --- /dev/null +++ b/common/utils/concurrent-swiss-map/swiss/map.go @@ -0,0 +1,357 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package swiss + +import ( + "github.com/mhmtszr/concurrent-swiss-map/maphash" +) + +const ( + maxLoadFactor = float32(maxAvgGroupLoad) / float32(groupSize) +) + +// Map is an open-addressing hash map +// based on Abseil's flat_hash_map. +type Map[K comparable, V any] struct { + ctrl []metadata + groups []group[K, V] + hash maphash.Hasher[K] + resident uint32 + dead uint32 + limit uint32 +} + +// metadata is the h2 metadata array for a group. +// find operations first probe the controls bytes +// to filter candidates before matching keys +type metadata [groupSize]int8 + +// group is a group of 16 key-value pairs +type group[K comparable, V any] struct { + keys [groupSize]K + values [groupSize]V +} + +const ( + h1Mask uint64 = 0xffff_ffff_ffff_ff80 + h2Mask uint64 = 0x0000_0000_0000_007f + empty int8 = -128 // 0b1000_0000 + tombstone int8 = -2 // 0b1111_1110 +) + +// h1 is a 57 bit hash prefix +type h1 uint64 + +// h2 is a 7 bit hash suffix +type h2 int8 + +// NewMap constructs a Map. +func NewMap[K comparable, V any](sz uint32) (m *Map[K, V]) { + groups := numGroups(sz) + m = &Map[K, V]{ + ctrl: make([]metadata, groups), + groups: make([]group[K, V], groups), + hash: maphash.NewHasher[K](), + limit: groups * maxAvgGroupLoad, + } + for i := range m.ctrl { + m.ctrl[i] = newEmptyMetadata() + } + return +} + +func (m *Map[K, V]) HasWithHash(key K, hash uint64) (ok bool) { + hi, lo := splitHash(hash) + g := probeStart(hi, len(m.groups)) + for { // inlined find loop + matches := metaMatchH2(&m.ctrl[g], lo) + for matches != 0 { + s := nextMatch(&matches) + if key == m.groups[g].keys[s] { + ok = true + return + } + } + // |key| is not in group |g|, + // stop probing if we see an empty slot + matches = metaMatchEmpty(&m.ctrl[g]) + if matches != 0 { + ok = false + return + } + g++ // linear probing + if g >= uint32(len(m.groups)) { + g = 0 + } + } +} + +func (m *Map[K, V]) GetWithHash(key K, hash uint64) (value V, ok bool) { + hi, lo := splitHash(hash) + g := probeStart(hi, len(m.groups)) + for { // inlined find loop + matches := metaMatchH2(&m.ctrl[g], lo) + for matches != 0 { + s := nextMatch(&matches) + if key == m.groups[g].keys[s] { + value, ok = m.groups[g].values[s], true + return + } + } + // |key| is not in group |g|, + // stop probing if we see an empty slot + matches = metaMatchEmpty(&m.ctrl[g]) + if matches != 0 { + ok = false + return + } + g++ // linear probing + if g >= uint32(len(m.groups)) { + g = 0 + } + } +} + +// Put attempts to insert |key| and |value| +func (m *Map[K, V]) Put(key K, value V) { + if m.resident >= m.limit { + m.rehash(m.nextSize()) + } + hi, lo := splitHash(m.hash.Hash(key)) + g := probeStart(hi, len(m.groups)) + for { // inlined find loop + matches := metaMatchH2(&m.ctrl[g], lo) + for matches != 0 { + s := nextMatch(&matches) + if key == m.groups[g].keys[s] { // update + m.groups[g].keys[s] = key + m.groups[g].values[s] = value + return + } + } + // |key| is not in group |g|, + // stop probing if we see an empty slot + matches = metaMatchEmpty(&m.ctrl[g]) + if matches != 0 { // insert + s := nextMatch(&matches) + m.groups[g].keys[s] = key + m.groups[g].values[s] = value + m.ctrl[g][s] = int8(lo) + m.resident++ + return + } + g++ // linear probing + if g >= uint32(len(m.groups)) { + g = 0 + } + } +} + +// Put attempts to insert |key| and |value| +func (m *Map[K, V]) PutWithHash(key K, value V, hash uint64) { + if m.resident >= m.limit { + m.rehash(m.nextSize()) + } + hi, lo := splitHash(hash) + g := probeStart(hi, len(m.groups)) + for { // inlined find loop + matches := metaMatchH2(&m.ctrl[g], lo) + for matches != 0 { + s := nextMatch(&matches) + if key == m.groups[g].keys[s] { // update + m.groups[g].keys[s] = key + m.groups[g].values[s] = value + return + } + } + // |key| is not in group |g|, + // stop probing if we see an empty slot + matches = metaMatchEmpty(&m.ctrl[g]) + if matches != 0 { // insert + s := nextMatch(&matches) + m.groups[g].keys[s] = key + m.groups[g].values[s] = value + m.ctrl[g][s] = int8(lo) + m.resident++ + return + } + g++ // linear probing + if g >= uint32(len(m.groups)) { + g = 0 + } + } +} + +func (m *Map[K, V]) DeleteWithHash(key K, hash uint64) (ok bool) { + hi, lo := splitHash(hash) + g := probeStart(hi, len(m.groups)) + for { + matches := metaMatchH2(&m.ctrl[g], lo) + for matches != 0 { + s := nextMatch(&matches) + if key == m.groups[g].keys[s] { + ok = true + // optimization: if |m.ctrl[g]| contains any empty + // metadata bytes, we can physically delete |key| + // rather than placing a tombstone. + // The observation is that any probes into group |g| + // would already be terminated by the existing empty + // slot, and therefore reclaiming slot |s| will not + // cause premature termination of probes into |g|. + if metaMatchEmpty(&m.ctrl[g]) != 0 { + m.ctrl[g][s] = empty + m.resident-- + } else { + m.ctrl[g][s] = tombstone + m.dead++ + } + var k K + var v V + m.groups[g].keys[s] = k + m.groups[g].values[s] = v + return + } + } + // |key| is not in group |g|, + // stop probing if we see an empty slot + matches = metaMatchEmpty(&m.ctrl[g]) + if matches != 0 { // |key| absent + ok = false + return + } + g++ // linear probing + if g >= uint32(len(m.groups)) { + g = 0 + } + } +} + +// Clear removes all elements from the Map. +func (m *Map[K, V]) Clear() { + for i, c := range m.ctrl { + for j := range c { + m.ctrl[i][j] = empty + } + } + var k K + var v V + for i := range m.groups { + g := &m.groups[i] + for i := range g.keys { + g.keys[i] = k + g.values[i] = v + } + } + m.resident, m.dead = 0, 0 +} + +// Iter iterates the elements of the Map, passing them to the callback. +// It guarantees that any key in the Map will be visited only once, and +// for un-mutated Maps, every key will be visited once. If the Map is +// Mutated during iteration, mutations will be reflected on return from +// Iter, but the set of keys visited by Iter is non-deterministic. +// +//nolint:gosec +func (m *Map[K, V]) Iter(cb func(k K, v V) (stop bool)) bool { + // take a consistent view of the table in case + // we rehash during iteration + ctrl, groups := m.ctrl, m.groups + // pick a random starting group + g := randIntN(len(groups)) + for n := 0; n < len(groups); n++ { + for s, c := range ctrl[g] { + if c == empty || c == tombstone { + continue + } + k, v := groups[g].keys[s], groups[g].values[s] + if stop := cb(k, v); stop { + return stop + } + } + g++ + if g >= uint32(len(groups)) { + g = 0 + } + } + return false +} + +// Count returns the number of elements in the Map. +func (m *Map[K, V]) Count() int { + return int(m.resident - m.dead) +} + +func (m *Map[K, V]) nextSize() (n uint32) { + n = uint32(len(m.groups)) * 2 + if m.dead >= (m.resident / 2) { + n = uint32(len(m.groups)) + } + return +} + +func (m *Map[K, V]) rehash(n uint32) { + groups, ctrl := m.groups, m.ctrl + m.groups = make([]group[K, V], n) + m.ctrl = make([]metadata, n) + for i := range m.ctrl { + m.ctrl[i] = newEmptyMetadata() + } + m.hash = maphash.NewSeed(m.hash) + m.limit = n * maxAvgGroupLoad + m.resident, m.dead = 0, 0 + for g := range ctrl { + for s := range ctrl[g] { + c := ctrl[g][s] + if c == empty || c == tombstone { + continue + } + m.Put(groups[g].keys[s], groups[g].values[s]) + } + } +} + +// numGroups returns the minimum number of groups needed to store |n| elems. +func numGroups(n uint32) (groups uint32) { + groups = (n + maxAvgGroupLoad - 1) / maxAvgGroupLoad + if groups == 0 { + groups = 1 + } + return +} + +func newEmptyMetadata() (meta metadata) { + for i := range meta { + meta[i] = empty + } + return +} + +func splitHash(h uint64) (h1, h2) { + return h1((h & h1Mask) >> 7), h2(h & h2Mask) +} + +func probeStart(hi h1, groups int) uint32 { + return fastModN(uint32(hi), uint32(groups)) +} + +// lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ +func fastModN(x, n uint32) uint32 { + return uint32((uint64(x) * uint64(n)) >> 32) +} + +// randIntN returns a random number in the interval [0, n). +func randIntN(n int) uint32 { + return fastModN(fastrand(), uint32(n)) +} diff --git a/common/utils/concurrent-swiss-map/swiss/simd/match.s b/common/utils/concurrent-swiss-map/swiss/simd/match.s new file mode 100644 index 000000000..a87a806b1 --- /dev/null +++ b/common/utils/concurrent-swiss-map/swiss/simd/match.s @@ -0,0 +1,19 @@ +// Code generated by command: go run asm.go -out match.s -stubs match_amd64.go. DO NOT EDIT. +//nolint +//go:build amd64 + +#include "textflag.h" + +// func MatchMetadata(metadata *[16]int8, hash int8) uint16 +// Requires: SSE2, SSSE3 +TEXT ·MatchMetadata(SB), NOSPLIT, $0-18 + MOVQ metadata+0(FP), AX + MOVBLSX hash+8(FP), CX + MOVD CX, X0 + PXOR X1, X1 + PSHUFB X1, X0 + MOVOU (AX), X1 + PCMPEQB X1, X0 + PMOVMSKB X0, AX + MOVW AX, ret+16(FP) + RET diff --git a/common/utils/concurrent-swiss-map/swiss/simd/match_amd64.go b/common/utils/concurrent-swiss-map/swiss/simd/match_amd64.go new file mode 100644 index 000000000..1dcf6f578 --- /dev/null +++ b/common/utils/concurrent-swiss-map/swiss/simd/match_amd64.go @@ -0,0 +1,9 @@ +// Code generated by command: go run asm.go -out match.s -stubs match_amd64.go. DO NOT EDIT. +//nolint:all +//go:build amd64 + +package simd + +// MatchMetadata performs a 16-way probe of |metadata| using SSE instructions +// nb: |metadata| must be an aligned pointer +func MatchMetadata(metadata *[16]int8, hash int8) uint16 diff --git a/common/utils/limit/NOTICE b/common/utils/limit/NOTICE deleted file mode 100644 index 73d68a92f..000000000 --- a/common/utils/limit/NOTICE +++ /dev/null @@ -1,13 +0,0 @@ - Copyright 2025 肖其顿 - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/common/utils/limit/README.md b/common/utils/limit/README.md deleted file mode 100644 index cb911dbaf..000000000 --- a/common/utils/limit/README.md +++ /dev/null @@ -1,42 +0,0 @@ -# limit [![PkgGoDev](https://pkg.go.dev/badge/github.com/xiaoqidun/limit)](https://pkg.go.dev/github.com/xiaoqidun/limit) -一个高性能、并发安全的 Go 语言动态速率限制器 - -# 安装指南 -```shell -go get -u github.com/xiaoqidun/limit -``` - -# 快速开始 -```go -package main - -import ( - "fmt" - - "github.com/xiaoqidun/limit" - "golang.org/x/time/rate" -) - -func main() { - // 1. 创建一个新的 Limiter 实例 - limiter := limit.New() - // 2. 确保在程序退出前优雅地停止后台任务,这非常重要 - defer limiter.Stop() - // 3. 为任意键 "some-key" 获取一个速率限制器 - // - rate.Limit(2): 表示速率为 "每秒2个请求" - // - 2: 表示桶的容量 (Burst),允许瞬时处理2个请求 - rateLimiter := limiter.Get("some-key", rate.Limit(2), 2) - // 4. 模拟3次连续的突发请求 - // 由于速率和容量都为2,只有前两次请求能立即成功 - for i := 0; i < 3; i++ { - if rateLimiter.Allow() { - fmt.Printf("请求 %d: 已允许\n", i+1) - } else { - fmt.Printf("请求 %d: 已拒绝\n", i+1) - } - } -} -``` - -# 授权协议 -本项目使用 [Apache License 2.0](https://github.com/xiaoqidun/limit/blob/main/LICENSE) 授权协议 \ No newline at end of file diff --git a/common/utils/limit/example_test.go b/common/utils/limit/example_test.go deleted file mode 100644 index 5f96cbba9..000000000 --- a/common/utils/limit/example_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2025 肖其顿 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package limit_test - -import ( - "fmt" - "time" - - "github.com/xiaoqidun/limit" - "golang.org/x/time/rate" -) - -// ExampleLimiter 演示了 limit 包的基本用法。 -func ExampleLimiter() { - // 创建一个使用默认配置的 Limiter 实例 - limiter := limit.New() - // 程序退出前,优雅地停止后台任务,这非常重要 - defer limiter.Stop() - // 为一个特定的测试键获取一个速率限制器 - // 限制为每秒2个请求,最多允许3个并发(桶容量) - testKey := "testKey" - rateLimiter := limiter.Get(testKey, rate.Limit(2), 3) - // 模拟连续的请求 - for i := 0; i < 5; i++ { - if rateLimiter.Allow() { - fmt.Printf("请求 %d: 已允许\n", i+1) - } else { - fmt.Printf("请求 %d: 已拒绝\n", i+1) - } - time.Sleep(100 * time.Millisecond) - } - // 手动移除一个不再需要的限制器 - limiter.Del(testKey) - // Output: - // 请求 1: 已允许 - // 请求 2: 已允许 - // 请求 3: 已允许 - // 请求 4: 已拒绝 - // 请求 5: 已拒绝 -} - -// ExampleNewWithConfig 展示了如何使用自定义配置。 -func ExampleNewWithConfig() { - // 自定义配置 - config := limit.Config{ - ShardCount: 64, // 分片数量,必须是2的幂 - GCInterval: 5 * time.Minute, // GC 检查周期 - Expiration: 15 * time.Minute, // 限制器过期时间 - } - // 使用自定义配置创建一个 Limiter 实例 - customLimiter := limit.NewWithConfig(config) - defer customLimiter.Stop() - fmt.Println("使用自定义配置的限制器已成功创建") - // Output: - // 使用自定义配置的限制器已成功创建 -} diff --git a/common/utils/limit/go.mod b/common/utils/limit/go.mod deleted file mode 100644 index f291a2948..000000000 --- a/common/utils/limit/go.mod +++ /dev/null @@ -1,5 +0,0 @@ -module github.com/xiaoqidun/limit - -go 1.20 - -require golang.org/x/time v0.8.0 diff --git a/common/utils/limit/go.sum b/common/utils/limit/go.sum deleted file mode 100644 index d06eb417b..000000000 --- a/common/utils/limit/go.sum +++ /dev/null @@ -1,2 +0,0 @@ -golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= -golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= diff --git a/common/utils/limit/limit.go b/common/utils/limit/limit.go deleted file mode 100644 index fd89e16f3..000000000 --- a/common/utils/limit/limit.go +++ /dev/null @@ -1,278 +0,0 @@ -// Copyright 2025 肖其顿 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package limit 提供了一个高性能、并发安全的动态速率限制器。 -// 它使用分片锁来减少高并发下的锁竞争,并能自动清理长期未使用的限制器。 -package limit - -import ( - "hash" - "hash/fnv" - "sync" - "sync/atomic" - "time" - - "golang.org/x/time/rate" -) - -// defaultShardCount 是默认的分片数量,设为2的幂可以优化哈希计算。 -const defaultShardCount = 32 - -// Config 定义了 Limiter 的可配置项。 -type Config struct { - // ShardCount 指定分片数量,必须是2的幂。如果为0或无效值,则使用默认值32。 - ShardCount int - // GCInterval 指定GC周期,即检查并清理过期限制器的间隔。如果为0,则使用默认值10分钟。 - GCInterval time.Duration - // Expiration 指定过期时间,即限制器在最后一次使用后能存活多久。如果为0,则使用默认值30分钟。 - Expiration time.Duration -} - -// Limiter 是一个高性能、分片实现的动态速率限制器。 -// 它的实例在并发使用时是安全的。 -type Limiter struct { - // 存储所有分片 - shards []*shard - // 配置信息 - config Config - // 标记限制器是否已停止 - stopped atomic.Bool - // 确保Stop方法只执行一次 - stopOnce sync.Once -} - -// New 使用默认配置创建一个新的 Limiter 实例。 -func New() *Limiter { - return NewWithConfig(Config{}) -} - -// NewWithConfig 根据提供的配置创建一个新的 Limiter 实例。 -func NewWithConfig(config Config) *Limiter { - // 如果未设置,则使用默认值 - if config.ShardCount == 0 { - config.ShardCount = defaultShardCount - } - if config.GCInterval == 0 { - config.GCInterval = 10 * time.Minute - } - if config.Expiration == 0 { - config.Expiration = 30 * time.Minute - } - // 确保分片数量是2的幂,以便进行高效的位运算 - if config.ShardCount <= 0 || (config.ShardCount&(config.ShardCount-1)) != 0 { - config.ShardCount = defaultShardCount - } - l := &Limiter{ - shards: make([]*shard, config.ShardCount), - config: config, - } - // 初始化所有分片 - for i := 0; i < config.ShardCount; i++ { - l.shards[i] = newShard(config.GCInterval, config.Expiration) - } - return l -} - -// Get 获取或创建一个与指定键关联的速率限制器。 -// 如果限制器已存在,它会根据传入的 r (速率) 和 b (并发数) 更新其配置。 -// 如果 Limiter 实例已被 Stop 方法关闭,此方法将返回 nil。 -func (l *Limiter) Get(k string, r rate.Limit, b int) *rate.Limiter { - // 快速路径检查,避免在已停止时进行哈希和查找 - if l.stopped.Load() { - return nil - } - // 定位到具体分片进行操作 - return l.getShard(k).get(k, r, b) -} - -// Del 手动移除一个与指定键关联的速率限制器。 -// 如果 Limiter 实例已被 Stop 方法关闭,此方法不执行任何操作。 -func (l *Limiter) Del(k string) { - // 快速路径检查 - if l.stopped.Load() { - return - } - // 定位到具体分片进行操作 - l.getShard(k).del(k) -} - -// Stop 停止 Limiter 的所有后台清理任务,并释放相关资源。 -// 此方法对于并发调用是安全的,并且可以被多次调用。 -func (l *Limiter) Stop() { - l.stopOnce.Do(func() { - l.stopped.Store(true) - for _, s := range l.shards { - s.stop() - } - }) -} - -// getShard 根据key的哈希值获取对应的分片。 -func (l *Limiter) getShard(key string) *shard { - hasher := fnvHasherPool.Get().(hash.Hash32) - defer func() { - hasher.Reset() - fnvHasherPool.Put(hasher) - }() - _, _ = hasher.Write([]byte(key)) // FNV-1a never returns an error. - // 使用位运算代替取模,提高效率 - return l.shards[hasher.Sum32()&(uint32(l.config.ShardCount)-1)] -} - -// shard 代表 Limiter 的一个分片,它包含独立的锁和数据,以减少全局锁竞争。 -type shard struct { - mutex sync.Mutex - stopCh chan struct{} - limiter map[string]*session - stopOnce sync.Once - waitGroup sync.WaitGroup -} - -// newShard 创建一个新的分片实例,并启动其gc任务。 -func newShard(gcInterval, expiration time.Duration) *shard { - s := &shard{ - // mutex 会被自动初始化为其零值(未锁定状态) - stopCh: make(chan struct{}), - limiter: make(map[string]*session), - } - s.waitGroup.Add(1) - go s.gc(gcInterval, expiration) - return s -} - -// gc 定期清理分片中过期的限制器。 -func (s *shard) gc(interval, expiration time.Duration) { - defer s.waitGroup.Done() - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - // 优先检查停止信号,确保能快速响应 - select { - case <-s.stopCh: - return - default: - } - select { - case <-ticker.C: - s.mutex.Lock() - // 再次检查分片是否已停止,防止在等待锁期间被停止 - if s.limiter == nil { - s.mutex.Unlock() - return - } - for k, v := range s.limiter { - // 清理过期的限制器 - if time.Since(v.lastGet) > expiration { - // 将 session 对象放回池中前,重置其状态 - v.limiter = nil - v.lastGet = time.Time{} - sessionPool.Put(v) - delete(s.limiter, k) - } - } - s.mutex.Unlock() - case <-s.stopCh: - // 收到停止信号,退出goroutine - return - } - } -} - -// get 获取或创建一个新的速率限制器,如果已存在则更新其配置。 -func (s *shard) get(k string, r rate.Limit, b int) *rate.Limiter { - s.mutex.Lock() - defer s.mutex.Unlock() - // 检查分片是否已停止 - if s.limiter == nil { - return nil - } - sess, ok := s.limiter[k] - if !ok { - // 从池中获取 session 对象 - sess = sessionPool.Get().(*session) - sess.limiter = rate.NewLimiter(r, b) - s.limiter[k] = sess - } else { - // 如果已存在,则更新其速率和并发数 - sess.limiter.SetLimit(r) - sess.limiter.SetBurst(b) - } - sess.lastGet = time.Now() - return sess.limiter -} - -// del 从分片中移除一个键的速率限制器。 -func (s *shard) del(k string) { - s.mutex.Lock() - defer s.mutex.Unlock() - // 检查分片是否已停止 - if s.limiter == nil { - return - } - if sess, ok := s.limiter[k]; ok { - // 将 session 对象放回池中前,重置其状态 - sess.limiter = nil - sess.lastGet = time.Time{} - sessionPool.Put(sess) - delete(s.limiter, k) - } -} - -// stop 停止分片的gc任务,并同步等待其完成后再清理资源。 -func (s *shard) stop() { - // 使用 sync.Once 确保 channel 只被关闭一次,彻底避免并发风险 - s.stopOnce.Do(func() { - close(s.stopCh) - }) - // 等待 gc goroutine 完全退出 - s.waitGroup.Wait() - // 锁定并进行最终的资源清理 - // 因为 gc 已经退出,所以此时只有 Get/Del 会竞争锁 - s.mutex.Lock() - defer s.mutex.Unlock() - // 检查是否已被清理,防止重复操作 - if s.limiter == nil { - return - } - // 将所有 session 对象放回对象池 - for _, sess := range s.limiter { - sess.limiter = nil - sess.lastGet = time.Time{} - sessionPool.Put(sess) - } - // 清理map,释放内存,并作为停止标记 - s.limiter = nil -} - -// session 存储每个键的速率限制器实例和最后访问时间。 -type session struct { - // 最后一次访问时间 - lastGet time.Time - // 速率限制器 - limiter *rate.Limiter -} - -// sessionPool 使用 sync.Pool 来复用 session 对象,以减少 GC 压力。 -var sessionPool = sync.Pool{ - New: func() interface{} { - return new(session) - }, -} - -// fnvHasherPool 使用 sync.Pool 来复用 FNV-1a 哈希对象,以减少高并发下的内存分配。 -var fnvHasherPool = sync.Pool{ - New: func() interface{} { - return fnv.New32a() - }, -} diff --git a/common/utils/limit/limit_test.go b/common/utils/limit/limit_test.go deleted file mode 100644 index 9b5e43e53..000000000 --- a/common/utils/limit/limit_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2025 肖其顿 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package limit - -import ( - "fmt" - "sync" - "testing" - "time" - - "golang.org/x/time/rate" -) - -// TestLimiter 覆盖了 Limiter 的主要功能。 -func TestLimiter(t *testing.T) { - // 子测试:验证基本的允许/拒绝逻辑 - t.Run("基本功能测试", func(t *testing.T) { - limiter := New() - defer limiter.Stop() - key := "测试键" - // 创建一个每秒2个令牌,桶容量为1的限制器 - rl := limiter.Get(key, rate.Limit(2), 1) - if rl == nil { - t.Fatal("limiter.Get() 意外返回 nil,测试无法继续") - } - if !rl.Allow() { - t.Error("rl.Allow(): 首次调用应返回 true, 实际为 false") - } - if rl.Allow() { - t.Error("rl.Allow(): 超出突发容量的调用应返回 false, 实际为 true") - } - time.Sleep(500 * time.Millisecond) - if !rl.Allow() { - t.Error("rl.Allow(): 令牌补充后的调用应返回 true, 实际为 false") - } - }) - - // 子测试:验证 Del 方法的功能 - t.Run("删除功能测试", func(t *testing.T) { - limiter := New() - defer limiter.Stop() - key := "测试键" - rl1 := limiter.Get(key, rate.Limit(2), 1) - if !rl1.Allow() { - t.Fatal("获取限制器后的首次 Allow() 调用失败") - } - limiter.Del(key) - rl2 := limiter.Get(key, rate.Limit(2), 1) - if !rl2.Allow() { - t.Error("Del() 后重新获取的限制器未能允许请求") - } - }) - - // 子测试:验证 Stop 方法的功能 - t.Run("停止功能测试", func(t *testing.T) { - limiter := New() - limiter.Stop() - if rl := limiter.Get("任意键", 1, 1); rl != nil { - t.Error("Stop() 后 Get() 应返回 nil, 实际返回了有效实例") - } - // 多次调用 Stop 不应引发 panic - limiter.Stop() - }) - - // 子测试:验证并发安全性 - t.Run("并发安全测试", func(t *testing.T) { - limiter := New() - defer limiter.Stop() - var wg sync.WaitGroup - numGoroutines := 100 - for i := 0; i < numGoroutines; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - key := fmt.Sprintf("并发测试键-%d", i) - if limiter.Get(key, rate.Limit(10), 5) == nil { - t.Errorf("并发获取键 '%s' 时, Get() 意外返回 nil", key) - } - }(i) - } - wg.Wait() - }) -} diff --git a/common/utils/timer/LICENSE b/common/utils/timer/LICENSE new file mode 100644 index 000000000..beb8af51b --- /dev/null +++ b/common/utils/timer/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020-2021 蚂蚁实验室 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/common/utils/timer/README.md b/common/utils/timer/README.md new file mode 100644 index 000000000..46b1141eb --- /dev/null +++ b/common/utils/timer/README.md @@ -0,0 +1,130 @@ +## timer +[![Go](https://github.com/antlabs/timer/workflows/Go/badge.svg)](https://github.com/antlabs/timer/actions) +[![codecov](https://codecov.io/gh/antlabs/timer/branch/master/graph/badge.svg)](https://codecov.io/gh/antlabs/timer) + +timer是高性能定时器库 +## feature +* 支持一次性定时器 +* 支持周期性定时器 +* 支持多种数据结构后端,最小堆,5级时间轮 + +## 一次性定时器 +```go +import ( + "github.com/antlabs/timer" + "log" +) + +func main() { + tm := timer.NewTimer() + + tm.AfterFunc(1*time.Second, func() { + log.Printf("after\n") + }) + + tm.AfterFunc(10*time.Second, func() { + log.Printf("after\n") + }) + tm.Run() +} +``` +## 周期性定时器 +```go +import ( + "github.com/antlabs/timer" + "log" +) + +func main() { + tm := timer.NewTimer() + + tm.ScheduleFunc(1*time.Second, func() { + log.Printf("schedule\n") + }) + + tm.Run() +} +``` +## 自定义周期性定时器 +实现时间翻倍定时的例子 +```go +type curstomTest struct { + count int +} +// 只要实现Next接口就行 +func (c *curstomTest) Next(now time.Time) (rv time.Time) { + rv = now.Add(time.Duration(c.count) * time.Millisecond * 10) + c.count++ + return +} + +func main() { + tm := timer.NewTimer(timer.WithMinHeap()) + node := tm.CustomFunc(&curstomTest{count: 1}, func() { + log.Printf("%v\n", time.Now()) + }) + tm.Run() +} +``` +## 取消某一个定时器 +```go +import ( + "log" + "time" + + "github.com/antlabs/timer" +) + +func main() { + + tm := timer.NewTimer() + + // 只会打印2 time.Second + tm.AfterFunc(2*time.Second, func() { + log.Printf("2 time.Second") + }) + + // tk3 会被 tk3.Stop()函数调用取消掉 + tk3 := tm.AfterFunc(3*time.Second, func() { + log.Printf("3 time.Second") + }) + + tk3.Stop() //取消tk3 + + tm.Run() +} +``` +## 选择不同的的数据结构 +```go +import ( + "github.com/antlabs/timer" + "log" +) + +func main() { + tm := timer.NewTimer(timer.WithMinHeap())// 选择最小堆,默认时间轮 +} +``` +## benchmark + +github.com/antlabs/timer 性能最高 +``` +goos: linux +goarch: amd64 +pkg: benchmark +Benchmark_antlabs_Timer_AddTimer/N-1m-16 9177537 124 ns/op +Benchmark_antlabs_Timer_AddTimer/N-5m-16 10152950 128 ns/op +Benchmark_antlabs_Timer_AddTimer/N-10m-16 9955639 127 ns/op +Benchmark_RussellLuo_Timingwheel_AddTimer/N-1m-16 5316916 222 ns/op +Benchmark_RussellLuo_Timingwheel_AddTimer/N-5m-16 5848843 218 ns/op +Benchmark_RussellLuo_Timingwheel_AddTimer/N-10m-16 5872621 231 ns/op +Benchmark_ouqiang_Timewheel/N-1m-16 720667 1622 ns/op +Benchmark_ouqiang_Timewheel/N-5m-16 807018 1573 ns/op +Benchmark_ouqiang_Timewheel/N-10m-16 666183 1557 ns/op +Benchmark_Stdlib_AddTimer/N-1m-16 8031864 144 ns/op +Benchmark_Stdlib_AddTimer/N-5m-16 8437442 151 ns/op +Benchmark_Stdlib_AddTimer/N-10m-16 8080659 167 ns/op + +``` +* 压测代码位于 +https://github.com/junelabs/timer-benchmark \ No newline at end of file diff --git a/common/utils/timer/_long-time-test/build.sh b/common/utils/timer/_long-time-test/build.sh new file mode 100755 index 000000000..a86bcc3e1 --- /dev/null +++ b/common/utils/timer/_long-time-test/build.sh @@ -0,0 +1 @@ +go build -race long-time-test.go diff --git a/common/utils/timer/_long-time-test/long-time-test.go b/common/utils/timer/_long-time-test/long-time-test.go new file mode 100644 index 000000000..06eba2188 --- /dev/null +++ b/common/utils/timer/_long-time-test/long-time-test.go @@ -0,0 +1,157 @@ +package main + +import ( + "log" + "sync" + "time" + + "github.com/antlabs/timer" +) + +// 这是一个长时间测试代码 + +// 测试周期执行 +func schedule(tm timer.Timer) { + tm.ScheduleFunc(200*time.Millisecond, func() { + log.Printf("schedule 200 milliseconds\n") + }) + + tm.ScheduleFunc(time.Second, func() { + log.Printf("schedule second\n") + }) + + tm.ScheduleFunc(1*time.Minute, func() { + log.Printf("schedule minute\n") + }) + + tm.ScheduleFunc(1*time.Hour, func() { + log.Printf("schedule hour\n") + }) + + tm.ScheduleFunc(24*time.Hour, func() { + log.Printf("schedule day\n") + }) +} + +// 测试一次性定时器 +func after(tm timer.Timer) { + var wg sync.WaitGroup + wg.Add(4) + defer wg.Wait() + + go func() { + defer wg.Done() + for i := 0; i < 3600*24; i++ { + i := i + tm.AfterFunc(time.Second, func() { + log.Printf("after second:%d\n", i) + }) + time.Sleep(900 * time.Millisecond) + } + }() + + go func() { + defer wg.Done() + for i := 0; i < 60*24; i++ { + i := i + tm.AfterFunc(time.Minute, func() { + log.Printf("after minute:%d\n", i) + }) + time.Sleep(50 * time.Second) + } + }() + + go func() { + defer wg.Done() + for i := 0; i < 24; i++ { + i := i + tm.AfterFunc(time.Hour, func() { + log.Printf("after hour:%d\n", i) + }) + time.Sleep(59 * time.Minute) + } + }() + + go func() { + defer wg.Done() + for i := 0; i < 1; i++ { + i := i + tm.AfterFunc(24*time.Hour, func() { + log.Printf("after day:%d\n", i) + }) + time.Sleep(59 * time.Minute) + } + }() +} + +// 检测 stop after 消息,没有打印是正确的行为 +func stopNode(tm timer.Timer) { + + var wg sync.WaitGroup + wg.Add(4) + defer wg.Wait() + + go func() { + defer wg.Done() + for i := 0; i < 3600*24; i++ { + i := i + node := tm.AfterFunc(time.Second, func() { + log.Printf("stop after second:%d\n", i) + }) + time.Sleep(900 * time.Millisecond) + node.Stop() + } + }() + + go func() { + defer wg.Done() + for i := 0; i < 60*24; i++ { + i := i + node := tm.AfterFunc(time.Minute, func() { + log.Printf("stop after minute:%d\n", i) + }) + time.Sleep(50 * time.Second) + node.Stop() + } + }() + + go func() { + defer wg.Done() + for i := 0; i < 24; i++ { + i := i + node := tm.AfterFunc(time.Hour, func() { + log.Printf("stop after hour:%d\n", i) + }) + time.Sleep(59 * time.Minute) + node.Stop() + } + }() + + go func() { + defer wg.Done() + for i := 0; i < 1; i++ { + i := i + node := tm.AfterFunc(23*time.Hour, func() { + log.Printf("stop after day:%d\n", i) + }) + time.Sleep(22 * time.Hour) + node.Stop() + } + }() +} + +func main() { + + log.SetFlags(log.Ldate | log.Lmicroseconds) + tm := timer.NewTimer() + + go schedule(tm) + go after(tm) + go stopNode(tm) + + go func() { + time.Sleep(time.Hour*24 + time.Hour) + tm.Stop() + }() + tm.Run() +} diff --git a/common/utils/timer/go.mod b/common/utils/timer/go.mod new file mode 100644 index 000000000..7f06e3d7e --- /dev/null +++ b/common/utils/timer/go.mod @@ -0,0 +1,5 @@ +module github.com/antlabs/timer + +go 1.19 + +require github.com/antlabs/stl v0.0.2 diff --git a/common/utils/timer/go.sum b/common/utils/timer/go.sum new file mode 100644 index 000000000..46a6f2770 --- /dev/null +++ b/common/utils/timer/go.sum @@ -0,0 +1,2 @@ +github.com/antlabs/stl v0.0.2 h1:sna1AXR5yIkNE9lWhCcKbheFJSVfCa3vugnGyakI79s= +github.com/antlabs/stl v0.0.2/go.mod h1:kKrO4xrn9cfS1mJVo+/BqePZjAYMXqD0amGF2Ouq7ac= diff --git a/common/utils/timer/min_heap.go b/common/utils/timer/min_heap.go new file mode 100644 index 000000000..d855e9d62 --- /dev/null +++ b/common/utils/timer/min_heap.go @@ -0,0 +1,221 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +import ( + "container/heap" + "context" + "sync" + "sync/atomic" + "time" + + "github.com/panjf2000/ants/v2" +) + +var _ Timer = (*minHeap)(nil) + +var defaultTimeout = time.Hour + +type minHeap struct { + mu sync.Mutex + minHeaps + chAdd chan struct{} + ctx context.Context + cancel context.CancelFunc + wait sync.WaitGroup + tm *time.Timer + runCount int32 // 单元测试时使用 +} + +// 一次性定时器 +func (m *minHeap) AfterFunc(expire time.Duration, callback func()) TimeNoder { + return m.addCallback(expire, nil, callback, false) +} + +// 周期性定时器 +func (m *minHeap) ScheduleFunc(expire time.Duration, callback func()) TimeNoder { + return m.addCallback(expire, nil, callback, true) +} + +// 自定义下次的时间 +func (m *minHeap) CustomFunc(n Next, callback func()) TimeNoder { + return m.addCallback(time.Duration(0), n, callback, true) +} + +// 加任务 +func (m *minHeap) addCallback(expire time.Duration, n Next, callback func(), isSchedule bool) TimeNoder { + select { + case <-m.ctx.Done(): + panic("cannot add a task to a closed timer") + default: + } + + node := minHeapNode{ + callback: callback, + userExpire: expire, + next: n, + absExpire: time.Now().Add(expire), + isSchedule: isSchedule, + root: m, + } + + if n != nil { + node.absExpire = n.Next(time.Now()) + } + + m.mu.Lock() + heap.Push(&m.minHeaps, &node) + m.wait.Add(1) + m.mu.Unlock() + + select { + case m.chAdd <- struct{}{}: + default: + } + + return &node +} + +func (m *minHeap) removeTimeNode(node *minHeapNode) { + m.mu.Lock() + if node.index < 0 || node.index > int32(len(m.minHeaps)) || int32(len(m.minHeaps)) == 0 { + m.mu.Unlock() + return + } + + heap.Remove(&m.minHeaps, int(node.index)) + m.wait.Done() + m.mu.Unlock() +} + +func (m *minHeap) resetTimeNode(node *minHeapNode, d time.Duration) { + m.mu.Lock() + node.userExpire = d + node.absExpire = time.Now().Add(d) + heap.Fix(&m.minHeaps, int(node.index)) + select { + case m.chAdd <- struct{}{}: + default: + } + m.mu.Unlock() +} + +func (m *minHeap) getNewSleepTime() time.Duration { + if m.minHeaps.Len() == 0 { + return time.Hour + } + + timeout := time.Until(m.minHeaps[0].absExpire) + if timeout < 0 { + timeout = 0 + } + return timeout +} + +var pool, _ = ants.NewPool(-1) + +func (m *minHeap) process() { + for { + m.mu.Lock() + now := time.Now() + // 如果堆中没有元素,就等待 + // 这时候设置一个相对长的时间,避免空转cpu + if m.minHeaps.Len() == 0 { + m.tm.Reset(time.Hour) + m.mu.Unlock() + return + } + + for { + // 取出最小堆的第一个元素 + first := m.minHeaps[0] + + // 时间未到直接过滤掉 + // 只是跳过最近的循环 + if !now.After(first.absExpire) { + break + } + + // 取出待执行的callback + callback := first.callback + // 如果是周期性任务 + if first.isSchedule { + // 计算下次触发的绝对时间点 + first.absExpire = first.Next(now) + // 修改下在堆中的位置 + heap.Fix(&m.minHeaps, int(first.index)) + } else { + // 从堆中删除 + heap.Pop(&m.minHeaps) + m.wait.Done() + } + + // 正在运行的任务数加1 + atomic.AddInt32(&m.runCount, 1) + pool.Submit(func() { + callback() + // 对正在运行的任务数减1 + atomic.AddInt32(&m.runCount, -1) + }) + + // 如果堆中没有元素,就等待 + if m.minHeaps.Len() == 0 { + m.tm.Reset(defaultTimeout) + m.mu.Unlock() + return + } + } + + // 取出第一个元素 + first := m.minHeaps[0] + // 如果第一个元素的时间还没到,就计算下次触发的时间 + if time.Now().Before(first.absExpire) { + to := m.getNewSleepTime() + m.tm.Reset(to) + // fmt.Printf("### now=%v, to = %v, m.minHeaps[0].absExpire = %v\n", time.Now(), to, m.minHeaps[0].absExpire) + m.mu.Unlock() + return + } + m.mu.Unlock() + } +} + +// 运行 +// 为了避免空转cpu, 会等待一个chan, 只要AfterFunc或者ScheduleFunc被调用就会往这个chan里面写值 +func (m *minHeap) Run() { + m.tm = time.NewTimer(time.Hour) + m.process() + for { + select { + case <-m.tm.C: + m.process() + case <-m.chAdd: + m.mu.Lock() + // 极端情况,加完任务立即给删除了, 判断下当前堆中是否有元素 + if m.minHeaps.Len() > 0 { + m.tm.Reset(m.getNewSleepTime()) + } + m.mu.Unlock() + // 进入事件循环,如果为空就会从事件循环里面退出 + case <-m.ctx.Done(): + // 等待所有任务结束 + m.wait.Wait() + return + } + + } +} + +// 停止所有定时器 +func (m *minHeap) Stop() { + m.cancel() +} + +func newMinHeap() (mh *minHeap) { + mh = &minHeap{} + heap.Init(&mh.minHeaps) + mh.chAdd = make(chan struct{}, 1024) + mh.ctx, mh.cancel = context.WithCancel(context.TODO()) + return +} diff --git a/common/utils/timer/min_heap_node.go b/common/utils/timer/min_heap_node.go new file mode 100644 index 000000000..939e690ef --- /dev/null +++ b/common/utils/timer/min_heap_node.go @@ -0,0 +1,62 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +import ( + "time" +) + +type minHeapNode struct { + callback func() // 用户的callback + absExpire time.Time // 绝对时间 + userExpire time.Duration // 过期时间段 + root *minHeap // 指向最小堆 + next Next // 自定义下个触发的时间点, cronex项目用到了 + index int32 // 在min heap中的索引,方便删除或者重新推入堆中 + isSchedule bool // 是否是周期性任务 +} + +func (m *minHeapNode) Stop() bool { + m.root.removeTimeNode(m) + return true +} +func (m *minHeapNode) Reset(d time.Duration) bool { + m.root.resetTimeNode(m, d) + return true +} + +func (m *minHeapNode) Next(now time.Time) time.Time { + if m.next != nil { + return (m.next).Next(now) + } + return now.Add(m.userExpire) +} + +type minHeaps []*minHeapNode + +func (m minHeaps) Len() int { return len(m) } + +func (m minHeaps) Less(i, j int) bool { return m[i].absExpire.Before(m[j].absExpire) } + +func (m minHeaps) Swap(i, j int) { + m[i], m[j] = m[j], m[i] + m[i].index = int32(i) + m[j].index = int32(j) +} + +func (m *minHeaps) Push(x any) { + // Push and Pop use pointer receivers because they modify the slice's length, + // not just its contents. + *m = append(*m, x.(*minHeapNode)) + lastIndex := int32(len(*m) - 1) + (*m)[lastIndex].index = lastIndex +} + +func (m *minHeaps) Pop() any { + old := *m + n := len(old) + x := old[n-1] + *m = old[0 : n-1] + return x +} diff --git a/common/utils/timer/min_heap_node_test.go b/common/utils/timer/min_heap_node_test.go new file mode 100644 index 000000000..6489d4367 --- /dev/null +++ b/common/utils/timer/min_heap_node_test.go @@ -0,0 +1,64 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license + +package timer + +import ( + "container/heap" + "testing" + "time" +) + +func Test_NodeSizeof(t *testing.T) { + t.Run("输出最小堆node的sizeof", func(t *testing.T) { + // t.Logf("minHeapNode size: %d, %d\n", unsafe.Sizeof(minHeapNode{}), unsafe.Sizeof(time.Timer{})) + }) +} +func Test_MinHeap(t *testing.T) { + t.Run("", func(t *testing.T) { + var mh minHeaps + now := time.Now() + n1 := minHeapNode{ + absExpire: now.Add(time.Second), + userExpire: 1 * time.Second, + } + + n2 := minHeapNode{ + absExpire: now.Add(2 * time.Second), + userExpire: 2 * time.Second, + } + + n3 := minHeapNode{ + absExpire: now.Add(3 * time.Second), + userExpire: 3 * time.Second, + } + + n6 := minHeapNode{ + absExpire: now.Add(6 * time.Second), + userExpire: 6 * time.Second, + } + n5 := minHeapNode{ + absExpire: now.Add(5 * time.Second), + userExpire: 5 * time.Second, + } + n4 := minHeapNode{ + absExpire: now.Add(4 * time.Second), + userExpire: 4 * time.Second, + } + mh.Push(&n1) + mh.Push(&n2) + mh.Push(&n3) + mh.Push(&n6) + mh.Push(&n5) + mh.Push(&n4) + + for i := 1; len(mh) > 0; i++ { + v := heap.Pop(&mh).(*minHeapNode) + + if v.userExpire != time.Duration(i)*time.Second { + t.Errorf("index(%d) v.userExpire(%v) != %v", i, v.userExpire, time.Duration(i)*time.Second) + } + } + }) +} diff --git a/common/utils/timer/min_heap_test.go b/common/utils/timer/min_heap_test.go new file mode 100644 index 000000000..e0788c012 --- /dev/null +++ b/common/utils/timer/min_heap_test.go @@ -0,0 +1,329 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +import ( + "log" + "sync" + "sync/atomic" + "testing" + "time" +) + +// 测试AfterFunc有没有运行以及时间间隔可对 +func Test_MinHeap_AfterFunc_Run(t *testing.T) { + t.Run("1ms", func(t *testing.T) { + tm := NewTimer(WithMinHeap()) + + go tm.Run() + count := int32(0) + + tc := make(chan time.Duration, 2) + + var mu sync.Mutex + isClose := false + now := time.Now() + node1 := tm.AfterFunc(time.Millisecond, func() { + + mu.Lock() + atomic.AddInt32(&count, 1) + if atomic.LoadInt32(&count) <= 2 && !isClose { + tc <- time.Since(now) + } + mu.Unlock() + }) + + node2 := tm.AfterFunc(time.Millisecond, func() { + mu.Lock() + atomic.AddInt32(&count, 1) + if atomic.LoadInt32(&count) <= 2 && !isClose { + tc <- time.Since(now) + } + mu.Unlock() + }) + + time.Sleep(time.Millisecond * 3) + mu.Lock() + isClose = true + close(tc) + node1.Stop() + node2.Stop() + mu.Unlock() + for tv := range tc { + if tv < time.Millisecond || tv > 2*time.Millisecond { + t.Errorf("tc < time.Millisecond tc > 2*time.Millisecond") + + } + } + if atomic.LoadInt32(&count) != 2 { + t.Errorf("count:%d != 2", atomic.LoadInt32(&count)) + } + + }) + + t.Run("10ms", func(t *testing.T) { + tm := NewTimer(WithMinHeap()) + + go tm.Run() // 运行事件循环 + count := int32(0) + tc := make(chan time.Duration, 2) + + var mu sync.Mutex + isClosed := false + now := time.Now() + node1 := tm.AfterFunc(time.Millisecond*10, func() { + now2 := time.Now() + mu.Lock() + atomic.AddInt32(&count, 1) + if atomic.LoadInt32(&count) <= 2 && !isClosed { + tc <- time.Since(now) + } + mu.Unlock() + log.Printf("node1.Lock:%v\n", time.Since(now2)) + }) + node2 := tm.AfterFunc(time.Millisecond*10, func() { + now2 := time.Now() + mu.Lock() + atomic.AddInt32(&count, 1) + if atomic.LoadInt32(&count) <= 2 && !isClosed { + tc <- time.Since(now) + } + mu.Unlock() + log.Printf("node2.Lock:%v\n", time.Since(now2)) + }) + + time.Sleep(time.Millisecond * 24) + now3 := time.Now() + mu.Lock() + node1.Stop() + node2.Stop() + isClosed = true + close(tc) + mu.Unlock() + + log.Printf("node1.Stop:%v\n", time.Since(now3)) + cnt := 1 + for tv := range tc { + left := time.Millisecond * 10 * time.Duration(cnt) + right := time.Duration(cnt) * 2 * 10 * time.Millisecond + if tv < left || tv > right { + t.Errorf("index(%d) (%v)tc < %v || tc > %v", cnt, tv, left, right) + } + // cnt++ + } + if atomic.LoadInt32(&count) != 2 { + + t.Errorf("count:%d != 2", atomic.LoadInt32(&count)) + } + + }) + + t.Run("90ms", func(t *testing.T) { + tm := NewTimer(WithMinHeap()) + go tm.Run() + count := int32(0) + tm.AfterFunc(time.Millisecond*90, func() { atomic.AddInt32(&count, 1) }) + tm.AfterFunc(time.Millisecond*90, func() { atomic.AddInt32(&count, 2) }) + + time.Sleep(time.Millisecond * 180) + if atomic.LoadInt32(&count) != 3 { + t.Errorf("count != 3") + } + + }) +} + +// 测试Schedule 运行的周期可对 +func Test_MinHeap_ScheduleFunc_Run(t *testing.T) { + t.Run("1ms", func(t *testing.T) { + tm := NewTimer(WithMinHeap()) + go tm.Run() + count := int32(0) + + _ = tm.ScheduleFunc(2*time.Millisecond, func() { + log.Printf("%v\n", time.Now()) + atomic.AddInt32(&count, 1) + if atomic.LoadInt32(&count) == 2 { + tm.Stop() + } + }) + + time.Sleep(time.Millisecond * 5) + if atomic.LoadInt32(&count) != 2 { + t.Errorf("count:%d != 2", atomic.LoadInt32(&count)) + } + + }) + + t.Run("10ms", func(t *testing.T) { + tm := NewTimer(WithMinHeap()) + go tm.Run() + count := int32(0) + tc := make(chan time.Duration, 2) + var mu sync.Mutex + isClosed := false + now := time.Now() + + node := tm.ScheduleFunc(time.Millisecond*10, func() { + mu.Lock() + atomic.AddInt32(&count, 1) + + if atomic.LoadInt32(&count) <= 2 && !isClosed { + tc <- time.Since(now) + } + mu.Unlock() + }) + + time.Sleep(time.Millisecond * 25) + + mu.Lock() + close(tc) + isClosed = true + node.Stop() + mu.Unlock() + + cnt := 1 + for tv := range tc { + left := time.Millisecond * 10 * time.Duration(cnt) + right := time.Duration(cnt) * 2 * 10 * time.Millisecond + if tv < left || tv > right { + t.Errorf("index(%d) (%v)tc < %v || tc > %v", cnt, tv, left, right) + } + cnt++ + } + + if atomic.LoadInt32(&count) != 2 { + t.Errorf("count:%d != 2", atomic.LoadInt32(&count)) + } + + }) + + t.Run("30ms", func(t *testing.T) { + tm := NewTimer(WithMinHeap()) + go tm.Run() + count := int32(0) + c := make(chan bool, 1) + + node := tm.ScheduleFunc(time.Millisecond*30, func() { + atomic.AddInt32(&count, 1) + if atomic.LoadInt32(&count) == 2 { + c <- true + } + }) + go func() { + <-c + node.Stop() + }() + + time.Sleep(time.Millisecond * 70) + if atomic.LoadInt32(&count) != 2 { + t.Errorf("count:%d != 2", atomic.LoadInt32(&count)) + } + + }) +} + +// 测试Stop是否会等待正在运行的任务结束 +func Test_Run_Stop(t *testing.T) { + t.Run("1ms", func(t *testing.T) { + tm := NewTimer(WithMinHeap()) + count := uint32(0) + tm.AfterFunc(time.Millisecond, func() { atomic.AddUint32(&count, 1) }) + tm.AfterFunc(time.Millisecond, func() { atomic.AddUint32(&count, 1) }) + go func() { + time.Sleep(time.Millisecond * 4) + tm.Stop() + }() + tm.Run() + if atomic.LoadUint32(&count) != 2 { + t.Errorf("count != 2") + } + }) +} + +type curstomTest struct { + count int32 +} + +func (c *curstomTest) Next(now time.Time) (rv time.Time) { + rv = now.Add(time.Duration(c.count) * time.Millisecond * 10) + atomic.AddInt32(&c.count, 1) + return +} + +// 验证自定义函数的运行间隔时间 +func Test_CustomFunc(t *testing.T) { + t.Run("custom", func(t *testing.T) { + tm := NewTimer(WithMinHeap()) + // mh := tm.(*minHeap) // 最小堆 + tc := make(chan time.Duration, 2) + now := time.Now() + count := uint32(1) + stop := make(chan bool, 1) + // 自定义函数 + node := tm.CustomFunc(&curstomTest{count: 1}, func() { + + if atomic.LoadUint32(&count) == 2 { + return + } + // 计算运行次数 + atomic.AddUint32(&count, 1) + tc <- time.Since(now) + // 关闭这个任务 + close(stop) + }) + + go func() { + <-stop + node.Stop() + tm.Stop() + }() + + tm.Run() + close(tc) + cnt := 1 + for tv := range tc { + left := time.Millisecond * 10 * time.Duration(cnt) + right := time.Duration(cnt) * 2 * 10 * time.Millisecond + if tv < left || tv > right { + t.Errorf("index(%d) (%v)tc < %v || tc > %v", cnt, tv, left, right) + } + cnt++ + } + if atomic.LoadUint32(&count) != 2 { + t.Errorf("count != 2") + } + + // 正在运行的任务是比较短暂的,所以外部很难 + // if mh.runCount != int32(1) { + // t.Errorf("mh.runCount:%d != 1", mh.runCount) + // } + + }) +} + +// 验证运行次数是符合预期的 +func Test_RunCount(t *testing.T) { + t.Run("runcount-10ms", func(t *testing.T) { + tm := NewTimer(WithMinHeap()) + max := 10 + go func() { + tm.Run() + }() + + count := uint32(0) + for i := 0; i < max; i++ { + tm.ScheduleFunc(time.Millisecond*10, func() { + atomic.AddUint32(&count, 1) + }) + } + + time.Sleep(time.Millisecond * 15) + tm.Stop() + if count != uint32(max) { + t.Errorf("count:%d != %d", count, max) + } + + }) +} diff --git a/common/utils/timer/option.go b/common/utils/timer/option.go new file mode 100644 index 000000000..b2606b6a9 --- /dev/null +++ b/common/utils/timer/option.go @@ -0,0 +1,39 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +type option struct { + timeWheel bool + minHeap bool + skiplist bool + rbtree bool +} + +type Option func(c *option) + +func WithTimeWheel() Option { + return func(o *option) { + o.timeWheel = true + } +} + +func WithMinHeap() Option { + return func(o *option) { + o.minHeap = true + } +} + +// TODO +func WithSkipList() Option { + return func(o *option) { + o.skiplist = true + } +} + +// TODO +func WithRbtree() Option { + return func(o *option) { + o.rbtree = true + } +} diff --git a/common/utils/timer/t_test.go b/common/utils/timer/t_test.go new file mode 100644 index 000000000..90b7b9ace --- /dev/null +++ b/common/utils/timer/t_test.go @@ -0,0 +1,17 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +import ( + "fmt" + "testing" + "unsafe" +) + +func Test_Look(t *testing.T) { + + tmp := newTimeHead(0, 0) + offset := unsafe.Offsetof(tmp.Head) + fmt.Printf("%d\n", offset) +} diff --git a/common/utils/timer/time_wheel.go b/common/utils/timer/time_wheel.go new file mode 100644 index 000000000..77fc907c6 --- /dev/null +++ b/common/utils/timer/time_wheel.go @@ -0,0 +1,294 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +import ( + "context" + "fmt" + "sync/atomic" + "time" + "unsafe" + + "github.com/antlabs/stl/list" +) + +const ( + nearShift = 8 + + nearSize = 1 << nearShift + + levelShift = 6 + + levelSize = 1 << levelShift + + nearMask = nearSize - 1 + + levelMask = levelSize - 1 +) + +type timeWheel struct { + // 单调递增累加值, 走过一个时间片就+1 + jiffies uint64 + + // 256个槽位 + t1 [nearSize]*Time + + // 4个64槽位, 代表不同的刻度 + t2Tot5 [4][levelSize]*Time + + // 时间只精确到10ms + // curTimePoint 为1就是10ms 为2就是20ms + curTimePoint time.Duration + + // 上下文 + ctx context.Context + + // 取消函数 + cancel context.CancelFunc +} + +func newTimeWheel() *timeWheel { + ctx, cancel := context.WithCancel(context.Background()) + + t := &timeWheel{ctx: ctx, cancel: cancel} + + t.init() + + return t +} + +func (t *timeWheel) init() { + for i := 0; i < nearSize; i++ { + t.t1[i] = newTimeHead(1, uint64(i)) + } + + for i := 0; i < 4; i++ { + for j := 0; j < levelSize; j++ { + t.t2Tot5[i][j] = newTimeHead(uint64(i+2), uint64(j)) + } + } + + // t.curTimePoint = get10Ms() +} + +func maxVal() uint64 { + return (1 << (nearShift + 4*levelShift)) - 1 +} + +func levelMax(index int) uint64 { + return 1 << (nearShift + index*levelShift) +} + +func (t *timeWheel) index(n int) uint64 { + return (t.jiffies >> (nearShift + levelShift*n)) & levelMask +} + +func (t *timeWheel) add(node *timeNode, jiffies uint64) *timeNode { + var head *Time + expire := node.expire + idx := expire - jiffies + + level, index := uint64(1), uint64(0) + + if idx < nearSize { + + index = uint64(expire) & nearMask + head = t.t1[index] + + } else { + + max := maxVal() + for i := 0; i <= 3; i++ { + + if idx > max { + idx = max + expire = idx + jiffies + } + + if uint64(idx) < levelMax(i+1) { + index = uint64(expire >> (nearShift + i*levelShift) & levelMask) + head = t.t2Tot5[i][index] + level = uint64(i) + 2 + break + } + } + } + + if head == nil { + panic("not found head") + } + + head.lockPushBack(node, level, index) + + return node +} + +func (t *timeWheel) AfterFunc(expire time.Duration, callback func()) TimeNoder { + jiffies := atomic.LoadUint64(&t.jiffies) + + expire = expire/(time.Millisecond*10) + time.Duration(jiffies) + + node := &timeNode{ + expire: uint64(expire), + callback: callback, + root: t, + } + + return t.add(node, jiffies) +} + +func getExpire(expire time.Duration, jiffies uint64) time.Duration { + return expire/(time.Millisecond*10) + time.Duration(jiffies) +} + +func (t *timeWheel) ScheduleFunc(userExpire time.Duration, callback func()) TimeNoder { + jiffies := atomic.LoadUint64(&t.jiffies) + + expire := getExpire(userExpire, jiffies) + + node := &timeNode{ + userExpire: userExpire, + expire: uint64(expire), + callback: callback, + isSchedule: true, + root: t, + } + + return t.add(node, jiffies) +} + +func (t *timeWheel) Stop() { + t.cancel() +} + +// 移动链表 +func (t *timeWheel) cascade(levelIndex int, index int) { + tmp := newTimeHead(0, 0) + + l := t.t2Tot5[levelIndex][index] + l.Lock() + if l.Len() == 0 { + l.Unlock() + return + } + + l.ReplaceInit(&tmp.Head) + + // 每次链表的元素被移动走,都修改version + l.version.Add(1) + l.Unlock() + + offset := unsafe.Offsetof(tmp.Head) + tmp.ForEachSafe(func(pos *list.Head) { + node := (*timeNode)(pos.Entry(offset)) + t.add(node, atomic.LoadUint64(&t.jiffies)) + }) +} + +// moveAndExec函数功能 +// 1. 先移动到near链表里面 +// 2. near链表节点为空时,从上一层里面移动一些节点到下一层 +// 3. 再执行 +func (t *timeWheel) moveAndExec() { + // 这里时间溢出 + if uint32(t.jiffies) == 0 { + // TODO + // return + } + + // 如果本层的盘子没有定时器,这时候从上层的盘子移动一些过来 + index := t.jiffies & nearMask + if index == 0 { + for i := 0; i <= 3; i++ { + index2 := t.index(i) + t.cascade(i, int(index2)) + if index2 != 0 { + break + } + } + } + + atomic.AddUint64(&t.jiffies, 1) + + t.t1[index].Lock() + if t.t1[index].Len() == 0 { + t.t1[index].Unlock() + return + } + + head := newTimeHead(0, 0) + t1 := t.t1[index] + t1.ReplaceInit(&head.Head) + t1.version.Add(1) + t.t1[index].Unlock() + + // 执行,链表中的定时器 + offset := unsafe.Offsetof(head.Head) + + head.ForEachSafe(func(pos *list.Head) { + val := (*timeNode)(pos.Entry(offset)) + head.Del(pos) + + if val.stop.Load() == haveStop { + return + } + + go val.callback() + + if val.isSchedule { + jiffies := t.jiffies + // 这里的jiffies必须要减去1 + // 当前的callback被调用,已经包含一个时间片,如果不把这个时间片减去, + // 每次多一个时间片,就变成累加器, 最后周期定时器慢慢会变得不准 + val.expire = uint64(getExpire(val.userExpire, jiffies-1)) + t.add(val, jiffies) + } + }) +} + +// get10Ms函数通过参数传递,为了方便测试 +func (t *timeWheel) run(get10Ms func() time.Duration) { + // 先判断是否需要更新 + // 内核里面实现使用了全局jiffies和本地的jiffies比较,应用层没有jiffies,直接使用时间比较 + // 这也是skynet里面的做法 + + ms10 := get10Ms() + + if ms10 < t.curTimePoint { + + fmt.Printf("github.com/antlabs/timer:Time has been called back?from(%d)(%d)\n", + ms10, t.curTimePoint) + + t.curTimePoint = ms10 + return + } + + diff := ms10 - t.curTimePoint + t.curTimePoint = ms10 + + for i := 0; i < int(diff); i++ { + t.moveAndExec() + } +} + +// 自定义, TODO +func (t *timeWheel) CustomFunc(n Next, callback func()) TimeNoder { + return &timeNode{} +} + +func (t *timeWheel) Run() { + t.curTimePoint = get10Ms() + // 10ms精度 + tk := time.NewTicker(time.Millisecond * 10) + defer tk.Stop() + + for { + select { + case <-tk.C: + t.run(get10Ms) + case <-t.ctx.Done(): + return + } + } +} diff --git a/common/utils/timer/time_wheel_node.go b/common/utils/timer/time_wheel_node.go new file mode 100644 index 000000000..c029513e0 --- /dev/null +++ b/common/utils/timer/time_wheel_node.go @@ -0,0 +1,114 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +import ( + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/antlabs/stl/list" +) + +const ( + haveStop = uint32(1) +) + +// 先使用sync.Mutex实现功能 +// 后面使用cas优化 +type Time struct { + timeNode + sync.Mutex + + // |---16bit---|---16bit---|------32bit-----| + // |---level---|---index---|-------seq------| + // level 在near盘子里就是1, 在T2ToTt[0]盘子里就是2起步 + // index 就是各自盘子的索引值 + // seq 自增id + version atomic.Uint64 +} + +func newTimeHead(level uint64, index uint64) *Time { + head := &Time{} + head.version.Store(genVersionHeight(level, index)) + head.Init() + return head +} + +func genVersionHeight(level uint64, index uint64) uint64 { + return level<<(32+16) | index<<32 +} + +func (t *Time) lockPushBack(node *timeNode, level uint64, index uint64) { + t.Lock() + defer t.Unlock() + if node.stop.Load() == haveStop { + return + } + + t.AddTail(&node.Head) + atomic.StorePointer(&node.list, unsafe.Pointer(t)) + //更新节点的version信息 + node.version.Store(t.version.Load()) +} + +type timeNode struct { + expire uint64 + userExpire time.Duration + callback func() + stop atomic.Uint32 + list unsafe.Pointer //存放表头信息 + version atomic.Uint64 //保存节点版本信息 + isSchedule bool + root *timeWheel + list.Head +} + +// 一个timeNode节点有4个状态 +// 1.存在于初始化链表中 +// 2.被移动到tmp链表 +// 3.1 和 3.2是if else的状态 +// +// 3.1被移动到new链表 +// 3.2直接执行 +// +// 1和3.1状态是没有问题的 +// 2和3.2状态会是没有锁保护下的操作,会有数据竞争 +func (t *timeNode) Stop() bool { + + t.stop.Store(haveStop) + + // 使用版本号算法让timeNode知道自己是否被移动了 + // timeNode的version和表头的version一样表示没有被移动可以直接删除 + // 如果不一样,可能在第2或者3.2状态,使用惰性删除 + cpyList := (*Time)(atomic.LoadPointer(&t.list)) + cpyList.Lock() + defer cpyList.Unlock() + if t.version.Load() != cpyList.version.Load() { + return false + } + + cpyList.Del(&t.Head) + return true +} + +// warning: 该函数目前没有稳定 +func (t *timeNode) Reset(expire time.Duration) bool { + cpyList := (*Time)(atomic.LoadPointer(&t.list)) + cpyList.Lock() + defer cpyList.Unlock() + // TODO: 这里有一个问题,如果在执行Reset的时候,这个节点已经被移动到tmp链表 + // if atomic.LoadUint64(&t.version) != atomic.LoadUint64(&cpyList.version) { + // return + // } + cpyList.Del(&t.Head) + jiffies := atomic.LoadUint64(&t.root.jiffies) + + expire = expire/(time.Millisecond*10) + time.Duration(jiffies) + t.expire = uint64(expire) + + t.root.add(t, jiffies) + return true +} diff --git a/common/utils/timer/time_wheel_test.go b/common/utils/timer/time_wheel_test.go new file mode 100644 index 000000000..520f04f1f --- /dev/null +++ b/common/utils/timer/time_wheel_test.go @@ -0,0 +1,189 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +import ( + "context" + "math" + "sync/atomic" + "testing" + "time" +) + +func Test_maxVal(t *testing.T) { + + if maxVal() != uint64(math.MaxUint32) { + t.Error("maxVal() != uint64(math.MaxUint32)") + } +} + +func Test_LevelMax(t *testing.T) { + if levelMax(1) != uint64(1<<(nearShift+levelShift)) { + t.Error("levelMax(1) != uint64(1<<(nearShift+levelShift))") + } + + if levelMax(2) != uint64(1<<(nearShift+2*levelShift)) { + t.Error("levelMax(2) != uint64(1<<(nearShift+2*levelShift))") + } + + if levelMax(3) != uint64(1<<(nearShift+3*levelShift)) { + t.Error("levelMax(3) != uint64(1<<(nearShift+3*levelShift))") + } + + if levelMax(4) != uint64(1<<(nearShift+4*levelShift)) { + t.Error("levelMax(4) != uint64(1<<(nearShift+4*levelShift))") + } + +} + +func Test_GenVersion(t *testing.T) { + if genVersionHeight(1, 0xf) != uint64(0x0001000f00000000) { + t.Error("genVersionHeight(1, 0xf) != uint64(0x0001000f00000000)") + } + + if genVersionHeight(1, 64) != uint64(0x0001004000000000) { + t.Error("genVersionHeight(2, 0xf) != uint64(0x0001004000000000)") + } + +} + +// 测试1小时 +func Test_hour(t *testing.T) { + tw := newTimeWheel() + + testHour := new(bool) + done := make(chan struct{}, 1) + tw.AfterFunc(time.Hour, func() { + *testHour = true + done <- struct{}{} + }) + + expire := getExpire(time.Hour, 0) + for i := 0; i < int(expire)+10; i++ { + get10Ms := func() time.Duration { + return tw.curTimePoint + 1 + } + tw.run(get10Ms) + } + + select { + case <-done: + case <-time.After(time.Second / 100): + } + + if *testHour == false { + t.Error("testHour == false") + } + +} + +// 测试周期性定时器, 5s +func Test_ScheduleFunc_5s(t *testing.T) { + tw := newTimeWheel() + + var first5 int32 + ctx, cancel := context.WithCancel(context.Background()) + + const total = int32(1000) + + testTime := time.Second * 5 + + tw.ScheduleFunc(testTime, func() { + atomic.AddInt32(&first5, 1) + if atomic.LoadInt32(&first5) == total { + cancel() + } + + }) + + expire := getExpire(testTime*time.Duration(total), 0) + for i := 0; i <= int(expire)+10; i++ { + get10Ms := func() time.Duration { + return tw.curTimePoint + 1 + } + tw.run(get10Ms) + } + + select { + case <-ctx.Done(): + case <-time.After(time.Second / 100): + } + + if total != first5 { + t.Errorf("total:%d != first5:%d\n", total, first5) + } +} + +// 测试周期性定时器, 1hour +func Test_ScheduleFunc_hour(t *testing.T) { + tw := newTimeWheel() + + var first5 int32 + ctx, cancel := context.WithCancel(context.Background()) + + const total = int32(100) + testTime := time.Hour + + tw.ScheduleFunc(testTime, func() { + atomic.AddInt32(&first5, 1) + if atomic.LoadInt32(&first5) == total { + cancel() + } + + }) + + expire := getExpire(testTime*time.Duration(total), 0) + for i := 0; i <= int(expire)+10; i++ { + get10Ms := func() time.Duration { + return tw.curTimePoint + 1 + } + tw.run(get10Ms) + } + + select { + case <-ctx.Done(): + case <-time.After(time.Second / 100): + } + + if total != first5 { + t.Errorf("total:%d != first5:%d\n", total, first5) + } + +} + +// 测试周期性定时器, 1day +func Test_ScheduleFunc_day(t *testing.T) { + tw := newTimeWheel() + + var first5 int32 + ctx, cancel := context.WithCancel(context.Background()) + + const total = int32(10) + testTime := time.Hour * 24 + + tw.ScheduleFunc(testTime, func() { + atomic.AddInt32(&first5, 1) + if atomic.LoadInt32(&first5) == total { + cancel() + } + + }) + + expire := getExpire(testTime*time.Duration(total), 0) + for i := 0; i <= int(expire)+10; i++ { + get10Ms := func() time.Duration { + return tw.curTimePoint + 1 + } + tw.run(get10Ms) + } + + select { + case <-ctx.Done(): + case <-time.After(time.Second / 100): + } + + if total != first5 { + t.Errorf("total:%d != first5:%d\n", total, first5) + } +} diff --git a/common/utils/timer/time_wheel_utils.go b/common/utils/timer/time_wheel_utils.go new file mode 100644 index 000000000..66f15bca7 --- /dev/null +++ b/common/utils/timer/time_wheel_utils.go @@ -0,0 +1,10 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +import "time" + +func get10Ms() time.Duration { + return time.Duration(int64(time.Now().UnixNano() / int64(time.Millisecond) / 10)) +} diff --git a/common/utils/timer/timer.go b/common/utils/timer/timer.go new file mode 100644 index 000000000..9dc952cc8 --- /dev/null +++ b/common/utils/timer/timer.go @@ -0,0 +1,53 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +import "time" + +type Next interface { + Next(time.Time) time.Time +} + +// 定时器接口 +type Timer interface { + // 一次性定时器 + AfterFunc(expire time.Duration, callback func()) TimeNoder + + // 周期性定时器 + ScheduleFunc(expire time.Duration, callback func()) TimeNoder + + // 自定义下次的时间 + CustomFunc(n Next, callback func()) TimeNoder + + // 运行 + Run() + + // 停止所有定时器 + Stop() +} + +// 停止单个定时器 +type TimeNoder interface { + Stop() bool + // 重置时间器 + Reset(expire time.Duration) bool +} + +// 定时器构造函数 +func NewTimer(opt ...Option) Timer { + var o option + for _, cb := range opt { + cb(&o) + } + + if o.timeWheel { + return newTimeWheel() + } + + if o.minHeap { + return newMinHeap() + } + + return newTimeWheel() +} diff --git a/common/utils/timer/timer_test.go b/common/utils/timer/timer_test.go new file mode 100644 index 000000000..7936bc613 --- /dev/null +++ b/common/utils/timer/timer_test.go @@ -0,0 +1,219 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +import ( + "log" + "sync" + "sync/atomic" + "testing" + "time" +) + +func Test_ScheduleFunc(t *testing.T) { + tm := NewTimer() + + log.SetFlags(log.Ldate | log.Lmicroseconds) + count := uint32(0) + log.Printf("start\n") + + tm.ScheduleFunc(time.Millisecond*100, func() { + log.Printf("schedule\n") + atomic.AddUint32(&count, 1) + }) + + go func() { + time.Sleep(570 * time.Millisecond) + log.Printf("stop\n") + tm.Stop() + }() + + tm.Run() + if count != 5 { + t.Errorf("count:%d != 5\n", count) + } + +} + +func Test_AfterFunc(t *testing.T) { + tm := NewTimer() + go tm.Run() + log.Printf("start\n") + + count := uint32(0) + tm.AfterFunc(time.Millisecond*20, func() { + log.Printf("after Millisecond * 20") + atomic.AddUint32(&count, 1) + }) + + tm.AfterFunc(time.Second, func() { + log.Printf("after second") + atomic.AddUint32(&count, 1) + }) + + /* + tm.AfterFunc(time.Minute, func() { + log.Printf("after Minute") + }) + */ + /* + tm.AfterFunc(time.Hour, nil) + tm.AfterFunc(time.Hour*24, nil) + tm.AfterFunc(time.Hour*24*365, nil) + tm.AfterFunc(time.Hour*24*365*12, nil) + */ + + time.Sleep(time.Second + time.Millisecond*100) + tm.Stop() + + if count != 2 { + t.Errorf("count:%d != 2\n", count) + } + +} + +func Test_Node_Stop_1(t *testing.T) { + tm := NewTimer() + count := uint32(0) + node := tm.AfterFunc(time.Millisecond*10, func() { + atomic.AddUint32(&count, 1) + }) + go func() { + time.Sleep(time.Millisecond * 30) + node.Stop() + tm.Stop() + }() + + tm.Run() + if count != 1 { + t.Errorf("count:%d == 1\n", count) + } +} + +func Test_Node_Stop(t *testing.T) { + tm := NewTimer() + count := uint32(0) + node := tm.AfterFunc(time.Millisecond*100, func() { + atomic.AddUint32(&count, 1) + }) + node.Stop() + go func() { + time.Sleep(time.Millisecond * 200) + tm.Stop() + }() + tm.Run() + + if count == 1 { + t.Errorf("count:%d == 1\n", count) + } + +} + +// 测试重置定时器 +func Test_Reset(t *testing.T) { + t.Run("min heap reset", func(t *testing.T) { + + tm := NewTimer(WithMinHeap()) + + go tm.Run() + count := int32(0) + + tc := make(chan time.Duration, 2) + + var mu sync.Mutex + isClose := false + now := time.Now() + node1 := tm.AfterFunc(time.Millisecond*100, func() { + + mu.Lock() + atomic.AddInt32(&count, 1) + if atomic.LoadInt32(&count) <= 2 && !isClose { + tc <- time.Since(now) + } + mu.Unlock() + }) + + node2 := tm.AfterFunc(time.Millisecond*100, func() { + mu.Lock() + atomic.AddInt32(&count, 1) + if atomic.LoadInt32(&count) <= 2 && !isClose { + tc <- time.Since(now) + } + mu.Unlock() + }) + node1.Reset(time.Millisecond) + node2.Reset(time.Millisecond) + + time.Sleep(time.Millisecond * 3) + mu.Lock() + isClose = true + close(tc) + node1.Stop() + node2.Stop() + mu.Unlock() + for tv := range tc { + if tv < time.Millisecond || tv > 2*time.Millisecond { + t.Errorf("tc < time.Millisecond tc > 2*time.Millisecond") + + } + } + if atomic.LoadInt32(&count) != 2 { + t.Errorf("count:%d != 2", atomic.LoadInt32(&count)) + } + + }) + + t.Run("time wheel reset", func(t *testing.T) { + tm := NewTimer() + + go func() { + tm.Run() + }() + + count := int32(0) + + tc := make(chan time.Duration, 2) + + var mu sync.Mutex + isClose := false + now := time.Now() + node1 := tm.AfterFunc(time.Millisecond*10, func() { + + mu.Lock() + atomic.AddInt32(&count, 1) + if atomic.LoadInt32(&count) <= 2 && !isClose { + tc <- time.Since(now) + } + mu.Unlock() + }) + + node2 := tm.AfterFunc(time.Millisecond*10, func() { + mu.Lock() + atomic.AddInt32(&count, 1) + if atomic.LoadInt32(&count) <= 2 && !isClose { + tc <- time.Since(now) + } + mu.Unlock() + }) + + node1.Reset(time.Millisecond * 20) + node2.Reset(time.Millisecond * 20) + + time.Sleep(time.Millisecond * 40) + mu.Lock() + isClose = true + close(tc) + node1.Stop() + node2.Stop() + mu.Unlock() + for tv := range tc { + if tv < time.Millisecond*20 || tv > 2*time.Millisecond*20 { + t.Errorf("tc < time.Millisecond tc > 2*time.Millisecond") + } + } + if atomic.LoadInt32(&count) != 2 { + t.Errorf("count:%d != 2", atomic.LoadInt32(&count)) + } + }) +} diff --git a/common/utils/timer/timer_wheel_utils_test.go b/common/utils/timer/timer_wheel_utils_test.go new file mode 100644 index 000000000..0c4cdd62e --- /dev/null +++ b/common/utils/timer/timer_wheel_utils_test.go @@ -0,0 +1,14 @@ +// Copyright 2020-2024 guonaihong, antlabs. All rights reserved. +// +// mit license +package timer + +import ( + "fmt" + "testing" +) + +func Test_Get10Ms(t *testing.T) { + + fmt.Printf("%v:%d", get10Ms(), get10Ms()) +} diff --git a/go.work b/go.work index c34248269..682455af7 100644 --- a/go.work +++ b/go.work @@ -8,24 +8,25 @@ use ( ./common/cool ./common/utils/bitset ./common/utils/bytearray + ./common/utils/concurrent-swiss-map ./common/utils/cronex ./common/utils/event ./common/utils/go-jsonrpc ./common/utils/go-sensitive-word-1.3.3 ./common/utils/goja - ./common/utils/limit ./common/utils/lockfree-1.1.3 ./common/utils/log ./common/utils/qqwry ./common/utils/sturc + ./common/utils/timer ./common/utils/xml ./logic ./login ./modules ./modules/base - ./modules/player ./modules/config ./modules/dict + ./modules/player ./modules/space ./modules/task ) diff --git a/logic/controller/login_getserver.go b/logic/controller/login_getserver.go index 1227435f5..52aca8d45 100644 --- a/logic/controller/login_getserver.go +++ b/logic/controller/login_getserver.go @@ -41,7 +41,7 @@ func (h Controller) GetServerOnline(data *user.SidInfo, c gnet.Conn) (result *rp ser := playerservice.NewUserService(data.Head.UserID) f, b := ser.Friend.Get() for _, v := range f { - result.FriendInfo = append(result.FriendInfo, rpc.FriendInfo{v, 1}) + result.FriendInfo = append(result.FriendInfo, rpc.FriendInfo{Userid: v, TimePoke: 1}) } result.BlackInfo = b defer func() { diff --git a/logic/controller/pet_egg.go b/logic/controller/pet_egg.go index 2ed762915..db96d2fc8 100644 --- a/logic/controller/pet_egg.go +++ b/logic/controller/pet_egg.go @@ -5,22 +5,20 @@ import ( "blazing/logic/service/fight" "blazing/logic/service/pet" "blazing/logic/service/player" + "blazing/modules/player/model" ) // GetBreedInfo 获取繁殖信息协议 // 前端到后端无数据 请求协议 func (ctl Controller) GetBreedInfo( - data *pet.C2S_GET_BREED_INFO, playerObj *player.Player) (result *pet.S2C_GET_BREED_INFO, err errorcode.ErrorCode) { //这个时候player应该是空的 + data *pet.C2S_GET_BREED_INFO, player *player.Player) (result *model.S2C_GET_BREED_INFO, err errorcode.ErrorCode) { //这个时候player应该是空的 - result = &pet.S2C_GET_BREED_INFO{} - result.BreedLeftTime = 5000 - result.HatchLeftTime = 5000 - result.HatchState = 1 - result.BreedState = 1 - result.EggID = 1 - result.Intimacy = 1 - result.FeMalePetID = 1 - result.MalePetID = 3 + result = &model.S2C_GET_BREED_INFO{} + r := player.Service.Egg.Get() + if r == nil { + return + } + result = &r.Data // TODO: 实现获取繁殖信息的具体逻辑 return result, 0 @@ -45,10 +43,21 @@ func (ctl Controller) GetBreedPet( // StartBreed 开始繁殖协议 // 前端到后端 func (ctl Controller) StartBreed( - data *pet.C2S_START_BREED, playerObj *player.Player) (result *fight.NullOutboundInfo, err errorcode.ErrorCode) { //这个时候player应该是空的 - + data *pet.C2S_START_BREED, player *player.Player) (result *fight.NullOutboundInfo, err errorcode.ErrorCode) { //这个时候player应该是空的 + _, MalePet, found := player.FindPet(data.Male) + if !found { + return nil, errorcode.ErrorCodes.ErrPokemonNotExists + } + _, Female, found := player.FindPet(data.Female) + if !found { + return nil, errorcode.ErrorCodes.ErrPokemonNotExists + } // TODO: 实现开始繁殖的具体逻辑 result = &fight.NullOutboundInfo{} + r := player.Service.Egg.StartBreed(MalePet, Female) + if !r { + return nil, errorcode.ErrorCodes.ErrCannotPerformAction + } return result, 0 } @@ -56,15 +65,20 @@ func (ctl Controller) StartBreed( // GetEggList 获取精灵蛋数组 // 前端到后端无数据 请求协议 func (ctl Controller) GetEggList( - data *pet.C2S_GET_EGG_LIST, playerObj *player.Player) (result *pet.S2C_GET_EGG_LIST, err errorcode.ErrorCode) { //这个时候player应该是空的 + data *pet.C2S_GET_EGG_LIST, player *player.Player) (result *pet.S2C_GET_EGG_LIST, err errorcode.ErrorCode) { //这个时候player应该是空的 result = &pet.S2C_GET_EGG_LIST{} // TODO: 实现获取精灵蛋列表的逻辑 // 示例数据,实际应从玩家数据中获取 - result.EggList = append(result.EggList, pet.EggInfo{EggID: 1, OwnerID: 10001, EggCatchTime: 122123, - MalePetID: 1, - FeMalePetID: 3, - }) + r := player.Service.Egg.Get() + if r == nil { + return + } + + for _, v := range r.EggList { + result.EggList = append(result.EggList, v) + } + return result, 0 } diff --git a/logic/controller/user_info.go b/logic/controller/user_info.go index 416b0bff4..81ca8b5d1 100644 --- a/logic/controller/user_info.go +++ b/logic/controller/user_info.go @@ -16,7 +16,7 @@ import ( func (h Controller) GetUserSimInfo(data *user.SimUserInfoInboundInfo, player *player.Player) (result *user.SimUserInfoOutboundInfo, err errorcode.ErrorCode) { result = &user.SimUserInfoOutboundInfo{} - copier.Copy(result, player.Service.Info.Person(data.UserId)) + copier.Copy(result, player.Service.Info.Person(data.UserId).Data) return result, 0 } @@ -27,7 +27,7 @@ func (h Controller) GetUserSimInfo(data *user.SimUserInfoInboundInfo, player *pl func (h Controller) GetUserMoreInfo(data *user.MoreUserInfoInboundInfo, player *player.Player) (result *user.MoreUserInfoOutboundInfo, err errorcode.ErrorCode) { result = &user.MoreUserInfoOutboundInfo{} info := player.Service.Info.Person(data.UserId) - copier.CopyWithOption(result, info, copier.Option{IgnoreEmpty: true, DeepCopy: true}) + copier.CopyWithOption(result, info.Data, copier.Option{IgnoreEmpty: true, DeepCopy: true}) //todo 待实现 return result, 0 @@ -56,4 +56,4 @@ func (h Controller) GetPlayerExp(data *item.ExpTotalRemainInboundInfo, player *p TotalExp: uint32(player.Info.ExpPool), }, 0 -} \ No newline at end of file +} diff --git a/logic/main.go b/logic/main.go index f8f900ce7..b6a6ede4a 100644 --- a/logic/main.go +++ b/logic/main.go @@ -46,6 +46,21 @@ func signalHandlerForMain(sig os.Signal) { // main 程序主入口函数 func main() { + // item := model.NeweggConfig() + // item.GeneratedPetIDs = []model.GeneratedPetID{ + // {PetID: 1, Prob: 0.01}, + // {PetID: 2, Prob: 0.01}, + // {PetID: 3, Prob: 0.01}, + // {PetID: 4, Prob: 0.01}, + // {PetID: 5, Prob: 0.01}, + // {PetID: 6, Prob: 0.01}, + // } + // item.MalePet = 1 + // item.FemalePet = 2 + // _, err := g.DB(item.GroupName()).Model(item.TableName()).FieldsEx("id").Data(item).Insert() + // if err != nil { + // panic(err) + // } //loadAccounts() // if cool.IsRedisMode { // go cool.ListenFunc(gctx.New()) diff --git a/logic/service/pet/egg.go b/logic/service/pet/egg.go index 3ea49190b..4af10a64d 100644 --- a/logic/service/pet/egg.go +++ b/logic/service/pet/egg.go @@ -1,6 +1,9 @@ package pet -import "blazing/logic/service/common" +import ( + "blazing/logic/service/common" + "blazing/modules/player/model" +) // C2S_GET_BREED_PET 获取繁殖精灵协议 // 前端到后端 @@ -16,33 +19,6 @@ type S2C_GET_BREED_PET struct { FemaleList []uint32 `json:"femaleList"` // 可繁殖雌性的精灵数组 参数为精灵捕获时间 } -// S2C_GET_BREED_INFO 获取繁殖信息协议 -// 后端到前端 -type S2C_GET_BREED_INFO struct { - // BreedState 繁殖状态 - BreedState uint32 `json:"breedState"` - // BreedLeftTime 繁殖剩余时间 - BreedLeftTime uint32 `json:"breedLeftTime"` - // BreedCoolTime 繁殖冷却时间 - BreedCoolTime uint32 `json:"breedCoolTime"` - // MalePetCatchTime 雄性精灵捕捉时间 - MalePetCatchTime uint32 `json:"malePetCatchTime"` - // MalePetID 雄性精灵ID - MalePetID uint32 `json:"malePetID"` - // FeMalePetCatchTime 雌性精灵捕捉时间 - FeMalePetCatchTime uint32 `json:"feMalePetCatchTime"` - // FeMalePetID 雌性精灵ID - FeMalePetID uint32 `json:"feMalePetID"` - // HatchState 孵化状态 - HatchState uint32 `json:"hatchState"` - // HatchLeftTime 孵化剩余时间 - HatchLeftTime uint32 `json:"hatchLeftTime"` - // EggID 当前孵化的精灵蛋ID - EggID uint32 `json:"eggID"` - // Intimacy 亲密度 1 = 悲伤 以此类推 ["悲伤","冷淡","平淡","友好","亲密无间"] - Intimacy uint32 `json:"intimacy"` -} - // C2S_GET_BREED_INFO 获取繁殖信息协议 // 前端到后端无数据 请求协议 type C2S_GET_BREED_INFO struct { @@ -58,17 +34,8 @@ type C2S_GET_EGG_LIST struct { // S2C_GET_EGG_LIST 获取精灵蛋数组协议 // 后端到前端 type S2C_GET_EGG_LIST struct { - EggListLen uint32 `struc:"sizeof=EggList"` - EggList []EggInfo `json:"eggList"` // 精灵蛋数组 跟其他数组一样 需要给有数量 -} - -// EggInfo 精灵蛋信息 -type EggInfo struct { - OwnerID uint32 `json:"ownerID"` // 所属人ID - EggCatchTime uint32 `json:"eggCatchTime"` // 精灵蛋获得时间 - EggID uint32 `json:"eggID"` // 精灵蛋ID - MalePetID uint32 `json:"male"` // 雄性精灵ID - FeMalePetID uint32 `json:"female"` // 雌性精灵ID + EggListLen uint32 `struc:"sizeof=EggList"` + EggList []model.EggInfo `json:"eggList"` // 精灵蛋数组 跟其他数组一样 需要给有数量 } // C2S_START_HATCH 开始孵化精灵蛋协议 diff --git a/logic/service/player/player.go b/logic/service/player/player.go index d28c00a2b..c82191e55 100644 --- a/logic/service/player/player.go +++ b/logic/service/player/player.go @@ -169,9 +169,9 @@ func (p *Player) SendPack(b []byte) error { if p.MainConn == nil { return nil } - _, ok := p.MainConn.Context().(*ClientData) + psocket, ok := p.MainConn.Context().(*ClientData) if ok { - return p.MainConn.Context().(*ClientData).SendPack(b) + return psocket.SendPack(b) } return nil } diff --git a/modules/config/model/boss_effect.go b/modules/config/model/boss_effect.go index e374289d7..e85a0e026 100644 --- a/modules/config/model/boss_effect.go +++ b/modules/config/model/boss_effect.go @@ -16,7 +16,7 @@ type PlayerPetSpecialEffect struct { SeIdx uint32 `gorm:"not null;uniqueIndex:idx_se_idx;comment:'精灵特效索引(XML中的Idx)'" json:"se_idx"` //Stat uint32 `gorm:"not null;default:0;comment:'精灵特效状态(XML中的Stat)'" json:"stat"` Eid uint32 `gorm:"not null;index:idx_eid;comment:'精灵特效Eid(XML中的Eid)'" json:"eid"` - Args []int `gorm:"type:json;comment:'精灵特效参数(XML中的Args)'" json:"args"` + Args []int `gorm:"type:jsonb;comment:'精灵特效参数(XML中的Args)'" json:"args"` Desc string `gorm:"type:varchar(255);default:'';comment:'精灵特效描述(XML中的Desc)'" json:"desc"` } diff --git a/modules/config/model/cdk.go b/modules/config/model/cdk.go index 6ce3de644..819575507 100644 --- a/modules/config/model/cdk.go +++ b/modules/config/model/cdk.go @@ -19,8 +19,8 @@ type CDKConfig struct { //cdk可兑换次数,where不等于0 ExchangeRemainCount int64 `gorm:"not null;default:1;comment:'CDK剩余可兑换次数(不能为0才允许兑换,支持查询where !=0)'" json:"exchange_remain_count" description:"剩余可兑换次数"` - ItemRewardIds []uint32 `gorm:"not null;type:json;default:'[]';comment:'绑定奖励物品ID数组,关联item_gift表主键'" json:"item_reward_ids" description:"奖励物品数组"` - ElfRewardIds []uint32 `gorm:"not null;type:json;default:'[]';comment:'绑定奖励精灵ID数组,关联config_pet_boss表主键'" json:"elf_reward_ids" description:"奖励精灵数组"` + ItemRewardIds []uint32 `gorm:"not null;type:jsonb;default:'[]';comment:'绑定奖励物品ID数组,关联item_gift表主键'" json:"item_reward_ids" description:"奖励物品数组"` + ElfRewardIds []uint32 `gorm:"not null;type:jsonb;default:'[]';comment:'绑定奖励精灵ID数组,关联config_pet_boss表主键'" json:"elf_reward_ids" description:"奖励精灵数组"` TitleRewardIds uint32 `gorm:"not null;default:0;comment:'绑定奖励称号'" json:"title_reward_ids" description:"绑定奖励称号"` ValidEndTime time.Time `gorm:"not null;comment:'CDK有效结束时间'" json:"valid_end_time" description:"有效结束时间"` diff --git a/modules/config/model/egg.go b/modules/config/model/egg.go new file mode 100644 index 000000000..63d1382ce --- /dev/null +++ b/modules/config/model/egg.go @@ -0,0 +1,52 @@ +package model + +import ( + "blazing/cool" +) + +// 表名常量定义:egg配置表 +const ( + TableNameeggConfig = "config_pet_egg" // egg配置表(记录egg编号、可兑换次数、奖励配置等核心信息) +) + +// EggConfig egg核心配置模型(含可兑换次数,满足查询`where 可兑换次数 != 0`需求) +type EggConfig struct { + *cool.Model + + //雄性 + + MalePet int32 `gorm:"not null;comment:'雄性宠物ID'" json:"male_pet"` + //雌性 + FemalePet int32 `gorm:"not null;comment:'雌性宠物ID'" json:"female_pet"` + + // 生成的精灵ID及对应概率 + GeneratedPetIDs []GeneratedPetID `gorm:"type:jsonb;comment:'生成的精灵ID及概率配置'" json:"generated_pet_ids"` + + Remark string `gorm:"size:512;default:'';comment:'egg备注'" json:"remark" description:"备注信息"` + //ItemGift []*ItemGift `gorm:"-" orm:"with:item_id=id"` +} +type GeneratedPetID struct { + PetID int32 `json:"pet_id" comment:"生成的精灵ID"` + Prob float64 `json:"prob" comment:"该精灵生成概率"` +} + +// -------------------------- 核心配套方法(遵循项目规范)-------------------------- +func (*EggConfig) TableName() string { + return TableNameeggConfig +} + +func (*EggConfig) GroupName() string { + return "default" +} + +func NeweggConfig() *EggConfig { + return &EggConfig{ + Model: cool.NewModel(), + } +} + +// -------------------------- 表结构自动同步 -------------------------- +func init() { + + cool.CreateTable(&EggConfig{}) +} diff --git a/modules/config/model/task.go b/modules/config/model/task.go index fb54e1c1a..e3272c0c3 100644 --- a/modules/config/model/task.go +++ b/modules/config/model/task.go @@ -32,7 +32,7 @@ type TaskConfig struct { IsAcceptable uint32 `gorm:"not null;default:1;comment:'是否可以被接受'" json:"is_acceptable" description:"是否可以被接受"` // 奖励配置 - ItemRewardIds []uint32 `gorm:"not null;type:json;default:'[]';comment:'绑定奖励物品ID数组,关联item_gift表主键'" json:"item_reward_ids" description:"奖励物品数组"` + ItemRewardIds []uint32 `gorm:"not null;type:jsonb;default:'[]';comment:'绑定奖励物品ID数组,关联item_gift表主键'" json:"item_reward_ids" description:"奖励物品数组"` ElfRewardIds uint32 `gorm:"not null;default:0;comment:'绑定奖励精灵ID,关联elf_gift表主键'" json:"elf_reward_ids" description:"绑定奖励精灵ID"` //绑定奖励 diff --git a/modules/player/model/egg.go b/modules/player/model/egg.go new file mode 100644 index 000000000..c9a39299a --- /dev/null +++ b/modules/player/model/egg.go @@ -0,0 +1,76 @@ +package model + +import ( + "blazing/cool" +) + +// 表名常量 +const TableNamePlayerEgg = "player_egg" + +// Egg 对应数据库表 player_cdk_log,用于记录CDK兑换日志 +type Egg struct { + Base + PlayerID uint64 `gorm:"not null;index:idx_player_Egg_by_player_id;comment:'所属玩家ID'" json:"player_id"` + Data S2C_GET_BREED_INFO `gorm:"type:jsonb;not null;comment:'全部数据'" json:"data"` + CurEgg EggInfo `gorm:"type:jsonb;not null;comment:'当前蛋'" json:"cur_egg"` + EggList []EggInfo `gorm:"type:jsonb;not null;comment:'蛋列表'" json:"egg_list"` +} + +// S2C_GET_BREED_INFO 获取繁殖信息协议 +// 后端到前端 +type S2C_GET_BREED_INFO struct { + // BreedState 繁殖状态 + BreedState uint32 `json:"breedState"` + StartTime uint32 `struc:"skip"` //返回记录 + // BreedLeftTime 繁殖剩余时间 + BreedLeftTime uint32 `json:"breedLeftTime"` + // BreedCoolTime 繁殖冷却时间 + BreedCoolTime uint32 `json:"breedCoolTime"` + // MalePetCatchTime 雄性精灵捕捉时间 + MalePetCatchTime uint32 `json:"malePetCatchTime"` + // MalePetID 雄性精灵ID + MalePetID uint32 `json:"malePetID"` + // FeMalePetCatchTime 雌性精灵捕捉时间 + FeMalePetCatchTime uint32 `json:"feMalePetCatchTime"` + // FeMalePetID 雌性精灵ID + FeMalePetID uint32 `json:"feMalePetID"` + // HatchState 孵化状态 ,0=未孵化 1=孵化中 2=已孵化 + HatchState uint32 `json:"hatchState"` + // HatchLeftTime 孵化剩余时间 + HatchLeftTime uint32 `json:"hatchLeftTime"` + // EggID 当前孵化的精灵蛋ID + EggID uint32 `json:"eggID"` + // Intimacy 亲密度 1 = 悲伤 以此类推 ["悲伤","冷淡","平淡","友好","亲密无间"] + Intimacy uint32 `json:"intimacy"` +} + +// EggInfo 精灵蛋信息 +type EggInfo struct { + OwnerID uint32 `json:"ownerID"` // 所属人ID + EggCatchTime uint32 `json:"eggCatchTime"` // 精灵蛋获得时间 + EggID uint32 `json:"eggID"` // 精灵蛋ID + MalePetID uint32 `json:"male"` // 雄性精灵ID + FeMalePetID uint32 `json:"female"` // 雌性精灵ID +} + +// TableName 返回表名 +func (*Egg) TableName() string { + return TableNamePlayerEgg +} + +// GroupName 返回表组名 +func (*Egg) GroupName() string { + return "default" +} + +// NewEgg 创建一个新的CDK记录 +func NewEgg() *Egg { + return &Egg{ + Base: *NewBase(), + } +} + +// init 程序启动时自动创建表 +func init() { + cool.CreateTable(&Egg{}) +} diff --git a/modules/player/model/FRIEND.go b/modules/player/model/friend.go similarity index 100% rename from modules/player/model/FRIEND.go rename to modules/player/model/friend.go diff --git a/modules/player/model/player.go b/modules/player/model/info.go similarity index 100% rename from modules/player/model/player.go rename to modules/player/model/info.go diff --git a/modules/player/model/sign.go b/modules/player/model/sign.go index c936a671d..20dc281cb 100644 --- a/modules/player/model/sign.go +++ b/modules/player/model/sign.go @@ -18,7 +18,7 @@ type SignInRecord struct { IsCompleted bool `gorm:"not null;default:false;comment:'签到是否完成(0-未完成 1-已完成)'" json:"is_completed"` //通过bitset来实现签到的进度记录 - SignInProgress []uint32 `gorm:"type:json;not null;comment:'签到进度(状压实现,存储每日签到状态)'" json:"sign_in_progress"` + SignInProgress []uint32 `gorm:"type:jsonb;not null;comment:'签到进度(状压实现,存储每日签到状态)'" json:"sign_in_progress"` } // TableName 指定表名(遵循现有规范) diff --git a/modules/player/service/egg.go b/modules/player/service/egg.go new file mode 100644 index 000000000..c0b4e4e74 --- /dev/null +++ b/modules/player/service/egg.go @@ -0,0 +1,48 @@ +package service + +import ( + "blazing/cool" + "blazing/modules/player/model" + "time" +) + +type EggService struct { + BaseService +} + +func NewEggService(id uint32) *EggService { + return &EggService{ + + BaseService: BaseService{userid: id, + + Service: &cool.Service{Model: model.NewEgg()}, + }, + } + +} +func (s *EggService) Get() (out *model.Egg) { + + s.TestModel(s.Model).Scan(&out) + + return + +} +func (s *EggService) StartBreed(m, f *model.PetInfo) bool { + + var tt *model.Egg + s.TestModel(s.Model).Scan(&tt) + if tt == nil { + tt = &model.Egg{} + } + if tt.Data.HatchState != 0 { + return false + + } + tt.Data.StartTime = uint32(time.Now().Unix()) + tt.Data.HatchState = 1 + tt.Data.FeMalePetCatchTime = f.CatchTime + tt.Data.MalePetCatchTime = m.CatchTime + tt.Data.FeMalePetID = f.ID + tt.Data.MalePetID = m.ID + return true +} diff --git a/modules/player/service/info.go b/modules/player/service/info.go index 1af56ad97..5d2d4fdc5 100644 --- a/modules/player/service/info.go +++ b/modules/player/service/info.go @@ -53,17 +53,11 @@ func (s *InfoService) Reg(nick string, color uint32) { //go s.InitTask() } -func (s *InfoService) Person(userid uint32) *model.PlayerInfo { +func (s *InfoService) Person(userid uint32) (out *model.PlayerEX) { - m := cool.DBM(s.Model).Where("player_id", userid) - var tt model.PlayerEX - err := m.Scan(&tt) - if err != nil { - return nil - } + cool.DBM(s.Model).Where("player_id", userid).Scan(&out) - ret := tt.Data - return &ret + return } func (s *InfoService) GetCache() *model.PlayerInfo { diff --git a/modules/player/service/pet.go b/modules/player/service/pet.go index 665024416..ab879d194 100644 --- a/modules/player/service/pet.go +++ b/modules/player/service/pet.go @@ -116,7 +116,7 @@ RETURNING max_ts; `, service.NewBaseSysUserService().Model.TableName()) // 执行 Raw SQL 并扫描返回值 - ret, err := cool.DBM(service.NewBaseSysUserService().Model).Raw(sql, s.userid).All() + ret, _ := cool.DBM(service.NewBaseSysUserService().Model).Raw(sql, s.userid).All() //fmt.Println(ret, err) y.CatchTime = ret.Array()[0].Uint32() m1 := cool.DBM(s.Model).Where("player_id", s.userid) @@ -127,7 +127,7 @@ RETURNING max_ts; player.Free = 0 player.IsVip = cool.Config.ServerInfo.IsVip - _, err = m1.Insert(player) + _, err := m1.Insert(player) if err != nil { panic(err) } diff --git a/modules/player/service/user.go b/modules/player/service/user.go index 296f4d6b3..01433f73b 100644 --- a/modules/player/service/user.go +++ b/modules/player/service/user.go @@ -22,6 +22,7 @@ type UserService struct { Title *TitleService Cdk *CdkService Friend *FriendService + Egg *EggService } func NewUserService(id uint32) *UserService { @@ -38,6 +39,7 @@ func NewUserService(id uint32) *UserService { Title: NewTitleService(id), Cdk: NewCdkService(id), Friend: NewFriendService(id), + Egg: NewEggService(id), } }