"refactor(common): 重构序列化工具包,将serialize重命名为utils并添加bitset组件"

This commit is contained in:
1
2025-07-25 01:29:03 +00:00
parent 84d6d99356
commit 58e972eea3
113 changed files with 11 additions and 11 deletions

View File

@@ -0,0 +1,21 @@
name: Test
on: [push, pull_request]
permissions:
  contents: read
jobs:
  test:
    strategy:
      matrix:
        go-version: [1.19.x]
        os: [ubuntu-latest, macos-latest, windows-latest]
    runs-on: ${{ matrix.os }}
    steps:
      - name: Install Go
        uses: actions/setup-go@v2
        with:
          go-version: ${{ matrix.go-version }}
      - name: Checkout code
        # Pin to a release tag rather than the mutable `master` branch:
        # tracking a branch makes builds non-reproducible and is a
        # supply-chain risk if the action repo is compromised.
        uses: actions/checkout@v3
      - name: Test
        run: go test ./...

View File

@@ -0,0 +1,28 @@
BSD 3-Clause License
Copyright (c) 2023, Nil@Pointer
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,68 @@
# bitset32
[![Test](https://github.com/pointernil/bitset32/workflows/Test/badge.svg)](https://github.com/pointernil/bitset32/actions?query=workflow%3ATest)
[zh_CN 简体中文](./README_zh_CN.md)
## Description
Package bitset32 modified from `"github.com/bits-and-blooms/bitset"`
implements bitset with uint32. Both packages are used in the same way.
If not necessary, it is highly recommended to use
["github.com/bits-and-blooms/bitset"](https://github.com/bits-and-blooms/bitset).
## Go version
```
go version go1.19.4 windows/amd64
```
## Install
```
go get github.com/pointernil/bitset32
```
## Testing
```
go test
go test -cover
```
## Usage
```
package main
import (
"fmt"
"math/rand"
"github.com/pointernil/bitset32"
)
func main() {
fmt.Printf("Hello from BitSet!\n")
var b bitset32.BitSet32
// play some Go Fish
for i := 0; i < 100; i++ {
card1 := uint(rand.Intn(52))
card2 := uint(rand.Intn(52))
b.Set(card1)
if b.Test(card2) {
fmt.Println("Go Fish!")
}
b.Clear(card1)
}
// Chaining
b.Set(10).Set(11)
for i, e := b.NextSet(0); e; i, e = b.NextSet(i + 1) {
fmt.Println("The following bit is set:", i)
}
if b.Intersection(bitset32.New(100).Set(10)).Count() == 1 {
fmt.Println("Intersection works.")
} else {
fmt.Println("Intersection doesn't work???")
}
}
```

View File

@@ -0,0 +1,66 @@
# bitset32
[![Test](https://github.com/pointernil/bitset32/workflows/Test/badge.svg)](https://github.com/pointernil/bitset32/actions?query=workflow%3ATest)
[en English](./README.md)
## 简介
`bitset32` 修改自 `"github.com/bits-and-blooms/bitset"`，底层使用 uint32 存储数据。`bitset32` 与 `bitset` 用法一致。
如非必要请使用 ["github.com/bits-and-blooms/bitset"](https://github.com/bits-and-blooms/bitset)。
## Golang版本
```
go version go1.19.4 windows/amd64
```
## 安装
```
go get github.com/pointernil/bitset32
```
## 测试
```
go test
go test -cover
```
## 使用示意
```
package main
import (
"fmt"
"math/rand"
"github.com/pointernil/bitset32"
)
func main() {
fmt.Printf("Hello from BitSet!\n")
var b bitset32.BitSet32
// play some Go Fish
for i := 0; i < 100; i++ {
card1 := uint(rand.Intn(52))
card2 := uint(rand.Intn(52))
b.Set(card1)
if b.Test(card2) {
fmt.Println("Go Fish!")
}
b.Clear(card1)
}
// Chaining
b.Set(10).Set(11)
for i, e := b.NextSet(0); e; i, e = b.NextSet(i + 1) {
fmt.Println("The following bit is set:", i)
}
if b.Intersection(bitset32.New(100).Set(10)).Count() == 1 {
fmt.Println("Intersection works.")
} else {
fmt.Println("Intersection doesn't work???")
}
}
```

View File

@@ -0,0 +1,7 @@
// Package bitset32 modified from "github.com/bits-and-blooms/bitset"
// implements bitset with uint32. Both packages are used in the same way.
// In bitset32, some methods are untested. So if not necessary,
// it is highly recommended to use "github.com/bits-and-blooms/bitset".
// go version go1.19.4 windows/amd64
package bitset32

View File

@@ -0,0 +1,963 @@
/*
Package bitset implements bitsets, a mapping
between non-negative integers and boolean values. It should be more
efficient than map[uint] bool.
It provides methods for setting, clearing, flipping, and testing
individual integers.
But it also provides set intersection, union, difference,
complement, and symmetric operations, as well as tests to
check whether any, all, or no bits are set, and querying a
bitset's current length and number of positive bits.
BitSets are expanded to the size of the largest set bit; the
memory allocation is approximately Max bits, where Max is
the largest set bit. BitSets are never shrunk. On creation,
a hint can be given for the number of bits that will be used.
Many of the methods, including Set,Clear, and Flip, return
a BitSet pointer, which allows for chaining.
Example use:
import "bitset"
var b BitSet
b.Set(10).Set(11)
if b.Test(1000) {
b.Clear(1000)
}
if b.Intersection(bitset.New(100).Set(10)).Count() > 1 {
fmt.Println("Intersection works.")
}
As an alternative to BitSets, one should check out the 'big' package,
which provides a (less set-theoretical) view of bitsets.
*/
package bitset32
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math/bits"
"strconv"
)
// the wordSize of a bit set
const wordSize = uint(32)
// log2WordSize is lg(wordSize)
const log2WordSize = uint(5)
// allBits has every bit set
const allBits uint32 = 0xffffffff
// binaryOrder is the byte order used by WriteTo/ReadFrom.
// It defaults to big-endian.
var binaryOrder binary.ByteOrder = binary.BigEndian
// A BitSet is a set of bits. The zero value of a BitSet is an empty set of length 0.
type BitSet32 struct {
length uint
set []uint32
}
// Error is used to distinguish errors (panics) generated in this package.
type Error string
// safeSet will fixup b.set to be non-nil and return the field value
func (b *BitSet32) safeSet() []uint32 {
if b.set == nil {
b.set = make([]uint32, wordsNeeded(0))
}
return b.set
}
// SetBitsetFrom fills the bitset with an array of integers without creating a new BitSet instance
func (b *BitSet32) SetBitsetFrom(buf []uint32) {
b.length = uint(len(buf)) * 32
b.set = buf
}
// From is a constructor used to create a BitSet from an array of integers
func From(buf []uint32) *BitSet32 {
return FromWithLength(uint(len(buf))*32, buf)
}
// FromWithLength constructs from an array of integers and length.
func FromWithLength(len uint, set []uint32) *BitSet32 {
return &BitSet32{len, set}
}
// Bytes returns the bitset as array of integers
func (b *BitSet32) Bytes() []uint32 {
return b.set
}
// wordsNeeded calculates the number of words needed for i bits
func wordsNeeded(i uint) int {
if i > (Cap() - wordSize + 1) {
return int(Cap() >> log2WordSize)
}
return int((i + (wordSize - 1)) >> log2WordSize)
}
// wordsNeededUnbound calculates the number of words needed for i bits, possibly exceeding the capacity.
// This function is useful if you know that the capacity cannot be exceeded (e.g., you have an existing bitmap).
func wordsNeededUnbound(i uint) int {
return int((i + (wordSize - 1)) >> log2WordSize)
}
// wordsIndex calculates the index of a bit inside a `uint32` word,
// i.e. i modulo wordSize. wordSize is a power of two, so masking with
// (wordSize - 1) is equivalent to (and cheaper than) i % wordSize.
func wordsIndex(i uint) uint {
	return i & (wordSize - 1)
}
// New creates a new BitSet with a hint that length bits will be required
func New(length uint) (bset *BitSet32) {
defer func() {
if r := recover(); r != nil {
bset = &BitSet32{
0,
make([]uint32, 0),
}
}
}()
bset = &BitSet32{
length,
make([]uint32, wordsNeeded(length)),
}
return bset
}
// Cap returns the total possible capacity, or number of bits,
// that a BitSet32 can address (the maximum value of uint).
func Cap() uint {
	var zero uint
	return ^zero
}
// Len returns the number of bits in the BitSet.
// Note the difference to method Count, see example.
func (b *BitSet32) Len() uint {
return b.length
}
// extendSet adds additional words to incorporate new bits if needed
func (b *BitSet32) extendSet(i uint) {
if i >= Cap() {
panic("You are exceeding the capacity")
}
nsize := wordsNeeded(i + 1)
if b.set == nil {
b.set = make([]uint32, nsize)
} else if cap(b.set) >= nsize {
b.set = b.set[:nsize] // fast resize
} else if len(b.set) < nsize {
newset := make([]uint32, nsize, 2*nsize) // increase capacity 2x
copy(newset, b.set)
b.set = newset
}
b.length = i + 1
}
// Test whether bit i is set.
func (b *BitSet32) Test(i uint) bool {
if i >= b.length {
return false
}
return b.set[i>>log2WordSize]&(1<<wordsIndex(i)) != 0
}
// Set bit i to 1, the capacity of the bitset is automatically
// increased accordingly.
// If i>= Cap(), this function will panic.
// Warning: using a very large value for 'i'
// may lead to a memory shortage and a panic: the caller is responsible
// for providing sensible parameters in line with their memory capacity.
func (b *BitSet32) Set(i uint) *BitSet32 {
if i >= b.length { // if we need more bits, make 'em
b.extendSet(i)
}
b.set[i>>log2WordSize] |= 1 << wordsIndex(i)
return b
}
// Clear bit i to 0
func (b *BitSet32) Clear(i uint) *BitSet32 {
if i >= b.length {
return b
}
b.set[i>>log2WordSize] &^= 1 << wordsIndex(i)
return b
}
// SetTo sets bit i to value.
// If i>= Cap(), this function will panic.
// Warning: using a very large value for 'i'
// may lead to a memory shortage and a panic: the caller is responsible
// for providing sensible parameters in line with their memory capacity.
func (b *BitSet32) SetTo(i uint, value bool) *BitSet32 {
if value {
return b.Set(i)
}
return b.Clear(i)
}
// Flip bit at i.
// If i>= Cap(), this function will panic.
// Warning: using a very large value for 'i'
// may lead to a memory shortage and a panic: the caller is responsible
// for providing sensible parameters in line with their memory capacity.
func (b *BitSet32) Flip(i uint) *BitSet32 {
if i >= b.length {
return b.Set(i)
}
b.set[i>>log2WordSize] ^= 1 << wordsIndex(i)
return b
}
// FlipRange bit in [start, end).
// If end>= Cap(), this function will panic.
// Warning: using a very large value for 'end'
// may lead to a memory shortage and a panic: the caller is responsible
// for providing sensible parameters in line with their memory capacity.
func (b *BitSet32) FlipRange(start, end uint) *BitSet32 {
if start >= end {
return b
}
if end-1 >= b.length { // if we need more bits, make 'em
b.extendSet(end - 1)
}
var startWord uint = start >> log2WordSize
var endWord uint = end >> log2WordSize
b.set[startWord] ^= ^(^uint32(0) << wordsIndex(start))
for i := startWord; i < endWord; i++ {
b.set[i] = ^b.set[i]
}
if end&(wordSize-1) != 0 {
b.set[endWord] ^= ^uint32(0) >> wordsIndex(-end)
}
return b
}
// Shrink shrinks BitSet so that the provided value is the last possible
// set value. It clears all bits > the provided index and reduces the size
// and length of the set.
//
// Note that the parameter value is not the new length in bits: it is the
// maximal value that can be stored in the bitset after the function call.
// The new length in bits is the parameter value + 1. Thus it is not possible
// to use this function to set the length to 0, the minimal value of the length
// after this function call is 1.
//
// A new slice is allocated to store the new bits, so you may see an increase in
// memory usage until the GC runs. Normally this should not be a problem, but if you
// have an extremely large BitSet its important to understand that the old BitSet will
// remain in memory until the GC frees it.
func (b *BitSet32) Shrink(lastbitindex uint) *BitSet32 {
length := lastbitindex + 1
idx := wordsNeeded(length)
if idx > len(b.set) {
return b
}
shrunk := make([]uint32, idx)
copy(shrunk, b.set[:idx])
b.set = shrunk
b.length = length
lastWordUsedBits := length % 32
if lastWordUsedBits != 0 {
b.set[idx-1] &= allBits >> uint32(32-wordsIndex(lastWordUsedBits))
}
return b
}
// Compact shrinks the BitSet so that all set bits are preserved while
// minimizing memory usage. Compact calls Shrink.
func (b *BitSet32) Compact() *BitSet32 {
	// Find the highest non-zero word.
	idx := len(b.set) - 1
	for ; idx >= 0 && b.set[idx] == 0; idx-- {
	}
	newlength := uint((idx + 1) << log2WordSize)
	if newlength >= b.length {
		return b // nothing to do
	}
	if newlength > 0 {
		return b.Shrink(newlength - 1)
	}
	// Empty set: Shrink cannot produce a zero length, so keep exactly one
	// word (bits 0..31). The constant 31 is wordSize-1 for this
	// uint32-based port (the uint64 original uses 63).
	return b.Shrink(31)
}
// InsertAt takes an index which indicates where a bit should be
// inserted. Then it shifts all the bits in the set to the left by 1, starting
// from the given index position, and sets the index position to 0.
//
// Depending on the size of your BitSet, and where you are inserting the new entry,
// this method could be extremely slow and in some cases might cause the entire BitSet
// to be recopied.
func (b *BitSet32) InsertAt(idx uint) *BitSet32 {
	insertAtElement := idx >> log2WordSize
	// if length of set is a multiple of wordSize we need to allocate more space first
	if b.isLenExactMultiple() {
		b.set = append(b.set, uint32(0))
	}
	var i uint
	for i = uint(len(b.set) - 1); i > insertAtElement; i-- {
		// all elements above the position where we want to insert can simply by shifted
		b.set[i] <<= 1
		// we take the most significant bit of the previous element and set it as
		// the least significant bit of the current element.
		// 0x80000000 is the MSB mask and 31 the MSB shift for a uint32 word
		// (the uint64-based original uses 0x8000000000000000 and 63).
		b.set[i] |= (b.set[i-1] & 0x80000000) >> 31
	}
	// generate a mask to extract the data that we need to shift left
	// within the element where we insert a bit
	dataMask := uint32(1)<<uint32(wordsIndex(idx)) - 1
	// extract that data that we'll shift
	data := b.set[i] & (^dataMask)
	// set the positions of the data mask to 0 in the element where we insert
	b.set[i] &= dataMask
	// shift data mask to the left and insert its data to the slice element
	b.set[i] |= data << 1
	// add 1 to length of BitSet
	b.length++
	return b
}
// String creates a string representation of the Bitmap
func (b *BitSet32) String() string {
// follows code from https://github.com/RoaringBitmap/roaring
var buffer bytes.Buffer
start := []byte("{")
buffer.Write(start)
counter := 0
i, e := b.NextSet(0)
for e {
counter = counter + 1
// to avoid exhausting the memory
if counter > 0x40000 {
buffer.WriteString("...")
break
}
buffer.WriteString(strconv.FormatInt(int64(i), 10))
i, e = b.NextSet(i + 1)
if e {
buffer.WriteString(",")
}
}
buffer.WriteString("}")
return buffer.String()
}
// DeleteAt deletes the bit at the given index position from
// within the bitset
// All the bits residing on the left of the deleted bit get
// shifted right by 1
// The running time of this operation may potentially be
// relatively slow, O(length)
func (b *BitSet32) DeleteAt(i uint) *BitSet32 {
// the index of the slice element where we'll delete a bit
deleteAtElement := i >> log2WordSize
// generate a mask for the data that needs to be shifted right
// within that slice element that gets modified
dataMask := ^((uint32(1) << wordsIndex(i)) - 1)
// extract the data that we'll shift right from the slice element
data := b.set[deleteAtElement] & dataMask
// set the masked area to 0 while leaving the rest as it is
b.set[deleteAtElement] &= ^dataMask
// shift the previously extracted data to the right and then
// set it in the previously masked area
b.set[deleteAtElement] |= (data >> 1) & dataMask
// loop over all the consecutive slice elements to copy each
// lowest bit into the highest position of the previous element,
// then shift the entire content to the right by 1
for i := int(deleteAtElement) + 1; i < len(b.set); i++ {
b.set[i-1] |= (b.set[i] & 1) << 31
b.set[i] >>= 1
}
b.length = b.length - 1
return b
}
// NextSet returns the next bit set from the specified index,
// including possibly the current index
// along with an error code (true = valid, false = no set bit found)
// for i,e := v.NextSet(0); e; i,e = v.NextSet(i + 1) {...}
//
// Users concerned with performance may want to use NextSetMany to
// retrieve several values at once.
func (b *BitSet32) NextSet(i uint) (uint, bool) {
x := int(i >> log2WordSize)
if x >= len(b.set) {
return 0, false
}
w := b.set[x]
w = w >> wordsIndex(i)
if w != 0 {
return i + uint(bits.TrailingZeros32(w)), true
}
x = x + 1
for x < len(b.set) {
if b.set[x] != 0 {
return uint(x)*wordSize + uint(bits.TrailingZeros32(b.set[x])), true
}
x = x + 1
}
return 0, false
}
// NextSetMany returns many next bit sets from the specified index,
// including possibly the current index and up to cap(buffer).
// If the returned slice has len zero, then no more set bits were found
//
// buffer := make([]uint, 256) // this should be reused
// j := uint(0)
// j, buffer = bitmap.NextSetMany(j, buffer)
// for ; len(buffer) > 0; j, buffer = bitmap.NextSetMany(j,buffer) {
// for k := range buffer {
// do something with buffer[k]
// }
// j += 1
// }
//
// It is possible to retrieve all set bits as follow:
//
// indices := make([]uint, bitmap.Count())
// bitmap.NextSetMany(0, indices)
//
// However if bitmap.Count() is large, it might be preferable to
// use several calls to NextSetMany, for performance reasons.
func (b *BitSet32) NextSetMany(i uint, buffer []uint) (uint, []uint) {
	myanswer := buffer
	capacity := cap(buffer)
	x := int(i >> log2WordSize)
	if x >= len(b.set) || capacity == 0 {
		return 0, myanswer[:0]
	}
	// Skip the bits below i in the first word.
	skip := wordsIndex(i)
	word := b.set[x] >> skip
	myanswer = myanswer[:capacity]
	size := int(0)
	for word != 0 {
		r := uint(bits.TrailingZeros32(word))
		t := word & ((^word) + 1) // isolate the lowest set bit
		myanswer[size] = r + i
		size++
		if size == capacity {
			goto End
		}
		word = word ^ t // clear the bit just reported
	}
	x++
	for idx, word := range b.set[x:] {
		for word != 0 {
			r := uint(bits.TrailingZeros32(word))
			t := word & ((^word) + 1)
			// BUGFIX: the word index must be scaled by the 32-bit word
			// size (<< log2WordSize, i.e. << 5). The previous `<< 6` was
			// inherited from the uint64-based original and produced wrong
			// bit positions for every word after the first.
			myanswer[size] = r + (uint(x+idx) << log2WordSize)
			size++
			if size == capacity {
				goto End
			}
			word = word ^ t
		}
	}
End:
	if size > 0 {
		// Return the last index found so the caller can resume at index+1.
		return myanswer[size-1], myanswer[:size]
	}
	return 0, myanswer[:0]
}
// NextClear returns the next clear bit from the specified index,
// including possibly the current index
// along with an error code (true = valid, false = no bit found i.e. all bits are set)
func (b *BitSet32) NextClear(i uint) (uint, bool) {
x := int(i >> log2WordSize)
if x >= len(b.set) {
return 0, false
}
w := b.set[x]
w = w >> wordsIndex(i)
wA := allBits >> wordsIndex(i)
index := i + uint(bits.TrailingZeros32(^w))
if w != wA && index < b.length {
return index, true
}
x++
for x < len(b.set) {
index = uint(x)*wordSize + uint(bits.TrailingZeros32(^b.set[x]))
if b.set[x] != allBits && index < b.length {
return index, true
}
x++
}
return 0, false
}
// ClearAll clears the entire BitSet
func (b *BitSet32) ClearAll() *BitSet32 {
if b != nil && b.set != nil {
for i := range b.set {
b.set[i] = 0
}
}
return b
}
// wordCount returns the number of words used in a bit set
func (b *BitSet32) wordCount() int {
return wordsNeededUnbound(b.length)
}
// Clone this BitSet
func (b *BitSet32) Clone() *BitSet32 {
c := New(b.length)
if b.set != nil { // Clone should not modify current object
copy(c.set, b.set)
}
return c
}
// Copy into a destination BitSet using the Go array copy semantics:
// the number of bits copied is the minimum of the number of bits in the current
// BitSet (Len()) and the destination Bitset.
// We return the number of bits copied in the destination BitSet.
func (b *BitSet32) Copy(c *BitSet32) (count uint) {
if c == nil {
return
}
if b.set != nil { // Copy should not modify current object
copy(c.set, b.set)
}
count = c.length
if b.length < c.length {
count = b.length
}
// Cleaning the last word is needed to keep the invariant that other functions, such as Count, require
// that any bits in the last word that would exceed the length of the bitmask are set to 0.
c.cleanLastWord()
return
}
// CopyFull copies into a destination BitSet such that the destination is
// identical to the source after the operation, allocating memory if necessary.
func (b *BitSet32) CopyFull(c *BitSet32) {
if c == nil {
return
}
c.length = b.length
if len(b.set) == 0 {
if c.set != nil {
c.set = c.set[:0]
}
} else {
if cap(c.set) < len(b.set) {
c.set = make([]uint32, len(b.set))
} else {
c.set = c.set[:len(b.set)]
}
copy(c.set, b.set)
}
}
// Count (number of set bits).
// Also known as "popcount" or "population count".
func (b *BitSet32) Count() uint {
if b != nil && b.set != nil {
return uint(popcntSlice(b.set))
}
return 0
}
// Equal tests the equivalence of two BitSets.
// False if they are of different sizes, otherwise true
// only if all the same bits are set
func (b *BitSet32) Equal(c *BitSet32) bool {
if c == nil || b == nil {
return c == b
}
if b.length != c.length {
return false
}
if b.length == 0 { // if they have both length == 0, then could have nil set
return true
}
wn := b.wordCount()
for p := 0; p < wn; p++ {
if c.set[p] != b.set[p] {
return false
}
}
return true
}
func panicIfNull(b *BitSet32) {
if b == nil {
panic(Error("BitSet must not be null"))
}
}
// Difference of base set and other set
// This is the BitSet equivalent of &^ (and not)
func (b *BitSet32) Difference(compare *BitSet32) (result *BitSet32) {
panicIfNull(b)
panicIfNull(compare)
result = b.Clone() // clone b (in case b is bigger than compare)
l := int(compare.wordCount())
if l > int(b.wordCount()) {
l = int(b.wordCount())
}
for i := 0; i < l; i++ {
result.set[i] = b.set[i] &^ compare.set[i]
}
return
}
// DifferenceCardinality computes the cardinality of the differnce
func (b *BitSet32) DifferenceCardinality(compare *BitSet32) uint {
panicIfNull(b)
panicIfNull(compare)
l := int(compare.wordCount())
if l > int(b.wordCount()) {
l = int(b.wordCount())
}
cnt := uint64(0)
cnt += popcntMaskSlice(b.set[:l], compare.set[:l])
cnt += popcntSlice(b.set[l:])
return uint(cnt)
}
// InPlaceDifference computes the difference of base set and other set
// This is the BitSet equivalent of &^ (and not)
func (b *BitSet32) InPlaceDifference(compare *BitSet32) {
panicIfNull(b)
panicIfNull(compare)
l := int(compare.wordCount())
if l > int(b.wordCount()) {
l = int(b.wordCount())
}
for i := 0; i < l; i++ {
b.set[i] &^= compare.set[i]
}
}
// Convenience function: return two bitsets ordered by
// increasing length. Note: neither can be nil
func sortByLength(a *BitSet32, b *BitSet32) (ap *BitSet32, bp *BitSet32) {
if a.length <= b.length {
ap, bp = a, b
} else {
ap, bp = b, a
}
return
}
// Intersection of base set and other set
// This is the BitSet equivalent of & (and)
func (b *BitSet32) Intersection(compare *BitSet32) (result *BitSet32) {
panicIfNull(b)
panicIfNull(compare)
b, compare = sortByLength(b, compare)
result = New(b.length)
for i, word := range b.set {
result.set[i] = word & compare.set[i]
}
return
}
// IntersectionCardinality computes the cardinality of the union
func (b *BitSet32) IntersectionCardinality(compare *BitSet32) uint {
panicIfNull(b)
panicIfNull(compare)
b, compare = sortByLength(b, compare)
cnt := popcntAndSlice(b.set, compare.set)
return uint(cnt)
}
// InPlaceIntersection destructively computes the intersection of
// base set and the compare set.
// This is the BitSet equivalent of & (and)
func (b *BitSet32) InPlaceIntersection(compare *BitSet32) {
panicIfNull(b)
panicIfNull(compare)
l := int(compare.wordCount())
if l > int(b.wordCount()) {
l = int(b.wordCount())
}
for i := 0; i < l; i++ {
b.set[i] &= compare.set[i]
}
for i := l; i < len(b.set); i++ {
b.set[i] = 0
}
if compare.length > 0 {
if compare.length-1 >= b.length {
b.extendSet(compare.length - 1)
}
}
}
// Union of base set and other set
// This is the BitSet equivalent of | (or)
func (b *BitSet32) Union(compare *BitSet32) (result *BitSet32) {
panicIfNull(b)
panicIfNull(compare)
b, compare = sortByLength(b, compare)
result = compare.Clone()
for i, word := range b.set {
result.set[i] = word | compare.set[i]
}
return
}
// UnionCardinality computes the cardinality of the uniton of the base set
// and the compare set.
func (b *BitSet32) UnionCardinality(compare *BitSet32) uint {
panicIfNull(b)
panicIfNull(compare)
b, compare = sortByLength(b, compare)
cnt := popcntOrSlice(b.set, compare.set)
if len(compare.set) > len(b.set) {
cnt += popcntSlice(compare.set[len(b.set):])
}
return uint(cnt)
}
// InPlaceUnion creates the destructive union of base set and compare set.
// This is the BitSet equivalent of | (or).
func (b *BitSet32) InPlaceUnion(compare *BitSet32) {
panicIfNull(b)
panicIfNull(compare)
l := int(compare.wordCount())
if l > int(b.wordCount()) {
l = int(b.wordCount())
}
if compare.length > 0 && compare.length-1 >= b.length {
b.extendSet(compare.length - 1)
}
for i := 0; i < l; i++ {
b.set[i] |= compare.set[i]
}
if len(compare.set) > l {
for i := l; i < len(compare.set); i++ {
b.set[i] = compare.set[i]
}
}
}
// SymmetricDifference of base set and other set
// This is the BitSet equivalent of ^ (xor)
func (b *BitSet32) SymmetricDifference(compare *BitSet32) (result *BitSet32) {
panicIfNull(b)
panicIfNull(compare)
b, compare = sortByLength(b, compare)
// compare is bigger, so clone it
result = compare.Clone()
for i, word := range b.set {
result.set[i] = word ^ compare.set[i]
}
return
}
// SymmetricDifferenceCardinality computes the cardinality of the symmetric difference
func (b *BitSet32) SymmetricDifferenceCardinality(compare *BitSet32) uint {
panicIfNull(b)
panicIfNull(compare)
b, compare = sortByLength(b, compare)
cnt := popcntXorSlice(b.set, compare.set)
if len(compare.set) > len(b.set) {
cnt += popcntSlice(compare.set[len(b.set):])
}
return uint(cnt)
}
// InPlaceSymmetricDifference creates the destructive SymmetricDifference of base set and other set
// This is the BitSet equivalent of ^ (xor)
func (b *BitSet32) InPlaceSymmetricDifference(compare *BitSet32) {
panicIfNull(b)
panicIfNull(compare)
l := int(compare.wordCount())
if l > int(b.wordCount()) {
l = int(b.wordCount())
}
if compare.length > 0 && compare.length-1 >= b.length {
b.extendSet(compare.length - 1)
}
for i := 0; i < l; i++ {
b.set[i] ^= compare.set[i]
}
if len(compare.set) > l {
for i := l; i < len(compare.set); i++ {
b.set[i] = compare.set[i]
}
}
}
// Is the length an exact multiple of word sizes?
func (b *BitSet32) isLenExactMultiple() bool {
return wordsIndex(b.length) == 0
}
// Clean last word by setting unused bits to 0
func (b *BitSet32) cleanLastWord() {
if !b.isLenExactMultiple() {
b.set[len(b.set)-1] &= allBits >> (wordSize - wordsIndex(b.length))
}
}
// Complement computes the (local) complement of a bitset (up to length bits)
func (b *BitSet32) Complement() (result *BitSet32) {
panicIfNull(b)
result = New(b.length)
for i, word := range b.set {
result.set[i] = ^word
}
result.cleanLastWord()
return
}
// All returns true if all bits are set, false otherwise. Returns true for
// empty sets.
func (b *BitSet32) All() bool {
panicIfNull(b)
return b.Count() == b.length
}
// None returns true if no bit is set, false otherwise. Returns true for
// empty sets.
func (b *BitSet32) None() bool {
panicIfNull(b)
if b != nil && b.set != nil {
for _, word := range b.set {
if word > 0 {
return false
}
}
}
return true
}
// Any returns true if any bit is set, false otherwise
func (b *BitSet32) Any() bool {
panicIfNull(b)
return !b.None()
}
// IsSuperSet returns true if this is a superset of the other set
func (b *BitSet32) IsSuperSet(other *BitSet32) bool {
for i, e := other.NextSet(0); e; i, e = other.NextSet(i + 1) {
if !b.Test(i) {
return false
}
}
return true
}
// IsStrictSuperSet returns true if this is a strict superset of the other set
func (b *BitSet32) IsStrictSuperSet(other *BitSet32) bool {
return b.Count() > other.Count() && b.IsSuperSet(other)
}
// DumpAsBits dumps the bit set as a string of binary words, most
// significant word first, each word followed by a '.' separator.
// It returns "." for an uninitialized (nil-backed) set.
func (b *BitSet32) DumpAsBits() string {
	if b.set == nil {
		return "."
	}
	buffer := bytes.NewBufferString("")
	i := len(b.set) - 1
	for ; i >= 0; i-- {
		// BUGFIX: words here are 32-bit; %064b was inherited from the
		// uint64-based original and zero-padded every word to 64 digits.
		fmt.Fprintf(buffer, "%032b.", b.set[i])
	}
	return buffer.String()
}
// BinaryStorageSize returns the binary storage requirements
func (b *BitSet32) BinaryStorageSize() int {
nWords := b.wordCount()
return binary.Size(uint64(0)) + binary.Size(b.set[:nWords])
}
// WriteTo writes a BitSet to a stream
func (b *BitSet32) WriteTo(stream io.Writer) (int64, error) {
length := uint64(b.length)
// Write length
err := binary.Write(stream, binaryOrder, length)
if err != nil {
return 0, err
}
// Write set
// current implementation of bufio.Writer is more memory efficient than
// binary.Write for large set
writer := bufio.NewWriter(stream)
var item = make([]byte, binary.Size(uint32(0))) // for serializing one uint32
nWords := b.wordCount()
for i := range b.set[:nWords] {
binaryOrder.PutUint32(item, b.set[i])
if nn, err := writer.Write(item); err != nil {
return int64(i*binary.Size(uint32(0)) + nn), err
}
}
err = writer.Flush()
return int64(b.BinaryStorageSize()), err
}
// ReadFrom reads a BitSet from a stream written using WriteTo
func (b *BitSet32) ReadFrom(stream io.Reader) (int64, error) {
var length uint64
// Read length first
err := binary.Read(stream, binaryOrder, &length)
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return 0, err
}
newset := New(uint(length))
if uint64(newset.length) != length {
return 0, errors.New("unmarshalling error: type mismatch")
}
var item [4]byte
nWords := wordsNeeded(uint(length))
reader := bufio.NewReader(io.LimitReader(stream, 4*int64(nWords)))
for i := 0; i < nWords; i++ {
if _, err := io.ReadFull(reader, item[:]); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return 0, err
}
newset.set[i] = binaryOrder.Uint32(item[:])
}
*b = *newset
return int64(b.BinaryStorageSize()), nil
}

View File

@@ -0,0 +1,43 @@
package bitset32
// MaxConsecutiveOne returns the length of the longest run of consecutive
// set bits within the half-open range [start, end).
func (b *BitSet32) MaxConsecutiveOne(start, end uint) uint {
	return b.consecutiveMaxCount(start, end, true)
}
// MaxConsecutiveZero returns the length of the longest run of consecutive
// clear bits within the half-open range [start, end).
func (b *BitSet32) MaxConsecutiveZero(start, end uint) uint {
	return b.consecutiveMaxCount(start, end, false)
}
// consecutiveMaxCount returns the length of the longest run of bits equal
// to flag (true = set, false = clear) in the half-open range [start, end),
// clamped to the set's length. Returns 0 when start is out of range or
// start > end.
func (b *BitSet32) consecutiveMaxCount(start, end uint, flag bool) uint {
	// Invert flag once so that xor(flag, b.Test(i)) below is true exactly
	// when bit i has the requested value.
	flag = !flag
	if end > b.Len() {
		end = b.Len()
	}
	if start >= b.Len() {
		return 0
	}
	if start > end {
		return 0
	}
	// rt is the best run seen so far; sum is the current run length.
	rt, sum := uint(0), uint(0)
	for i := start; i < end; i++ {
		if xor(flag, b.Test(i)) {
			sum++
			continue
		}
		if sum > rt {
			rt = sum
		}
		sum = 0
	}
	// Account for a run that extends to the end of the range.
	if sum > rt {
		rt = sum
	}
	return rt
}
// xor reports whether exactly one of a and b is true (logical
// exclusive-or), which for booleans is simply inequality.
func xor(a, b bool) bool {
	return a != b
}

View File

@@ -0,0 +1,454 @@
// Copyright 2014 Will Fitzgerald. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file tests bit sets
package bitset32
import (
"math/rand"
"testing"
)
func BenchmarkSet(b *testing.B) {
b.StopTimer()
r := rand.New(rand.NewSource(0))
sz := 100000
s := New(uint(sz))
b.StartTimer()
for i := 0; i < b.N; i++ {
s.Set(uint(r.Int31n(int32(sz))))
}
}
func BenchmarkGetTest(b *testing.B) {
b.StopTimer()
r := rand.New(rand.NewSource(0))
sz := 100000
s := New(uint(sz))
b.StartTimer()
for i := 0; i < b.N; i++ {
s.Test(uint(r.Int31n(int32(sz))))
}
}
func BenchmarkSetExpand(b *testing.B) {
b.StopTimer()
sz := uint(100000)
b.StartTimer()
for i := 0; i < b.N; i++ {
var s BitSet32
s.Set(sz)
}
}
// go test -bench=Count
func BenchmarkCount(b *testing.B) {
b.StopTimer()
s := New(100000)
for i := 0; i < 100000; i += 100 {
s.Set(uint(i))
}
b.StartTimer()
for i := 0; i < b.N; i++ {
s.Count()
}
}
// go test -bench=Iterate
func BenchmarkIterate(b *testing.B) {
b.StopTimer()
s := New(10000)
for i := 0; i < 10000; i += 3 {
s.Set(uint(i))
}
b.StartTimer()
for j := 0; j < b.N; j++ {
c := uint(0)
for i, e := s.NextSet(0); e; i, e = s.NextSet(i + 1) {
c++
}
}
}
// go test -bench=SparseIterate
func BenchmarkSparseIterate(b *testing.B) {
b.StopTimer()
s := New(100000)
for i := 0; i < 100000; i += 30 {
s.Set(uint(i))
}
b.StartTimer()
for j := 0; j < b.N; j++ {
c := uint(0)
for i, e := s.NextSet(0); e; i, e = s.NextSet(i + 1) {
c++
}
}
}
// go test -bench=LemireCreate
// see http://lemire.me/blog/2016/09/22/swift-versus-java-the-BitSet32-performance-test/
func BenchmarkLemireCreate(b *testing.B) {
for i := 0; i < b.N; i++ {
bitmap := New(0) // we force dynamic memory allocation
for v := uint(0); v <= 100000000; v += 100 {
bitmap.Set(v)
}
}
}
// go test -bench=LemireCount
// see http://lemire.me/blog/2016/09/22/swift-versus-java-the-BitSet32-performance-test/
func BenchmarkLemireCount(b *testing.B) {
bitmap := New(100000000)
for v := uint(0); v <= 100000000; v += 100 {
bitmap.Set(v)
}
b.ResetTimer()
sum := uint(0)
for i := 0; i < b.N; i++ {
sum += bitmap.Count()
}
if sum == 0 { // added just to fool ineffassign
return
}
}
// go test -bench=LemireIterate
// see http://lemire.me/blog/2016/09/22/swift-versus-java-the-BitSet32-performance-test/
func BenchmarkLemireIterate(b *testing.B) {
bitmap := New(100000000)
for v := uint(0); v <= 100000000; v += 100 {
bitmap.Set(v)
}
b.ResetTimer()
sum := uint(0)
for i := 0; i < b.N; i++ {
for j, e := bitmap.NextSet(0); e; j, e = bitmap.NextSet(j + 1) {
sum++
}
}
if sum == 0 { // added just to fool ineffassign
return
}
}
// go test -bench=LemireIterateb
// see http://lemire.me/blog/2016/09/22/swift-versus-java-the-BitSet32-performance-test/
func BenchmarkLemireIterateb(b *testing.B) {
bitmap := New(100000000)
for v := uint(0); v <= 100000000; v += 100 {
bitmap.Set(v)
}
b.ResetTimer()
sum := uint(0)
for i := 0; i < b.N; i++ {
for j, e := bitmap.NextSet(0); e; j, e = bitmap.NextSet(j + 1) {
sum += j
}
}
if sum == 0 { // added just to fool ineffassign
return
}
}
// go test -bench=BenchmarkLemireIterateManyb
// see http://lemire.me/blog/2016/09/22/swift-versus-java-the-BitSet32-performance-test/
func BenchmarkLemireIterateManyb(b *testing.B) {
bitmap := New(100000000)
for v := uint(0); v <= 100000000; v += 100 {
bitmap.Set(v)
}
buffer := make([]uint, 256)
b.ResetTimer()
sum := uint(0)
for i := 0; i < b.N; i++ {
j := uint(0)
j, buffer = bitmap.NextSetMany(j, buffer)
for ; len(buffer) > 0; j, buffer = bitmap.NextSetMany(j, buffer) {
for k := range buffer {
sum += buffer[k]
}
j++
}
}
if sum == 0 { // added just to fool ineffassign
return
}
}
// setRnd fills bits with pseudo-random words. Each word starts as all-ones
// and is ANDed with `halfings` random masks, so the expected bit density
// halves with every halfing. The generator is seeded with 0 for determinism.
func setRnd(bits []uint32, halfings int) {
	rnd := rand.New(rand.NewSource(0))
	for i := range bits {
		word := uint32(0xFFFFFFFF)
		for j := 0; j < halfings; j++ {
			word &= rnd.Uint32()
		}
		bits[i] = word
	}
}
// go test -bench=BenchmarkFlorianUekermannIterateMany
func BenchmarkFlorianUekermannIterateMany(b *testing.B) {
var input = make([]uint32, 68)
setRnd(input, 4)
var bitmap = From(input)
buffer := make([]uint, 256)
b.ResetTimer()
var checksum = uint(0)
for i := 0; i < b.N; i++ {
var last, batch = bitmap.NextSetMany(0, buffer)
for len(batch) > 0 {
for _, idx := range batch {
checksum += idx
}
last, batch = bitmap.NextSetMany(last+1, batch)
}
}
if checksum == 0 { // added just to fool ineffassign
return
}
}
func BenchmarkFlorianUekermannIterateManyReg(b *testing.B) {
var input = make([]uint32, 68)
setRnd(input, 4)
var bitmap = From(input)
b.ResetTimer()
var checksum = uint(0)
for i := 0; i < b.N; i++ {
for j, e := bitmap.NextSet(0); e; j, e = bitmap.NextSet(j + 1) {
checksum += j
}
}
if checksum == 0 { // added just to fool ineffassign
return
}
}
// good is a reference bit-iteration routine (provided by FlorianUekermann)
// used as a baseline by the *Comp benchmarks. It visits every set bit of the
// packed uint32 words and sums the absolute bit indices into a checksum.
func good(set []uint32) (checksum uint) {
	for wordIdx, word := range set {
		// Each word holds 32 bits, so the word's first absolute bit index is
		// wordIdx*32. (The original multiplied by 64, which is only correct
		// for uint64 words and produced indices inconsistent with the 32-bit
		// bitset iterated by the sibling benchmarks.)
		base := uint(wordIdx) * 32
		for word != 0 {
			bitIdx := uint(bits.TrailingZeros32(word))
			word ^= 1 << bitIdx // clear the bit just visited
			checksum += base + bitIdx
		}
	}
	return checksum
}
func BenchmarkFlorianUekermannIterateManyComp(b *testing.B) {
var input = make([]uint32, 68)
setRnd(input, 4)
b.ResetTimer()
var checksum = uint(0)
for i := 0; i < b.N; i++ {
checksum += good(input)
}
if checksum == 0 { // added just to fool ineffassign
return
}
}
/////// Low density
// go test -bench=BenchmarkFlorianUekermannLowDensityIterateMany
func BenchmarkFlorianUekermannLowDensityIterateMany(b *testing.B) {
var input = make([]uint32, 1000000)
var rnd = rand.NewSource(0).(rand.Source64)
for i := 0; i < 50000; i++ {
input[rnd.Uint64()%1000000] = 1
}
var bitmap = From(input)
buffer := make([]uint, 256)
b.ResetTimer()
var sum = uint(0)
for i := 0; i < b.N; i++ {
j := uint(0)
j, buffer = bitmap.NextSetMany(j, buffer)
for ; len(buffer) > 0; j, buffer = bitmap.NextSetMany(j, buffer) {
for k := range buffer {
sum += buffer[k]
}
j++
}
}
if sum == 0 { // added just to fool ineffassign
return
}
}
func BenchmarkFlorianUekermannLowDensityIterateManyReg(b *testing.B) {
var input = make([]uint32, 1000000)
var rnd = rand.NewSource(0).(rand.Source64)
for i := 0; i < 50000; i++ {
input[rnd.Uint64()%1000000] = 1
}
var bitmap = From(input)
b.ResetTimer()
var checksum = uint(0)
for i := 0; i < b.N; i++ {
for j, e := bitmap.NextSet(0); e; j, e = bitmap.NextSet(j + 1) {
checksum += j
}
}
if checksum == 0 { // added just to fool ineffassign
return
}
}
func BenchmarkFlorianUekermannLowDensityIterateManyComp(b *testing.B) {
var input = make([]uint32, 1000000)
var rnd = rand.NewSource(0).(rand.Source64)
for i := 0; i < 50000; i++ {
input[rnd.Uint64()%1000000] = 1
}
b.ResetTimer()
var checksum = uint(0)
for i := 0; i < b.N; i++ {
checksum += good(input)
}
if checksum == 0 { // added just to fool ineffassign
return
}
}
/////// Mid density
// go test -bench=BenchmarkFlorianUekermannMidDensityIterateMany
func BenchmarkFlorianUekermannMidDensityIterateMany(b *testing.B) {
var input = make([]uint32, 1000000)
var rndSource = rand.NewSource(0)
var rnd = rand.New(rndSource)
for i := 0; i < 3000000; i++ {
input[rnd.Uint32()%1000000] |= uint32(1) << (rnd.Uint32() % 32)
}
var bitmap = From(input)
buffer := make([]uint, 256)
b.ResetTimer()
sum := uint(0)
for i := 0; i < b.N; i++ {
j := uint(0)
j, buffer = bitmap.NextSetMany(j, buffer)
for ; len(buffer) > 0; j, buffer = bitmap.NextSetMany(j, buffer) {
for k := range buffer {
sum += buffer[k]
}
j++
}
}
if sum == 0 { // added just to fool ineffassign
return
}
}
func BenchmarkFlorianUekermannMidDensityIterateManyReg(b *testing.B) {
var input = make([]uint32, 1000000)
var rndSource = rand.NewSource(0)
var rnd = rand.New(rndSource)
for i := 0; i < 3000000; i++ {
input[rnd.Uint32()%1000000] |= uint32(1) << (rnd.Uint32() % 32)
}
var bitmap = From(input)
b.ResetTimer()
var checksum = uint(0)
for i := 0; i < b.N; i++ {
for j, e := bitmap.NextSet(0); e; j, e = bitmap.NextSet(j + 1) {
checksum += j
}
}
if checksum == 0 { // added just to fool ineffassign
return
}
}
func BenchmarkFlorianUekermannMidDensityIterateManyComp(b *testing.B) {
var input = make([]uint32, 1000000)
var rndSource = rand.NewSource(0)
var rnd = rand.New(rndSource)
for i := 0; i < 3000000; i++ {
input[rnd.Uint32()%1000000] |= uint32(1) << (rnd.Uint32() % 32)
}
b.ResetTimer()
var checksum = uint(0)
for i := 0; i < b.N; i++ {
checksum += good(input)
}
if checksum == 0 { // added just to fool ineffassign
return
}
}
////////// High density
// BenchmarkFlorianUekermannMidStrongDensityIterateMany measures NextSetMany
// iteration over a very dense bitmap (~20M random bit sets across 1M words).
func BenchmarkFlorianUekermannMidStrongDensityIterateMany(b *testing.B) {
	var input = make([]uint32, 1000000)
	var rndSource = rand.NewSource(0)
	var rnd = rand.New(rndSource)
	for i := 0; i < 20000000; i++ {
		// Words are 32 bits wide, so shift amounts must be taken modulo 32.
		// (The original used %64; shifts >= 32 zero a uint32, silently
		// halving the intended density and diverging from the sibling
		// ...Reg and ...Comp benchmarks, which already use %32.)
		input[rnd.Uint32()%1000000] |= uint32(1) << (rnd.Uint32() % 32)
	}
	var bitmap = From(input)
	buffer := make([]uint, 256)
	b.ResetTimer()
	sum := uint(0)
	for i := 0; i < b.N; i++ {
		j := uint(0)
		j, buffer = bitmap.NextSetMany(j, buffer)
		for ; len(buffer) > 0; j, buffer = bitmap.NextSetMany(j, buffer) {
			for k := range buffer {
				sum += buffer[k]
			}
			j++
		}
	}
	if sum == 0 { // added just to fool ineffassign
		return
	}
}
func BenchmarkFlorianUekermannMidStrongDensityIterateManyReg(b *testing.B) {
var input = make([]uint32, 1000000)
var rndSource = rand.NewSource(0)
var rnd = rand.New(rndSource)
for i := 0; i < 20000000; i++ {
input[rnd.Uint32()%1000000] |= uint32(1) << (rnd.Uint32() % 32)
}
var bitmap = From(input)
b.ResetTimer()
var checksum = uint(0)
for i := 0; i < b.N; i++ {
for j, e := bitmap.NextSet(0); e; j, e = bitmap.NextSet(j + 1) {
checksum += j
}
}
if checksum == 0 { // added just to fool ineffassign
return
}
}
func BenchmarkFlorianUekermannMidStrongDensityIterateManyComp(b *testing.B) {
var input = make([]uint32, 1000000)
var rndSource = rand.NewSource(0)
var rnd = rand.New(rndSource)
for i := 0; i < 20000000; i++ {
input[rnd.Uint32()%1000000] |= uint32(1) << (rnd.Uint32() % 32)
}
b.ResetTimer()
var checksum = uint(0)
for i := 0; i < b.N; i++ {
checksum += good(input)
}
if checksum == 0 { // added just to fool ineffassign
return
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,46 @@
package bitset32
import (
bitset64 "github.com/bits-and-blooms/bitset"
)
// BitSet64 wraps the 64-bit-word bitset from bits-and-blooms, adding the
// consecutive-run helpers so it can be compared against BitSet32 in tests.
type BitSet64 struct {
	*bitset64.BitSet
}
// MaxConsecutiveOne returns the length of the longest run of consecutive
// set (1) bits within [start, end).
// TODO: TestFunc
func (b *BitSet64) MaxConsecutiveOne(start, end uint) uint {
	return b.continueMaxCount(start, end, true)
}
// MaxConsecutiveZero returns the length of the longest run of consecutive
// clear (0) bits within [start, end).
func (b *BitSet64) MaxConsecutiveZero(start, end uint) uint {
	return b.continueMaxCount(start, end, false)
}
// continueMaxCount returns the length of the longest run of bits equal to
// flag within the half-open range [start, end). The end bound is clamped to
// the set's length; out-of-range or inverted ranges yield 0.
// NOTE(review): the bitset32 twin of this helper is named
// consecutiveMaxCount — consider unifying the names.
func (b *BitSet64) continueMaxCount(start, end uint, flag bool) uint {
	length := b.Len()
	if end > length {
		end = length
	}
	if start >= length || start > end {
		return 0
	}
	best, run := uint(0), uint(0)
	for i := start; i < end; i++ {
		// Extend the current run while the bit matches the wanted value.
		if b.Test(i) == flag {
			run++
			continue
		}
		if run > best {
			best = run
		}
		run = 0
	}
	// The range may end mid-run; account for the trailing run.
	if run > best {
		best = run
	}
	return best
}

View File

@@ -0,0 +1,174 @@
package bitset32
import (
"math"
"math/rand"
"testing"
"time"
"github.com/bits-and-blooms/bitset"
)
var opc int
var opcT int
var ft string = "%v|pos:%9X|opc:%9X|result:%v\n"
var rt string = "%v|opc:%9d|pass:%9d\n"
var opNum = 100
var bitTestNum = 10000
var randNum = math.MaxInt32 / 2
func TestBitSet(t *testing.T) {
var b32 = New(1)
var b64 = bitset.New(1)
res := true
pos := uint(0)
rand.Seed(time.Now().Unix())
for j := 0; j < 1; j++ {
// Test, Set,
for i := 0; i < opNum; i++ {
pos = uint(rand.Intn(randNum))
b32 = b32.Set(uint(pos))
b64 = b64.Set(uint(pos))
res = b32.Test(pos) == b64.Test(pos)
opc++
if res {
opcT++
} else {
t.Log(ft, time.Now(), pos, opc, res)
}
}
// Clear
for i := 0; i < opNum; i++ {
pos = uint(rand.Intn(randNum))
b32 = b32.Clear(uint(pos))
b64 = b64.Clear(uint(pos))
res = b32.Test(pos) == b64.Test(pos)
opc++
if res {
opcT++
} else {
t.Logf(ft, time.Now().Unix(), pos, opc, res)
}
}
// SetTo = Set + Clear
for i := 0; i < opNum; i++ {
pos = uint(rand.Intn(randNum))
value := rand.Intn(randNum)%2 == 1
b32 = b32.SetTo(uint(pos), value)
b64 = b64.SetTo(uint(pos), value)
res = b32.Test(pos) == b64.Test(pos)
opc++
if res {
opcT++
} else {
t.Logf(ft, time.Now().Unix(), pos, opc, res)
}
}
// Flip
for i := 0; i < opNum; i++ {
pos = uint(rand.Intn(randNum))
b64 = b64.Flip(pos)
b32 = b32.Flip(pos)
res = isSameBitset(b32, b64)
opc++
if res {
opcT++
} else {
t.Logf(ft, time.Now().Unix(), pos, opc, res)
}
}
// Flip Range
for i := 0; i < opNum; i++ {
start, end := uint(rand.Intn(randNum)), uint(rand.Intn(randNum))
if start > end {
start, end = end, start
}
b64 = b64.FlipRange(start, end)
b32 = b32.FlipRange(start, end)
res = isSameBitset(b32, b64)
opc++
if res {
opcT++
} else {
t.Logf(ft, time.Now().Unix(), pos, opc, res)
}
}
// InsertAt
for i := 0; i < opNum; i++ {
pos = uint(rand.Intn(randNum))
b64 = b64.InsertAt(pos)
b32 = b32.InsertAt(pos)
res = isSameBitset(b32, b64)
opc++
if res {
opcT++
} else {
t.Logf(ft, time.Now().Unix(), pos, opc, res)
}
}
// DeleteAt
for i := 0; i < opNum; i++ {
pos = uint(rand.Intn(randNum))
if b64.Len() < pos || b32.Len() < pos {
continue
}
b64 = b64.DeleteAt(pos)
b32 = b32.DeleteAt(pos)
res = isSameBitset(b32, b64)
opc++
if res {
opcT++
} else {
t.Logf(ft, time.Now().Unix(), pos, opc, res)
}
}
// Compact, Shrink
for i := 0; i < opNum; i++ {
b32 = b32.Compact()
b64 = b64.Compact()
res = isSameBitset(b32, b64)
opc++
if res {
opcT++
} else {
t.Logf(ft, time.Now().Unix(), pos, opc, res)
}
}
}
// Compact
b32 = b32.Compact()
b64 = b64.Compact()
res = isSameBitset(b32, b64)
t.Log("Compact:", res)
bs64 := &BitSet64{b64}
t.Log("Max Count:", b32.MaxConsecutiveOne(0, b32.Len()), bs64.MaxConsecutiveOne(0, b64.Len()))
t.Log("String:", b32.String() == b64.String())
t.Logf(rt, time.Now().Unix(), opc, opcT)
}
func isSameBitset(b32 *BitSet32, b64 *bitset.BitSet) bool {
if b32.Len() != b64.Len() {
return false
}
for i := 0; i < bitTestNum; i++ {
pos := uint(rand.Intn(randNum))
if b32.Test(pos) != b64.Test(pos) {
return false
}
}
return true
}
/*
Running tool: C:\support\go\bin\go.exe test -timeout 30s -run ^TestBitSet$ bitset -v
=== RUN TestBitSet
d:\workspace\DataStruct\go\bitset\bitset_test.go:139: Max Count: 23067608 23067608
d:\workspace\DataStruct\go\bitset\bitset_test.go:140: true
d:\workspace\DataStruct\go\bitset\bitset_test.go:141: 1678626066|opc: 800|pass: 800
--- PASS: TestBitSet (23.61s)
PASS
ok bitset 24.243s
*/

View File

@@ -0,0 +1,5 @@
module github.com/pointernil/bitset32
go 1.19
require github.com/bits-and-blooms/bitset v1.5.0 // direct

View File

@@ -0,0 +1,2 @@
github.com/bits-and-blooms/bitset v1.5.0 h1:NpE8frKRLGHIcEzkR+gZhiioW1+WbYV6fKwD6ZIpQT8=
github.com/bits-and-blooms/bitset v1.5.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA=

View File

@@ -0,0 +1,43 @@
package bitset32
import "math/bits"
// popcntSlice returns the total number of set bits across all words in s.
func popcntSlice(s []uint32) uint64 {
	total := uint64(0)
	for _, word := range s {
		total += uint64(bits.OnesCount32(word))
	}
	return total
}
// popcntMaskSlice returns the number of set bits in s after clearing, in
// each word, the bits that are set in the corresponding word of m (s AND NOT m).
// m must be at least as long as s.
func popcntMaskSlice(s, m []uint32) uint64 {
	total := uint64(0)
	for i, word := range s {
		total += uint64(bits.OnesCount32(word &^ m[i]))
	}
	return total
}
// popcntAndSlice returns the number of set bits in the word-wise AND of s
// and m. m must be at least as long as s.
func popcntAndSlice(s, m []uint32) uint64 {
	total := uint64(0)
	for i, word := range s {
		total += uint64(bits.OnesCount32(word & m[i]))
	}
	return total
}
// popcntOrSlice returns the number of set bits in the word-wise OR of s
// and m. m must be at least as long as s.
func popcntOrSlice(s, m []uint32) uint64 {
	total := uint64(0)
	for i, word := range s {
		total += uint64(bits.OnesCount32(word | m[i]))
	}
	return total
}
// popcntXorSlice returns the number of set bits in the word-wise XOR of s
// and m. m must be at least as long as s.
func popcntXorSlice(s, m []uint32) uint64 {
	total := uint64(0)
	for i, word := range s {
		total += uint64(bits.OnesCount32(word ^ m[i]))
	}
	return total
}

View File

@@ -0,0 +1,10 @@
//go:build go1.9
// +build go1.9
package bitset32
import "math/bits"
// trailingZeroes32 returns the number of trailing zero bits in v;
// it reports 32 when v == 0 (mirroring bits.TrailingZeros32).
func trailingZeroes32(v uint32) uint {
	return uint(bits.TrailingZeros32(v))
}

View File

@@ -0,0 +1,524 @@
package bytearray
import (
"encoding/binary"
"errors"
"io"
"math"
"sync"
)
// ByteArray 提供字节数组的读写操作,支持大小端字节序
type ByteArray struct {
buf []byte
posWrite int
posRead int
endian binary.ByteOrder
}
// 默认使用大端字节序
var defaultEndian binary.ByteOrder = binary.BigEndian
// bufferpool 用于重用ByteArray实例
var bufferpool = &sync.Pool{
New: func() interface{} {
return &ByteArray{endian: defaultEndian}
},
}
// CreateByteArray creates a ByteArray. With no arguments it takes a reusable
// instance from the pool (intended for writing; release it via Bytes()).
// With one or more byte-slice arguments it builds a dedicated instance whose
// buffer is the concatenation of the slices (intended for reading).
func CreateByteArray(bytes ...[]byte) *ByteArray {
	var ba *ByteArray
	if len(bytes) == 0 { // no input: reuse a pooled instance for writing
		ba = bufferpool.Get().(*ByteArray)
	} else { // read sequence: fresh instance wrapping the supplied bytes
		ba = &ByteArray{endian: defaultEndian}
	}
	for _, num := range bytes {
		ba.buf = append(ba.buf, num...)
	}
	ba.ResetPos()
	return ba
}
// releaseByteArray 将ByteArray实例放回池中以便重用
func releaseByteArray(ba *ByteArray) {
ba.Reset()
bufferpool.Put(ba)
}
// Length 返回字节数组的总长度
func (ba *ByteArray) Length() int {
return len(ba.buf)
}
// Available 返回可读取的字节数
func (ba *ByteArray) Available() int {
return ba.Length() - ba.posRead
}
// SetEndian 设置字节序(大端或小端)
func (ba *ByteArray) SetEndian(endian binary.ByteOrder) {
ba.endian = endian
}
// GetEndian 获取当前字节序
func (ba *ByteArray) GetEndian() binary.ByteOrder {
if ba.endian == nil {
return defaultEndian
}
return ba.endian
}
// Grow ensures the buffer can hold size more bytes at the current write
// position, reallocating if necessary. After Grow, len(ba.buf) >= posWrite+size.
func (ba *ByteArray) Grow(size int) {
	if size <= 0 {
		return
	}
	required := ba.posWrite + size
	if len(ba.buf) >= required {
		return
	}
	if cap(ba.buf) >= required {
		// Enough spare capacity: just extend the length, no copy needed.
		ba.buf = ba.buf[:required]
		return
	}
	// Amortize growth by at least doubling the capacity; growing to exactly
	// `required` each time makes a sequence of small writes O(n²) in copies.
	newCap := 2 * cap(ba.buf)
	if newCap < required {
		newCap = required
	}
	newBuf := make([]byte, required, newCap)
	copy(newBuf, ba.buf)
	ba.buf = newBuf
}
// SetWritePos 设置写指针位置
func (ba *ByteArray) SetWritePos(pos int) error {
if pos < 0 || pos > ba.Length() {
return io.EOF
}
ba.posWrite = pos
return nil
}
// SetWriteEnd 将写指针设置到末尾
func (ba *ByteArray) SetWriteEnd() {
ba.posWrite = ba.Length()
}
// GetWritePos 获取写指针位置
func (ba *ByteArray) GetWritePos() int {
return ba.posWrite
}
// SetReadPos 设置读指针位置
func (ba *ByteArray) SetReadPos(pos int) error {
if pos < 0 || pos > ba.Length() {
return io.EOF
}
ba.posRead = pos
return nil
}
// SetReadEnd 将读指针设置到末尾
func (ba *ByteArray) SetReadEnd() {
ba.posRead = ba.Length()
}
// GetReadPos 获取读指针位置
func (ba *ByteArray) GetReadPos() int {
return ba.posRead
}
// ResetPos 重置读写指针到开始位置
func (ba *ByteArray) ResetPos() {
ba.posWrite = 0
ba.posRead = 0
}
// Reset 重置ByteArray清空缓冲区并重置指针
func (ba *ByteArray) Reset() {
ba.buf = nil
ba.ResetPos()
}
// Bytes returns the complete underlying byte slice and releases this
// ByteArray back to the pool. The receiver must NOT be used after this call;
// the returned slice stays valid because releasing resets ba.buf to nil
// rather than reusing its backing array.
func (ba *ByteArray) Bytes() []byte {
	defer releaseByteArray(ba) // write path: return the instance to the pool on exit
	return ba.buf
}
// BytesAvailable 返回从当前读指针位置到末尾的字节数组
func (ba *ByteArray) BytesAvailable() []byte {
return ba.buf[ba.posRead:]
}
// ========== 写入方法 ==========
// Write 写入字节数组
func (ba *ByteArray) Write(bytes []byte) (int, error) {
if len(bytes) == 0 {
return 0, nil
}
ba.Grow(len(bytes))
n := copy(ba.buf[ba.posWrite:], bytes)
ba.posWrite += n
return n, nil
}
// WriteByte 写入单个字节
func (ba *ByteArray) WriteByte(b byte) error {
ba.Grow(1)
ba.buf[ba.posWrite] = b
ba.posWrite++
return nil
}
// WriteInt8 写入int8
func (ba *ByteArray) WriteInt8(value int8) error {
return ba.WriteByte(byte(value))
}
// WriteInt16 写入int16根据当前字节序处理
func (ba *ByteArray) WriteInt16(value int16) error {
return ba.writeNumber(value)
}
// WriteUInt16 写入uint16根据当前字节序处理
func (ba *ByteArray) WriteUInt16(value uint16) error {
return ba.writeNumber(value)
}
// WriteInt32 写入int32根据当前字节序处理
func (ba *ByteArray) WriteInt32(value int32) error {
return ba.writeNumber(value)
}
// WriteUInt32 写入uint32根据当前字节序处理
func (ba *ByteArray) WriteUInt32(value uint32) error {
return ba.writeNumber(value)
}
// WriteInt64 写入int64根据当前字节序处理
func (ba *ByteArray) WriteInt64(value int64) error {
return ba.writeNumber(value)
}
// Writeuint32 写入uint32根据当前字节序处理
func (ba *ByteArray) Writeuint32(value uint32) error {
return ba.writeNumber(value)
}
// WriteFloat32 写入float32根据当前字节序处理
func (ba *ByteArray) WriteFloat32(value float32) error {
return ba.writeNumber(math.Float32bits(value))
}
// WriteFloat64 写入float64根据当前字节序处理
func (ba *ByteArray) WriteFloat64(value float64) error {
return ba.writeNumber(math.Float64bits(value))
}
// WriteBool 写入布尔值
func (ba *ByteArray) WriteBool(value bool) error {
var b byte
if value {
b = 1
} else {
b = 0
}
return ba.WriteByte(b)
}
// WriteString 写入字符串
func (ba *ByteArray) WriteString(value string) error {
_, err := ba.Write([]byte(value))
return err
}
// WriteUTF writes a UTF string prefixed with its byte length as a uint16.
// Strings longer than 65535 bytes are rejected: the previous implementation
// silently truncated the length prefix, corrupting the stream for readers.
func (ba *ByteArray) WriteUTF(value string) error {
	bytes := []byte(value)
	if len(bytes) > math.MaxUint16 {
		return errors.New("string too long for uint16 length prefix")
	}
	if err := ba.WriteUInt16(uint16(len(bytes))); err != nil {
		return err
	}
	_, err := ba.Write(bytes)
	return err
}
// ReadUTF8Array 读取 UTF8 字符串数组(格式:先读取 Int32 长度,再读取多个 UTF 字符串)
func (ba *ByteArray) ReadUTF8Array() ([]string, error) {
count, err := ba.ReadInt32()
if err != nil {
return nil, err
}
if count < 0 {
return nil, errors.New("invalid array length")
}
array := make([]string, 0, count)
for i := 0; i < int(count); i++ {
str, err := ba.ReadUTF()
if err != nil {
return nil, err
}
array = append(array, str)
}
return array, nil
}
// ReadInt32Array 读取 Int32 数组(格式:先读取 Int32 长度,再读取多个 Int32
func (ba *ByteArray) ReadInt32Array() ([]int32, error) {
count, err := ba.ReadInt32()
if err != nil {
return nil, err
}
if count < 0 {
return nil, errors.New("invalid array length")
}
array := make([]int32, 0, count)
for i := 0; i < int(count); i++ {
val, err := ba.ReadInt32()
if err != nil {
return nil, err
}
array = append(array, val)
}
return array, nil
}
// WriteUTF8 写入UTF8字符串不带长度前缀
func (ba *ByteArray) WriteUTF8(value string) error {
_, err := ba.Write([]byte(value))
return err
}
// writeNumber encodes a fixed-width numeric value at the current write
// position using the configured byte order, growing the buffer as needed.
// Only explicitly sized types are supported; platform-width int/uint
// (including untyped integer constants) are rejected with an error.
func (ba *ByteArray) writeNumber(value interface{}) error {
	var size int
	switch value.(type) {
	case int8, uint8:
		size = 1
	case int16, uint16:
		size = 2
	case int32, uint32, float32:
		size = 4
	case int64, uint64, float64:
		size = 8
	default:
		return errors.New("unsupported number type")
	}
	ba.Grow(size)
	// Use GetEndian so a zero-value ByteArray (endian == nil) falls back to
	// the default byte order instead of panicking on a nil interface.
	endian := ba.GetEndian()
	switch v := value.(type) {
	case int8:
		ba.buf[ba.posWrite] = byte(v)
	case uint8:
		ba.buf[ba.posWrite] = v
	case int16:
		endian.PutUint16(ba.buf[ba.posWrite:], uint16(v))
	case uint16:
		endian.PutUint16(ba.buf[ba.posWrite:], v)
	case int32:
		endian.PutUint32(ba.buf[ba.posWrite:], uint32(v))
	case uint32:
		endian.PutUint32(ba.buf[ba.posWrite:], v)
	case int64:
		endian.PutUint64(ba.buf[ba.posWrite:], uint64(v))
	case uint64:
		endian.PutUint64(ba.buf[ba.posWrite:], v)
	case float32:
		endian.PutUint32(ba.buf[ba.posWrite:], math.Float32bits(v))
	case float64:
		endian.PutUint64(ba.buf[ba.posWrite:], math.Float64bits(v))
	}
	ba.posWrite += size
	return nil
}
// ========== 读取方法 ==========
// Read 读取字节数组到指定缓冲区
func (ba *ByteArray) Read(bytes []byte) (int, error) {
if len(bytes) == 0 {
return 0, nil
}
if ba.posRead+len(bytes) > ba.Length() {
return 0, io.EOF
}
n := copy(bytes, ba.buf[ba.posRead:])
ba.posRead += n
return n, nil
}
// ReadByte 读取单个字节
func (ba *ByteArray) ReadByte() (byte, error) {
if ba.posRead >= ba.Length() {
return 0, io.EOF
}
b := ba.buf[ba.posRead]
ba.posRead++
return b, nil
}
// ReadInt8 读取int8
func (ba *ByteArray) ReadInt8() (int8, error) {
b, err := ba.ReadByte()
return int8(b), err
}
// ReadUInt8 读取uint8
func (ba *ByteArray) ReadUInt8() (uint8, error) {
return ba.ReadByte()
}
// ReadInt16 读取int16根据当前字节序处理
func (ba *ByteArray) ReadInt16() (int16, error) {
var v uint16
if err := ba.readNumber(&v); err != nil {
return 0, err
}
return int16(v), nil
}
// ReadUInt16 读取uint16根据当前字节序处理
func (ba *ByteArray) ReadUInt16() (uint16, error) {
var v uint16
if err := ba.readNumber(&v); err != nil {
return 0, err
}
return v, nil
}
// ReadInt32 读取int32根据当前字节序处理
func (ba *ByteArray) ReadInt32() (int32, error) {
var v uint32
if err := ba.readNumber(&v); err != nil {
return 0, err
}
return int32(v), nil
}
// ReadUInt32 读取uint32根据当前字节序处理
func (ba *ByteArray) ReadUInt32() (uint32, error) {
var v uint32
if err := ba.readNumber(&v); err != nil {
return 0, err
}
return v, nil
}
// ReadInt64 reads an int64 using the configured byte order.
func (ba *ByteArray) ReadInt64() (int64, error) {
	// Must read a full 8-byte word: the previous implementation read a
	// uint32, consuming only 4 bytes and truncating every value to its
	// low 32 bits.
	var v uint64
	if err := ba.readNumber(&v); err != nil {
		return 0, err
	}
	return int64(v), nil
}
// Readuint32 读取uint32根据当前字节序处理
func (ba *ByteArray) Readuint32() (uint32, error) {
var v uint32
if err := ba.readNumber(&v); err != nil {
return 0, err
}
return v, nil
}
// ReadFloat32 读取float32根据当前字节序处理
func (ba *ByteArray) ReadFloat32() (float32, error) {
var v uint32
if err := ba.readNumber(&v); err != nil {
return 0, err
}
return math.Float32frombits(v), nil
}
// ReadFloat64 读取float64根据当前字节序处理
func (ba *ByteArray) ReadFloat64() (float64, error) {
var v uint64
if err := ba.readNumber(&v); err != nil {
return 0, err
}
return math.Float64frombits(v), nil
}
// ReadBool 读取布尔值
func (ba *ByteArray) ReadBool() (bool, error) {
b, err := ba.ReadByte()
if err != nil {
return false, err
}
return b != 0, nil
}
// ReadString 读取指定长度的字符串
func (ba *ByteArray) ReadString(length int) (string, error) {
if length < 0 {
return "", errors.New("invalid string length")
}
if ba.posRead+length > ba.Length() {
return "", io.EOF
}
str := string(ba.buf[ba.posRead : ba.posRead+length])
ba.posRead += length
return str, nil
}
// ReadUTF 读取UTF字符串带长度前缀
func (ba *ByteArray) ReadUTF() (string, error) {
length, err := ba.ReadUInt16()
if err != nil {
return "", err
}
return ba.ReadString(int(length))
}
// readNumber decodes a fixed-width numeric value at the current read
// position into the pointer target, using the configured byte order.
// Only pointers to explicitly sized types are supported. Returns io.EOF
// when fewer than size bytes remain.
func (ba *ByteArray) readNumber(value interface{}) error {
	var size int
	switch value.(type) {
	case *int16, *uint16:
		size = 2
	case *int32, *uint32, *float32:
		size = 4
	case *int64, *uint64, *float64:
		size = 8
	default:
		return errors.New("unsupported number type")
	}
	if ba.posRead+size > ba.Length() {
		return io.EOF
	}
	buf := ba.buf[ba.posRead : ba.posRead+size]
	ba.posRead += size
	// Use GetEndian so a zero-value ByteArray (endian == nil) falls back to
	// the default byte order instead of panicking on a nil interface.
	endian := ba.GetEndian()
	switch v := value.(type) {
	case *int16:
		*v = int16(endian.Uint16(buf))
	case *uint16:
		*v = endian.Uint16(buf)
	case *int32:
		*v = int32(endian.Uint32(buf))
	case *uint32:
		*v = endian.Uint32(buf)
	case *int64:
		*v = int64(endian.Uint64(buf))
	case *uint64:
		*v = endian.Uint64(buf)
	case *float32:
		*v = math.Float32frombits(endian.Uint32(buf))
	case *float64:
		*v = math.Float64frombits(endian.Uint64(buf))
	}
	return nil
}

View File

@@ -0,0 +1,21 @@
package bytearray
import (
"testing"
)
// BenchmarkByteArray measures the cost of serializing a small mixed payload
// and extracting the resulting bytes.
func BenchmarkByteArray(b *testing.B) {
	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		// Acquire a fresh instance per iteration: Bytes() returns the
		// instance to the pool, so the previous code's reuse of one
		// instance after calling Bytes() was a use-after-release.
		ba := CreateByteArray()
		// Use explicitly sized types: untyped constants arrive in
		// writeNumber as int/float64, and plain int is rejected — the
		// previous code silently benchmarked four failing calls.
		_ = ba.writeNumber(int32(42))
		_ = ba.writeNumber(float64(12345.4))
		_ = ba.writeNumber(int32(123456789))
		_ = ba.writeNumber(int64(1234567890123456789))
		_ = ba.WriteString("test string")
		_ = ba.Bytes()
	}
}

View File

@@ -0,0 +1,3 @@
module blazing/common/utils/bytearray
go 1.20

View File

@@ -0,0 +1,18 @@
name: Go Checks
on:
pull_request:
push:
branches: ["main"]
workflow_dispatch:
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.event_name == 'push' && github.sha || github.ref }}
cancel-in-progress: true
jobs:
go-check:
uses: ipdxco/unified-github-workflows/.github/workflows/go-check.yml@v1.0

View File

@@ -0,0 +1,20 @@
name: Go Test
on:
pull_request:
push:
branches: ["main"]
workflow_dispatch:
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.event_name == 'push' && github.sha || github.ref }}
cancel-in-progress: true
jobs:
go-test:
uses: ipdxco/unified-github-workflows/.github/workflows/go-test.yml@v1.0
secrets:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -0,0 +1,19 @@
name: Release Checker
on:
pull_request_target:
paths: [ 'version.json' ]
types: [ opened, synchronize, reopened, labeled, unlabeled ]
workflow_dispatch:
permissions:
contents: write
pull-requests: write
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
release-check:
uses: ipdxco/unified-github-workflows/.github/workflows/release-check.yml@v1.0

View File

@@ -0,0 +1,17 @@
name: Releaser
on:
push:
paths: [ 'version.json' ]
workflow_dispatch:
permissions:
contents: write
concurrency:
group: ${{ github.workflow }}-${{ github.sha }}
cancel-in-progress: true
jobs:
releaser:
uses: ipdxco/unified-github-workflows/.github/workflows/releaser.yml@v1.0

View File

@@ -0,0 +1,18 @@
name: Tag Push Checker
on:
push:
tags:
- v*
permissions:
contents: read
issues: write
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
releaser:
uses: ipdxco/unified-github-workflows/.github/workflows/tagpush.yml@v1.0

View File

@@ -0,0 +1,5 @@
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

View File

@@ -0,0 +1,19 @@
The MIT License (MIT)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@@ -0,0 +1,313 @@
go-jsonrpc
==================
[![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white&style=flat-square)](https://pkg.go.dev/github.com/filecoin-project/go-jsonrpc)
[![](https://img.shields.io/badge/made%20by-Protocol%20Labs-blue.svg?style=flat-square)](https://protocol.ai)
> Low Boilerplate JSON-RPC 2.0 library
## Usage examples
### Server
```go
// Have a type with some exported methods
type SimpleServerHandler struct {
n int
}
func (h *SimpleServerHandler) AddGet(in int) int {
h.n += in
return h.n
}
func main() {
// create a new server instance
rpcServer := jsonrpc.NewServer()
// create a handler instance and register it
serverHandler := &SimpleServerHandler{}
rpcServer.Register("SimpleServerHandler", serverHandler)
// rpcServer is now http.Handler which will serve jsonrpc calls to SimpleServerHandler.AddGet
// a method with a single int param, and an int response. The server supports both http and websockets.
// serve the api
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()
fmt.Println("URL: ", "ws://"+testServ.Listener.Addr().String())
[..do other app stuff / wait..]
}
```
### Client
```go
func start() error {
// Create a struct where each field is an exported function with signatures matching rpc calls
var client struct {
AddGet func(int) int
}
	// Make jsonrpc populate func fields in the struct with JSONRPC calls
closer, err := jsonrpc.NewClient(context.Background(), rpcURL, "SimpleServerHandler", &client, nil)
if err != nil {
return err
}
defer closer()
...
n := client.AddGet(10)
// if the server is the one from the example above, n = 10
	n = client.AddGet(2)
// if the server is the one from the example above, n = 12
}
```
### Supported function signatures
```go
type _ interface {
// No Params / Return val
Func1()
// With Params
// Note: If param types implement json.[Un]Marshaler, go-jsonrpc will use it
Func2(param1 int, param2 string, param3 struct{A int})
// Returning errors
// * For some connection errors, go-jsonrpc will return jsonrpc.RPCConnectionError{}.
// * RPC-returned errors will be constructed with basic errors.New(__"string message"__)
// * JSON-RPC error codes can be mapped to typed errors with jsonrpc.Errors - https://pkg.go.dev/github.com/filecoin-project/go-jsonrpc#Errors
// * For typed errors to work, server needs to be constructed with the `WithServerErrors`
// option, and the client needs to be constructed with the `WithErrors` option
Func3() error
// Returning a value
// Note: The value must be serializable with encoding/json.
Func4() int
// Returning a value and an error
// Note: if the handler returns an error and a non-zero value, the value will not
// be returned to the client - the client will see a zero value.
Func4() (int, error)
// With context
// * Context isn't passed as JSONRPC param, instead it has a number of different uses
// * When the context is cancelled on the client side, context cancellation should propagate to the server handler
// * In http mode the http request will be aborted
// * In websocket mode the client will send a `xrpc.cancel` with a single param containing ID of the cancelled request
// * If the context contains an opencensus trace span, it will be propagated to the server through a
// `"Meta": {"SpanContext": base64.StdEncoding.EncodeToString(propagation.Binary(span.SpanContext()))}` field in
// the jsonrpc request
//
Func5(ctx context.Context, param1 string) error
// With non-json-serializable (e.g. interface) params
// * There are client and server options which make it possible to register transformers for types
// to make them json-(de)serializable
// * Server side: jsonrpc.WithParamDecoder(new(io.Reader), func(ctx context.Context, b []byte) (reflect.Value, error) { ... }
// * Client side: jsonrpc.WithParamEncoder(new(io.Reader), func(value reflect.Value) (reflect.Value, error) { ... }
// * For io.Reader specifically there's a simple param encoder/decoder implementation in go-jsonrpc/httpio package
// which will pass reader data through separate http streams on a different handler.
// * Note: a similar mechanism for return value transformation isn't supported yet
Func6(r io.Reader)
// Returning a channel
// * Only supported in websocket mode
// * If no error is returned, the return value will be an int channelId
// * When the server handler writes values into the channel, the client will receive `xrpc.ch.val` notifications
// with 2 params: [chanID: int, value: any]
// * When the channel is closed the client will receive `xrpc.ch.close` notification with a single param: [chanId: int]
// * The client-side channel will be closed when the websocket connection breaks; Server side will discard writes to
// the channel. Handlers should rely on the context to know when to stop writing to the returned channel.
// NOTE: There is no good backpressure mechanism implemented for channels; returning values faster than the client can
// receive them may cause memory leaks.
Func7(ctx context.Context, param1 int, param2 string) (<-chan int, error)
}
```
### Custom Transport Feature
The go-jsonrpc library supports creating clients with custom transport mechanisms (e.g. use for IPC). This allows for greater flexibility in how requests are sent and received, enabling the use of custom protocols, special handling of requests, or integration with other systems.
#### Example Usage of Custom Transport
Here is an example demonstrating how to create a custom client with a custom transport mechanism:
```go
// Setup server
serverHandler := &SimpleServerHandler{} // some type with methods
rpcServer := jsonrpc.NewServer()
rpcServer.Register("SimpleServerHandler", serverHandler)
// Custom doRequest function
doRequest := func(ctx context.Context, body []byte) (io.ReadCloser, error) {
reader := bytes.NewReader(body)
pr, pw := io.Pipe()
go func() {
defer pw.Close()
rpcServer.HandleRequest(ctx, reader, pw) // handle the rpc frame
}()
return pr, nil
}
var client struct {
Add func(int) error
}
// Create custom client
closer, err := jsonrpc.NewCustomClient("SimpleServerHandler", []interface{}{&client}, doRequest)
if err != nil {
log.Fatalf("Failed to create client: %v", err)
}
defer closer()
// Use the client
if err := client.Add(10); err != nil {
log.Fatalf("Failed to call Add: %v", err)
}
fmt.Println("Successfully called Add")
```
### Reverse Calling Feature
The go-jsonrpc library also supports reverse calling, where the server can make calls to the client. This is useful in scenarios where the server needs to notify or request data from the client.
NOTE: Reverse calling only works in websocket mode
#### Example Usage of Reverse Calling
Here is an example demonstrating how to set up reverse calling:
```go
// Define the client handler interface
type ClientHandler struct {
CallOnClient func(int) (int, error)
}
// Define the server handler
type ServerHandler struct {}
func (h *ServerHandler) Call(ctx context.Context) error {
revClient, ok := jsonrpc.ExtractReverseClient[ClientHandler](ctx)
if !ok {
return fmt.Errorf("no reverse client")
}
result, err := revClient.CallOnClient(7) // Multiply by 2 on client
if err != nil {
return fmt.Errorf("call on client: %w", err)
}
if result != 14 {
return fmt.Errorf("unexpected result: %d", result)
}
return nil
}
// Define client handler
type RevCallTestClientHandler struct {
}
func (h *RevCallTestClientHandler) CallOnClient(a int) (int, error) {
return a * 2, nil
}
// Setup server with reverse client capability
rpcServer := jsonrpc.NewServer(jsonrpc.WithReverseClient[ClientHandler]("Client"))
rpcServer.Register("ServerHandler", &ServerHandler{})
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()
// Setup client with reverse call handler
var client struct {
Call func() error
}
closer, err := jsonrpc.NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "ServerHandler", []interface{}{
&client,
}, nil, jsonrpc.WithClientHandler("Client", &RevCallTestClientHandler{}))
if err != nil {
log.Fatalf("Failed to create client: %v", err)
}
defer closer()
// Make a call from the client to the server, which will trigger a reverse call
if err := client.Call(); err != nil {
log.Fatalf("Failed to call server: %v", err)
}
```
## Options
### Using `WithServerMethodNameFormatter`
`WithServerMethodNameFormatter` allows you to customize a function that formats the JSON-RPC method name, given namespace and method name.
There are four possible options:
- `jsonrpc.DefaultMethodNameFormatter` - default method name formatter, e.g. `SimpleServerHandler.AddGet`
- `jsonrpc.NewMethodNameFormatter(true, jsonrpc.LowerFirstCharCase)` - method name formatter with namespace, e.g. `SimpleServerHandler.addGet`
- `jsonrpc.NewMethodNameFormatter(false, jsonrpc.OriginalCase)` - method name formatter without namespace, e.g. `AddGet`
- `jsonrpc.NewMethodNameFormatter(false, jsonrpc.LowerFirstCharCase)` - method name formatter without namespace and with the first char lowercased, e.g. `addGet`
> [!NOTE]
> The default method name formatter concatenates the namespace and method name with a dot.
> Go exported methods are capitalized, so, the method name will be capitalized as well.
> e.g. `SimpleServerHandler.AddGet` (capital "A" in "AddGet")
```go
func main() {
// create a new server instance with a custom separator
rpcServer := jsonrpc.NewServer(jsonrpc.WithServerMethodNameFormatter(
func(namespace, method string) string {
return namespace + "_" + method
}),
)
// create a handler instance and register it
serverHandler := &SimpleServerHandler{}
rpcServer.Register("SimpleServerHandler", serverHandler)
// serve the api
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()
fmt.Println("URL: ", "ws://"+testServ.Listener.Addr().String())
// rpc method becomes SimpleServerHandler_AddGet
[..do other app stuff / wait..]
}
```
### Using `WithMethodNameFormatter`
`WithMethodNameFormatter` is the client-side counterpart to `WithServerMethodNameFormatter`.
```go
func main() {
closer, err := NewMergeClient(
context.Background(),
"http://example.com",
"SimpleServerHandler",
[]any{&client},
nil,
WithMethodNameFormatter(jsonrpc.NewMethodNameFormatter(false, OriginalCase)),
)
defer closer()
}
```
## Contribute
PRs are welcome!
## License
Dual-licensed under [MIT](https://github.com/filecoin-project/go-jsonrpc/blob/master/LICENSE-MIT) + [Apache 2.0](https://github.com/filecoin-project/go-jsonrpc/blob/master/LICENSE-APACHE)

View File

@@ -0,0 +1,79 @@
package auth
import (
"context"
"reflect"
"golang.org/x/xerrors"
)
// Permission is a named capability (e.g. "read", "write", "admin") a caller
// may hold. Permissions are attached to request contexts and checked by
// HasPerm before protected methods are invoked.
type Permission string

// permKey is a private context-key type, preventing collisions with keys
// registered by other packages.
type permKey int

// permCtxKey is the context key under which the caller's []Permission is stored.
var permCtxKey permKey

// WithPerm returns a copy of ctx carrying the given caller permissions,
// retrievable later via HasPerm.
func WithPerm(ctx context.Context, perms []Permission) context.Context {
	return context.WithValue(ctx, permCtxKey, perms)
}
// HasPerm reports whether the caller identified by ctx holds perm. If the
// context carries no permission set (i.e. WithPerm was never applied),
// defaultPerms is consulted instead.
func HasPerm(ctx context.Context, defaultPerms []Permission, perm Permission) bool {
	perms := defaultPerms
	if fromCtx, ok := ctx.Value(permCtxKey).([]Permission); ok {
		perms = fromCtx
	}

	for _, have := range perms {
		if have == perm {
			return true
		}
	}
	return false
}
// PermissionedProxy fills out with wrapper funcs around in's methods of the
// same name, each of which checks the caller's permissions (via HasPerm)
// before delegating to the real implementation.
//
// out must be a pointer to a struct of func fields; every field must carry a
// `perm:"..."` struct tag naming one of validPerms, and in must expose a
// method matching each field's name. Tag problems panic at construction time
// (programmer error); a missing runtime permission produces an error return
// instead of calling through.
func PermissionedProxy(validPerms, defaultPerms []Permission, in interface{}, out interface{}) {
	rint := reflect.ValueOf(out).Elem()
	ra := reflect.ValueOf(in)
	for f := 0; f < rint.NumField(); f++ {
		field := rint.Type().Field(f)
		requiredPerm := Permission(field.Tag.Get("perm"))
		if requiredPerm == "" {
			panic("missing 'perm' tag on " + field.Name) // ok
		}
		// Validate perm tag
		ok := false
		for _, perm := range validPerms {
			if requiredPerm == perm {
				ok = true
				break
			}
		}
		if !ok {
			panic("unknown 'perm' tag on " + field.Name) // ok
		}
		fn := ra.MethodByName(field.Name)
		rint.Field(f).Set(reflect.MakeFunc(field.Type, func(args []reflect.Value) (results []reflect.Value) {
			// NOTE(review): assumes the proxied method takes a context.Context
			// as its first argument — a different first arg panics here;
			// confirm all proxied APIs follow this convention.
			ctx := args[0].Interface().(context.Context)
			if HasPerm(ctx, defaultPerms, requiredPerm) {
				return fn.Call(args)
			}
			err := xerrors.Errorf("missing permission to invoke '%s' (need '%s')", field.Name, requiredPerm)
			rerr := reflect.ValueOf(&err).Elem()
			// Methods are assumed to return either (T, error) or just error;
			// fill the value slot with its zero value when present.
			if field.Type.NumOut() == 2 {
				return []reflect.Value{
					reflect.Zero(field.Type.Out(0)),
					rerr,
				}
			} else {
				return []reflect.Value{rerr}
			}
		}))
	}
}

View File

@@ -0,0 +1,48 @@
package auth
import (
"context"
"net/http"
"strings"
logging "github.com/ipfs/go-log/v2"
)
// log is the package logger for auth events.
var log = logging.Logger("auth")

// Handler is HTTP middleware that authenticates requests via bearer tokens
// before passing them to the wrapped handler.
type Handler struct {
	// Verify validates a raw token string and returns the permissions it grants.
	Verify func(ctx context.Context, token string) ([]Permission, error)
	// Next is invoked with the (possibly permission-enriched) request.
	Next http.HandlerFunc
}
// ServeHTTP implements http.Handler. It extracts a bearer token from the
// Authorization header (falling back to the "token" form value), verifies it
// via h.Verify, attaches the resulting permissions to the request context and
// delegates to h.Next.
//
// A malformed auth header or a failing verification is rejected with
// 401 Unauthorized. Requests carrying no token at all are passed through
// unauthenticated — h.Next decides what such callers may do.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	token := r.Header.Get("Authorization")
	if token == "" {
		token = r.FormValue("token")
		if token != "" {
			token = "Bearer " + token
		}
	}
	if token != "" {
		if !strings.HasPrefix(token, "Bearer ") {
			log.Warn("missing Bearer prefix in auth header")
			// http.StatusUnauthorized instead of the magic 401 literal
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		token = strings.TrimPrefix(token, "Bearer ")
		allow, err := h.Verify(ctx, token)
		if err != nil {
			log.Warnf("JWT Verification failed (originating from %s): %s", r.RemoteAddr, err)
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		ctx = WithPerm(ctx, allow)
	}
	h.Next(w, r.WithContext(ctx))
}

View File

@@ -0,0 +1,751 @@
package jsonrpc
import (
"bytes"
"container/list"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"reflect"
"runtime/pprof"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
logging "github.com/ipfs/go-log/v2"
"go.opencensus.io/trace"
"go.opencensus.io/trace/propagation"
"golang.org/x/xerrors"
)
// Bounds for the exponential backoff applied when retrying methods tagged
// with the "retry" proxy tag after a temporary websocket failure.
const (
	methodMinRetryDelay = 100 * time.Millisecond
	methodMaxRetryDelay = 10 * time.Minute
)

var (
	// Cached reflect types used when inspecting client struct fields.
	errorType   = reflect.TypeOf(new(error)).Elem()
	contextType = reflect.TypeOf(new(context.Context)).Elem()

	log = logging.Logger("rpc")

	// _defaultHTTPClient is the fallback http.Client (pooled connections,
	// conservative timeouts) used when no custom client is configured.
	_defaultHTTPClient = &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyFromEnvironment,
			DialContext: (&net.Dialer{
				Timeout:   30 * time.Second,
				KeepAlive: 30 * time.Second,
				// NOTE(review): DualStack is a deprecated no-op since Go 1.12
				// (Happy Eyeballs is on by default); safe to drop.
				DualStack: true,
			}).DialContext,
			ForceAttemptHTTP2:     true,
			MaxIdleConns:          100,
			MaxIdleConnsPerHost:   100,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		},
	}
)
// ErrClient is an error which occurred on the client side of the library
// (marshalling, transport, response handling) as opposed to an error returned
// by the remote handler.
type ErrClient struct {
	err error
}

// Error implements the error interface, prefixing the wrapped message.
func (e *ErrClient) Error() string {
	return fmt.Sprintf("RPC client error: %s", e.err)
}

// Unwrap unwraps the actual error, enabling errors.Is / errors.As matching.
func (e *ErrClient) Unwrap() error {
	return e.err
}
// clientResponse is the decoded wire form of a JSON-RPC 2.0 response frame.
type clientResponse struct {
	Jsonrpc string          `json:"jsonrpc"`
	Result  json.RawMessage `json:"result"`
	ID      interface{}     `json:"id"`
	Error   *JSONRPCError   `json:"error,omitempty"`
}

// makeChanSink is used for channel-returning methods: it produces the context
// to watch plus a callback fed with each incoming (payload, ok) channel
// notification; ok=false signals the remote closed the channel.
type makeChanSink func() (context.Context, func([]byte, bool))

// clientRequest pairs an outgoing request with the plumbing needed to route
// its response back to the caller.
type clientRequest struct {
	req request
	// ready receives the matching response exactly once.
	ready chan clientResponse

	// retCh provides a context and sink for handling incoming channel messages
	retCh makeChanSink
}
// ClientCloser is used to close Client from further use
type ClientCloser func()

// NewClient creates new jsonrpc 2.0 client
//
// handler must be pointer to a struct with function fields
// Returned value closes the client connection
// addr may use an http(s):// or ws(s):// scheme; see NewMergeClient.
// TODO: Example
func NewClient(ctx context.Context, addr string, namespace string, handler interface{}, requestHeader http.Header) (ClientCloser, error) {
	return NewMergeClient(ctx, addr, namespace, []interface{}{handler}, requestHeader)
}
// client holds the transport-independent state of a JSON-RPC client; the
// concrete transport (http, websocket, custom) is injected via doRequest.
type client struct {
	namespace     string
	paramEncoders map[reflect.Type]ParamEncoder
	errors        *Errors

	// doRequest sends one request over the underlying transport and blocks
	// until its response (or failure) is available.
	doRequest func(context.Context, clientRequest) (clientResponse, error)
	// exiting is closed when the client shuts down.
	exiting <-chan struct{}
	// idCtr is incremented atomically to produce unique request IDs.
	idCtr int64

	methodNameFormatter MethodNameFormatter
}
// NewMergeClient is like NewClient, but allows to specify multiple structs
// to be filled in the same namespace, using one connection. The transport is
// selected from the URL scheme of addr: ws(s) uses a persistent websocket,
// http(s) sends one POST per call.
func NewMergeClient(ctx context.Context, addr string, namespace string, outs []interface{}, requestHeader http.Header, opts ...Option) (ClientCloser, error) {
	cfg := defaultConfig()
	for _, apply := range opts {
		apply(&cfg)
	}

	parsed, err := url.Parse(addr)
	if err != nil {
		return nil, xerrors.Errorf("parsing address: %w", err)
	}

	switch parsed.Scheme {
	case "ws", "wss":
		return websocketClient(ctx, addr, namespace, outs, requestHeader, cfg)
	case "http", "https":
		return httpClient(ctx, addr, namespace, outs, requestHeader, cfg)
	}
	return nil, xerrors.Errorf("unknown url scheme '%s'", parsed.Scheme)
}
// NewCustomClient is like NewMergeClient in single-request (http) mode,
// except it allows for a custom doRequest function.
//
// doRequest receives the marshalled JSON-RPC request body and must return a
// reader over the raw response frame; the reader is closed once the response
// has been decoded. This enables custom transports such as IPC pipes or
// in-process dispatch.
func NewCustomClient(namespace string, outs []interface{}, doRequest func(ctx context.Context, body []byte) (io.ReadCloser, error), opts ...Option) (ClientCloser, error) {
	config := defaultConfig()
	for _, o := range opts {
		o(&config)
	}

	c := client{
		namespace:           namespace,
		paramEncoders:       config.paramEncoders,
		errors:              config.errors,
		methodNameFormatter: config.methodNamer,
	}

	stop := make(chan struct{})
	c.exiting = stop

	c.doRequest = func(ctx context.Context, cr clientRequest) (clientResponse, error) {
		b, err := json.Marshal(&cr.req)
		if err != nil {
			return clientResponse{}, xerrors.Errorf("marshalling request: %w", err)
		}

		if ctx == nil {
			ctx = context.Background()
		}

		rawResp, err := doRequest(ctx, b)
		if err != nil {
			return clientResponse{}, xerrors.Errorf("doRequest failed: %w", err)
		}
		defer rawResp.Close()

		var resp clientResponse
		if cr.req.ID != nil { // non-notification
			if err := json.NewDecoder(rawResp).Decode(&resp); err != nil {
				return clientResponse{}, xerrors.Errorf("unmarshaling response: %w", err)
			}

			// IDs travel as float64 after JSON decoding; normalize before comparing.
			if resp.ID, err = normalizeID(resp.ID); err != nil {
				// fixed: message previously read "failed to response ID"
				return clientResponse{}, xerrors.Errorf("failed to normalize response ID: %w", err)
			}

			if resp.ID != cr.req.ID {
				return clientResponse{}, xerrors.New("request and response id didn't match")
			}
		}

		return resp, nil
	}

	if err := c.provide(outs); err != nil {
		return nil, err
	}

	return func() {
		close(stop)
	}, nil
}
// httpClient constructs a client which sends each request as a separate HTTP
// POST to addr. Channel-returning methods and server->client notifications
// are not available in this mode (use the websocket transport for those).
func httpClient(ctx context.Context, addr string, namespace string, outs []interface{}, requestHeader http.Header, config Config) (ClientCloser, error) {
	c := client{
		namespace:           namespace,
		paramEncoders:       config.paramEncoders,
		errors:              config.errors,
		methodNameFormatter: config.methodNamer,
	}

	stop := make(chan struct{})
	c.exiting = stop

	if requestHeader == nil {
		requestHeader = http.Header{}
	}

	c.doRequest = func(ctx context.Context, cr clientRequest) (clientResponse, error) {
		b, err := json.Marshal(&cr.req)
		if err != nil {
			return clientResponse{}, xerrors.Errorf("marshalling request: %w", err)
		}

		hreq, err := http.NewRequest("POST", addr, bytes.NewReader(b))
		if err != nil {
			return clientResponse{}, &RPCConnectionError{err}
		}

		hreq.Header = requestHeader.Clone()

		if ctx != nil {
			hreq = hreq.WithContext(ctx)
		}

		hreq.Header.Set("Content-Type", "application/json")

		httpResp, err := config.httpClient.Do(hreq)
		if err != nil {
			return clientResponse{}, &RPCConnectionError{err}
		}
		// Close the body on every exit path. (Previously the defer sat below
		// the status check, leaking the body — and the pooled connection — on
		// non-JSON status returns.)
		defer httpResp.Body.Close()

		// likely a failure outside of our control and ability to inspect; jsonrpc server only ever
		// returns json format errors with either a StatusBadRequest or a StatusInternalServerError
		if httpResp.StatusCode > http.StatusBadRequest && httpResp.StatusCode != http.StatusInternalServerError {
			return clientResponse{}, xerrors.Errorf("request failed, http status %s", httpResp.Status)
		}

		var resp clientResponse
		if cr.req.ID != nil { // non-notification
			if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
				return clientResponse{}, xerrors.Errorf("http status %s unmarshaling response: %w", httpResp.Status, err)
			}

			// IDs travel as float64 after JSON decoding; normalize before comparing.
			if resp.ID, err = normalizeID(resp.ID); err != nil {
				// fixed: message previously read "failed to response ID"
				return clientResponse{}, xerrors.Errorf("failed to normalize response ID: %w", err)
			}

			if resp.ID != cr.req.ID {
				return clientResponse{}, xerrors.New("request and response id didn't match")
			}
		}

		return resp, nil
	}

	if err := c.provide(outs); err != nil {
		return nil, err
	}

	return func() {
		close(stop)
	}, nil
}
// websocketClient constructs a client over a single persistent websocket
// connection. This mode additionally supports channel-returning methods,
// request-cancellation propagation, and (when config.reverseHandlers is set)
// reverse-direction calls from the server handled locally.
func websocketClient(ctx context.Context, addr string, namespace string, outs []interface{}, requestHeader http.Header, config Config) (ClientCloser, error) {
	connFactory := func() (*websocket.Conn, error) {
		conn, _, err := websocket.DefaultDialer.Dial(addr, requestHeader)
		if err != nil {
			return nil, &RPCConnectionError{xerrors.Errorf("cannot dial address %s for %w", addr, err)}
		}
		return conn, nil
	}
	if config.proxyConnFactory != nil {
		// used in tests
		connFactory = config.proxyConnFactory(connFactory)
	}
	conn, err := connFactory()
	if err != nil {
		return nil, err
	}
	// A nil connFactory tells wsConn not to attempt reconnection.
	if config.noReconnect {
		connFactory = nil
	}
	c := client{
		namespace:           namespace,
		paramEncoders:       config.paramEncoders,
		errors:              config.errors,
		methodNameFormatter: config.methodNamer,
	}
	requests := c.setupRequestChan()
	stop := make(chan struct{})
	exiting := make(chan struct{})
	c.exiting = exiting
	// With reverse handlers registered, build a server-side handler so the
	// remote end can invoke methods on this client over the same connection.
	var hnd reqestHandler
	if len(config.reverseHandlers) > 0 {
		h := makeHandler(defaultServerConfig())
		h.aliasedMethods = config.aliasedHandlerMethods
		for _, reverseHandler := range config.reverseHandlers {
			h.register(reverseHandler.ns, reverseHandler.hnd)
		}
		hnd = h
	}
	wconn := &wsConn{
		conn:             conn,
		connFactory:      connFactory,
		reconnectBackoff: config.reconnectBackoff,
		pingInterval:     config.pingInterval,
		timeout:          config.timeout,
		handler:          hnd,
		requests:         requests,
		stop:             stop,
		exiting:          exiting,
		reconfun:         config.reconnfun,
	}
	// Run the connection loop under pprof labels so goroutine dumps identify
	// which remote this websocket belongs to.
	go func() {
		lbl := pprof.Labels("jrpc-mode", "wsclient", "jrpc-remote", addr, "jrpc-local", conn.LocalAddr().String(), "jrpc-uuid", uuid.New().String())
		pprof.Do(ctx, lbl, func(ctx context.Context) {
			wconn.handleWsConn(ctx)
		})
	}()
	if err := c.provide(outs); err != nil {
		return nil, err
	}
	// Closing stop asks the connection loop to exit; waiting on exiting
	// ensures it has fully wound down before the closer returns.
	return func() {
		close(stop)
		<-exiting
	}, nil
}
// setupRequestChan wires doRequest to a channel consumed by the websocket
// connection goroutine and returns that channel. The installed doRequest
// blocks until a response arrives on the per-request ready channel; if the
// caller's context is cancelled first, a wsCancel request is sent so the
// server can abort the in-flight handler — we still wait for the original
// response afterwards.
func (c *client) setupRequestChan() chan clientRequest {
	requests := make(chan clientRequest)
	c.doRequest = func(ctx context.Context, cr clientRequest) (clientResponse, error) {
		select {
		case requests <- cr:
		case <-c.exiting:
			return clientResponse{}, fmt.Errorf("websocket routine exiting")
		}
		var ctxDone <-chan struct{}
		var resp clientResponse
		if ctx != nil {
			ctxDone = ctx.Done()
		}
		// wait for response, handle context cancellation
	loop:
		for {
			select {
			case resp = <-cr.ready:
				break loop
			case <-ctxDone: // send cancel request
				// nil the channel so cancellation is only signalled once
				ctxDone = nil
				rp, err := json.Marshal([]param{{v: reflect.ValueOf(cr.req.ID)}})
				if err != nil {
					return clientResponse{}, xerrors.Errorf("marshalling cancel request: %w", err)
				}
				cancelReq := clientRequest{
					req: request{
						Jsonrpc: "2.0",
						Method:  wsCancel,
						Params:  rp,
					},
					ready: make(chan clientResponse, 1),
				}
				select {
				case requests <- cancelReq:
				case <-c.exiting:
					log.Warn("failed to send request cancellation, websocket routing exited")
				}
			}
		}
		return resp, nil
	}
	return requests
}
// provide fills every function field of each struct pointed to by outs with a
// generated RPC stub. Each element must be a pointer to a struct whose fields
// are all funcs; any other shape yields an error.
func (c *client) provide(outs []interface{}) error {
	for _, out := range outs {
		ptrType := reflect.TypeOf(out)
		if ptrType.Kind() != reflect.Ptr {
			return xerrors.New("expected handler to be a pointer")
		}
		structType := ptrType.Elem()
		if structType.Kind() != reflect.Struct {
			return xerrors.New("handler should be a struct")
		}

		structVal := reflect.ValueOf(out).Elem()
		for i := 0; i < structType.NumField(); i++ {
			stub, err := c.makeRpcFunc(structType.Field(i))
			if err != nil {
				return err
			}
			structVal.Field(i).Set(stub)
		}
	}
	return nil
}
// makeOutChan builds the client-side plumbing for methods returning a channel
// (<-chan T). It returns a getter for the reflect.Value handed to the caller
// (the channel type's zero value until the sink is constructed), plus a
// makeChanSink which, when invoked by the transport, creates the live channel
// and a callback that decodes each incoming notification payload into it.
//
// Incoming values are buffered in an unbounded list so a slow consumer never
// blocks the websocket read loop; the channel is closed when ctx is cancelled
// or when the server signals close (ok=false) and the buffer has drained.
func (c *client) makeOutChan(ctx context.Context, ftyp reflect.Type, valOut int) (func() reflect.Value, makeChanSink) {
	retVal := reflect.Zero(ftyp.Out(valOut))
	chCtor := func() (context.Context, func([]byte, bool)) {
		// unpack chan type to make sure it's reflect.BothDir
		ctyp := reflect.ChanOf(reflect.BothDir, ftyp.Out(valOut).Elem())
		ch := reflect.MakeChan(ctyp, 0) // todo: buffer?
		retVal = ch.Convert(ftyp.Out(valOut))
		incoming := make(chan reflect.Value, 32)
		// goroutine to handle buffering of items
		go func() {
			buf := (&list.List{}).Init()
			for {
				front := buf.Front()
				cases := []reflect.SelectCase{
					{
						Dir:  reflect.SelectRecv,
						Chan: reflect.ValueOf(ctx.Done()),
					},
					{
						Dir:  reflect.SelectRecv,
						Chan: reflect.ValueOf(incoming),
					},
				}
				// only offer a send case when a buffered item is available
				if front != nil {
					cases = append(cases, reflect.SelectCase{
						Dir:  reflect.SelectSend,
						Chan: ch,
						Send: front.Value.(reflect.Value).Elem(),
					})
				}
				chosen, val, ok := reflect.Select(cases)
				switch chosen {
				case 0: // context cancelled: stop delivering, close the channel
					ch.Close()
					return
				case 1: // new value (or close signal) from the transport
					if ok {
						vvval := val.Interface().(reflect.Value)
						buf.PushBack(vvval)
						if buf.Len() > 1 {
							if buf.Len() > 10 {
								log.Warnw("rpc output message buffer", "n", buf.Len())
							} else {
								log.Debugw("rpc output message buffer", "n", buf.Len())
							}
						}
					} else {
						// remote closed; stop reading but keep draining buf
						incoming = nil
					}
				case 2: // delivered the front buffered item to the consumer
					buf.Remove(front)
				}
				if incoming == nil && buf.Len() == 0 {
					ch.Close()
					return
				}
			}
		}()
		return ctx, func(result []byte, ok bool) {
			if !ok {
				close(incoming)
				return
			}
			val := reflect.New(ftyp.Out(valOut).Elem())
			if err := json.Unmarshal(result, val.Interface()); err != nil {
				log.Errorf("error unmarshaling chan response: %s", err)
				return
			}
			if ctx.Err() != nil {
				log.Errorf("got rpc message with cancelled context: %s", ctx.Err())
				return
			}
			select {
			case incoming <- val:
			case <-ctx.Done():
			}
		}
	}
	return func() reflect.Value { return retVal }, chCtor
}
// sendRequest submits req through the transport-specific doRequest hook,
// pairing it with a fresh buffered ready channel and the (possibly nil)
// channel sink constructor for channel-returning methods.
func (c *client) sendRequest(ctx context.Context, req request, chCtor makeChanSink) (clientResponse, error) {
	return c.doRequest(ctx, clientRequest{
		req:   req,
		ready: make(chan clientResponse, 1),
		retCh: chCtor,
	})
}
// rpcFunc carries the precomputed reflection metadata for one generated proxy
// method; handleRpcCall uses it to marshal arguments, send the request, and
// shape the reflect return values.
type rpcFunc struct {
	client *client

	ftyp reflect.Type
	name string

	// nout is the total number of return values; valOut/errOut are the
	// indexes of the value and error returns respectively, or -1 when absent.
	nout   int
	valOut int
	errOut int

	// hasCtx is 1 if the function has a context.Context as its first argument.
	// Used as the number of the first non-context argument.
	hasCtx int
	// hasRawParams is true when the only (non-ctx) argument is RawParams,
	// passed to the wire without re-marshalling.
	hasRawParams bool

	returnValueIsChannel bool

	// retry / notify mirror the "retry" and "notify" struct tags.
	retry  bool
	notify bool
}
// processResponse maps a decoded RPC response onto the reflect return values
// expected by the proxied function: the value slot (if present) receives
// rval, and the error slot (if present) receives either a nil error or the
// decoded RPC error.
func (fn *rpcFunc) processResponse(resp clientResponse, rval reflect.Value) []reflect.Value {
	results := make([]reflect.Value, fn.nout)
	if fn.valOut != -1 {
		results[fn.valOut] = rval
	}
	if fn.errOut != -1 {
		errVal := reflect.New(errorType).Elem()
		if resp.Error != nil {
			errVal.Set(resp.Error.val(fn.client.errors))
		}
		results[fn.errOut] = errVal
	}
	return results
}
// processError builds the reflect return values for a client-side failure:
// the value slot (if present) gets its zero value, and the error slot (if
// present) gets err wrapped in *ErrClient.
func (fn *rpcFunc) processError(err error) []reflect.Value {
	results := make([]reflect.Value, fn.nout)
	if fn.valOut != -1 {
		results[fn.valOut] = reflect.New(fn.ftyp.Out(fn.valOut)).Elem()
	}
	if fn.errOut != -1 {
		errVal := reflect.New(errorType).Elem()
		errVal.Set(reflect.ValueOf(&ErrClient{err}))
		results[fn.errOut] = errVal
	}
	return results
}
// handleRpcCall is the reflect.MakeFunc body behind every generated proxy
// method. It assigns a request ID (unless the method is a notification),
// marshals the arguments, sends the request through the client transport,
// optionally retries on temporary websocket errors, and converts the response
// (or failure) back into reflect return values.
func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value) {
	var id interface{}
	if !fn.notify {
		id = atomic.AddInt64(&fn.client.idCtr, 1)
		// Prepare the ID to send on the wire.
		// We track int64 ids as float64 in the inflight map (because that's what
		// they'll be decoded to). encoding/json outputs numbers with their minimal
		// encoding, avoiding the decimal point when possible, i.e. 3 will never get
		// converted to 3.0.
		var err error
		id, err = normalizeID(id)
		if err != nil {
			return fn.processError(fmt.Errorf("failed to normalize id")) // should probably panic
		}
	}
	var serializedParams json.RawMessage
	if fn.hasRawParams {
		// raw params are passed through to the wire without re-marshalling
		serializedParams = json.RawMessage(args[fn.hasCtx].Interface().(RawParams))
	} else {
		params := make([]param, len(args)-fn.hasCtx)
		for i, arg := range args[fn.hasCtx:] {
			enc, found := fn.client.paramEncoders[arg.Type()]
			if found {
				// custom param encoder
				var err error
				arg, err = enc(arg)
				if err != nil {
					return fn.processError(fmt.Errorf("sendRequest failed: %w", err))
				}
			}
			params[i] = param{
				v: arg,
			}
		}
		var err error
		serializedParams, err = json.Marshal(params)
		if err != nil {
			return fn.processError(fmt.Errorf("marshaling params failed: %w", err))
		}
	}
	var ctx context.Context
	var span *trace.Span
	if fn.hasCtx == 1 {
		ctx = args[0].Interface().(context.Context)
		ctx, span = trace.StartSpan(ctx, "api.call")
		defer span.End()
	}
	retVal := func() reflect.Value { return reflect.Value{} }
	// if the function returns a channel, we need to provide a sink for the
	// messages
	var chCtor makeChanSink
	if fn.returnValueIsChannel {
		retVal, chCtor = fn.client.makeOutChan(ctx, fn.ftyp, fn.valOut)
	}
	req := request{
		Jsonrpc: "2.0",
		ID:      id,
		Method:  fn.name,
		Params:  serializedParams,
	}
	// Propagate the trace span to the server via the request Meta field.
	if span != nil {
		span.AddAttributes(trace.StringAttribute("method", req.Method))
		eSC := base64.StdEncoding.EncodeToString(
			propagation.Binary(span.SpanContext()))
		req.Meta = map[string]string{
			"SpanContext": eSC,
		}
	}
	b := backoff{
		maxDelay: methodMaxRetryDelay,
		minDelay: methodMinRetryDelay,
	}
	var err error
	var resp clientResponse
	// keep retrying if got a forced closed websocket conn and calling method
	// has retry annotation
	for attempt := 0; true; attempt++ {
		resp, err = fn.client.sendRequest(ctx, req, chCtor)
		if err != nil {
			return fn.processError(fmt.Errorf("sendRequest failed: %w", err))
		}
		if !fn.notify && resp.ID != req.ID {
			return fn.processError(xerrors.New("request and response id didn't match"))
		}
		if fn.valOut != -1 && !fn.returnValueIsChannel {
			val := reflect.New(fn.ftyp.Out(fn.valOut))
			if resp.Result != nil {
				log.Debugw("rpc result", "type", fn.ftyp.Out(fn.valOut))
				if err := json.Unmarshal(resp.Result, val.Interface()); err != nil {
					log.Warnw("unmarshaling failed", "message", string(resp.Result))
					return fn.processError(xerrors.Errorf("unmarshaling result: %w", err))
				}
			}
			retVal = func() reflect.Value { return val.Elem() }
		}
		// only retry on the dedicated temporary-websocket-error code, and
		// only for methods explicitly tagged as retryable
		retry := resp.Error != nil && resp.Error.Code == eTempWSError && fn.retry
		if !retry {
			break
		}
		time.Sleep(b.next(attempt))
	}
	return fn.processResponse(resp, retVal())
}
// Struct tags recognized on client struct fields:
//
//	ProxyTagRetry     - "retry": retry the call after temporary websocket errors
//	ProxyTagNotify    - "notify": send as a notification (no ID, no response)
//	ProxyTagRPCMethod - "rpc_method": override the wire method name
const (
	ProxyTagRetry     = "retry"
	ProxyTagNotify    = "notify"
	ProxyTagRPCMethod = "rpc_method"
)
// makeRpcFunc builds the proxy implementation for a single client struct
// field. It validates the field's function signature, derives the wire method
// name (namespace formatting, optionally overridden via the rpc_method tag),
// and returns a reflect func value to assign into the field.
func (c *client) makeRpcFunc(f reflect.StructField) (reflect.Value, error) {
	ftyp := f.Type
	if ftyp.Kind() != reflect.Func {
		return reflect.Value{}, xerrors.New("handler field not a func")
	}
	name := c.methodNameFormatter(c.namespace, f.Name)
	if tag, ok := f.Tag.Lookup(ProxyTagRPCMethod); ok {
		name = tag
	}
	fun := &rpcFunc{
		client: c,
		ftyp:   ftyp,
		name:   name,
		retry:  f.Tag.Get(ProxyTagRetry) == "true",
		notify: f.Tag.Get(ProxyTagNotify) == "true",
	}
	fun.valOut, fun.errOut, fun.nout = processFuncOut(ftyp)
	// Notifications get no response frame, so they cannot carry a value.
	if fun.valOut != -1 && fun.notify {
		return reflect.Value{}, xerrors.New("notify methods cannot return values")
	}
	fun.returnValueIsChannel = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan
	if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
		fun.hasCtx = 1
	}
	// note: hasCtx is also the number of the first non-context argument
	if ftyp.NumIn() > fun.hasCtx && ftyp.In(fun.hasCtx) == rtRawParams {
		if ftyp.NumIn() > fun.hasCtx+1 {
			return reflect.Value{}, xerrors.New("raw params can't be mixed with other arguments")
		}
		fun.hasRawParams = true
	}
	return reflect.MakeFunc(ftyp, fun.handleRpcCall), nil
}

View File

@@ -0,0 +1,65 @@
package jsonrpc
import (
"encoding/json"
"errors"
"reflect"
)
// eTempWSError is the reserved internal error code signalling a temporary
// websocket failure; methods tagged "retry" re-send when they see it.
const eTempWSError = -1111111

// RPCConnectionError wraps transport-level failures (dial errors, broken
// connections) so callers can distinguish them from server-returned errors.
type RPCConnectionError struct {
	err error
}

// Error returns the wrapped error's message, or a placeholder when none is set.
func (e *RPCConnectionError) Error() string {
	if e.err != nil {
		return e.err.Error()
	}
	return "RPCConnectionError"
}

// Unwrap exposes the underlying transport error.
// NOTE(review): convention is for Unwrap to return nil when nothing is
// wrapped; returning a fresh error here adds a synthetic step to
// errors.Is/As chains — confirm no caller depends on it before changing.
func (e *RPCConnectionError) Unwrap() error {
	if e.err != nil {
		return e.err
	}
	return errors.New("RPCConnectionError")
}
type Errors struct {
byType map[reflect.Type]ErrorCode
byCode map[ErrorCode]reflect.Type
}
type ErrorCode int
const FirstUserCode = 2
func NewErrors() Errors {
return Errors{
byType: map[reflect.Type]ErrorCode{},
byCode: map[ErrorCode]reflect.Type{
-1111111: reflect.TypeOf(&RPCConnectionError{}),
},
}
}
// Register associates error code c with the error type pointed to by typ
// (pass a pointer such as new(MyError)). It panics if the pointed-to type
// does not implement the error interface — registration happens at startup,
// so a bad type is a programmer error.
func (e *Errors) Register(c ErrorCode, typ interface{}) {
	elemType := reflect.TypeOf(typ).Elem()
	if !elemType.Implements(errorType) {
		panic("can't register non-error types")
	}
	e.byType[elemType] = c
	e.byCode[c] = elemType
}
// marshalable constrains types that provide both custom JSON encoding and
// decoding.
type marshalable interface {
	json.Marshaler
	json.Unmarshaler
}

// RPCErrorCodec may be implemented by registered error types to control how
// they serialize to and deserialize from the wire-level JSONRPCError frame.
type RPCErrorCodec interface {
	FromJSONRPCError(JSONRPCError) error
	ToJSONRPCError() (JSONRPCError, error)
}

View File

@@ -0,0 +1,23 @@
module github.com/filecoin-project/go-jsonrpc
go 1.20
require (
github.com/google/uuid v1.1.1
github.com/gorilla/mux v1.7.4
github.com/gorilla/websocket v1.4.2
github.com/ipfs/go-log/v2 v2.0.8
github.com/stretchr/testify v1.5.1
go.opencensus.io v0.22.3
go.uber.org/zap v1.14.1
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
go.uber.org/atomic v1.6.0 // indirect
go.uber.org/multierr v1.5.0 // indirect
gopkg.in/yaml.v2 v2.2.2 // indirect
)

View File

@@ -0,0 +1,102 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM39fZuwSd1LwSqqSW0hOdXCYYDX0R3I=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc=
github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/ipfs/go-log/v2 v2.0.8 h1:3b3YNopMHlj4AvyhWAx0pDxqSQWYi4/WuWO7yRV6/Qg=
github.com/ipfs/go-log/v2 v2.0.8/go.mod h1:eZs4Xt4ZUJQFM3DlanGhy7TkwwawCZcSByscwkWG+dw=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
go.opencensus.io v0.22.3 h1:8sGtKOrtQqkN1bp2AtX+misvLIlOmsEsNd+9NIcPEm8=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A=
go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU=
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4=
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 h1:hKsoRgsbwY1NafxrwTs+k64bikrLBkAgPir1TNCj3Zs=
golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=

View File

@@ -0,0 +1,513 @@
package jsonrpc
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"reflect"
"go.opencensus.io/stats"
"go.opencensus.io/tag"
"go.opencensus.io/trace"
"go.opencensus.io/trace/propagation"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/xerrors"
"github.com/filecoin-project/go-jsonrpc/metrics"
)
type RawParams json.RawMessage
var rtRawParams = reflect.TypeOf(RawParams{})
// todo is there a better way to tell 'struct with any number of fields'?
func DecodeParams[T any](p RawParams) (T, error) {
var t T
err := json.Unmarshal(p, &t)
// todo also handle list-encoding automagically (json.Unmarshal doesn't do that, does it?)
return t, err
}
// methodHandler is a handler for a single method
type methodHandler struct {
	paramReceivers []reflect.Type // expected Go type of each JSON param, in order
	nParams        int            // number of JSON params (excludes receiver and ctx)

	receiver    reflect.Value // value the method is invoked on (first call arg)
	handlerFunc reflect.Value // the method func itself (receiver passed explicitly)

	hasCtx       int  // 1 if the method's first arg is a context.Context, else 0
	hasRawParams bool // true if the sole param is RawParams (no per-field decode)

	errOut int // index of the error return value, or -1 if none
	valOut int // index of the data return value, or -1 if none
}
// Request / response

// request is a single incoming JSON-RPC 2.0 request.
type request struct {
	Jsonrpc string            `json:"jsonrpc"` // protocol version, expected "2.0"
	ID      interface{}       `json:"id,omitempty"` // nil marks a notification (no response sent)
	Method  string            `json:"method"`
	Params  json.RawMessage   `json:"params"`
	Meta    map[string]string `json:"meta,omitempty"` // out-of-band data, e.g. trace span context
}

// Limit request size. Ideally this limit should be specific for each field
// in the JSON request but as a simple defensive measure we just limit the
// entire HTTP body.
// Configured by WithMaxRequestSize.
const DEFAULT_MAX_REQUEST_SIZE = 100 << 20 // 100 MiB
// handler dispatches parsed JSON-RPC requests to registered methods; it is
// shared by the HTTP and websocket code paths.
type handler struct {
	methods map[string]methodHandler // formatted method name -> handler
	errors  *Errors                  // optional typed-error registry for responses

	maxRequestSize int64 // request body byte limit (see DEFAULT_MAX_REQUEST_SIZE)

	// aliasedMethods contains a map of alias:original method names.
	// These are used as fallbacks if a method is not found by the given method name.
	aliasedMethods map[string]string

	paramDecoders map[reflect.Type]ParamDecoder // per-type custom param decoders

	methodNameFormatter MethodNameFormatter

	tracer Tracer
}

// Tracer is an observability callback invoked for every dispatched call
// with the raw reflect params/results and the call error (if any).
type Tracer func(method string, params []reflect.Value, results []reflect.Value, err error)
// makeHandler builds a request dispatcher from the given server configuration.
func makeHandler(sc ServerConfig) *handler {
	h := &handler{
		methods:        make(map[string]methodHandler),
		aliasedMethods: map[string]string{},

		errors:              sc.errors,
		paramDecoders:       sc.paramDecoders,
		methodNameFormatter: sc.methodNameFormatter,
		maxRequestSize:      sc.maxRequestSize,
		tracer:              sc.tracer,
	}
	return h
}
// Register

// register exposes every exported method of r under the given namespace.
// Method signatures are inspected once up front so per-call dispatch only
// needs to decode params and invoke the cached reflect.Value.
func (s *handler) register(namespace string, r interface{}) {
	val := reflect.ValueOf(r)
	// TODO: expect ptr
	for i := 0; i < val.NumMethod(); i++ {
		method := val.Type().Method(i)
		funcType := method.Func.Type()

		// In(0) is the receiver; an optional context.Context comes next.
		hasCtx := 0
		if funcType.NumIn() >= 2 && funcType.In(1) == contextType {
			hasCtx = 1
		}

		hasRawParams := false
		ins := funcType.NumIn() - 1 - hasCtx // number of JSON-decoded params
		recvs := make([]reflect.Type, ins)
		for i := 0; i < ins; i++ {
			// A RawParams argument must be the only parameter (checked on the
			// following iteration once hasRawParams is set).
			if hasRawParams && i > 0 {
				panic("raw params must be the last parameter")
			}
			if funcType.In(i+1+hasCtx) == rtRawParams {
				hasRawParams = true
			}
			recvs[i] = method.Type.In(i + 1 + hasCtx)
		}

		valOut, errOut, _ := processFuncOut(funcType)

		s.methods[s.methodNameFormatter(namespace, method.Name)] = methodHandler{
			paramReceivers: recvs,
			nParams:        ins,

			handlerFunc: method.Func,
			receiver:    val,

			hasCtx:       hasCtx,
			hasRawParams: hasRawParams,

			errOut: errOut,
			valOut: valOut,
		}
	}
}
// Handle
// rpcErrFunc writes an error response for req (req may be nil when the
// request could not even be parsed); w defers acquiring the output writer.
type rpcErrFunc func(w func(func(io.Writer)), req *request, code ErrorCode, err error)

// chanOut registers a channel-typed result for streaming, keyed by the
// request ID; a non-nil error means streaming could not be set up.
type chanOut func(reflect.Value, interface{}) error
// handleReader parses one request — or a batch — from r, dispatches it, and
// writes the response(s) to w. rpcError reports protocol-level failures
// (parse errors, oversized bodies, ...).
func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) {
	wf := func(cb func(io.Writer)) {
		cb(w)
	}

	// We read the entire request upfront in a buffer to be able to tell if the
	// client sent more than maxRequestSize and report it back as an explicit error,
	// instead of just silently truncating it and reporting a more vague parsing
	// error.
	bufferedRequest := new(bytes.Buffer)
	// We use LimitReader to enforce maxRequestSize. Since it won't return an
	// EOF we can't actually know if the client sent more than the maximum or
	// not, so we read one byte more over the limit to explicitly query that.
	// FIXME: Maybe there's a cleaner way to do this.
	reqSize, err := bufferedRequest.ReadFrom(io.LimitReader(r, s.maxRequestSize+1))
	if err != nil {
		// ReadFrom will discard EOF so any error here is unexpected and should
		// be reported.
		rpcError(wf, nil, rpcParseError, xerrors.Errorf("reading request: %w", err))
		return
	}
	if reqSize > s.maxRequestSize {
		rpcError(wf, nil, rpcParseError,
			// rpcParseError is the closest we have from the standard errors defined
			// in [jsonrpc spec](https://www.jsonrpc.org/specification#error_object)
			// to report the maximum limit.
			xerrors.Errorf("request bigger than maximum %d allowed",
				s.maxRequestSize))
		return
	}

	// Trim spaces to avoid issues with batch request detection.
	bufferedRequest = bytes.NewBuffer(bytes.TrimSpace(bufferedRequest.Bytes()))
	reqSize = int64(bufferedRequest.Len())

	if reqSize == 0 {
		rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request"))
		return
	}

	// Leading '[' with a matching trailing ']' marks a batch request.
	if bufferedRequest.Bytes()[0] == '[' && bufferedRequest.Bytes()[reqSize-1] == ']' {
		var reqs []request

		if err := json.NewDecoder(bufferedRequest).Decode(&reqs); err != nil {
			rpcError(wf, nil, rpcParseError, xerrors.New("Parse error"))
			return
		}

		if len(reqs) == 0 {
			rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request"))
			return
		}

		// Batch responses are emitted as a comma-joined JSON array.
		_, _ = w.Write([]byte("[")) // todo consider handling this error
		for idx, req := range reqs {
			if req.ID, err = normalizeID(req.ID); err != nil {
				rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
				return
			}

			s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)

			if idx != len(reqs)-1 {
				_, _ = w.Write([]byte(",")) // todo consider handling this error
			}
		}
		_, _ = w.Write([]byte("]")) // todo consider handling this error
	} else {
		var req request
		if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil {
			rpcError(wf, &req, rpcParseError, xerrors.New("Parse error"))
			return
		}

		if req.ID, err = normalizeID(req.ID); err != nil {
			rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
			return
		}

		s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)
	}
}
// doCall invokes f with params, converting a panic inside the RPC method
// into a returned error (setting the named return from the deferred
// recover) and logging it with a stack trace.
func doCall(methodName string, f reflect.Value, params []reflect.Value) (out []reflect.Value, err error) {
	defer func() {
		if i := recover(); i != nil {
			err = xerrors.Errorf("panic in rpc method '%s': %s", methodName, i)
			log.Desugar().WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)).Sugar().Error(err)
		}
	}()

	out = f.Call(params)
	return out, nil
}
// getSpan starts a tracing span for the request. If the client propagated a
// base64-encoded binary span context in req.Meta["SpanContext"], the new
// span is linked to that remote parent; otherwise a fresh "api.handle" span
// is started. Returns a nil span when req.Meta is absent.
func (s *handler) getSpan(ctx context.Context, req request) (context.Context, *trace.Span) {
	if req.Meta == nil {
		return ctx, nil
	}
	var span *trace.Span
	if eSC, ok := req.Meta["SpanContext"]; ok {
		bSC := make([]byte, base64.StdEncoding.DecodedLen(len(eSC)))
		_, err := base64.StdEncoding.Decode(bSC, []byte(eSC))
		if err != nil {
			// NOTE(review): Errorf is printf-style; these key/value args are
			// appended verbatim — Errorw may have been intended.
			log.Errorf("SpanContext: decode", "error", err)
			return ctx, nil
		}
		sc, ok := propagation.FromBinary(bSC)
		if !ok {
			log.Errorf("SpanContext: could not create span", "data", bSC)
			return ctx, nil
		}
		ctx, span = trace.StartSpanWithRemoteParent(ctx, "api.handle", sc)
	} else {
		ctx, span = trace.StartSpan(ctx, "api.handle")
	}
	span.AddAttributes(trace.StringAttribute("method", req.Method))
	return ctx, span
}
// createError converts err into a JSONRPCError response payload.
//
// The code is looked up in the server's registered error map (falling back
// to 1), then the error may customize the payload further: RPCErrorCodec
// implementations take over serialization entirely, while marshalable
// errors get their JSON attached as metadata.
func (s *handler) createError(err error) *JSONRPCError {
	var code ErrorCode = 1
	if s.errors != nil {
		c, ok := s.errors.byType[reflect.TypeOf(err)]
		if ok {
			code = c
		}
	}

	out := &JSONRPCError{
		Code:    code,
		Message: err.Error(),
	}

	switch m := err.(type) {
	case RPCErrorCodec:
		o, err := m.ToJSONRPCError()
		if err != nil {
			// zap's Errorf formats with fmt.Sprintf, which doesn't support
			// %w — use %v so the error prints instead of "%!w(...)".
			log.Errorf("Failed to convert error to JSONRPCError: %v", err)
		} else {
			out = &o
		}
	case marshalable:
		meta, marshalErr := m.MarshalJSON()
		if marshalErr == nil {
			out.Meta = meta
		} else {
			log.Errorf("Failed to marshal error metadata: %v", marshalErr)
		}
	}

	return out
}
// handle executes a single parsed request: it resolves the method (falling
// back to aliases), decodes parameters, invokes the target function via
// reflection, and writes a response unless the request is a notification
// (nil ID).
//
// done is invoked with true when the request's context must outlive this
// call (channel/streaming results). chOut registers channel results for
// streaming and may be nil when the transport cannot stream.
func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer)), rpcError rpcErrFunc, done func(keepCtx bool), chOut chanOut) {
	// Not sure if we need to sanitize the incoming req.Method or not.
	ctx, span := s.getSpan(ctx, req)
	ctx, _ = tag.New(ctx, tag.Insert(metrics.RPCMethod, req.Method))
	defer span.End()

	handler, ok := s.methods[req.Method]
	if !ok {
		// Fall back to aliased method names before giving up.
		aliasTo, ok := s.aliasedMethods[req.Method]
		if ok {
			handler, ok = s.methods[aliasTo]
		}
		if !ok {
			rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not found", req.Method))
			stats.Record(ctx, metrics.RPCInvalidMethod.M(1))
			done(false)
			return
		}
	}

	// Channel-returning methods keep the request context alive after return.
	outCh := handler.valOut != -1 && handler.handlerFunc.Type().Out(handler.valOut).Kind() == reflect.Chan
	defer done(outCh)

	if chOut == nil && outCh {
		rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not supported in this mode (no out channel support)", req.Method))
		stats.Record(ctx, metrics.RPCRequestError.M(1))
		return
	}

	// Build the reflect call args: receiver, optional ctx, then the params.
	callParams := make([]reflect.Value, 1+handler.hasCtx+handler.nParams)
	callParams[0] = handler.receiver
	if handler.hasCtx == 1 {
		callParams[1] = reflect.ValueOf(ctx)
	}

	if handler.hasRawParams {
		// When hasRawParams is true, there is only one parameter and it is a
		// json.RawMessage.
		callParams[1+handler.hasCtx] = reflect.ValueOf(RawParams(req.Params))
	} else {
		// "normal" param list; no good way to do named params in Golang
		var ps []param
		if len(req.Params) > 0 {
			err := json.Unmarshal(req.Params, &ps)
			if err != nil {
				rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling param array: %w", err))
				stats.Record(ctx, metrics.RPCRequestError.M(1))
				return
			}
		}

		if len(ps) != handler.nParams {
			rpcError(w, &req, rpcInvalidParams, fmt.Errorf("wrong param count (method '%s'): %d != %d", req.Method, len(ps), handler.nParams))
			stats.Record(ctx, metrics.RPCRequestError.M(1))
			done(false)
			return
		}

		// Decode each param, preferring a registered custom decoder over
		// plain JSON unmarshaling into the expected receiver type.
		for i := 0; i < handler.nParams; i++ {
			var rp reflect.Value

			typ := handler.paramReceivers[i]
			dec, found := s.paramDecoders[typ]
			if !found {
				rp = reflect.New(typ)
				if err := json.NewDecoder(bytes.NewReader(ps[i].data)).Decode(rp.Interface()); err != nil {
					rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling params for '%s' (param: %T): %w", req.Method, rp.Interface(), err))
					stats.Record(ctx, metrics.RPCRequestError.M(1))
					return
				}
				rp = rp.Elem()
			} else {
				var err error
				rp, err = dec(ctx, ps[i].data)
				if err != nil {
					rpcError(w, &req, rpcParseError, xerrors.Errorf("decoding params for '%s' (param: %d; custom decoder): %w", req.Method, i, err))
					stats.Record(ctx, metrics.RPCRequestError.M(1))
					return
				}
			}

			callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Interface())
		}
	}

	// /////////////////

	callResult, err := doCall(req.Method, handler.handlerFunc, callParams)
	if err != nil {
		rpcError(w, &req, 0, xerrors.Errorf("fatal error calling '%s': %w", req.Method, err))
		stats.Record(ctx, metrics.RPCRequestError.M(1))
		if s.tracer != nil {
			s.tracer(req.Method, callParams, nil, err)
		}
		return
	}
	if req.ID == nil {
		return // notification
	}
	if s.tracer != nil {
		s.tracer(req.Method, callParams, callResult, nil)
	}

	// /////////////////

	resp := response{
		Jsonrpc: "2.0",
		ID:      req.ID,
	}

	if handler.errOut != -1 {
		err := callResult[handler.errOut].Interface()
		if err != nil {
			log.Warnf("error in RPC call to '%s': %+v", req.Method, err)
			stats.Record(ctx, metrics.RPCResponseError.M(1))
			resp.Error = s.createError(err.(error))
		}
	}

	var kind reflect.Kind
	var res interface{}
	var nonZero bool
	if handler.valOut != -1 {
		res = callResult[handler.valOut].Interface()
		kind = callResult[handler.valOut].Kind()
		nonZero = !callResult[handler.valOut].IsZero()
	}

	// check error as JSON-RPC spec prohibits error and value at the same time
	if resp.Error == nil {
		if res != nil && kind == reflect.Chan {
			// Channel responses are sent from channel control goroutine.
			// Sending responses here could cause deadlocks on writeLk, or allow
			// sending channel messages before this rpc call returns

			//noinspection GoNilness // already checked above
			err = chOut(callResult[handler.valOut], req.ID)
			if err == nil {
				return // channel goroutine handles responding
			}

			log.Warnf("failed to setup channel in RPC call to '%s': %+v", req.Method, err)
			stats.Record(ctx, metrics.RPCResponseError.M(1))
			resp.Error = &JSONRPCError{
				Code:    1,
				Message: err.Error(),
			}
		} else {
			resp.Result = res
		}
	}
	if resp.Error != nil && nonZero {
		log.Errorw("error and res returned", "request", req, "r.err", resp.Error, "res", res)
	}

	withLazyWriter(w, func(w io.Writer) {
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			log.Error(err)
			stats.Record(ctx, metrics.RPCResponseError.M(1))
			return
		}
	})
}
// withLazyWriter makes it possible to defer acquiring a writer until the first write.
// This is useful because json.Encode needs to marshal the response fully before writing, which may be
// a problem for very large responses.
func withLazyWriter(withWriterFunc func(func(io.Writer)), cb func(io.Writer)) {
	lw := &lazyWriter{
		withWriterFunc: withWriterFunc,

		done: make(chan struct{}),
	}
	// Closing done releases the goroutine spawned by lw.Write (if a write
	// ever happened) that is parked inside withWriterFunc.
	defer close(lw.done)

	cb(lw)
}

// lazyWriter acquires the real writer on first Write and holds on to it
// (via a goroutine parked inside withWriterFunc) until done is closed.
type lazyWriter struct {
	withWriterFunc func(func(io.Writer))
	w              io.Writer // nil until the first Write

	done chan struct{}
}

// Write lazily acquires the underlying writer on the first call, then
// forwards all writes to it. The goroutine stays inside withWriterFunc
// until lw.done is closed, so whatever lock or resource the writer
// provider holds is kept for the whole response.
func (lw *lazyWriter) Write(p []byte) (n int, err error) {
	if lw.w == nil {
		acquired := make(chan struct{})
		go func() {
			lw.withWriterFunc(func(w io.Writer) {
				lw.w = w
				close(acquired)
				<-lw.done
			})
		}()
		<-acquired
	}

	return lw.w.Write(p)
}

View File

@@ -0,0 +1,2 @@
This package provides param encoders / decoders for `io.Reader` parameters,
which proxy the data over temporary HTTP endpoints.

View File

@@ -0,0 +1,142 @@
package httpio
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"path"
"reflect"
"sync"
"github.com/google/uuid"
logging "github.com/ipfs/go-log/v2"
"golang.org/x/xerrors"
"github.com/filecoin-project/go-jsonrpc"
)
var log = logging.Logger("rpc")

// ReaderParamEncoder returns a jsonrpc option which, for io.Reader params,
// streams the reader's contents via HTTP POST to a push endpoint under addr
// and sends only the stream's UUID in the JSON-RPC params.
func ReaderParamEncoder(addr string) jsonrpc.Option {
	return jsonrpc.WithParamEncoder(new(io.Reader), func(value reflect.Value) (reflect.Value, error) {
		r := value.Interface().(io.Reader)

		reqID := uuid.New()
		u, err := url.Parse(addr)
		if err != nil {
			// Previously the parse error was discarded, which would panic on
			// the nil *url.URL below for a malformed addr.
			return reflect.Value{}, xerrors.Errorf("parsing push endpoint addr: %w", err)
		}
		u.Path = path.Join(u.Path, reqID.String())

		go func() {
			// TODO: figure out errors here
			resp, err := http.Post(u.String(), "application/octet-stream", r)
			if err != nil {
				log.Errorf("sending reader param: %+v", err)
				return
			}
			defer resp.Body.Close()

			if resp.StatusCode != 200 {
				// The original call was missing the %s verb for resp.Status.
				log.Errorf("sending reader param: non-200 status: %s", resp.Status)
				return
			}
		}()

		return reflect.ValueOf(reqID), nil
	})
}
type waitReadCloser struct {
io.ReadCloser
wait chan struct{}
}
func (w *waitReadCloser) Read(p []byte) (int, error) {
n, err := w.ReadCloser.Read(p)
if err != nil {
close(w.wait)
}
return n, err
}
func (w *waitReadCloser) Close() error {
close(w.wait)
return w.ReadCloser.Close()
}
// ReaderParamDecoder returns the http handler for the push endpoint together
// with a jsonrpc server option which decodes io.Reader params by matching
// the UUID from the JSON-RPC request with the body of the corresponding
// HTTP push request.
func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
	var readersLk sync.Mutex
	// readers holds a rendezvous channel per stream UUID; whichever side
	// (push handler or param decoder) arrives first creates the channel.
	readers := map[uuid.UUID]chan *waitReadCloser{}

	hnd := func(resp http.ResponseWriter, req *http.Request) {
		strId := path.Base(req.URL.Path)
		u, err := uuid.Parse(strId)
		if err != nil {
			http.Error(resp, fmt.Sprintf("parsing reader uuid: %s", err), 400)
			return // previously fell through and registered a zero UUID
		}

		readersLk.Lock()
		ch, found := readers[u]
		if !found {
			ch = make(chan *waitReadCloser)
			readers[u] = ch
		}
		readersLk.Unlock()

		wr := &waitReadCloser{
			ReadCloser: req.Body,
			wait:       make(chan struct{}),
		}

		// Hand the body to the decoder, then block until it is fully consumed
		// so the HTTP request (and body) stays alive for the RPC call.
		select {
		case ch <- wr:
		case <-req.Context().Done():
			// Errorf (not Error): the message carries a %v format verb.
			log.Errorf("context error in reader stream handler (1): %v", req.Context().Err())
			resp.WriteHeader(500)
			return
		}

		select {
		case <-wr.wait:
		case <-req.Context().Done():
			log.Errorf("context error in reader stream handler (2): %v", req.Context().Err())
			resp.WriteHeader(500)
			return
		}

		resp.WriteHeader(200)
	}

	dec := jsonrpc.WithParamDecoder(new(io.Reader), func(ctx context.Context, b []byte) (reflect.Value, error) {
		var strId string
		if err := json.Unmarshal(b, &strId); err != nil {
			return reflect.Value{}, xerrors.Errorf("unmarshaling reader id: %w", err)
		}

		u, err := uuid.Parse(strId)
		if err != nil {
			return reflect.Value{}, xerrors.Errorf("parsing reader UUID: %w", err)
		}

		readersLk.Lock()
		ch, found := readers[u]
		if !found {
			ch = make(chan *waitReadCloser)
			readers[u] = ch
		}
		readersLk.Unlock()

		select {
		case wr := <-ch:
			return reflect.ValueOf(wr), nil
		case <-ctx.Done():
			return reflect.Value{}, ctx.Err()
		}
	})

	return hnd, dec
}

View File

@@ -0,0 +1,54 @@
package httpio
import (
"context"
"io"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/require"
"github.com/filecoin-project/go-jsonrpc"
)
type ReaderHandler struct {
}
func (h *ReaderHandler) ReadAll(ctx context.Context, r io.Reader) ([]byte, error) {
return io.ReadAll(r)
}
func (h *ReaderHandler) ReadUrl(ctx context.Context, u string) (string, error) {
return u, nil
}
// TestReaderProxy round-trips an io.Reader parameter through a real HTTP
// server: the reader's bytes are pushed to the /rpc/streams endpoint while
// the RPC call itself travels over websocket.
func TestReaderProxy(t *testing.T) {
	var client struct {
		ReadAll func(ctx context.Context, r io.Reader) ([]byte, error)
	}

	serverHandler := &ReaderHandler{}

	readerHandler, readerServerOpt := ReaderParamDecoder()
	rpcServer := jsonrpc.NewServer(readerServerOpt)
	rpcServer.Register("ReaderHandler", serverHandler)

	mux := mux.NewRouter()
	mux.Handle("/rpc/v0", rpcServer)
	mux.Handle("/rpc/streams/v0/push/{uuid}", readerHandler)

	testServ := httptest.NewServer(mux)
	defer testServ.Close()

	// Client: encode readers by pushing to the stream endpoint; call over ws.
	re := ReaderParamEncoder("http://" + testServ.Listener.Addr().String() + "/rpc/streams/v0/push")
	closer, err := jsonrpc.NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String()+"/rpc/v0", "ReaderHandler", []interface{}{&client}, nil, re)
	require.NoError(t, err)
	defer closer()

	read, err := client.ReadAll(context.TODO(), strings.NewReader("pooooootato"))
	require.NoError(t, err)
	require.Equal(t, "pooooootato", string(read), "potatos weren't equal")
}

View File

@@ -0,0 +1,38 @@
package jsonrpc
import "strings"
// MethodNameFormatter is a function that takes a namespace and a method name
// and returns the full method name, sent via JSON-RPC. This is useful if you
// want to customize the default behaviour, e.g. send without the namespace
// or make it lowercase.
type MethodNameFormatter func(namespace, method string) string

// CaseStyle represents the case style for method names.
type CaseStyle int

const (
	// OriginalCase leaves the method name untouched.
	OriginalCase CaseStyle = iota
	// LowerFirstCharCase lowercases only the method name's first character.
	LowerFirstCharCase
	// AllFirstCharCase lowercases the whole namespaced name.
	AllFirstCharCase
)

// NewMethodNameFormatter creates a new method name formatter based on the
// provided options.
func NewMethodNameFormatter(includeNamespace bool, nameCase CaseStyle) MethodNameFormatter {
	return func(namespace, method string) string {
		name := method
		if nameCase == LowerFirstCharCase && name != "" {
			name = strings.ToLower(name[:1]) + name[1:]
		}
		switch {
		case nameCase == AllFirstCharCase:
			// Note: lowercases the namespace too, and always includes it.
			return strings.ToLower(namespace + "." + name)
		case includeNamespace:
			return namespace + "." + name
		default:
			return name
		}
	}
}

// DefaultMethodNameFormatter is a pass-through formatter with default options.
var DefaultMethodNameFormatter = NewMethodNameFormatter(false, OriginalCase)

View File

@@ -0,0 +1,125 @@
package jsonrpc
import (
"context"
"fmt"
"github.com/stretchr/testify/require"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestDifferentMethodNamers verifies that a server configured with each
// supported MethodNameFormatter accepts the correspondingly-formatted
// method name on the wire and actually dispatches the call (observed via
// the handler's counter).
func TestDifferentMethodNamers(t *testing.T) {
	tests := map[string]struct {
		namer           MethodNameFormatter
		requestedMethod string // the wire method name that must resolve
	}{
		"default namer": {
			namer:           DefaultMethodNameFormatter,
			requestedMethod: "SimpleServerHandler.Inc",
		},
		"lower fist char": {
			namer:           NewMethodNameFormatter(true, LowerFirstCharCase),
			requestedMethod: "SimpleServerHandler.inc",
		},
		"no namespace namer": {
			namer:           NewMethodNameFormatter(false, OriginalCase),
			requestedMethod: "Inc",
		},
		"no namespace & lower fist char": {
			namer:           NewMethodNameFormatter(false, LowerFirstCharCase),
			requestedMethod: "inc",
		},
	}
	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			rpcServer := NewServer(WithServerMethodNameFormatter(test.namer))
			serverHandler := &SimpleServerHandler{}
			rpcServer.Register("SimpleServerHandler", serverHandler)

			testServ := httptest.NewServer(rpcServer)
			defer testServ.Close()

			req := fmt.Sprintf(`{"jsonrpc": "2.0", "method": "%s", "params": [], "id": 1}`, test.requestedMethod)
			res, err := http.Post(testServ.URL, "application/json", strings.NewReader(req))
			require.NoError(t, err)
			require.Equal(t, http.StatusOK, res.StatusCode)
			// Inc must have run exactly once.
			require.Equal(t, int32(1), serverHandler.n)
		})
	}
}
// TestDifferentMethodNamersWithClient verifies end-to-end dispatch when the
// client and server share the same MethodNameFormatter, over both plain
// HTTP and websocket transports.
func TestDifferentMethodNamersWithClient(t *testing.T) {
	tests := map[string]struct {
		namer     MethodNameFormatter
		urlPrefix string // transport scheme under test
	}{
		"default namer & http": {
			namer:     DefaultMethodNameFormatter,
			urlPrefix: "http://",
		},
		"default namer & ws": {
			namer:     DefaultMethodNameFormatter,
			urlPrefix: "ws://",
		},
		"lower first char namer & http": {
			namer:     NewMethodNameFormatter(true, LowerFirstCharCase),
			urlPrefix: "http://",
		},
		"lower first char namer & ws": {
			namer:     NewMethodNameFormatter(true, LowerFirstCharCase),
			urlPrefix: "ws://",
		},
		"no namespace namer & http": {
			namer:     NewMethodNameFormatter(false, OriginalCase),
			urlPrefix: "http://",
		},
		"no namespace namer & ws": {
			namer:     NewMethodNameFormatter(false, OriginalCase),
			urlPrefix: "ws://",
		},
		"no namespace & lower first char & http": {
			namer:     NewMethodNameFormatter(false, LowerFirstCharCase),
			urlPrefix: "http://",
		},
		"no namespace & lower first char & ws": {
			namer:     NewMethodNameFormatter(false, LowerFirstCharCase),
			urlPrefix: "ws://",
		},
	}
	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			rpcServer := NewServer(WithServerMethodNameFormatter(test.namer))
			serverHandler := &SimpleServerHandler{}
			rpcServer.Register("SimpleServerHandler", serverHandler)

			testServ := httptest.NewServer(rpcServer)
			defer testServ.Close()

			var client struct {
				AddGet func(int) int
			}

			// The client must format the method name the same way the server does.
			closer, err := NewMergeClient(
				context.Background(),
				test.urlPrefix+testServ.Listener.Addr().String(),
				"SimpleServerHandler",
				[]any{&client},
				nil,
				WithHTTPClient(testServ.Client()),
				WithMethodNameFormatter(test.namer),
			)
			require.NoError(t, err)
			defer closer()

			n := client.AddGet(123)
			require.Equal(t, 123, n)
		})
	}
}

View File

@@ -0,0 +1,45 @@
package metrics
import (
"go.opencensus.io/stats"
"go.opencensus.io/stats/view"
"go.opencensus.io/tag"
)
// Global Tags
var (
	// RPCMethod tags every RPC metric with the method name being invoked.
	RPCMethod, _ = tag.NewKey("method")
)

// Measures
var (
	RPCInvalidMethod = stats.Int64("rpc/invalid_method", "Total number of invalid RPC methods called", stats.UnitDimensionless)
	RPCRequestError  = stats.Int64("rpc/request_error", "Total number of request errors handled", stats.UnitDimensionless)
	RPCResponseError = stats.Int64("rpc/response_error", "Total number of responses errors handled", stats.UnitDimensionless)
)

// Views (count aggregations over the measures above).
var (
	// All RPC related metrics should at the very least tag the RPCMethod
	RPCInvalidMethodView = &view.View{
		Measure:     RPCInvalidMethod,
		Aggregation: view.Count(),
		TagKeys:     []tag.Key{RPCMethod},
	}
	RPCRequestErrorView = &view.View{
		Measure:     RPCRequestError,
		Aggregation: view.Count(),
		TagKeys:     []tag.Key{RPCMethod},
	}
	RPCResponseErrorView = &view.View{
		Measure:     RPCResponseError,
		Aggregation: view.Count(),
		TagKeys:     []tag.Key{RPCMethod},
	}
)

// DefaultViews is an array of OpenCensus views for metric gathering purposes
var DefaultViews = []*view.View{
	RPCInvalidMethodView,
	RPCRequestErrorView,
	RPCResponseErrorView,
}

View File

@@ -0,0 +1,128 @@
package jsonrpc
import (
"net/http"
"reflect"
"time"
"github.com/gorilla/websocket"
)
type ParamEncoder func(reflect.Value) (reflect.Value, error)
type clientHandler struct {
ns string
hnd interface{}
}
// Config carries all client-side options; it is populated by applying
// Option functions on top of defaultConfig.
type Config struct {
	reconnectBackoff backoff       // delay schedule for websocket reconnect attempts
	pingInterval     time.Duration // how often websocket pings are sent
	timeout          time.Duration // client timeout (see WithTimeout)

	paramEncoders map[reflect.Type]ParamEncoder // per-type custom request parameter encoders
	errors        *Errors                       // optional registry of error codes <-> typed errors

	reverseHandlers       []clientHandler   // handlers exposed to the server over websocket (reverse calls)
	aliasedHandlerMethods map[string]string // alias -> original method name for reverse handlers

	httpClient  *http.Client // client used for HTTP transport
	noReconnect bool         // when true, dropped websockets are not re-established

	proxyConnFactory func(func() (*websocket.Conn, error)) func() (*websocket.Conn, error) // for testing

	methodNamer MethodNameFormatter // maps Go method names to wire method names

	reconnfun func() // optional callback invoked on reconnect (set via WithReconnFun)
}
// defaultConfig returns the baseline client configuration that Option
// functions are applied on top of.
func defaultConfig() Config {
	return Config{
		reconnectBackoff: backoff{
			minDelay: 100 * time.Millisecond,
			maxDelay: 5 * time.Second,
		},
		pingInterval: 5 * time.Second,
		timeout:      30 * time.Second,

		aliasedHandlerMethods: map[string]string{},

		paramEncoders: map[reflect.Type]ParamEncoder{},

		httpClient:  _defaultHTTPClient,
		methodNamer: DefaultMethodNameFormatter,
	}
}
// Option mutates a client Config; pass Options when constructing a client.
type Option func(c *Config)

// WithReconnectBackoff sets the min/max delay bounds used when retrying a
// dropped websocket connection.
func WithReconnectBackoff(minDelay, maxDelay time.Duration) func(c *Config) {
	return func(c *Config) {
		c.reconnectBackoff = backoff{
			minDelay: minDelay,
			maxDelay: maxDelay,
		}
	}
}

// WithPingInterval sets how often websocket pings are sent.
// Must be < Timeout/2
func WithPingInterval(d time.Duration) func(c *Config) {
	return func(c *Config) {
		c.pingInterval = d
	}
}

// WithTimeout sets the client timeout.
func WithTimeout(d time.Duration) func(c *Config) {
	return func(c *Config) {
		c.timeout = d
	}
}

// WithNoReconnect disables automatic websocket reconnection.
func WithNoReconnect() func(c *Config) {
	return func(c *Config) {
		c.noReconnect = true
	}
}

// WithParamEncoder registers a custom encoder for request parameters of the
// type pointed to by t.
func WithParamEncoder(t interface{}, encoder ParamEncoder) func(c *Config) {
	return func(c *Config) {
		c.paramEncoders[reflect.TypeOf(t).Elem()] = encoder
	}
}

// WithErrors supplies the registry mapping error codes to typed errors.
func WithErrors(es Errors) func(c *Config) {
	return func(c *Config) {
		c.errors = &es
	}
}

// WithClientHandler registers a handler that the remote server can call back
// into over the websocket connection, under the given namespace.
func WithClientHandler(ns string, hnd interface{}) func(c *Config) {
	return func(c *Config) {
		c.reverseHandlers = append(c.reverseHandlers, clientHandler{ns, hnd})
	}
}

// WithReconnFun registers a callback invoked on websocket reconnect.
func WithReconnFun(s func()) func(c *Config) {
	return func(c *Config) {
		c.reconnfun = s
	}
}

// WithClientHandlerAlias creates an alias for a client HANDLER method - for handlers created
// with WithClientHandler
func WithClientHandlerAlias(alias, original string) func(c *Config) {
	return func(c *Config) {
		c.aliasedHandlerMethods[alias] = original
	}
}

// WithHTTPClient sets the http.Client used for the HTTP transport.
func WithHTTPClient(h *http.Client) func(c *Config) {
	return func(c *Config) {
		c.httpClient = h
	}
}

// WithMethodNameFormatter sets how Go method names are mapped to wire
// method names.
func WithMethodNameFormatter(namer MethodNameFormatter) func(c *Config) {
	return func(c *Config) {
		c.methodNamer = namer
	}
}

View File

@@ -0,0 +1,125 @@
package jsonrpc
import (
"context"
"reflect"
"time"
"golang.org/x/xerrors"
)
// note: we embed reflect.Type because proxy-structs are not comparable
type jsonrpcReverseClient struct{ reflect.Type }

// ParamDecoder decodes a raw JSON parameter into a reflect.Value.
type ParamDecoder func(ctx context.Context, json []byte) (reflect.Value, error)

// ServerConfig carries all server-side options; it is populated by applying
// ServerOption values on top of defaultServerConfig.
type ServerConfig struct {
	maxRequestSize int64         // upper bound on accepted request size
	pingInterval   time.Duration // websocket ping period

	paramDecoders map[reflect.Type]ParamDecoder // per-type custom parameter decoders
	errors        *Errors                       // optional registry of error codes <-> typed errors

	reverseClientBuilder func(context.Context, *wsConn) (context.Context, error) // set by WithReverseClient

	tracer Tracer // optional call/result tracer (see WithTracer)

	methodNameFormatter MethodNameFormatter // maps Go method names to wire method names
}

// ServerOption mutates a ServerConfig; pass ServerOptions to NewServer.
type ServerOption func(c *ServerConfig)

// defaultServerConfig returns the baseline server settings applied before
// any ServerOption.
func defaultServerConfig() ServerConfig {
	return ServerConfig{
		paramDecoders:  map[reflect.Type]ParamDecoder{},
		maxRequestSize: DEFAULT_MAX_REQUEST_SIZE,

		pingInterval: 5 * time.Second,

		methodNameFormatter: DefaultMethodNameFormatter,
	}
}
// WithParamDecoder registers a custom decoder for parameters of the type
// pointed to by t.
func WithParamDecoder(t interface{}, decoder ParamDecoder) ServerOption {
	return func(c *ServerConfig) {
		c.paramDecoders[reflect.TypeOf(t).Elem()] = decoder
	}
}

// WithMaxRequestSize overrides the maximum accepted request size
// (DEFAULT_MAX_REQUEST_SIZE by default).
func WithMaxRequestSize(max int64) ServerOption {
	return func(c *ServerConfig) {
		c.maxRequestSize = max
	}
}

// WithServerErrors supplies the registry of error codes <-> typed errors
// used when encoding handler errors.
func WithServerErrors(es Errors) ServerOption {
	return func(c *ServerConfig) {
		c.errors = &es
	}
}

// WithServerPingInterval sets how often the server pings websocket clients.
func WithServerPingInterval(d time.Duration) ServerOption {
	return func(c *ServerConfig) {
		c.pingInterval = d
	}
}

// WithServerMethodNameFormatter sets how Go method names are mapped to wire
// method names on the server side.
func WithServerMethodNameFormatter(formatter MethodNameFormatter) ServerOption {
	return func(c *ServerConfig) {
		c.methodNameFormatter = formatter
	}
}

// WithTracer allows the instantiator to trace the method calls and results.
// This is useful for debugging a client-server interaction.
func WithTracer(l Tracer) ServerOption {
	return func(c *ServerConfig) {
		c.tracer = l
	}
}
// WithReverseClient will allow extracting reverse client on **WEBSOCKET** calls.
// RP is a proxy-struct type, much like the one passed to NewClient.
func WithReverseClient[RP any](namespace string) ServerOption {
	return func(c *ServerConfig) {
		c.reverseClientBuilder = func(ctx context.Context, conn *wsConn) (context.Context, error) {
			// build a client whose requests flow back over the server's
			// websocket connection
			cl := client{
				namespace:           namespace,
				paramEncoders:       map[reflect.Type]ParamEncoder{},
				methodNameFormatter: c.methodNameFormatter,
			}
			// todo test that everything is closing correctly
			cl.exiting = conn.exiting

			// wire the client's outgoing requests into the connection
			requests := cl.setupRequestChan()
			conn.requests = requests

			// populate the proxy struct's function fields
			calls := new(RP)
			err := cl.provide([]interface{}{
				calls,
			})
			if err != nil {
				return nil, xerrors.Errorf("provide reverse client calls: %w", err)
			}

			// key the context value by the proxy type so that
			// ExtractReverseClient[RP] can find it later
			return context.WithValue(ctx, jsonrpcReverseClient{reflect.TypeOf(calls).Elem()}, calls), nil
		}
	}
}
// ExtractReverseClient retrieves the reverse client of proxy type C from the
// request context. A reverse client is present only when the server was
// built with a matching WithReverseClient option and the call arrived over
// a websocket connection; otherwise the zero value and false are returned.
func ExtractReverseClient[C any](ctx context.Context) (C, bool) {
	var zero C
	cl, ok := ctx.Value(jsonrpcReverseClient{reflect.TypeOf(new(C)).Elem()}).(*C)
	if !ok || cl == nil {
		// either never registered, or something went very wrong; don't panic
		return zero, false
	}
	return *cl, true
}

View File

@@ -0,0 +1,306 @@
package jsonrpc
import (
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
type ComplexData struct {
Foo string `json:"foo"`
Bar int `json:"bar"`
}
type StaticError struct{}
func (e *StaticError) Error() string { return "static error" }
// Define the error types
type SimpleError struct {
Message string
}
func (e *SimpleError) Error() string {
return e.Message
}
func (e *SimpleError) FromJSONRPCError(jerr JSONRPCError) error {
e.Message = jerr.Message
return nil
}
func (e *SimpleError) ToJSONRPCError() (JSONRPCError, error) {
return JSONRPCError{Message: e.Message}, nil
}
var _ RPCErrorCodec = (*SimpleError)(nil)
type DataStringError struct {
Message string `json:"message"`
Data string `json:"data"`
}
func (e *DataStringError) Error() string {
return e.Message
}
func (e *DataStringError) FromJSONRPCError(jerr JSONRPCError) error {
e.Message = jerr.Message
data, ok := jerr.Data.(string)
if !ok {
return fmt.Errorf("expected string data, got %T", jerr.Data)
}
e.Data = data
return nil
}
func (e *DataStringError) ToJSONRPCError() (JSONRPCError, error) {
return JSONRPCError{Message: e.Message, Data: e.Data}, nil
}
var _ RPCErrorCodec = (*DataStringError)(nil)
// DataComplexError is a test error type whose Data field round-trips a
// structured JSON object (ComplexData) through JSONRPCError.Data.
type DataComplexError struct {
	Message      string
	internalData ComplexData
}

func (e *DataComplexError) Error() string {
	return e.Message
}

// FromJSONRPCError restores the error from its wire form; Data must arrive
// as json.RawMessage containing a ComplexData object.
func (e *DataComplexError) FromJSONRPCError(jerr JSONRPCError) error {
	e.Message = jerr.Message

	data, ok := jerr.Data.(json.RawMessage)
	if !ok {
		// fixed: message previously said "expected string data" — a
		// copy-paste from DataStringError; this type requires raw JSON
		return fmt.Errorf("expected json.RawMessage data, got %T", jerr.Data)
	}

	if err := json.Unmarshal(data, &e.internalData); err != nil {
		return err
	}

	return nil
}

// ToJSONRPCError serializes internalData into the wire error's Data field.
func (e *DataComplexError) ToJSONRPCError() (JSONRPCError, error) {
	data, err := json.Marshal(e.internalData)
	if err != nil {
		return JSONRPCError{}, err
	}
	return JSONRPCError{Message: e.Message, Data: data}, nil
}

var _ RPCErrorCodec = (*DataComplexError)(nil)
type MetaError struct {
Message string
Details string
}
func (e *MetaError) Error() string {
return e.Message
}
func (e *MetaError) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Message string `json:"message"`
Details string `json:"details"`
}{
Message: e.Message,
Details: e.Details,
})
}
func (e *MetaError) UnmarshalJSON(data []byte) error {
var temp struct {
Message string `json:"message"`
Details string `json:"details"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
e.Message = temp.Message
e.Details = temp.Details
return nil
}
type ComplexError struct {
Message string
Data ComplexData
Details string
}
func (e *ComplexError) Error() string {
return e.Message
}
func (e *ComplexError) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Message string `json:"message"`
Details string `json:"details"`
Data any `json:"data"`
}{
Details: e.Details,
Message: e.Message,
Data: e.Data,
})
}
func (e *ComplexError) UnmarshalJSON(data []byte) error {
var temp struct {
Message string `json:"message"`
Details string `json:"details"`
Data ComplexData `json:"data"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
e.Details = temp.Details
e.Message = temp.Message
e.Data = temp.Data
return nil
}
func TestRespErrorVal(t *testing.T) {
// Initialize the Errors struct and register error types
errorsMap := NewErrors()
errorsMap.Register(1000, new(*StaticError))
errorsMap.Register(1001, new(*SimpleError))
errorsMap.Register(1002, new(*DataStringError))
errorsMap.Register(1003, new(*DataComplexError))
errorsMap.Register(1004, new(*MetaError))
errorsMap.Register(1005, new(*ComplexError))
// Define test cases
testCases := []struct {
name string
respError *JSONRPCError
expectedType interface{}
expectedMessage string
verify func(t *testing.T, err error)
}{
{
name: "StaticError",
respError: &JSONRPCError{
Code: 1000,
Message: "this is ignored",
},
expectedType: &StaticError{},
expectedMessage: "static error",
},
{
name: "SimpleError",
respError: &JSONRPCError{
Code: 1001,
Message: "simple error occurred",
},
expectedType: &SimpleError{},
expectedMessage: "simple error occurred",
},
{
name: "DataStringError",
respError: &JSONRPCError{
Code: 1002,
Message: "data error occurred",
Data: "additional data",
},
expectedType: &DataStringError{},
expectedMessage: "data error occurred",
verify: func(t *testing.T, err error) {
require.IsType(t, &DataStringError{}, err)
require.Equal(t, "data error occurred", err.Error())
require.Equal(t, "additional data", err.(*DataStringError).Data)
},
},
{
name: "DataComplexError",
respError: &JSONRPCError{
Code: 1003,
Message: "data error occurred",
Data: json.RawMessage(`{"foo":"boop","bar":101}`),
},
expectedType: &DataComplexError{},
expectedMessage: "data error occurred",
verify: func(t *testing.T, err error) {
require.Equal(t, ComplexData{Foo: "boop", Bar: 101}, err.(*DataComplexError).internalData)
},
},
{
name: "MetaError",
respError: &JSONRPCError{
Code: 1004,
Message: "meta error occurred",
Meta: func() json.RawMessage {
me := &MetaError{
Message: "meta error occurred",
Details: "meta details",
}
metaData, _ := me.MarshalJSON()
return metaData
}(),
},
expectedType: &MetaError{},
expectedMessage: "meta error occurred",
verify: func(t *testing.T, err error) {
// details will also be included in the error message since it implements the marshable interface
require.Equal(t, "meta details", err.(*MetaError).Details)
},
},
{
name: "ComplexError",
respError: &JSONRPCError{
Code: 1005,
Message: "complex error occurred",
Data: json.RawMessage(`"complex data"`),
Meta: func() json.RawMessage {
ce := &ComplexError{
Message: "complex error occurred",
Details: "complex details",
Data: ComplexData{Foo: "foo", Bar: 42},
}
metaData, _ := ce.MarshalJSON()
return metaData
}(),
},
expectedType: &ComplexError{},
expectedMessage: "complex error occurred",
verify: func(t *testing.T, err error) {
require.Equal(t, ComplexData{Foo: "foo", Bar: 42}, err.(*ComplexError).Data)
require.Equal(t, "complex details", err.(*ComplexError).Details)
},
},
{
name: "UnregisteredError",
respError: &JSONRPCError{
Code: 9999,
Message: "unregistered error occurred",
Data: json.RawMessage(`"some data"`),
},
expectedType: &JSONRPCError{},
expectedMessage: "unregistered error occurred",
verify: func(t *testing.T, err error) {
require.Equal(t, json.RawMessage(`"some data"`), err.(*JSONRPCError).Data)
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
errValue := tc.respError.val(&errorsMap)
errInterface := errValue.Interface()
err, ok := errInterface.(error)
require.True(t, ok, "returned value does not implement error interface")
require.IsType(t, tc.expectedType, err)
require.Equal(t, tc.expectedMessage, err.Error())
if tc.verify != nil {
tc.verify(t, err)
}
})
}
}

View File

@@ -0,0 +1,89 @@
package jsonrpc
import (
"encoding/json"
"fmt"
"reflect"
)
// response is the wire form of a JSON-RPC 2.0 response object.
type response struct {
	Jsonrpc string        `json:"jsonrpc"`
	Result  interface{}   `json:"result,omitempty"`
	ID      interface{}   `json:"id"`
	Error   *JSONRPCError `json:"error,omitempty"`
}

func (r response) MarshalJSON() ([]byte, error) {
	// Custom marshal logic as per JSON-RPC 2.0 spec:
	// > `result`:
	// > This member is REQUIRED on success.
	// > This member MUST NOT exist if there was an error invoking the method.
	//
	// > `error`:
	// > This member is REQUIRED on error.
	// > This member MUST NOT exist if there was no error triggered during invocation.
	data := map[string]interface{}{
		"jsonrpc": r.Jsonrpc,
		"id":      r.ID,
	}

	// emit exactly one of error/result; note that on success a nil Result
	// still produces "result": null, keeping the member present as required
	if r.Error != nil {
		data["error"] = r.Error
	} else {
		data["result"] = r.Result
	}
	return json.Marshal(data)
}
// JSONRPCError is the wire representation of a JSON-RPC 2.0 error object.
type JSONRPCError struct {
	Code    ErrorCode       `json:"code"`
	Message string          `json:"message"`
	Meta    json.RawMessage `json:"meta,omitempty"`
	Data    interface{}     `json:"data,omitempty"`
}

// Error implements the error interface. Codes within the JSON-RPC reserved
// range [-32768, -32000] are rendered with the numeric code prefixed;
// application-defined codes return just the message.
func (e *JSONRPCError) Error() string {
	inReservedRange := e.Code >= -32768 && e.Code <= -32000
	if !inReservedRange {
		return e.Message
	}
	return fmt.Sprintf("RPC error (%d): %s", e.Code, e.Message)
}
var (
_ error = (*JSONRPCError)(nil)
marshalableRT = reflect.TypeOf(new(marshalable)).Elem()
errorCodecRT = reflect.TypeOf(new(RPCErrorCodec)).Elem()
)
// val resolves this wire error into the Go error value handed back to the
// caller. When errors has a type registered for e.Code, a fresh value of
// that type is built and populated: RPCErrorCodec implementations decode
// via FromJSONRPCError; otherwise, if Meta is present and the type is
// marshalable, Meta is unmarshalled into it. On any decode failure — or
// with no registered type — the generic *JSONRPCError itself is returned.
func (e *JSONRPCError) val(errors *Errors) reflect.Value {
	if errors != nil {
		t, ok := errors.byCode[e.Code]
		if ok {
			var v reflect.Value
			if t.Kind() == reflect.Ptr {
				v = reflect.New(t.Elem())
			} else {
				v = reflect.New(t)
			}
			if v.Type().Implements(errorCodecRT) {
				if err := v.Interface().(RPCErrorCodec).FromJSONRPCError(*e); err != nil {
					// fixed: log.Errorf is Printf-style, which does not
					// support the %w verb (that's fmt.Errorf-only) — use %v
					log.Errorf("Error converting JSONRPCError to custom error type '%s' (code %d): %v", t.String(), e.Code, err)
					return reflect.ValueOf(e)
				}
			} else if len(e.Meta) > 0 && v.Type().Implements(marshalableRT) {
				if err := v.Interface().(marshalable).UnmarshalJSON(e.Meta); err != nil {
					log.Errorf("Error unmarshalling error metadata to custom error type '%s' (code %d): %v", t.String(), e.Code, err)
					return reflect.ValueOf(e)
				}
			}
			if t.Kind() != reflect.Ptr {
				// registered as a value type: dereference before returning
				v = v.Elem()
			}
			return v
		}
	}
	return reflect.ValueOf(e)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,183 @@
package jsonrpc
import (
"context"
"encoding/json"
"io"
"net/http"
"runtime/pprof"
"strings"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
const (
rpcParseError = -32700
rpcInvalidRequest = -32600
rpcMethodNotFound = -32601
rpcInvalidParams = -32602
)
// ConnectionType indicates the type of connection, this is set in the context and can be retrieved
// with GetConnectionType.
type ConnectionType string
const (
// ConnectionTypeUnknown indicates that the connection type cannot be determined, likely because
// it hasn't passed through an RPCServer.
ConnectionTypeUnknown ConnectionType = "unknown"
// ConnectionTypeHTTP indicates that the connection is an HTTP connection.
ConnectionTypeHTTP ConnectionType = "http"
// ConnectionTypeWS indicates that the connection is a WebSockets connection.
ConnectionTypeWS ConnectionType = "websockets"
)
var connectionTypeCtxKey = &struct{ name string }{"jsonrpc-connection-type"}
// GetConnectionType returns the connection type of the request if it was set by an RPCServer.
// A connection type of ConnectionTypeUnknown means that the connection type was not set.
func GetConnectionType(ctx context.Context) ConnectionType {
if v := ctx.Value(connectionTypeCtxKey); v != nil {
return v.(ConnectionType)
}
return ConnectionTypeUnknown
}
// RPCServer provides a jsonrpc 2.0 http server handler
type RPCServer struct {
	*handler
	reverseClientBuilder func(context.Context, *wsConn) (context.Context, error) // optional, set via WithReverseClient
	pingInterval         time.Duration                                           // websocket ping period
}

// NewServer creates new RPCServer instance
func NewServer(opts ...ServerOption) *RPCServer {
	// apply options on top of the defaults
	config := defaultServerConfig()
	for _, o := range opts {
		o(&config)
	}

	return &RPCServer{
		handler:              makeHandler(config),
		reverseClientBuilder: config.reverseClientBuilder,
		pingInterval:         config.pingInterval,
	}
}
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// handleWS upgrades the HTTP request to a websocket connection and services
// it until the peer disconnects.
func (s *RPCServer) handleWS(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	// TODO: allow setting
	// (note that we still are mostly covered by jwt tokens)
	w.Header().Set("Access-Control-Allow-Origin", "*")
	// echo the requested subprotocol back so the handshake succeeds
	if r.Header.Get("Sec-WebSocket-Protocol") != "" {
		w.Header().Set("Sec-WebSocket-Protocol", r.Header.Get("Sec-WebSocket-Protocol"))
	}

	c, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Errorw("upgrading connection", "error", err)
		// note that upgrader.Upgrade will set http error if there is an error
		return
	}

	wc := &wsConn{
		conn:         c,
		handler:      s,
		pingInterval: s.pingInterval,
		exiting:      make(chan struct{}),
	}

	// optionally attach a reverse client (server->client calls) to the context
	if s.reverseClientBuilder != nil {
		ctx, err = s.reverseClientBuilder(ctx, wc)
		if err != nil {
			log.Errorf("failed to build reverse client: %s", err)
			// NOTE(review): the connection has already been hijacked by
			// Upgrade at this point, so WriteHeader likely has no effect —
			// confirm and consider closing the ws connection instead
			w.WriteHeader(500)
			return
		}
	}

	// label the connection goroutines so they are identifiable in pprof
	lbl := pprof.Labels("jrpc-mode", "wsserver", "jrpc-remote", r.RemoteAddr, "jrpc-uuid", uuid.New().String())
	pprof.Do(ctx, lbl, func(ctx context.Context) {
		wc.handleWsConn(ctx)
	})

	if err := c.Close(); err != nil {
		log.Errorw("closing websocket connection", "error", err)
		return
	}
}
// TODO: return errors to clients per spec

// ServeHTTP dispatches an incoming request either to the websocket handler
// (when the client requests a connection upgrade) or to the plain HTTP
// JSON-RPC reader, tagging the context with the connection type either way.
func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	h := strings.ToLower(r.Header.Get("Connection"))
	if strings.Contains(h, "upgrade") {
		ctx = context.WithValue(ctx, connectionTypeCtxKey, ConnectionTypeWS)
		s.handleWS(ctx, w, r)
		return
	}

	ctx = context.WithValue(ctx, connectionTypeCtxKey, ConnectionTypeHTTP)
	s.handleReader(ctx, r.Body, w, rpcError)
}

// HandleRequest processes a single JSON-RPC request read from r, writing
// the response to w. Useful for transports other than HTTP/websockets.
func (s *RPCServer) HandleRequest(ctx context.Context, r io.Reader, w io.Writer) {
	s.handleReader(ctx, r, w, rpcError)
}
// rpcError writes a JSON-RPC error response for req (or a bare response
// when req is nil) using the connection writer wf. When the writer is an
// http.ResponseWriter, an appropriate HTTP status code is set first.
func rpcError(wf func(func(io.Writer)), req *request, code ErrorCode, err error) {
	// log once at entry; previously the same error was also logged a second
	// time (log.Warnf) inside the writer callback below
	log.Errorf("RPC Error: %s", err)
	wf(func(w io.Writer) {
		if hw, ok := w.(http.ResponseWriter); ok {
			if code == rpcInvalidRequest {
				hw.WriteHeader(http.StatusBadRequest)
			} else {
				hw.WriteHeader(http.StatusInternalServerError)
			}
		}

		if req == nil {
			req = &request{}
		}

		resp := response{
			Jsonrpc: "2.0",
			ID:      req.ID,
			Error: &JSONRPCError{
				Code:    code,
				Message: err.Error(),
			},
		}

		if err := json.NewEncoder(w).Encode(resp); err != nil {
			log.Warnf("failed to write rpc error: %s", err)
			return
		}
	})
}
// Register registers new RPC handler
//
// Handler is any value with methods defined
func (s *RPCServer) Register(namespace string, handler interface{}) {
	s.register(namespace, handler)
}

// AliasMethod declares "alias" as an alternate wire name for "original".
func (s *RPCServer) AliasMethod(alias, original string) {
	s.aliasedMethods[alias] = original
}
var _ error = &JSONRPCError{}

View File

@@ -0,0 +1,81 @@
package jsonrpc
import (
"encoding/json"
"fmt"
"math"
"math/rand"
"reflect"
"time"
)
type param struct {
data []byte // from unmarshal
v reflect.Value // to marshal
}
func (p *param) UnmarshalJSON(raw []byte) error {
p.data = make([]byte, len(raw))
copy(p.data, raw)
return nil
}
func (p *param) MarshalJSON() ([]byte, error) {
if p.v.Kind() == reflect.Invalid {
return p.data, nil
}
return json.Marshal(p.v.Interface())
}
// processFuncOut inspects a function type's results, reporting the position
// of the data value (valOut) and the error (errOut) among them; each is -1
// when absent, and n is the total result count. Panics on signatures with
// more than two results, or two results whose second is not an error.
func processFuncOut(funcType reflect.Type) (valOut int, errOut int, n int) {
	valOut, errOut = -1, -1
	n = funcType.NumOut()

	switch n {
	case 0:
		// no results at all
	case 1:
		if funcType.Out(0) == errorType {
			errOut = 0
		} else {
			valOut = 0
		}
	case 2:
		if funcType.Out(1) != errorType {
			panic("expected error as second return value")
		}
		valOut, errOut = 0, 1
	default:
		panic(fmt.Sprintf("too many return values: %s", funcType))
	}

	return
}
type backoff struct {
minDelay time.Duration
maxDelay time.Duration
}
func (b *backoff) next(attempt int) time.Duration {
if attempt < 0 {
return b.minDelay
}
minf := float64(b.minDelay)
durf := minf * math.Pow(1.5, float64(attempt))
durf = durf + rand.Float64()*minf
delay := time.Duration(durf)
if delay > b.maxDelay {
return b.maxDelay
}
return delay
}

View File

@@ -0,0 +1,3 @@
{
"version": "v0.8.0"
}

View File

@@ -0,0 +1,974 @@
package jsonrpc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"reflect"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"golang.org/x/xerrors"
)
const wsCancel = "xrpc.cancel"
const chValue = "xrpc.ch.val"
const chClose = "xrpc.ch.close"
var debugTrace = os.Getenv("JSONRPC_ENABLE_DEBUG_TRACE") == "1"
type frame struct {
// common
Jsonrpc string `json:"jsonrpc"`
ID interface{} `json:"id,omitempty"`
Meta map[string]string `json:"meta,omitempty"`
// request
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
// response
Result json.RawMessage `json:"result,omitempty"`
Error *JSONRPCError `json:"error,omitempty"`
}
type outChanReg struct {
reqID interface{}
chID uint64
ch reflect.Value
}
type reqestHandler interface {
handle(ctx context.Context, req request, w func(func(io.Writer)), rpcError rpcErrFunc, done func(keepCtx bool), chOut chanOut)
}
type wsConn struct {
// outside params
conn *websocket.Conn
connFactory func() (*websocket.Conn, error)
reconnectBackoff backoff
pingInterval time.Duration
timeout time.Duration
handler reqestHandler
requests <-chan clientRequest
pongs chan struct{}
stopPings func()
stop <-chan struct{}
exiting chan struct{}
reconfun func()
// incoming messages
incoming chan io.Reader
incomingErr error
errLk sync.Mutex
readError chan error
frameExecQueue chan []byte
// outgoing messages
writeLk sync.Mutex
// ////
// Client related
// inflight are requests we've sent to the remote
inflight map[interface{}]clientRequest
inflightLk sync.Mutex
// chanHandlers is a map of client-side channel handlers
chanHandlersLk sync.Mutex
chanHandlers map[uint64]*chanHandler
// ////
// Server related
// handling are the calls we handle
handling map[interface{}]context.CancelFunc
handlingLk sync.Mutex
spawnOutChanHandlerOnce sync.Once
// chanCtr is a counter used for identifying output channels on the server side
chanCtr uint64
registerCh chan outChanReg
}
type chanHandler struct {
// take inside chanHandlersLk
lk sync.Mutex
cb func(m []byte, ok bool)
}
// //
// WebSocket Message utils //
// //
// nextMessage wait for one message and puts it to the incoming channel
func (c *wsConn) nextMessage() {
c.resetReadDeadline()
msgType, r, err := c.conn.NextReader()
if err != nil {
c.errLk.Lock()
c.incomingErr = err
c.errLk.Unlock()
close(c.incoming)
return
}
if msgType != websocket.BinaryMessage && msgType != websocket.TextMessage {
c.errLk.Lock()
c.incomingErr = errors.New("unsupported message type")
c.errLk.Unlock()
close(c.incoming)
return
}
c.incoming <- r
}
// nextWriter waits for writeLk and invokes the cb callback with WS message
// writer when the lock is acquired
func (c *wsConn) nextWriter(cb func(io.Writer)) {
c.writeLk.Lock()
defer c.writeLk.Unlock()
wcl, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
log.Error("handle me:", err)
return
}
cb(wcl)
if err := wcl.Close(); err != nil {
log.Error("handle me:", err)
return
}
}
func (c *wsConn) sendRequest(req request) error {
c.writeLk.Lock()
defer c.writeLk.Unlock()
if debugTrace {
log.Debugw("sendRequest", "req", req.Method, "id", req.ID)
}
if err := c.conn.WriteJSON(req); err != nil {
return err
}
return nil
}
// //
// Output channels //
// //
// handleOutChans handles channel communication on the server side
// (forwards channel messages to client)
func (c *wsConn) handleOutChans() {
regV := reflect.ValueOf(c.registerCh)
exitV := reflect.ValueOf(c.exiting)
cases := []reflect.SelectCase{
{ // registration chan always 0
Dir: reflect.SelectRecv,
Chan: regV,
},
{ // exit chan always 1
Dir: reflect.SelectRecv,
Chan: exitV,
},
}
internal := len(cases)
var caseToID []uint64
for {
chosen, val, ok := reflect.Select(cases)
switch chosen {
case 0: // registration channel
if !ok {
// control channel closed - signals closed connection
// This shouldn't happen, instead the exiting channel should get closed
log.Warn("control channel closed")
return
}
registration := val.Interface().(outChanReg)
caseToID = append(caseToID, registration.chID)
cases = append(cases, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: registration.ch,
})
c.nextWriter(func(w io.Writer) {
resp := &response{
Jsonrpc: "2.0",
ID: registration.reqID,
Result: registration.chID,
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Error(err)
return
}
})
continue
case 1: // exiting channel
if !ok {
// exiting channel closed - signals closed connection
//
// We're not closing any channels as we're on receiving end.
// Also, context cancellation below should take care of any running
// requests
return
}
log.Warn("exiting channel received a message")
continue
}
if !ok {
// Output channel closed, cleanup, and tell remote that this happened
id := caseToID[chosen-internal]
n := len(cases) - 1
if n > 0 {
cases[chosen] = cases[n]
caseToID[chosen-internal] = caseToID[n-internal]
}
cases = cases[:n]
caseToID = caseToID[:n-internal]
rp, err := json.Marshal([]param{{v: reflect.ValueOf(id)}})
if err != nil {
log.Error(err)
continue
}
if err := c.sendRequest(request{
Jsonrpc: "2.0",
ID: nil, // notification
Method: chClose,
Params: rp,
}); err != nil {
log.Warnf("closed out channel sendRequest failed: %s", err)
}
continue
}
// forward message
rp, err := json.Marshal([]param{{v: reflect.ValueOf(caseToID[chosen-internal])}, {v: val}})
if err != nil {
log.Errorw("marshaling params for sendRequest failed", "err", err)
continue
}
if err := c.sendRequest(request{
Jsonrpc: "2.0",
ID: nil, // notification
Method: chValue,
Params: rp,
}); err != nil {
log.Warnf("sendRequest failed: %s", err)
return
}
}
}
// handleChanOut registers output channel for forwarding to client
func (c *wsConn) handleChanOut(ch reflect.Value, req interface{}) error {
c.spawnOutChanHandlerOnce.Do(func() {
go c.handleOutChans()
})
id := atomic.AddUint64(&c.chanCtr, 1)
select {
case c.registerCh <- outChanReg{
reqID: req,
chID: id,
ch: ch,
}:
return nil
case <-c.exiting:
return xerrors.New("connection closing")
}
}
// //
// Context.Done propagation //
// //
// handleCtxAsync handles context lifetimes for client
// TODO: this should be aware of events going through chanHandlers, and quit
//
// when the related channel is closed.
// This should also probably be a single goroutine,
// Note that not doing this should be fine for now as long as we are using
// contexts correctly (cancelling when async functions are no longer is use)
func (c *wsConn) handleCtxAsync(actx context.Context, id interface{}) {
<-actx.Done()
rp, err := json.Marshal([]param{{v: reflect.ValueOf(id)}})
if err != nil {
log.Errorw("marshaling params for sendRequest failed", "err", err)
return
}
if err := c.sendRequest(request{
Jsonrpc: "2.0",
Method: wsCancel,
Params: rp,
}); err != nil {
log.Warnw("failed to send request", "method", wsCancel, "id", id, "error", err.Error())
}
}
// cancelCtx is a built-in rpc which handles context cancellation over rpc
func (c *wsConn) cancelCtx(req frame) {
	if req.ID != nil {
		log.Warnf("%s call with ID set, won't respond", wsCancel)
	}

	var params []param
	if err := json.Unmarshal(req.Params, &params); err != nil {
		// fixed: log.Error is Sprint-style so a %s verb printed literally;
		// the message also wrongly referenced xrpc.ch.val
		log.Errorf("failed to unmarshal params in %s: %s", wsCancel, err)
		return
	}
	if len(params) == 0 {
		// guard against a malformed notification panicking on params[0]
		log.Errorf("%s call with no params", wsCancel)
		return
	}

	var id interface{}
	if err := json.Unmarshal(params[0].data, &id); err != nil {
		log.Error("handle me:", err)
		return
	}

	// cancel the matching in-flight handler, if any
	c.handlingLk.Lock()
	defer c.handlingLk.Unlock()
	cf, ok := c.handling[id]
	if ok {
		cf()
	}
}
// //
// Main Handling logic //
// //
// handleChanMessage delivers an xrpc.ch.val notification (a value emitted
// by a server-side channel) to the registered client-side channel handler.
func (c *wsConn) handleChanMessage(frame frame) {
	var params []param
	if err := json.Unmarshal(frame.Params, &params); err != nil {
		// fixed: log.Error is Sprint-style, so the %s verb was printed
		// literally; use Errorf for format directives
		log.Errorf("failed to unmarshal params in xrpc.ch.val: %s", err)
		return
	}
	if len(params) < 2 {
		// guard against a malformed notification panicking on params[1]
		log.Errorf("xrpc.ch.val: expected 2 params, got %d", len(params))
		return
	}

	var chid uint64
	if err := json.Unmarshal(params[0].data, &chid); err != nil {
		log.Errorf("failed to unmarshal channel id in xrpc.ch.val: %s", err)
		return
	}

	c.chanHandlersLk.Lock()
	hnd, ok := c.chanHandlers[chid]
	if !ok {
		c.chanHandlersLk.Unlock()
		log.Errorf("xrpc.ch.val: handler %d not found", chid)
		return
	}

	// take hnd.lk while still holding chanHandlersLk (see chanHandler),
	// then release the map lock before invoking the callback
	hnd.lk.Lock()
	defer hnd.lk.Unlock()

	c.chanHandlersLk.Unlock()

	hnd.cb(params[1].data, true)
}
// handleChanClose handles an xrpc.ch.close notification: the remote channel
// is done, so the local handler is removed and notified with ok=false.
func (c *wsConn) handleChanClose(frame frame) {
	var params []param
	if err := json.Unmarshal(frame.Params, &params); err != nil {
		// fixed: log.Error is Sprint-style (the %s verb printed literally)
		// and the messages wrongly referenced xrpc.ch.val in this handler
		log.Errorf("failed to unmarshal params in xrpc.ch.close: %s", err)
		return
	}
	if len(params) == 0 {
		// guard against a malformed notification panicking on params[0]
		log.Errorf("xrpc.ch.close: missing channel id param")
		return
	}

	var chid uint64
	if err := json.Unmarshal(params[0].data, &chid); err != nil {
		log.Errorf("failed to unmarshal channel id in xrpc.ch.close: %s", err)
		return
	}

	c.chanHandlersLk.Lock()
	hnd, ok := c.chanHandlers[chid]
	if !ok {
		c.chanHandlersLk.Unlock()
		log.Errorf("xrpc.ch.close: handler %d not found", chid)
		return
	}

	// lock the handler before dropping it from the map, then release the
	// map lock before invoking the close callback
	hnd.lk.Lock()
	defer hnd.lk.Unlock()

	delete(c.chanHandlers, chid)
	c.chanHandlersLk.Unlock()

	hnd.cb(nil, false)
}
func (c *wsConn) handleResponse(frame frame) {
c.inflightLk.Lock()
req, ok := c.inflight[frame.ID]
c.inflightLk.Unlock()
if !ok {
log.Error("client got unknown ID in response")
return
}
if req.retCh != nil && frame.Result != nil {
// output is channel
var chid uint64
if err := json.Unmarshal(frame.Result, &chid); err != nil {
log.Errorf("failed to unmarshal channel id response: %s, data '%s'", err, string(frame.Result))
return
}
chanCtx, chHnd := req.retCh()
c.chanHandlersLk.Lock()
c.chanHandlers[chid] = &chanHandler{cb: chHnd}
c.chanHandlersLk.Unlock()
go c.handleCtxAsync(chanCtx, frame.ID)
}
req.ready <- clientResponse{
Jsonrpc: frame.Jsonrpc,
Result: frame.Result,
ID: frame.ID,
Error: frame.Error,
}
c.inflightLk.Lock()
delete(c.inflight, frame.ID)
c.inflightLk.Unlock()
}
func (c *wsConn) handleCall(ctx context.Context, frame frame) {
if c.handler == nil {
log.Error("handleCall on client with no reverse handler")
return
}
req := request{
Jsonrpc: frame.Jsonrpc,
ID: frame.ID,
Meta: frame.Meta,
Method: frame.Method,
Params: frame.Params,
}
ctx, cancel := context.WithCancel(ctx)
nextWriter := func(cb func(io.Writer)) {
cb(io.Discard)
}
done := func(keepCtx bool) {
if !keepCtx {
cancel()
}
}
if frame.ID != nil {
nextWriter = c.nextWriter
c.handlingLk.Lock()
c.handling[frame.ID] = cancel
c.handlingLk.Unlock()
done = func(keepctx bool) {
c.handlingLk.Lock()
defer c.handlingLk.Unlock()
if !keepctx {
cancel()
delete(c.handling, frame.ID)
}
}
}
go c.handler.handle(ctx, req, nextWriter, rpcError, done, c.handleChanOut)
}
// handleFrame handles all incoming messages (calls and responses)
func (c *wsConn) handleFrame(ctx context.Context, frame frame) {
// Get message type by method name:
// "" - response
// "xrpc.*" - builtin
// anything else - incoming remote call
switch frame.Method {
case "": // Response to our call
c.handleResponse(frame)
case wsCancel:
c.cancelCtx(frame)
case chValue:
c.handleChanMessage(frame)
case chClose:
c.handleChanClose(frame)
default: // Remote call
c.handleCall(ctx, frame)
}
}
func (c *wsConn) closeInFlight() {
c.inflightLk.Lock()
for id, req := range c.inflight {
req.ready <- clientResponse{
Jsonrpc: "2.0",
ID: id,
Error: &JSONRPCError{
Message: "handler: websocket connection closed",
Code: eTempWSError,
},
}
}
c.inflight = map[interface{}]clientRequest{}
c.inflightLk.Unlock()
c.handlingLk.Lock()
for _, cancel := range c.handling {
cancel()
}
c.handling = map[interface{}]context.CancelFunc{}
c.handlingLk.Unlock()
}
// closeChans signals end-of-stream (ok=false) to every registered channel
// handler callback and removes the handlers from the map.
func (c *wsConn) closeChans() {
	c.chanHandlersLk.Lock()
	defer c.chanHandlersLk.Unlock()

	for chid := range c.chanHandlers {
		hnd := c.chanHandlers[chid]
		// Lock the individual handler before releasing the map lock so the
		// callback invocation below is serialized per handler.
		hnd.lk.Lock()

		delete(c.chanHandlers, chid)

		// The map lock is dropped while the callback runs — presumably so a
		// callback that touches chanHandlers (directly or indirectly) cannot
		// deadlock — and re-taken before continuing the iteration.
		// NOTE(review): the range continues over a map that may be mutated
		// while the lock is released; confirm all other writers hold
		// chanHandlersLk. TODO confirm.
		c.chanHandlersLk.Unlock()
		hnd.cb(nil, false)
		hnd.lk.Unlock()
		c.chanHandlersLk.Lock()
	}
}
// setupPings registers pong/ping handlers that record remote activity on
// c.pongs and starts a goroutine sending a ping every pingInterval. It
// returns an idempotent stop function; when pings are disabled
// (pingInterval == 0) the returned function is a no-op.
func (c *wsConn) setupPings() func() {
	if c.pingInterval == 0 {
		return func() {}
	}

	// Both pings and pongs from the peer count as liveness — a server too
	// busy to answer our pings may still be sending its own pings.
	markAlive := func(appData string) error {
		select {
		case c.pongs <- struct{}{}:
		default:
		}
		return nil
	}
	c.conn.SetPongHandler(markAlive)
	c.conn.SetPingHandler(markAlive)

	stop := make(chan struct{})

	go func() {
		for {
			select {
			case <-time.After(c.pingInterval):
				c.writeLk.Lock()
				if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
					log.Errorf("sending ping message: %+v", err)
				}
				c.writeLk.Unlock()
			case <-stop:
				return
			}
		}
	}()

	var once sync.Once
	return func() {
		once.Do(func() {
			close(stop)
		})
	}
}
// tryReconnect attempts to re-establish a dropped client connection.
// It returns true if a reconnect was initiated (client side), false on the
// server side where reconnecting is not possible. The actual dialing runs
// in a background goroutine with exponential backoff.
// returns true if reconnected
func (c *wsConn) tryReconnect(ctx context.Context) bool {
	if c.connFactory == nil { // server side
		return false
	}

	// connection dropped unexpectedly, do our best to recover it

	// Fail pending calls and channel subscriptions so callers can retry,
	// then recreate the incoming channel for the new connection's reader.
	c.closeInFlight()
	c.closeChans()
	c.incoming = make(chan io.Reader) // listen again for responses
	go func() {
		c.stopPings()

		attempts := 0
		var conn *websocket.Conn
		for conn == nil {
			// NOTE(review): the backoff sleep happens before the ctx check,
			// so a cancelled context still waits out one backoff period
			// before this goroutine exits.
			time.Sleep(c.reconnectBackoff.next(attempts))
			if ctx.Err() != nil {
				return
			}
			var err error
			if conn, err = c.connFactory(); err != nil {
				log.Debugw("websocket connection retry failed", "error", err)
			}
			select {
			case <-ctx.Done():
				return
			default:
			}
			attempts++
		}

		// Swap in the fresh connection under the write lock, clear the
		// recorded error, and restart pings and the reader goroutine.
		c.writeLk.Lock()
		c.conn = conn
		c.errLk.Lock()
		c.incomingErr = nil
		c.errLk.Unlock()
		c.stopPings = c.setupPings()
		c.writeLk.Unlock()

		go c.nextMessage()
		c.reconfun()
	}()

	return true
}
// readFrame buffers one whole frame from r, queues it for the frame
// executor, and then restarts the background message reader. Read errors
// are reported on c.readError instead of being handled here.
func (c *wsConn) readFrame(ctx context.Context, r io.Reader) {
	// debug util - dump all messages to stderr
	// r = io.TeeReader(r, os.Stderr)

	// json.NewDecoder(r).Decode would read the whole frame as well, so might as well do it
	// with ReadAll which should be much faster
	// use a autoResetReader in case the read takes a long time
	buf, err := io.ReadAll(c.autoResetReader(r)) // todo buffer pool
	if err != nil {
		c.readError <- xerrors.Errorf("reading frame into a buffer: %w", err)
		return
	}

	c.frameExecQueue <- buf
	if queued := len(c.frameExecQueue); queued > 2*cap(c.frameExecQueue)/3 { // warn at 2/3 capacity
		log.Warnw("frame executor queue is backlogged", "queued", queued, "cap", cap(c.frameExecQueue))
	}

	// got the whole frame, can start reading the next one in background
	go c.nextMessage()
}
// frameExecutor drains queued frame payloads, decodes and validates them,
// and dispatches each one via handleFrame until ctx is cancelled.
func (c *wsConn) frameExecutor(ctx context.Context) {
	for {
		var buf []byte
		select {
		case <-ctx.Done():
			return
		case buf = <-c.frameExecQueue:
		}

		var frame frame
		if err := json.Unmarshal(buf, &frame); err != nil {
			log.Warnw("failed to unmarshal frame", "error", err)
			// todo send invalid request response
			continue
		}

		// Canonicalize the wire ID so it can be used as a map key.
		id, err := normalizeID(frame.ID)
		if err != nil {
			log.Warnw("failed to normalize frame id", "error", err)
			// todo send invalid request response
			continue
		}
		frame.ID = id

		c.handleFrame(ctx, frame)
	}
}
// maxQueuedFrames bounds the frameExecQueue buffer; readFrame logs a warning
// once the queue passes 2/3 of this capacity.
var maxQueuedFrames = 256
// handleWsConn is the main event loop for one websocket connection. It owns
// all per-connection state, reads incoming frames, sends outgoing requests,
// and enforces the ping/pong timeout until the connection stops or the
// context is cancelled.
func (c *wsConn) handleWsConn(ctx context.Context) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// (Re)initialize per-connection state.
	c.incoming = make(chan io.Reader)
	c.readError = make(chan error, 1)
	c.frameExecQueue = make(chan []byte, maxQueuedFrames)
	c.inflight = map[interface{}]clientRequest{}
	c.handling = map[interface{}]context.CancelFunc{}
	c.chanHandlers = map[uint64]*chanHandler{}
	c.pongs = make(chan struct{}, 1)

	c.registerCh = make(chan outChanReg)
	defer close(c.exiting)

	// ////

	// on close, make sure to return from all pending calls, and cancel context
	// on all calls we handle
	defer c.closeInFlight()
	defer c.closeChans()

	// setup pings

	c.stopPings = c.setupPings()
	defer c.stopPings()

	var timeoutTimer *time.Timer
	if c.timeout != 0 {
		timeoutTimer = time.NewTimer(c.timeout)
		defer timeoutTimer.Stop()
	}

	// start frame executor
	go c.frameExecutor(ctx)

	// wait for the first message
	go c.nextMessage()
	for {
		var timeoutCh <-chan time.Time
		if timeoutTimer != nil {
			// Drain-and-reset pattern for reusing a single timer across
			// loop iterations (required before Go 1.23 timer semantics).
			if !timeoutTimer.Stop() {
				select {
				case <-timeoutTimer.C:
				default:
				}
			}
			timeoutTimer.Reset(c.timeout)

			timeoutCh = timeoutTimer.C
		}

		// start/action are used below for slow-iteration diagnostics.
		start := time.Now()
		action := ""

		select {
		case r, ok := <-c.incoming:
			action = "incoming"

			c.errLk.Lock()
			err := c.incomingErr
			c.errLk.Unlock()

			if ok {
				// Buffer the frame in the background; readFrame restarts
				// nextMessage once it has the whole payload.
				go c.readFrame(ctx, r)
				break
			}

			if err == nil {
				return // remote closed
			}
			log.Debugw("websocket error", "error", err, "lastAction", action, "time", time.Since(start))
			// only client needs to reconnect
			if !c.tryReconnect(ctx) {
				return // failed to reconnect
			}
		case rerr := <-c.readError:
			action = "read-error"

			log.Debugw("websocket error", "error", rerr, "lastAction", action, "time", time.Since(start))
			if !c.tryReconnect(ctx) {
				return // failed to reconnect
			}
		case <-ctx.Done():
			log.Debugw("context cancelled", "error", ctx.Err(), "lastAction", action, "time", time.Since(start))
			return
		case req := <-c.requests:
			action = fmt.Sprintf("send-request(%s,%v)", req.req.Method, req.req.ID)

			c.writeLk.Lock()
			if req.req.ID != nil { // non-notification
				c.errLk.Lock()
				hasErr := c.incomingErr != nil
				c.errLk.Unlock()
				if hasErr { // No conn?, immediate fail
					req.ready <- clientResponse{
						Jsonrpc: "2.0",
						ID:      req.req.ID,
						Error: &JSONRPCError{
							Message: "handler: websocket connection closed",
							Code:    eTempWSError,
						},
					}
					c.writeLk.Unlock()
					break
				}
				// Track the call so handleResponse can route the reply.
				c.inflightLk.Lock()
				c.inflight[req.req.ID] = req
				c.inflightLk.Unlock()
			}
			c.writeLk.Unlock()
			serr := c.sendRequest(req.req)
			if serr != nil {
				log.Errorf("sendReqest failed (Handle me): %s", serr)
			}
			if req.req.ID == nil { // notification, return immediately
				resp := clientResponse{
					Jsonrpc: "2.0",
				}

				if serr != nil {
					resp.Error = &JSONRPCError{
						Code:    eTempWSError,
						Message: fmt.Sprintf("sendRequest: %s", serr),
					}
				}

				req.ready <- resp
			}
		case <-c.pongs:
			action = "pong"

			// Any pong (or peer ping) proves liveness; push the deadline out.
			c.resetReadDeadline()
		case <-timeoutCh:
			if c.pingInterval == 0 {
				// pings not running, this is perfectly normal
				continue
			}

			c.writeLk.Lock()
			if err := c.conn.Close(); err != nil {
				log.Warnw("timed-out websocket close error", "error", err)
			}
			c.writeLk.Unlock()
			log.Errorw("Connection timeout", "remote", c.conn.RemoteAddr(), "lastAction", action)
			// The server side does not perform the reconnect operation, so need to exit
			if c.connFactory == nil {
				return
			}
			// The client performs the reconnect operation, and if it exits it cannot start a handleWsConn again, so it does not need to exit
			continue
		case <-c.stop:
			// Graceful shutdown: send a close frame, then close the socket.
			c.writeLk.Lock()
			cmsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
			if err := c.conn.WriteMessage(websocket.CloseMessage, cmsg); err != nil {
				log.Warn("failed to write close message: ", err)
			}
			if err := c.conn.Close(); err != nil {
				log.Warnw("websocket close error", "error", err)
			}
			c.writeLk.Unlock()
			return
		}

		if c.pingInterval > 0 && time.Since(start) > c.pingInterval*2 {
			log.Warnw("websocket long time no response", "lastAction", action, "time", time.Since(start))
		}
		if debugTrace {
			log.Debugw("websocket action", "lastAction", action, "time", time.Since(start))
		}
	}
}
// onReadDeadlineResetInterval is how often deadlineResetReader refreshes the
// connection read deadline while a single long/large frame read is in flight.
var onReadDeadlineResetInterval = 5 * time.Second
// autoResetReader wraps a reader and resets the read deadline on if needed when doing large reads.
// Without this, a frame that takes longer than c.timeout to arrive would
// trip the connection read deadline mid-read.
func (c *wsConn) autoResetReader(reader io.Reader) io.Reader {
	return &deadlineResetReader{
		r:         reader,
		reset:     c.resetReadDeadline,
		lastReset: time.Now(),
	}
}
// deadlineResetReader is an io.Reader wrapper that periodically invokes a
// reset callback during slow/large reads so the underlying connection's
// read deadline keeps being pushed forward.
type deadlineResetReader struct {
	r         io.Reader // wrapped reader
	reset     func()    // pushes the read deadline forward
	lastReset time.Time // when reset was last invoked
}
// Read forwards to the wrapped reader and, if more than
// onReadDeadlineResetInterval has passed since the last reset, refreshes the
// read deadline so a long-running frame read does not time out.
func (r *deadlineResetReader) Read(p []byte) (n int, err error) {
	n, err = r.r.Read(p)

	elapsed := time.Since(r.lastReset)
	if elapsed <= onReadDeadlineResetInterval {
		return
	}

	log.Warnw("slow/large read, resetting deadline while reading the frame", "since", elapsed, "n", n, "err", err, "p", len(p))
	r.reset()
	r.lastReset = time.Now()
	return
}
// resetReadDeadline pushes the connection's read deadline c.timeout into the
// future; it is a no-op when no timeout is configured.
func (c *wsConn) resetReadDeadline() {
	if c.timeout <= 0 {
		return
	}
	if err := c.conn.SetReadDeadline(time.Now().Add(c.timeout)); err != nil {
		log.Error("setting read deadline", err)
	}
}
// Takes an ID as received on the wire, validates it, and translates it to a
// normalized ID appropriate for keying. Strings, float64 values and nil pass
// through unchanged; int64 values (as produced by some clients) are widened
// to float64 so they compare equal to JSON-decoded numbers. Any other type
// is rejected with an error.
func normalizeID(id interface{}) (interface{}, error) {
	if id == nil {
		return nil, nil
	}
	switch v := id.(type) {
	case string:
		return v, nil
	case float64:
		return v, nil
	case int64: // clients sending int64 need to normalize to float64
		return float64(v), nil
	}
	return nil, xerrors.Errorf("invalid id type: %T", id)
}
// Retry runs fn with exponential backoff between attempts.
//
// Parameters:
//   - baseDelay: base delay; the i-th retry waits baseDelay * 2^i
//   - maxRetries: maximum number of retries (fn runs at most maxRetries+1 times)
//   - fn: the operation to run; it reports whether it succeeded and, on
//     failure, the error for that attempt
//
// Returns nil as soon as fn reports success, otherwise an error wrapping the
// last attempt's error once all retries are exhausted. Note that fn may
// report failure with a nil error, in which case the wrapped error is nil.
func Retry(baseDelay time.Duration, maxRetries int, fn func() (bool, error)) error {
	var lastErr error
	for i := 0; i <= maxRetries; i++ {
		if i > 0 {
			// Exponential backoff: the delay doubles on every retry.
			delay := baseDelay * time.Duration(1<<uint(i))
			// Bug fix: the progress message was missing its trailing newline,
			// causing consecutive attempts to run together on one line.
			fmt.Printf("重试 %d/%d: 将在 %v 后重试,上次错误: %v\n", i, maxRetries, delay, lastErr)
			time.Sleep(delay)
		}

		success, err := fn()
		if success {
			return nil
		}
		lastErr = err // remember for the final report
	}
	return fmt.Errorf("达到最大重试次数 (%d),最后错误: %w", maxRetries, lastErr)
}

19
common/utils/log/.gitignore vendored Normal file
View File

@@ -0,0 +1,19 @@
# Binaries for programs and plugins
*.exe
*.dll
*.so
*.dylib
# Development binary, built with makefile
*.dev
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# IDE
.idea/*
.vscode/*
.history

View File

@@ -0,0 +1,9 @@
language: go
go:
- 1.6
- 1.7
- tip
script:
- go vet $(go list ./...|grep -v "/vendor/")
- go test -v -race ./...

202
common/utils/log/LICENSE Normal file
View File

@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

103
common/utils/log/README.md Normal file
View File

@@ -0,0 +1,103 @@
# Termtables
[![Build Status](https://travis-ci.org/scylladb/termtables.svg?branch=master)](https://travis-ci.org/scylladb/termtables)
A [Go](http://golang.org) port of the Ruby library [terminal-tables](https://github.com/visionmedia/terminal-table) for
fast and simple ASCII table generation.
## Installation
```bash
go get github.com/scylladb/termtables
```
## Go Style Documentation
[http://godoc.org/github.com/scylladb/termtables](http://godoc.org/github.com/scylladb/termtables)
## APC Command Line usage
`--markdown` output a markdown table, e.g. `apc app list --markdown`
`--html` output an html table, e.g. `apc app list --html`
`--ascii` output an ascii table, e.g. `apc app list --ascii`
## Basic Usage
```go
package main
import (
"fmt"
"github.com/apcera/termtables"
)
func main() {
table := termtables.CreateTable()
table.AddHeaders("Name", "Age")
table.AddRow("John", "30")
table.AddRow("Sam", 18)
table.AddRow("Julie", 20.14)
fmt.Println(table.Render())
}
```
Result:
```
+-------+-------+
| Name | Age |
+-------+-------+
| John | 30 |
| Sam | 18 |
| Julie | 20.14 |
+-------+-------+
```
## Advanced Usage
The package function-call `EnableUTF8()` will cause any tables created after
that point to use Unicode box-drawing characters for the table lines.
Calling `EnableUTF8PerLocale()` uses the C library's locale functionality to
determine if the current locale environment variables say that the current
character map is UTF-8. If, and only if, so, then `EnableUTF8()` will be
called.
Calling `SetModeHTML(true)` will cause any tables created after that point
to be emitted in HTML, while `SetModeMarkdown(true)` will trigger Markdown.
Neither should result in changes to later API to get the different results;
the primary intended use-case is extracting the same table, but for
documentation.
The table method `.AddSeparator()` inserts a rule line in the output. This
only applies in normal terminal output mode.
The table method `.AddTitle()` adds a title to the table; in terminal output,
this is an initial row; in HTML, it's a caption. In Markdown, it's a line of
text before the table, prefixed by `Table: `.
The table method `.SetAlign()` takes an alignment and a column number
(indexing starts at 1) and changes all _current_ cells in that column to have
the given alignment. It does not change the alignment of cells added to the
table after this call. Alignment is only stored on a per-cell basis.
## Known Issues
Normal output:
* `.SetAlign()` does not affect headers.
Markdown output mode:
* When emitting Markdown, the column markers are not re-flowed if a vertical
bar is an element of a cell, causing an escape to take place; since Markdown
is often converted to HTML, this only affects text viewing.
* A title in Markdown is not escaped against all possible forms of Markdown
markup (to avoid adding a dependency upon a Markdown library, as supported
syntax can vary).
* Markdown requires headers, so a dummy header will be inserted if needed.
* Table alignment is not reflected in Markdown output.

168
common/utils/log/cell.go Normal file
View File

@@ -0,0 +1,168 @@
// Copyright 2012 Apcera Inc. All rights reserved.
package termtables
import (
"fmt"
"math"
"regexp"
"strconv"
"strings"
"unicode/utf8"
runewidth "github.com/mattn/go-runewidth"
)
var (
// Must match SGR escape sequence, which is "CSI Pm m", where the Control
// Sequence Introducer (CSI) is "ESC ["; where Pm is "A multiple numeric
// parameter composed of any number of single numeric parameters, separated
// by ; character(s). Individual values for the parameters are listed with
// Ps" and where Ps is A single (usually optional) numeric parameter,
// composed of one of [sic] more digits."
//
// In practice, the end sequence is usually given as \e[0m but reading that
// definition, it's clear that the 0 is optional and some testing confirms
// that it is certainly optional with MacOS Terminal 2.3, so we need to
// support the string \e[m as a terminator too.
colorFilter = regexp.MustCompile(`\033\[(?:\d+(?:;\d+)*)?m`)
)
// A Cell denotes one cell of a table; it spans one row and a variable number
// of columns. A given Cell can only be used at one place in a table; the act
// of adding the Cell to the table mutates it with position information, so
// do not create one "const" Cell to add it multiple times.
type Cell struct {
column int
formattedValue string
alignment *TableAlignment
colSpan int
}
// CreateCell returns a Cell where the content is the supplied value, with the
// optional supplied style (which may be given as nil). The style can include
// a non-zero ColSpan to cause the cell to become column-spanning. Changing
// the style afterwards will not adjust the column-spanning state of the cell
// itself.
func CreateCell(v interface{}, style *CellStyle) *Cell {
return createCell(0, v, style)
}
func createCell(column int, v interface{}, style *CellStyle) *Cell {
cell := &Cell{column: column, formattedValue: renderValue(v), colSpan: 1}
if style != nil {
cell.alignment = &style.Alignment
if style.ColSpan != 0 {
cell.colSpan = style.ColSpan
}
}
return cell
}
// Width returns the width of the content of the cell, measured in runes as best
// as possible considering sophisticated Unicode.
func (c *Cell) Width() int {
return runewidth.StringWidth(filterColorCodes(c.formattedValue))
}
// Filter out terminal bold/color sequences in a string.
// This supports only basic bold/color escape sequences.
func filterColorCodes(s string) string {
return colorFilter.ReplaceAllString(s, "")
}
// Render returns a string representing the content of the cell, together with
// padding (to the widths specified) and handling any alignment.
func (c *Cell) Render(style *renderStyle) (buffer string) {
// if no alignment is set, import the table's default
if c.alignment == nil {
c.alignment = &style.Alignment
}
// left padding
buffer += strings.Repeat(" ", style.PaddingLeft)
// append the main value and handle alignment
buffer += c.alignCell(style)
// right padding
buffer += strings.Repeat(" ", style.PaddingRight)
// this handles escaping for, eg, Markdown, where we don't care about the
// alignment quite as much
if style.replaceContent != nil {
buffer = style.replaceContent(buffer)
}
return buffer
}
func (c *Cell) alignCell(style *renderStyle) string {
buffer := ""
width := style.CellWidth(c.column)
if c.colSpan > 1 {
for i := 1; i < c.colSpan; i++ {
w := style.CellWidth(c.column + i)
if w == 0 {
break
}
width += style.PaddingLeft + w + style.PaddingRight + utf8.RuneCountInString(style.BorderY)
}
}
switch *c.alignment {
default:
buffer += c.formattedValue
if l := width - c.Width(); l > 0 {
buffer += strings.Repeat(" ", l)
}
case AlignLeft:
buffer += c.formattedValue
if l := width - c.Width(); l > 0 {
buffer += strings.Repeat(" ", l)
}
case AlignRight:
if l := width - c.Width(); l > 0 {
buffer += strings.Repeat(" ", l)
}
buffer += c.formattedValue
case AlignCenter:
left, right := 0, 0
if l := width - c.Width(); l > 0 {
lf := float64(l)
left = int(math.Floor(lf / 2))
right = int(math.Ceil(lf / 2))
}
buffer += strings.Repeat(" ", left)
buffer += c.formattedValue
buffer += strings.Repeat(" ", right)
}
return buffer
}
// Format the raw value as a string depending on the type
func renderValue(v interface{}) string {
switch vv := v.(type) {
case string:
return vv
case bool:
return strconv.FormatBool(vv)
case int:
return strconv.Itoa(vv)
case int64:
return strconv.FormatInt(vv, 10)
case uint64:
return strconv.FormatUint(vv, 10)
case float64:
return strconv.FormatFloat(vv, 'f', 2, 64)
case fmt.Stringer:
return vv.String()
}
return fmt.Sprintf("%v", v)
}

View File

@@ -0,0 +1,113 @@
// Copyright 2012-2015 Apcera Inc. All rights reserved.
package termtables
import (
"testing"
)
func TestCellRenderString(t *testing.T) {
style := &renderStyle{TableStyle: TableStyle{}, cellWidths: map[int]int{}}
cell := createCell(0, "foobar", nil)
output := cell.Render(style)
if output != "foobar" {
t.Fatal("Unexpected output:", output)
}
}
func TestCellRenderBool(t *testing.T) {
style := &renderStyle{TableStyle: TableStyle{}, cellWidths: map[int]int{}}
cell := createCell(0, true, nil)
output := cell.Render(style)
if output != "true" {
t.Fatal("Unexpected output:", output)
}
}
func TestCellRenderInteger(t *testing.T) {
style := &renderStyle{TableStyle: TableStyle{}, cellWidths: map[int]int{}}
cell := createCell(0, 12345, nil)
output := cell.Render(style)
if output != "12345" {
t.Fatal("Unexpected output:", output)
}
}
func TestCellRenderFloat(t *testing.T) {
style := &renderStyle{TableStyle: TableStyle{}, cellWidths: map[int]int{}}
cell := createCell(0, 12.345, nil)
output := cell.Render(style)
if output != "12.35" {
t.Fatal("Unexpected output:", output)
}
}
func TestCellRenderPadding(t *testing.T) {
style := &renderStyle{TableStyle: TableStyle{PaddingLeft: 3, PaddingRight: 4}, cellWidths: map[int]int{}}
cell := createCell(0, "foobar", nil)
output := cell.Render(style)
if output != " foobar " {
t.Fatal("Unexpected output:", output)
}
}
type foo struct {
v string
}
func (f *foo) String() string {
return f.v
}
func TestCellRenderStringerStruct(t *testing.T) {
style := &renderStyle{TableStyle: TableStyle{}, cellWidths: map[int]int{}}
cell := createCell(0, &foo{v: "bar"}, nil)
output := cell.Render(style)
if output != "bar" {
t.Fatal("Unexpected output:", output)
}
}
type fooString string
func TestCellRenderGeneric(t *testing.T) {
style := &renderStyle{TableStyle: TableStyle{}, cellWidths: map[int]int{}}
cell := createCell(0, fooString("baz"), nil)
output := cell.Render(style)
if output != "baz" {
t.Fatal("Unexpected output:", output)
}
}
func TestFilterColorCodes(t *testing.T) {
tests := []struct {
in string
out string
}{
{"abc", "abc"},
{"", ""},
{"\033[31m\033[0m", ""},
{"a\033[31mb\033[0mc", "abc"},
{"\033[31mabc\033[0m", "abc"},
{"\033[31mfoo\033[0mbar", "foobar"},
{"\033[31mfoo\033[mbar", "foobar"},
{"\033[31mfoo\033[0;0mbar", "foobar"},
{"\033[31;4mfoo\033[0mbar", "foobar"},
{"\033[31;4;43mfoo\033[0mbar", "foobar"},
}
for _, test := range tests {
got := filterColorCodes(test.in)
if got != test.out {
t.Errorf("Invalid color-code filter result; expected %q but got %q from input %q",
test.out, got, test.in)
}
}
}

8
common/utils/log/go.mod Normal file
View File

@@ -0,0 +1,8 @@
module github.com/apcera/termtables
go 1.20
require (
github.com/mattn/go-runewidth v0.0.3-0.20170201023540-14207d285c6c
//github.com/scylladb/termtables v1.0.0
)

4
common/utils/log/go.sum Normal file
View File

@@ -0,0 +1,4 @@
github.com/mattn/go-runewidth v0.0.3-0.20170201023540-14207d285c6c h1:jQ6tSGsM/2TGhmbzHl9wXDtm2YjZDAfMsHyxaBDwywA=
github.com/mattn/go-runewidth v0.0.3-0.20170201023540-14207d285c6c/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/scylladb/termtables v1.0.0 h1:uUnesUY4V1VPCotpOQLb1LjTXVvzwy7Ramx8K8+w+8U=
github.com/scylladb/termtables v1.0.0/go.mod h1:C1a7PQSMz9NShzorzCiG2fk9+xuCgLkPeCvMHYR2OWg=

107
common/utils/log/html.go Normal file
View File

@@ -0,0 +1,107 @@
// Copyright 2013 Apcera Inc. All rights reserved.
package termtables
import (
"bytes"
"fmt"
"html"
"strings"
)
type titleStyle int
const (
TitleAsCaption titleStyle = iota
TitleAsThSpan
)
// htmlStyleRules defines attributes which we can use, and might be set on a
// table by accessors, to influence the type of HTML which is output.
type htmlStyleRules struct {
title titleStyle
}
// HTML returns an HTML representations of the contents of one row of a table.
func (r *Row) HTML(tag string, style *renderStyle) string {
attrs := make([]string, len(r.cells))
elems := make([]string, len(r.cells))
for i := range r.cells {
if r.cells[i].alignment != nil {
switch *r.cells[i].alignment {
case AlignLeft:
attrs[i] = " align='left'"
case AlignCenter:
attrs[i] = " align='center'"
case AlignRight:
attrs[i] = " align='right'"
}
}
elems[i] = html.EscapeString(strings.TrimSpace(r.cells[i].Render(style)))
}
// WAG as to max capacity, plus a bit
buf := bytes.NewBuffer(make([]byte, 0, 8192))
buf.WriteString("<tr>")
for i := range elems {
fmt.Fprintf(buf, "<%s%s>%s</%s>", tag, attrs[i], elems[i], tag)
}
buf.WriteString("</tr>\n")
return buf.String()
}
// generateHtmlTitleRow renders a table title as one HTML fragment whose
// shape depends on style.htmlRules.title:
//   - TitleAsCaption: a <caption> element;
//   - TitleAsThSpan:  a single <th> spanning every column;
//   - anything else:  an HTML comment carrying the escaped title text.
// The title content is escaped so markup within it cannot break the table.
func generateHtmlTitleRow(title interface{}, t *Table, style *renderStyle) string {
	// Render the supplied title, not t.title: the parameter was previously
	// ignored, which silently broke any caller passing a different value
	// (the existing call site passes t.title, so its output is unchanged).
	elContent := html.EscapeString(
		strings.TrimSpace(CreateCell(title, &CellStyle{}).Render(style)),
	)
	switch style.htmlRules.title {
	case TitleAsCaption:
		return "<caption>" + elContent + "</caption>\n"
	case TitleAsThSpan:
		return fmt.Sprintf("<tr><th style=\"text-align: center\" colspan=\"%d\">%s</th></tr>\n",
			style.columns, elContent)
	default:
		// Note: unlike the other branches this emits no trailing newline;
		// preserved as-is since RenderHTML concatenates fragments verbatim.
		return "<!-- " + elContent + " -->"
	}
}
// RenderHTML returns a string representation of the table, suitable for
// inclusion as HTML elsewhere. The primary use-case controlling layout style
// is inclusion into Markdown documents, documenting normal table use; cell
// padding is therefore dropped and we rely upon HTML ignoring whitespace.
func (t *Table) RenderHTML() (buffer string) {
	// elements is already populated with row data; resolve the runtime style.
	style := createRenderStyle(t)
	style.PaddingLeft = 0
	style.PaddingRight = 0
	// TODO: control CSS styles to suppress border based upon t.Style.SkipBorder
	var sb strings.Builder
	sb.WriteString("<table class=\"termtable\">\n")
	if t.title != nil || t.headers != nil {
		sb.WriteString("<thead>\n")
		if t.title != nil {
			sb.WriteString(generateHtmlTitleRow(t.title, t, style))
		}
		if t.headers != nil {
			sb.WriteString(CreateRow(t.headers).HTML("th", style))
		}
		sb.WriteString("</thead>\n")
	}
	sb.WriteString("<tbody>\n")
	for i := range t.elements {
		row, ok := t.elements[i].(*Row)
		if !ok {
			// Separators and other element kinds have no HTML rendering.
			fmt.Fprintf(&sb, "<!-- unable to render line %d, unhandled type -->\n", i)
			continue
		}
		sb.WriteString(row.HTML("td", style))
	}
	sb.WriteString("</tbody>\n")
	sb.WriteString("</table>\n")
	return sb.String()
}

View File

@@ -0,0 +1,222 @@
// Copyright 2013 Apcera Inc. All rights reserved.
package termtables
import (
"testing"
)
// TestCreateTableHTML checks basic HTML rendering with headers and mixed
// value types (floats appear rounded to two decimal places).
func TestCreateTableHTML(t *testing.T) {
	expected := "<table class=\"termtable\">\n" +
		"<thead>\n" +
		"<tr><th>Name</th><th>Value</th></tr>\n" +
		"</thead>\n" +
		"<tbody>\n" +
		"<tr><td>hey</td><td>you</td></tr>\n" +
		"<tr><td>ken</td><td>1234</td></tr>\n" +
		"<tr><td>derek</td><td>3.14</td></tr>\n" +
		"<tr><td>derek too</td><td>3.15</td></tr>\n" +
		"</tbody>\n" +
		"</table>\n"
	table := CreateTable()
	table.SetModeHTML()
	table.AddHeaders("Name", "Value")
	table.AddRow("hey", "you")
	table.AddRow("ken", 1234)
	table.AddRow("derek", 3.14)
	table.AddRow("derek too", 3.1456788)
	output := table.Render()
	if output != expected {
		t.Fatal(DisplayFailedOutput(output, expected))
	}
}
// TestTableWithHeaderHTML checks that AddTitle renders as a <caption>
// element by default.
func TestTableWithHeaderHTML(t *testing.T) {
	expected := "<table class=\"termtable\">\n" +
		"<thead>\n" +
		"<caption>Example</caption>\n" +
		"<tr><th>Name</th><th>Value</th></tr>\n" +
		"</thead>\n" +
		"<tbody>\n" +
		"<tr><td>hey</td><td>you</td></tr>\n" +
		"<tr><td>ken</td><td>1234</td></tr>\n" +
		"<tr><td>derek</td><td>3.14</td></tr>\n" +
		"<tr><td>derek too</td><td>3.15</td></tr>\n" +
		"</tbody>\n" +
		"</table>\n"
	table := CreateTable()
	table.SetModeHTML()
	table.AddTitle("Example")
	table.AddHeaders("Name", "Value")
	table.AddRow("hey", "you")
	table.AddRow("ken", 1234)
	table.AddRow("derek", 3.14)
	table.AddRow("derek too", 3.1456788)
	output := table.Render()
	if output != expected {
		t.Fatal(DisplayFailedOutput(output, expected))
	}
}
// TestTableTitleWidthAdjustsHTML checks that title text is HTML-escaped
// (the apostrophe becomes &#39;).
func TestTableTitleWidthAdjustsHTML(t *testing.T) {
	expected := "<table class=\"termtable\">\n" +
		"<thead>\n" +
		"<caption>Example My Foo Bar&#39;d Test</caption>\n" +
		"<tr><th>Name</th><th>Value</th></tr>\n" +
		"</thead>\n" +
		"<tbody>\n" +
		"<tr><td>hey</td><td>you</td></tr>\n" +
		"<tr><td>ken</td><td>1234</td></tr>\n" +
		"<tr><td>derek</td><td>3.14</td></tr>\n" +
		"<tr><td>derek too</td><td>3.15</td></tr>\n" +
		"</tbody>\n" +
		"</table>\n"
	table := CreateTable()
	table.SetModeHTML()
	table.AddTitle("Example My Foo Bar'd Test")
	table.AddHeaders("Name", "Value")
	table.AddRow("hey", "you")
	table.AddRow("ken", 1234)
	table.AddRow("derek", 3.14)
	table.AddRow("derek too", 3.1456788)
	output := table.Render()
	if output != expected {
		t.Fatal(DisplayFailedOutput(output, expected))
	}
}
// TestTableWithNoHeadersHTML checks that <thead> is omitted entirely when
// the table has neither a title nor headers.
func TestTableWithNoHeadersHTML(t *testing.T) {
	expected := "<table class=\"termtable\">\n" +
		"<tbody>\n" +
		"<tr><td>hey</td><td>you</td></tr>\n" +
		"<tr><td>ken</td><td>1234</td></tr>\n" +
		"<tr><td>derek</td><td>3.14</td></tr>\n" +
		"<tr><td>derek too</td><td>3.15</td></tr>\n" +
		"</tbody>\n" +
		"</table>\n"
	table := CreateTable()
	table.SetModeHTML()
	table.AddRow("hey", "you")
	table.AddRow("ken", 1234)
	table.AddRow("derek", 3.14)
	table.AddRow("derek too", 3.1456788)
	output := table.Render()
	if output != expected {
		t.Fatal(DisplayFailedOutput(output, expected))
	}
}
// TestTableUnicodeWidthsHTML checks that multi-byte currency symbols pass
// through HTML rendering unharmed.
func TestTableUnicodeWidthsHTML(t *testing.T) {
	expected := "<table class=\"termtable\">\n" +
		"<thead>\n" +
		"<tr><th>Name</th><th>Cost</th></tr>\n" +
		"</thead>\n" +
		"<tbody>\n" +
		"<tr><td>Currency</td><td>¤10</td></tr>\n" +
		"<tr><td>US Dollar</td><td>$30</td></tr>\n" +
		"<tr><td>Euro</td><td>€27</td></tr>\n" +
		"<tr><td>Thai</td><td>฿70</td></tr>\n" +
		"</tbody>\n" +
		"</table>\n"
	table := CreateTable()
	table.SetModeHTML()
	table.AddHeaders("Name", "Cost")
	table.AddRow("Currency", "¤10")
	table.AddRow("US Dollar", "$30")
	table.AddRow("Euro", "€27")
	table.AddRow("Thai", "฿70")
	output := table.Render()
	if output != expected {
		t.Fatal(DisplayFailedOutput(output, expected))
	}
}
// TestTableWithAlignment checks that a per-cell AlignRight becomes
// align='right' and that cell content ("<-") is HTML-escaped.
func TestTableWithAlignment(t *testing.T) {
	expected := "<table class=\"termtable\">\n" +
		"<thead>\n" +
		"<tr><th>Foo</th><th>Bar</th></tr>\n" +
		"</thead>\n" +
		"<tbody>\n" +
		"<tr><td>humpty</td><td>dumpty</td></tr>\n" +
		"<tr><td align='right'>r</td><td>&lt;- on right</td></tr>\n" +
		"</tbody>\n" +
		"</table>\n"
	table := CreateTable()
	table.SetModeHTML()
	table.AddHeaders("Foo", "Bar")
	table.AddRow("humpty", "dumpty")
	table.AddRow(CreateCell("r", &CellStyle{Alignment: AlignRight}), "<- on right")
	output := table.Render()
	if output != expected {
		t.Fatal(DisplayFailedOutput(output, expected))
	}
}
// TestTableAfterSetAlign checks that SetAlign retroactively aligns column 1
// of already-added rows.
func TestTableAfterSetAlign(t *testing.T) {
	expected := "<table class=\"termtable\">\n" +
		"<thead>\n" +
		"<tr><th>Alphabetical</th><th>Num</th></tr>\n" +
		"</thead>\n" +
		"<tbody>\n" +
		"<tr><td align='right'>alfa</td><td>1</td></tr>\n" +
		"<tr><td align='right'>bravo</td><td>2</td></tr>\n" +
		"<tr><td align='right'>charlie</td><td>3</td></tr>\n" +
		"</tbody>\n" +
		"</table>\n"
	table := CreateTable()
	table.SetModeHTML()
	table.AddHeaders("Alphabetical", "Num")
	table.AddRow("alfa", 1)
	table.AddRow("bravo", 2)
	table.AddRow("charlie", 3)
	table.SetAlign(AlignRight, 1)
	output := table.Render()
	if output != expected {
		t.Fatal(DisplayFailedOutput(output, expected))
	}
}
// TestTableWithAltTitleStyle checks that SetHTMLStyleTitle(TitleAsThSpan)
// renders the title as a column-spanning <th> instead of a <caption>.
func TestTableWithAltTitleStyle(t *testing.T) {
	expected := "" +
		"<table class=\"termtable\">\n" +
		"<thead>\n" +
		"<tr><th style=\"text-align: center\" colspan=\"3\">Metasyntactic</th></tr>\n" +
		"<tr><th>Foo</th><th>Bar</th><th>Baz</th></tr>\n" +
		"</thead>\n" +
		"<tbody>\n" +
		"<tr><td>a</td><td>b</td><td>c</td></tr>\n" +
		"<tr><td>α</td><td>β</td><td>γ</td></tr>\n" +
		"</tbody>\n" +
		"</table>\n"
	table := CreateTable()
	table.SetModeHTML()
	table.SetHTMLStyleTitle(TitleAsThSpan)
	table.AddTitle("Metasyntactic")
	table.AddHeaders("Foo", "Bar", "Baz")
	table.AddRow("a", "b", "c")
	table.AddRow("α", "β", "γ")
	output := table.Render()
	if output != expected {
		t.Fatal(DisplayFailedOutput(output, expected))
	}
}

47
common/utils/log/row.go Normal file
View File

@@ -0,0 +1,47 @@
// Copyright 2012 Apcera Inc. All rights reserved.
package termtables
import "strings"
// A Row represents one row of a Table, consisting of some number of Cell
// items.
type Row struct {
	// cells holds the row's cells in column order.
	cells []*Cell
}
// CreateRow returns a Row where the cells are created as needed to hold each
// item given; each item can be a Cell or content to go into a Cell created
// to hold it.
func CreateRow(items []interface{}) *Row {
	row := &Row{cells: make([]*Cell, 0, len(items))}
	for _, item := range items {
		row.AddCell(item)
	}
	return row
}
// AddCell adds one item to a row as a new cell, where the item is either a
// Cell or content to be wrapped into a freshly created cell.
func (r *Row) AddCell(item interface{}) {
	cell, ok := item.(*Cell)
	if !ok {
		r.cells = append(r.cells, createCell(len(r.cells), item, nil))
		return
	}
	cell.column = len(r.cells)
	r.cells = append(r.cells, cell)
}
// Render returns a string representing the content of one row of a table,
// where the Row contains Cells (not Separators) and the representation
// includes any vertical borders needed.
func (r *Row) Render(style *renderStyle) string {
	// Render each cell up-front so borders can be cleanly interleaved.
	rendered := make([]string, len(r.cells))
	for i, cell := range r.cells {
		rendered[i] = cell.Render(style)
	}
	return style.BorderY + strings.Join(rendered, style.BorderY) + style.BorderY
}

View File

@@ -0,0 +1,29 @@
// Copyright 2012-2015 Apcera Inc. All rights reserved.
package termtables
import (
"testing"
)
// TestBasicRowRender checks a two-cell row against exact column widths.
func TestBasicRowRender(t *testing.T) {
	row := CreateRow([]interface{}{"foo", "bar"})
	style := &renderStyle{TableStyle: TableStyle{BorderX: "-", BorderY: "|", BorderI: "+",
		PaddingLeft: 1, PaddingRight: 1}, cellWidths: map[int]int{0: 3, 1: 3}}
	output := row.Render(style)
	if output != "| foo | bar |" {
		t.Fatal("Unexpected output:", output)
	}
}
// TestRowRenderWidthBasedPadding checks rendering when a column is declared
// wider than its content.
// NOTE(review): with cellWidths[1] == 5 the expected literal looks like it
// should pad "bar" with trailing spaces; runs of spaces in these literals
// appear to have been collapsed in transit — verify against the original
// source before trusting the exact expected string.
func TestRowRenderWidthBasedPadding(t *testing.T) {
	row := CreateRow([]interface{}{"foo", "bar"})
	style := &renderStyle{TableStyle: TableStyle{BorderX: "-", BorderY: "|", BorderI: "+",
		PaddingLeft: 1, PaddingRight: 1}, cellWidths: map[int]int{0: 3, 1: 5}}
	output := row.Render(style)
	if output != "| foo | bar |" {
		t.Fatal("Unexpected output:", output)
	}
}

View File

@@ -0,0 +1,60 @@
// Copyright 2012 Apcera Inc. All rights reserved.
package termtables
import "strings"
// lineType classifies a horizontal rule by its position within the table.
type lineType int
// These lines are for horizontal rules; these indicate desired styling,
// but simplistic (pure ASCII) markup characters may end up leaving the
// variant lines indistinguishable from LINE_INNER.
const (
	// LINE_INNER *must* be the default; where there are vertical lines drawn
	// across an inner line, the character at that position should indicate
	// that the vertical line goes both up and down from this horizontal line.
	LINE_INNER lineType = iota
	// LINE_TOP has only descenders
	LINE_TOP
	// LINE_SUBTOP has only descenders in the middle, but goes both up and
	// down at the far left & right edges.
	LINE_SUBTOP
	// LINE_BOTTOM has only ascenders.
	LINE_BOTTOM
)
// A Separator is a horizontal rule line, with associated information which
// indicates where in a table it is, sufficient for simple cases to let
// clean tables be drawn. If a row-spanning cell is created, then this will
// be insufficient: we can get away with hand-waving of "well, it's showing
// where the border would be" but a more capable handling will require
// structure reworking. Patches welcome.
type Separator struct {
	// where records the rule's position class (top/subtop/inner/bottom).
	where lineType
}
// Render returns the string representation of a horizontal rule line in the
// table, with junction characters chosen to match the rule's position.
func (s *Separator) Render(style *renderStyle) string {
	// One run of BorderX per column, sized to content plus padding.
	segments := make([]string, style.columns)
	for i := range segments {
		w := style.PaddingLeft + style.CellWidth(i) + style.PaddingRight
		segments[i] = strings.Repeat(style.BorderX, w)
	}
	var left, mid, right string
	switch s.where {
	case LINE_TOP:
		left, mid, right = style.BorderTopLeft, style.BorderTop, style.BorderTopRight
	case LINE_SUBTOP:
		left, mid, right = style.BorderLeft, style.BorderTop, style.BorderRight
	case LINE_BOTTOM:
		left, mid, right = style.BorderBottomLeft, style.BorderBottom, style.BorderBottomRight
	case LINE_INNER:
		left, mid, right = style.BorderLeft, style.BorderI, style.BorderRight
	default:
		panic("not reached")
	}
	return left + strings.Join(segments, mid) + right
}

View File

@@ -0,0 +1,36 @@
// Copyright 2012 Apcera Inc. All rights reserved.
package termtables
import (
"strings"
"unicode/utf8"
)
// A StraightSeparator is a horizontal line with associated information about
// what sort of position it takes in the table, so as to control which shapes
// will be used where vertical lines are expected to touch this horizontal
// line.
type StraightSeparator struct {
	// where records the rule's position class (top/inner/subtop/bottom).
	where lineType
}
// Render returns a string representing this separator, with the corner
// characters chosen to match the separator's position in the table.
// NOTE(review): the width sum charges one BorderI rune-width per column and
// then subtracts one; this assumes BorderI and the edge borders have equal
// rune widths — confirm for styles where they differ.
func (s *StraightSeparator) Render(style *renderStyle) string {
	width := 0
	for i := 0; i < style.columns; i++ {
		width += style.PaddingLeft + style.CellWidth(i) + style.PaddingRight + utf8.RuneCountInString(style.BorderI)
	}
	line := func(left, right string) string {
		return left + strings.Repeat(style.BorderX, width-1) + right
	}
	switch s.where {
	case LINE_TOP:
		return line(style.BorderTopLeft, style.BorderTopRight)
	case LINE_INNER, LINE_SUBTOP:
		return line(style.BorderLeft, style.BorderRight)
	case LINE_BOTTOM:
		return line(style.BorderBottomLeft, style.BorderBottomRight)
	}
	panic("not reached")
}

214
common/utils/log/style.go Normal file
View File

@@ -0,0 +1,214 @@
// Copyright 2012-2013 Apcera Inc. All rights reserved.
package termtables
import (
"fmt"
"strings"
"unicode/utf8"
)
// TableAlignment is the horizontal alignment applied when rendering the
// content of a cell.
type TableAlignment int
// These constants control the alignment which should be used when rendering
// the content of a cell.
const (
	AlignLeft = TableAlignment(1)
	AlignCenter = TableAlignment(2)
	AlignRight = TableAlignment(3)
)
// TableStyle controls styling information for a Table as a whole.
//
// For the Border rules, only X, Y and I are needed, and all have defaults.
// The others will all default to the same as BorderI.
type TableStyle struct {
	// SkipBorder, when true, suppresses the outer border rule lines.
	SkipBorder bool
	// BorderX, BorderY and BorderI are the horizontal, vertical and
	// intersection drawing characters respectively.
	BorderX string
	BorderY string
	BorderI string
	// Specialised edge/corner characters; empty values are filled in from
	// BorderI by fillStyleRules.
	BorderTop string
	BorderBottom string
	BorderRight string
	BorderLeft string
	BorderTopLeft string
	BorderTopRight string
	BorderBottomLeft string
	BorderBottomRight string
	// Padding (in character cells) inserted on each side of cell content.
	PaddingLeft int
	PaddingRight int
	// Width is the overall rendered width; recomputed by createRenderStyle.
	Width int
	// Alignment is the default cell alignment.
	Alignment TableAlignment
	// htmlRules carries HTML-specific options (e.g. title markup choice).
	htmlRules htmlStyleRules
}
// A CellStyle controls all style applicable to one Cell.
type CellStyle struct {
	// Alignment indicates the alignment to be used in rendering the content
	Alignment TableAlignment
	// ColSpan indicates how many columns this Cell is expected to consume.
	// A zero value behaves like a normal single-column cell (width sizing
	// only special-cases colSpan > 1; see createRenderStyle).
	ColSpan int
}
// DefaultStyle is a TableStyle which can be used to get some simple
// default styling for a table, using ASCII characters for drawing borders.
var DefaultStyle = &TableStyle{
SkipBorder: false,
BorderX: "-", BorderY: "|", BorderI: "+",
PaddingLeft: 1, PaddingRight: 1,
Width: 80,
Alignment: AlignLeft,
// FIXME: the use of a Width here may interact poorly with a changing
// MaxColumns value; we don't set MaxColumns here because the evaluation
// order of a var and an init value adds undesired subtlety.
}
// renderStyle is the resolved, per-render styling state: the table's
// TableStyle plus the computed column count and per-column widths.
type renderStyle struct {
	// cellWidths maps column index to the widest cell seen in that column.
	cellWidths map[int]int
	columns int
	// replaceContent, when non-nil, rewrites cell content; used for markdown
	// rendering to escape the column-separator character.
	replaceContent func(string) string
	TableStyle
}
// setUtfBoxStyle changes the border characters to be suitable for use when
// the output stream can render UTF-8 box-drawing characters.
func (s *TableStyle) setUtfBoxStyle() {
	s.BorderX, s.BorderY, s.BorderI = "─", "│", "┼"
	s.BorderTop, s.BorderBottom = "┬", "┴"
	s.BorderLeft, s.BorderRight = "├", "┤"
	s.BorderTopLeft, s.BorderTopRight = "╭", "╮"
	s.BorderBottomLeft, s.BorderBottomRight = "╰", "╯"
}
// setAsciiBoxStyle changes the border characters back to their ASCII
// defaults, clearing the specialised edge/corner characters so that
// fillStyleRules resets them all to BorderI.
func (s *TableStyle) setAsciiBoxStyle() {
	s.BorderX, s.BorderY, s.BorderI = "-", "|", "+"
	s.BorderTop, s.BorderBottom = "", ""
	s.BorderLeft, s.BorderRight = "", ""
	s.BorderTopLeft, s.BorderTopRight = "", ""
	s.BorderBottomLeft, s.BorderBottomRight = "", ""
	s.fillStyleRules()
}
// fillStyleRules populates members of the TableStyle box-drawing
// specification with BorderI as the default: any edge/corner character
// that is still empty is set to the intersection character.
func (s *TableStyle) fillStyleRules() {
	// Iterate over field pointers rather than repeating the same
	// if-empty-default assignment eight times.
	for _, border := range []*string{
		&s.BorderTop, &s.BorderBottom, &s.BorderLeft, &s.BorderRight,
		&s.BorderTopLeft, &s.BorderTopRight, &s.BorderBottomLeft, &s.BorderBottomRight,
	} {
		if *border == "" {
			*border = s.BorderI
		}
	}
}
// createRenderStyle resolves a table's TableStyle into a renderStyle: it
// measures every column's widest cell, counts the columns, and computes the
// total rendered width, growing the last column when the title needs room.
func createRenderStyle(table *Table) *renderStyle {
	style := &renderStyle{TableStyle: *table.Style, cellWidths: map[int]int{}}
	style.TableStyle.fillStyleRules()
	if table.outputMode == outputMarkdown {
		// Markdown table content must not contain the column separator.
		style.buildReplaceContent(table.Style.BorderY)
	}
	// FIXME: handle actually defined width condition
	// loop over the rows and cells to calculate widths
	for _, element := range table.elements {
		// skip separators
		if _, ok := element.(*Separator); ok {
			continue
		}
		// iterate over cells
		if row, ok := element.(*Row); ok {
			for i, cell := range row.cells {
				// FIXME: need to support sizing with colspan handling
				if cell.colSpan > 1 {
					continue
				}
				if style.cellWidths[i] < cell.Width() {
					style.cellWidths[i] = cell.Width()
				}
			}
		}
	}
	style.columns = len(style.cellWidths)
	// calculate actual width
	width := utf8.RuneCountInString(style.BorderLeft) // start at '1' for left border
	internalBorderWidth := utf8.RuneCountInString(style.BorderI)
	lastIndex := 0
	for i, v := range style.cellWidths {
		width += v + style.PaddingLeft + style.PaddingRight + internalBorderWidth
		if i > lastIndex {
			lastIndex = i
		}
	}
	// The loop above charged BorderI's width once per column; compensate
	// when the right border has a different rune width.
	if internalBorderWidth != utf8.RuneCountInString(style.BorderRight) {
		width += utf8.RuneCountInString(style.BorderRight) - internalBorderWidth
	}
	if table.titleCell != nil {
		// The title row spans the full table; widen the last column when
		// the title alone would not fit.
		titleMinWidth := 0 +
			table.titleCell.Width() +
			utf8.RuneCountInString(style.BorderLeft) +
			utf8.RuneCountInString(style.BorderRight) +
			style.PaddingLeft +
			style.PaddingRight
		if width < titleMinWidth {
			// minWidth must be set to include padding of the title, as required
			style.cellWidths[lastIndex] += (titleMinWidth - width)
			width = titleMinWidth
		}
	}
	// right border is covered in loop
	style.Width = width
	return style
}
// CellWidth returns the width of the cell at the supplied index, where the
// width is the number of tty character-cells required to draw the glyphs.
// Unknown column indexes yield 0 (the map's zero value).
func (s *renderStyle) CellWidth(i int) int {
	return s.cellWidths[i]
}
// buildReplaceContent creates a function closure, with minimal bound lexical
// state, which replaces content: every occurrence of bad is replaced by an
// HTML-style numeric entity.
// NOTE(review): the %02x verb hex-encodes all bytes of bad at once, so the
// resulting entity is only well-formed when bad is a single one-byte
// character (e.g. "|") — confirm multi-byte separators are never passed.
func (s *renderStyle) buildReplaceContent(bad string) {
	replacement := fmt.Sprintf("&#x%02x;", bad)
	s.replaceContent = func(old string) string {
		return strings.Replace(old, bad, replacement, -1)
	}
}

373
common/utils/log/table.go Normal file
View File

@@ -0,0 +1,373 @@
// Copyright 2012-2013 Apcera Inc. All rights reserved.
package termtables
import (
"bytes"
"os"
"regexp"
"runtime"
"strings"
"github.com/apcera/termtables/term"
)
// MaxColumns represents the maximum number of columns that are available for
// display without wrapping around the right-hand side of the terminal window.
// At program initialization, the value will be automatically set according
// to available sources of information, including the $COLUMNS environment
// variable and, on Unix, tty information (see init, which queries the
// terminal size via the term package).
var MaxColumns = 80
// Element is the interface implemented by anything that can draw a
// representation of the contents of a table cell.
type Element interface {
	Render(*renderStyle) string
}
// outputMode selects the markup family produced by Table.Render.
type outputMode int
const (
	outputTerminal outputMode = iota // plain text with box-drawing borders
	outputMarkdown                   // GitHub Flavored Markdown table markup
	outputHTML                       // HTML <table> markup
)
// Open question: should UTF-8 become an output mode? It does require more
// tracking when resetting, if the locale-enabling had been used.
//
// outputsEnabled records the package-level feature toggles that determine
// the default mode and style applied to newly created tables.
var outputsEnabled struct {
	UTF8 bool
	HTML bool
	Markdown bool
	titleStyle titleStyle
}
// defaultOutputMode is applied to each table created by CreateTable; it is
// recomputed from outputsEnabled by chooseDefaultOutput.
var defaultOutputMode outputMode = outputTerminal
// Table represents a terminal table. The Style can be directly accessed
// and manipulated; all other access is via methods.
type Table struct {
	// Style is the rendering style; exported for direct manipulation.
	Style *TableStyle
	// elements holds rows and separators in display order.
	elements []Element
	// headers holds the header cells, or nil when no header row exists.
	headers []interface{}
	// title is the table title, or nil when none has been set.
	title interface{}
	// titleCell caches the rendered title cell (set during rendering).
	titleCell *Cell
	// outputMode selects terminal, markdown or HTML rendering.
	outputMode outputMode
}
// EnableUTF8 will unconditionally enable using UTF-8 box-drawing characters
// for any tables created after this call, as the default style.
// (Package-level default; per-table control is Table.UTF8Box.)
func EnableUTF8() {
	outputsEnabled.UTF8 = true
}
// SetModeHTML will control whether or not new tables generated will be in HTML
// mode by default; HTML-or-not takes precedence over options which control how
// a terminal output will be rendered, such as whether or not to use UTF8.
// This affects any tables created after this call.
func SetModeHTML(onoff bool) {
	outputsEnabled.HTML = onoff
	chooseDefaultOutput()
}
// SetModeMarkdown will control whether or not new tables generated will be
// in Markdown mode by default. HTML-mode takes precedence.
func SetModeMarkdown(onoff bool) {
	outputsEnabled.Markdown = onoff
	chooseDefaultOutput()
}
// utfRe matches locale names that advertise a UTF-8 character map.
var utfRe = regexp.MustCompile(`utf\-8|utf8|UTF\-8|UTF8`)
// EnableUTF8PerLocale will use current locale character map information to
// determine if UTF-8 is expected and, if so, is equivalent to EnableUTF8.
func EnableUTF8PerLocale() {
locale := getLocale()
if utfRe.MatchString(locale) {
EnableUTF8()
}
}
// getLocale returns the current locale name; on non-Windows systems it is
// derived from the locale environment variables.
func getLocale() string {
	if runtime.GOOS != "windows" {
		return unixLocale()
	}
	// TODO: detect windows locale
	return "US-ASCII"
}
// unixLocale returns the locale by checking the $LC_ALL, $LC_CTYPE, and $LANG
// environment variables (in that order of precedence). If none of those are
// set, it returns "US-ASCII".
func unixLocale() string {
	for _, name := range [...]string{"LC_ALL", "LC_CTYPE", "LANG"} {
		if value := os.Getenv(name); value != "" {
			return value
		}
	}
	return "US-ASCII"
}
// SetHTMLStyleTitle lets an HTML title output mode be chosen; it applies to
// tables created after this call (CreateTable copies the setting in).
func SetHTMLStyleTitle(want titleStyle) {
	outputsEnabled.titleStyle = want
}
// chooseDefaultOutput recomputes defaultOutputMode from the enabled outputs,
// HTML taking priority over Markdown, plain terminal being the fallback.
// Pros: simpler encapsulation; cons: setting markdown doesn't disable HTML
// if HTML was previously enabled and was later disabled.
// This seems fairly reasonable.
func chooseDefaultOutput() {
	switch {
	case outputsEnabled.HTML:
		defaultOutputMode = outputHTML
	case outputsEnabled.Markdown:
		defaultOutputMode = outputMarkdown
	default:
		defaultOutputMode = outputTerminal
	}
}
// init sizes MaxColumns from the controlling terminal when available.
func init() {
	// Do not enable UTF-8 per locale by default, breaks tests.
	sz, err := term.GetSize()
	if err == nil && sz.Columns != 0 {
		MaxColumns = sz.Columns
	}
}
// CreateTable creates an empty Table using defaults for style, honouring the
// package-level defaults set via EnableUTF8, SetModeHTML, SetModeMarkdown
// and SetHTMLStyleTitle.
func CreateTable() *Table {
	// Copy DefaultStyle rather than aliasing it: the previous code stored
	// the shared pointer, so per-table style mutations (UTF8Box, direct
	// Style field edits) leaked into the package default and every other
	// table created afterwards.
	style := *DefaultStyle
	t := &Table{elements: []Element{}, Style: &style}
	if outputsEnabled.UTF8 {
		t.Style.setUtfBoxStyle()
	}
	if outputsEnabled.titleStyle != titleStyle(0) {
		t.Style.htmlRules.title = outputsEnabled.titleStyle
	}
	t.outputMode = defaultOutputMode
	return t
}
// AddSeparator adds a line to the table content, where the line
// consists of separator characters.
func (t *Table) AddSeparator() {
	t.elements = append(t.elements, &Separator{})
}
// AddRow adds the supplied items as cells in one row of the table and
// returns the new row (so callers can adjust it further).
func (t *Table) AddRow(items ...interface{}) *Row {
	row := CreateRow(items)
	t.elements = append(t.elements, row)
	return row
}
// AddTitle supplies a table title, which if present will be rendered as
// one cell across the width of the table, as the first row.
// Calling it again replaces any previous title.
func (t *Table) AddTitle(title interface{}) {
	t.title = title
}
// AddHeaders supplies column headers for the table. Repeated calls append
// further headers rather than replacing existing ones.
func (t *Table) AddHeaders(headers ...interface{}) {
	t.headers = append(t.headers, headers...)
}
// SetAlign changes the alignment for elements in a column of the table;
// alignments are stored with each cell, so cells added after a call to
// SetAlign will not pick up the change. Columns are numbered from 1.
// Column numbers outside [1, len(row.cells)] are silently ignored.
func (t *Table) SetAlign(align TableAlignment, columns ...int) {
	for i := range t.elements {
		row, ok := t.elements[i].(*Row)
		if !ok {
			continue
		}
		for _, column := range columns {
			// Columns are 1-based: reject 0 and negatives. The previous
			// check (column < 0) let column 0 through, indexing
			// row.cells[-1] and panicking.
			if column < 1 || column > len(row.cells) {
				continue
			}
			row.cells[column-1].alignment = &align
		}
	}
}
// UTF8Box sets the table style to use UTF-8 box-drawing characters,
// overriding all relevant style elements at the time of the call.
func (t *Table) UTF8Box() {
	t.Style.setUtfBoxStyle()
}
// SetModeHTML switches this table to be in HTML when rendered; the
// default depends upon whether the package function SetModeHTML() has been
// called, and with what value. This method forces the feature on for this
// table. Turning off involves choosing a different mode, per-table.
func (t *Table) SetModeHTML() {
	t.outputMode = outputHTML
}
// SetModeMarkdown switches this table to be in Markdown mode when rendered.
func (t *Table) SetModeMarkdown() {
	t.outputMode = outputMarkdown
}
// SetModeTerminal switches this table to be in terminal mode when rendered.
func (t *Table) SetModeTerminal() {
	t.outputMode = outputTerminal
}
// SetHTMLStyleTitle lets an HTML output mode be chosen; we should rework this
// into a more generic and extensible API as we clean up termtables.
func (t *Table) SetHTMLStyleTitle(want titleStyle) {
	t.Style.htmlRules.title = want
}
// Render returns a string representation of a fully rendered table, drawn
// out for display, with embedded newlines. The markup family is chosen by
// the table's outputMode; in HTML mode this is equivalent to RenderHTML().
func (t *Table) Render() string {
	// Elements is already populated with row data.
	switch t.outputMode {
	case outputTerminal:
		return t.renderTerminal()
	case outputMarkdown:
		return t.renderMarkdown()
	case outputHTML:
		return t.RenderHTML()
	default:
		// outputMode is only ever set from the constants above.
		panic("unknown output mode set")
	}
}
// renderTerminal returns a string representation of a fully rendered table,
// drawn out for display, with embedded newlines.
func (t *Table) renderTerminal() string {
	// Use a placeholder rather than adding titles/headers to the tables
	// elements or else successive calls will compound them.
	tt := t.clone()
	// Initial top line: its shape depends on what is drawn directly below
	// (title row, header row, or plain content).
	if !tt.Style.SkipBorder {
		if tt.title != nil && tt.headers == nil {
			tt.elements = append([]Element{&Separator{where: LINE_SUBTOP}}, tt.elements...)
		} else if tt.title == nil && tt.headers == nil {
			tt.elements = append([]Element{&Separator{where: LINE_TOP}}, tt.elements...)
		} else {
			tt.elements = append([]Element{&Separator{where: LINE_INNER}}, tt.elements...)
		}
	}
	// If we have headers, include them, preceded by the appropriate rule.
	if tt.headers != nil {
		ne := make([]Element, 2)
		ne[1] = CreateRow(tt.headers)
		if tt.title != nil {
			ne[0] = &Separator{where: LINE_SUBTOP}
		} else {
			ne[0] = &Separator{where: LINE_TOP}
		}
		tt.elements = append(ne, tt.elements...)
	}
	// If we have a title, write it as a single centered row spanning the
	// whole table width.
	if tt.title != nil {
		// Match changes to this into renderMarkdown too.
		tt.titleCell = CreateCell(tt.title, &CellStyle{Alignment: AlignCenter, ColSpan: 999})
		ne := []Element{
			&StraightSeparator{where: LINE_TOP},
			CreateRow([]interface{}{tt.titleCell}),
		}
		tt.elements = append(ne, tt.elements...)
	}
	// Generate the runtime style. Must include all cells being printed.
	style := createRenderStyle(tt)
	// Loop over the elements and render them.
	b := bytes.NewBuffer(nil)
	for _, e := range tt.elements {
		b.WriteString(e.Render(style))
		b.WriteString("\n")
	}
	// Add bottom line.
	if !style.SkipBorder {
		b.WriteString((&Separator{where: LINE_BOTTOM}).Render(style) + "\n")
	}
	return b.String()
}
// renderMarkdown returns a string representation of a table in Markdown
// markup format using GitHub Flavored Markdown's notation (since tables
// are not in the core Markdown spec).
//
// NOTE(review): unlike renderTerminal, this works on the receiver directly
// (mutating t.Style and t.elements) rather than a clone, so repeated renders
// would compound the injected header rows — confirm whether intended.
func (t *Table) renderMarkdown() string {
	// We need ASCII drawing characters; we need a line after the header;
	// *do* need a header! Do not need to markdown-escape contents of
	// tables as markdown is ignored in there. Do need to do _something_
	// with a '|' character shown as a member of a table.
	t.Style.setAsciiBoxStyle()
	firstLines := make([]Element, 0, 2)
	if t.headers == nil {
		initial := createRenderStyle(t)
		if initial.columns > 1 {
			// NOTE(review): this synthesised numeric header row is built but
			// never appended anywhere — dead code; presumably it was meant
			// to serve as the header line below. Preserved as-is.
			row := CreateRow([]interface{}{})
			for i := 0; i < initial.columns; i++ {
				row.AddCell(CreateCell(i+1, &CellStyle{}))
			}
		}
	}
	firstLines = append(firstLines, CreateRow(t.headers))
	// This is a dummy line, swapped out below.
	firstLines = append(firstLines, firstLines[0])
	t.elements = append(firstLines, t.elements...)
	// Generate the runtime style.
	style := createRenderStyle(t)
	// We know that the second line is a dummy, we can replace it with the
	// markdown header/body delimiter row ("---" per column).
	mdRow := CreateRow([]interface{}{})
	for i := 0; i < style.columns; i++ {
		mdRow.AddCell(CreateCell(strings.Repeat("-", style.cellWidths[i]), &CellStyle{}))
	}
	t.elements[1] = mdRow
	b := bytes.NewBuffer(nil)
	// Comes after style is generated, which must come after all width-affecting
	// changes are in.
	if t.title != nil {
		// Markdown doesn't support titles or column spanning; we _should_
		// escape the title, but doing that to handle all possible forms of
		// markup would require a heavy dependency, so we punt.
		b.WriteString("Table: ")
		b.WriteString(strings.TrimSpace(CreateCell(t.title, &CellStyle{}).Render(style)))
		b.WriteString("\n\n")
	}
	// Loop over the elements and render them.
	for _, e := range t.elements {
		b.WriteString(e.Render(style))
		b.WriteString("\n")
	}
	return b.String()
}
// clone returns a copy of the table with the underlying slices being copied;
// the references to the Elements/cells (and the Style) remain shallow.
func (t *Table) clone() *Table {
	dup := &Table{
		outputMode: t.outputMode,
		Style:      t.Style,
		title:      t.title,
	}
	if t.headers != nil {
		dup.headers = make([]interface{}, len(t.headers))
		copy(dup.headers, t.headers)
	}
	if t.elements != nil {
		dup.elements = make([]Element, len(t.elements))
		copy(dup.elements, t.elements)
	}
	return dup
}

View File

@@ -0,0 +1,562 @@
// Copyright 2012-2013 Apcera Inc. All rights reserved.
package termtables
import "testing"
// DisplayFailedOutput formats a test-failure message showing the actual
// rendered output followed by the expected output.
func DisplayFailedOutput(actual, expected string) string {
	message := "Output didn't match expected\n\n"
	message += "Actual:\n\n" + actual + "\n"
	message += "Expected:\n\n" + expected
	return message
}
// checkRendersTo renders table and fails the test immediately when the
// output differs from expected.
func checkRendersTo(t *testing.T, table *Table, expected string) {
	output := table.Render()
	if output != expected {
		t.Fatal(DisplayFailedOutput(output, expected))
	}
}
func TestCreateTable(t *testing.T) {
expected := "" +
"+-----------+-------+\n" +
"| Name | Value |\n" +
"+-----------+-------+\n" +
"| hey | you |\n" +
"| ken | 1234 |\n" +
"| derek | 3.14 |\n" +
"| derek too | 3.15 |\n" +
"| escaping | rox%% |\n" +
"+-----------+-------+\n"
table := CreateTable()
table.AddHeaders("Name", "Value")
table.AddRow("hey", "you")
table.AddRow("ken", 1234)
table.AddRow("derek", 3.14)
table.AddRow("derek too", 3.1456788)
table.AddRow("escaping", "rox%%")
checkRendersTo(t, table, expected)
}
func TestStyleResets(t *testing.T) {
expected := "" +
"+-----------+-------+\n" +
"| Name | Value |\n" +
"+-----------+-------+\n" +
"| hey | you |\n" +
"| ken | 1234 |\n" +
"| derek | 3.14 |\n" +
"| derek too | 3.15 |\n" +
"+-----------+-------+\n"
table := CreateTable()
table.UTF8Box()
table.Style.setAsciiBoxStyle()
table.AddHeaders("Name", "Value")
table.AddRow("hey", "you")
table.AddRow("ken", 1234)
table.AddRow("derek", 3.14)
table.AddRow("derek too", 3.1456788)
checkRendersTo(t, table, expected)
}
func TestTableWithHeader(t *testing.T) {
expected := "" +
"+-------------------+\n" +
"| Example |\n" +
"+-----------+-------+\n" +
"| Name | Value |\n" +
"+-----------+-------+\n" +
"| hey | you |\n" +
"| ken | 1234 |\n" +
"| derek | 3.14 |\n" +
"| derek too | 3.15 |\n" +
"+-----------+-------+\n"
table := CreateTable()
table.AddTitle("Example")
table.AddHeaders("Name", "Value")
table.AddRow("hey", "you")
table.AddRow("ken", 1234)
table.AddRow("derek", 3.14)
table.AddRow("derek too", 3.1456788)
checkRendersTo(t, table, expected)
}
// TestTableWithHeaderMultipleTimes ensures that printing a table with headers
// multiple times continues to render correctly.
func TestTableWithHeaderMultipleTimes(t *testing.T) {
expected := "" +
"+-------------------+\n" +
"| Example |\n" +
"+-----------+-------+\n" +
"| Name | Value |\n" +
"+-----------+-------+\n" +
"| hey | you |\n" +
"| ken | 1234 |\n" +
"| derek | 3.14 |\n" +
"| derek too | 3.15 |\n" +
"+-----------+-------+\n"
table := CreateTable()
table.AddTitle("Example")
table.AddHeaders("Name", "Value")
table.AddRow("hey", "you")
table.AddRow("ken", 1234)
table.AddRow("derek", 3.14)
table.AddRow("derek too", 3.1456788)
checkRendersTo(t, table, expected)
checkRendersTo(t, table, expected)
}
// TestTableTitleWidthAdjusts verifies that a title wider than the natural
// column widths forces the table (and its last column) to expand.
func TestTableTitleWidthAdjusts(t *testing.T) {
	expected := "" +
		"+---------------------------+\n" +
		"| Example My Foo Bar'd Test |\n" +
		"+-----------+---------------+\n" +
		"| Name | Value |\n" +
		"+-----------+---------------+\n" +
		"| hey | you |\n" +
		"| ken | 1234 |\n" +
		"| derek | 3.14 |\n" +
		"| derek too | 3.15 |\n" +
		"+-----------+---------------+\n"
	table := CreateTable()
	table.AddTitle("Example My Foo Bar'd Test")
	table.AddHeaders("Name", "Value")
	table.AddRow("hey", "you")
	table.AddRow("ken", 1234)
	table.AddRow("derek", 3.14)
	table.AddRow("derek too", 3.1456788)
	checkRendersTo(t, table, expected)
}
// TestTableHeaderWidthAdjusts verifies that headers wider than the cell
// contents determine the column widths.
func TestTableHeaderWidthAdjusts(t *testing.T) {
	expected := "" +
		"+---------------+---------------------+\n" +
		"| Slightly Long | More than 2 columns |\n" +
		"+---------------+---------------------+\n" +
		"| a | b |\n" +
		"+---------------+---------------------+\n"
	table := CreateTable()
	table.AddHeaders("Slightly Long", "More than 2 columns")
	table.AddRow("a", "b")
	checkRendersTo(t, table, expected)
}
// TestTableWithNoHeaders verifies that a table with rows but no headers
// renders without a header separator segment.
func TestTableWithNoHeaders(t *testing.T) {
	expected := "" +
		"+-----------+------+\n" +
		"| hey | you |\n" +
		"| ken | 1234 |\n" +
		"| derek | 3.14 |\n" +
		"| derek too | 3.15 |\n" +
		"+-----------+------+\n"
	table := CreateTable()
	table.AddRow("hey", "you")
	table.AddRow("ken", 1234)
	table.AddRow("derek", 3.14)
	table.AddRow("derek too", 3.1456788)
	checkRendersTo(t, table, expected)
}
// TestTableUnicodeWidths verifies that cell widths are computed in
// characters, not bytes: the currency symbols below are multi-byte in UTF-8
// but each occupies a single display column.
func TestTableUnicodeWidths(t *testing.T) {
	expected := "" +
		"+-----------+------+\n" +
		"| Name | Cost |\n" +
		"+-----------+------+\n" +
		"| Currency | ¤10 |\n" +
		"| US Dollar | $30 |\n" +
		"| Euro | €27 |\n" +
		"| Thai | ฿70 |\n" +
		"+-----------+------+\n"
	table := CreateTable()
	table.AddHeaders("Name", "Cost")
	table.AddRow("Currency", "¤10")
	table.AddRow("US Dollar", "$30")
	table.AddRow("Euro", "€27")
	table.AddRow("Thai", "฿70")
	checkRendersTo(t, table, expected)
}
// TestTableInUTF8 verifies rendering with the UTF-8 box-drawing style:
// rounded corners, light box-drawing lines, and junction characters at the
// title/header boundaries.
func TestTableInUTF8(t *testing.T) {
	expected := "" +
		"╭───────────────────╮\n" +
		"│ Example │\n" +
		"├───────────┬───────┤\n" +
		"│ Name │ Value │\n" +
		"├───────────┼───────┤\n" +
		"│ hey │ you │\n" +
		"│ ken │ 1234 │\n" +
		"│ derek │ 3.14 │\n" +
		"│ derek too │ 3.15 │\n" +
		"│ escaping │ rox%% │\n" +
		"╰───────────┴───────╯\n"
	table := CreateTable()
	table.UTF8Box()
	table.AddTitle("Example")
	table.AddHeaders("Name", "Value")
	table.AddRow("hey", "you")
	table.AddRow("ken", 1234)
	table.AddRow("derek", 3.14)
	table.AddRow("derek too", 3.1456788)
	table.AddRow("escaping", "rox%%")
	checkRendersTo(t, table, expected)
}
// TestTableUnicodeUTF8AndSGR verifies that ANSI SGR escape sequences
// (bold, colors, underline) do not count towards cell widths.
func TestTableUnicodeUTF8AndSGR(t *testing.T) {
	// at present, this mostly just tests that alignment still works
	expected := "" +
		"╭───────────────────────╮\n" +
		"│ \033[1mFanciness\033[0m │\n" +
		"├──────────┬────────────┤\n" +
		"│ \033[31mred\033[0m │ \033[32mgreen\033[0m │\n" +
		"├──────────┼────────────┤\n" +
		"│ plain │ text │\n" +
		"│ Καλημέρα │ κόσμε │\n" +
		"│ \033[1mvery\033[0m │ \033[4munderlined\033[0m │\n" +
		"│ a\033[1mb\033[0mc │ \033[45mmagenta\033[0m │\n" +
		"│ \033[31m→\033[0m │ \033[32m←\033[0m │\n" +
		"╰──────────┴────────────╯\n"
	// sgred wraps in with the given SGR parameter string and a trailing reset.
	sgred := func(in string, sgrPm string) string {
		return "\033[" + sgrPm + "m" + in + "\033[0m"
	}
	bold := func(in string) string { return sgred(in, "1") }
	table := CreateTable()
	table.UTF8Box()
	table.AddTitle(bold("Fanciness"))
	table.AddHeaders(sgred("red", "31"), sgred("green", "32"))
	table.AddRow("plain", "text")
	table.AddRow("Καλημέρα", "κόσμε") // from http://plan9.bell-labs.com/sys/doc/utf.html
	table.AddRow(bold("very"), sgred("underlined", "4"))
	table.AddRow("a"+bold("b")+"c", sgred("magenta", "45"))
	table.AddRow(sgred("→", "31"), sgred("←", "32"))
	// TODO: in future, if we start detecting presence of SGR sequences, we
	// should ensure that the SGR reset is done at the end of the cell content,
	// so that SGR doesn't "bleed across" (cells or rows). We would then add
	// tests for that here.
	//
	// Of course, at that point, we'd also want to support automatic HTML
	// styling conversion too, so would need a test for that also.
	checkRendersTo(t, table, expected)
}
// TestTableInMarkdown verifies markdown-mode rendering: a "Table:" caption,
// pipe-delimited rows, and escaping of literal "|" in cell content.
func TestTableInMarkdown(t *testing.T) {
	expected := "" +
		"Table: Example\n\n" +
		"| Name | Value |\n" +
		"| ----- | ----- |\n" +
		"| hey | you |\n" +
		"| a &#x7c; b | esc |\n" +
		"| esc | rox%% |\n"
	table := CreateTable()
	table.SetModeMarkdown()
	table.AddTitle("Example")
	table.AddHeaders("Name", "Value")
	table.AddRow("hey", "you")
	table.AddRow("a | b", "esc") // "|" must be emitted as &#x7c;
	table.AddRow("esc", "rox%%")
	checkRendersTo(t, table, expected)
}
// TestTitleUnicodeWidths verifies that the title's width is measured in
// characters rather than bytes when sizing the table.
func TestTitleUnicodeWidths(t *testing.T) {
	expected := "" +
		"+-------+\n" +
		"| ← 5 → |\n" +
		"+---+---+\n" +
		"| a | b |\n" +
		"| c | d |\n" +
		"| e | 3 |\n" +
		"+---+---+\n"
	// minimum width for a table of two columns is 9 characters, given
	// one space of padding, and non-empty tables.
	table := CreateTable()
	// We have 4 characters down for left and right columns and padding, so
	// a width of 5 for us should match the minimum per the columns
	// 5 characters; each arrow is three octets in UTF-8, giving 9 bytes
	// so, same in character-count-width, longer in bytes
	table.AddTitle("← 5 →")
	// a single character per cell, here; use ASCII characters
	table.AddRow("a", "b")
	table.AddRow("c", "d")
	table.AddRow("e", 3)
	checkRendersTo(t, table, expected)
}
// We identified two error conditions wherein length wrapping would not correctly
// wrap width when, for instance, in a two-column table, the longest row in the
// right-hand column was not the same as the longest row in the left-hand column.
// This tests that we correctly accumulate the maximum width across all rows of
// the termtable and adjust width accordingly.
func TestTableWidthHandling(t *testing.T) {
	expected := "" +
		"+-----------------------------------------+\n" +
		"| Example... to Fix My Test |\n" +
		"+-----------------+-----------------------+\n" +
		"| hey foo bar baz | you |\n" +
		"| ken | you should write code |\n" +
		"| derek | 3.14 |\n" +
		"| derek too | 3.15 |\n" +
		"+-----------------+-----------------------+\n"
	table := CreateTable()
	table.AddTitle("Example... to Fix My Test")
	table.AddRow("hey foo bar baz", "you")
	table.AddRow("ken", "you should write code")
	table.AddRow("derek", 3.14)
	table.AddRow("derek too", 3.1456788)
	// Use the shared checkRendersTo helper for consistency with the other
	// tests in this file (it performs the same Render-and-compare with
	// DisplayFailedOutput on mismatch).
	checkRendersTo(t, table, expected)
}
// TestTableWidthHandling_SecondErrorCondition covers the second width-wrap
// error condition described above TestTableWidthHandling, with the longest
// right-hand cell one character shorter.
func TestTableWidthHandling_SecondErrorCondition(t *testing.T) {
	expected := "" +
		"+----------------------------------------+\n" +
		"| Example... to Fix My Test |\n" +
		"+-----------------+----------------------+\n" +
		"| hey foo bar baz | you |\n" +
		"| ken | you should sell cod! |\n" +
		"| derek | 3.14 |\n" +
		"| derek too | 3.15 |\n" +
		"+-----------------+----------------------+\n"
	table := CreateTable()
	table.AddTitle("Example... to Fix My Test")
	table.AddRow("hey foo bar baz", "you")
	table.AddRow("ken", "you should sell cod!")
	table.AddRow("derek", 3.14)
	table.AddRow("derek too", 3.1456788)
	// Use the shared checkRendersTo helper for consistency with the other
	// tests in this file.
	checkRendersTo(t, table, expected)
}
// TestTableAlignPostsetting verifies that SetAlign applied after all rows
// have been added still affects the rendering of the targeted columns.
func TestTableAlignPostsetting(t *testing.T) {
	expected := "" +
		"+-----------+-------+----------+\n"+
		"| Name | Value | Value 2 |\n"+
		"+-----------+-------+----------+\n"+
		"| hey | you | man |\n"+
		"| ken | 1234 | 4321 |\n"+
		"| derek | 3.14 | bob |\n"+
		"| derek too | 3.15 | long bob |\n"+
		"| escaping | rox%% | :) |\n"+
		"+-----------+-------+----------+\n"
	table := CreateTable()
	table.AddHeaders("Name", "Value", "Value 2")
	table.AddRow("hey", "you", "man")
	table.AddRow("ken", 1234, 4321)
	table.AddRow("derek", 3.14, "bob")
	table.AddRow("derek too", 3.1456788, "long bob")
	table.AddRow("escaping", "rox%%", ":)")
	// Right-align columns 2 and 3 after the rows are already in place.
	table.SetAlign(AlignRight, 2, 3)
	checkRendersTo(t, table, expected)
}
// TestTableMissingCells pins the current (known-imperfect) rendering of rows
// that have fewer cells than the header: short rows are closed early instead
// of being padded out to the full column count.
func TestTableMissingCells(t *testing.T) {
	expected := "" +
		"+----------+---------+---------+\n" +
		"| Name | Value 1 | Value 2 |\n" +
		"+----------+---------+---------+\n" +
		"| hey | you | person |\n" +
		"| ken | 1234 |\n" +
		"| escaping | rox%s%% |\n" +
		"+----------+---------+---------+\n"
	// FIXME: missing extra cells there
	table := CreateTable()
	table.AddHeaders("Name", "Value 1", "Value 2")
	table.AddRow("hey", "you", "person")
	table.AddRow("ken", 1234)
	table.AddRow("escaping", "rox%s%%")
	checkRendersTo(t, table, expected)
}
// We don't yet support combining characters, double-width characters or
// anything to do with estimating a tty-style "character width" for what in
// Unicode is a grapheme cluster. This disabled test shows what we want
// to support, but don't yet.
func TestTableWithCombiningChars(t *testing.T) {
	expected := "" +
		"+------+---+\n" +
		"| noel | 1 |\n" +
		"| noël | 2 |\n" +
		"| noël | 3 |\n" +
		"+------+---+\n"
	table := CreateTable()
	table.AddRow("noel", "1")
	table.AddRow("noe\u0308l", "2") // LATIN SMALL LETTER E + COMBINING DIAERESIS
	table.AddRow("noël", "3")       // Hex EB; LATIN SMALL LETTER E WITH DIAERESIS
	checkRendersTo(t, table, expected)
}
// another unicode length issue
// TestTableWithFullwidthChars exercises fullwidth Latin forms, which occupy
// two display columns each in a terminal.
func TestTableWithFullwidthChars(t *testing.T) {
	expected := "" +
		"+----------+------------+\n" +
		"| wide | not really |\n" +
		"| | fullwidth |\n" +
		"+----------+------------+\n"
	table := CreateTable()
	table.AddRow("wide", "not really")
	table.AddRow("", "fullwidth") // FULLWIDTH LATIN SMALL LETTER <X>
	checkRendersTo(t, table, expected)
}
// Tests CJK characters using examples given in issue #33. The examples may not
// look like they line up but you can visually confirm its accuracy with a
// fmt.Print.
func TestCJKChars(t *testing.T) {
	expected := "" +
		"+-------+---------+----------+\n" +
		"| KeyID | ValueID | ValueCN |\n" +
		"+-------+---------+----------+\n" +
		"| 8 | 51 | 精钢 |\n" +
		"| 8 | 52 | 鳄鱼皮 |\n" +
		"| 8 | 53 | 镀金皮带 |\n" +
		"| 8 | 54 | 精钢 |\n" +
		"+-------+---------+----------+\n"
	table := CreateTable()
	table.AddHeaders("KeyID", "ValueID", "ValueCN")
	table.AddRow("8", 51, "精钢")
	table.AddRow("8", 52, "鳄鱼皮")
	table.AddRow("8", 53, "镀金皮带")
	table.AddRow("8", 54, "精钢")
	checkRendersTo(t, table, expected)
	// Second sample: CJK strings in the value column only.
	expected2 := "" +
		"+--------------------+----------------------+\n" +
		"| field | value |\n" +
		"+--------------------+----------------------+\n" +
		"| GoodsPropertyKeyID | 9 |\n" +
		"| MerchantAccountID | 0 |\n" +
		"| GoodsCategoryCode | 100001 |\n" +
		"| NameCN | 机芯类型 |\n" +
		"| NameJP | ムーブメントのタイプ |\n" +
		"+--------------------+----------------------+\n"
	table = CreateTable()
	table.AddHeaders("field", "value")
	table.AddRow("GoodsPropertyKeyID", 9)
	table.AddRow("MerchantAccountID", 0)
	table.AddRow("GoodsCategoryCode", 100001)
	table.AddRow("NameCN", "机芯类型")
	table.AddRow("NameJP", "ムーブメントのタイプ")
	checkRendersTo(t, table, expected2)
}
// TestTableMultipleAddHeader verifies that repeated AddHeaders calls append
// to the existing header row rather than replacing it.
func TestTableMultipleAddHeader(t *testing.T) {
	expected := "" +
		"+--------------+--------+-------+\n" +
		"| First column | Second | Third |\n" +
		"+--------------+--------+-------+\n" +
		"| 2 | 3 | 5 |\n" +
		"+--------------+--------+-------+\n"
	table := CreateTable()
	table.AddHeaders("First column", "Second")
	table.AddHeaders("Third")
	table.AddRow(2, 3, 5)
	checkRendersTo(t, table, expected)
}
// createTestTable builds a large table (50 columns x 3000 rows of fixed
// strings) used as a fixture by the rendering benchmarks below.
func createTestTable() *Table {
	const (
		numColumns = 50
		numRows    = 3000
	)
	table := CreateTable()
	// Preallocate to the known size, and use distinct loop variables: the
	// original shadowed `i` in the nested loops, which was confusing even
	// though harmless.
	header := make([]interface{}, 0, numColumns)
	for c := 0; c < numColumns; c++ {
		header = append(header, "First Column")
	}
	table.AddHeaders(header...)
	for r := 0; r < numRows; r++ {
		row := make([]interface{}, 0, numColumns)
		for c := 0; c < numColumns; c++ {
			row = append(row, "First row value")
		}
		table.AddRow(row...)
	}
	return table
}
// BenchmarkTableRenderTerminal measures terminal-mode rendering of the
// large benchmark fixture table.
func BenchmarkTableRenderTerminal(b *testing.B) {
	tbl := createTestTable()
	tbl.SetModeTerminal()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		tbl.Render()
	}
}
// BenchmarkTableRenderMarkdown measures markdown-mode rendering of the
// large benchmark fixture table.
func BenchmarkTableRenderMarkdown(b *testing.B) {
	tbl := createTestTable()
	tbl.SetModeMarkdown()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		tbl.Render()
	}
}
// BenchmarkTableRenderHTML measures HTML-mode rendering of the large
// benchmark fixture table.
func BenchmarkTableRenderHTML(b *testing.B) {
	tbl := createTestTable()
	tbl.SetModeHTML()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		tbl.Render()
	}
}

View File

@@ -0,0 +1,43 @@
// Copyright 2013 Apcera Inc. All rights reserved.
package term
import (
"os"
"strconv"
)
// GetEnvWindowSize returns the window Size, as determined by process
// environment; if either LINES or COLUMNS is present, and whichever is
// present is also numeric, the Size will be non-nil. If Size is nil,
// there's insufficient data in environ. If one entry is 0, that means
// that the environment does not include that data. If a value is
// negative, we treat that as an error.
func GetEnvWindowSize() *Size {
	lines := os.Getenv("LINES")
	columns := os.Getenv("COLUMNS")
	if lines == "" && columns == "" {
		return nil
	}
	// parse maps "" to (0, ok) — "no data" — and rejects non-numeric or
	// negative values.
	parse := func(raw string) (int, bool) {
		if raw == "" {
			return 0, true
		}
		v, err := strconv.Atoi(raw)
		if err != nil || v < 0 {
			return 0, false
		}
		return v, true
	}
	nLines, ok := parse(lines)
	if !ok {
		return nil
	}
	nColumns, ok := parse(columns)
	if !ok {
		return nil
	}
	return &Size{
		Lines:   nLines,
		Columns: nColumns,
	}
}

View File

@@ -0,0 +1,54 @@
// Copyright 2013 Apcera Inc. All rights reserved.
package term
import (
"os"
)
// Size is the size of a terminal, expressed in character cells, as Lines and
// Columns. This might come from environment variables or OS-dependent
// resources.
type Size struct {
	// Lines is the terminal height in character cells; 0 means unknown.
	Lines int
	// Columns is the terminal width in character cells; 0 means unknown.
	Columns int
}
// GetSize will return the terminal window size.
//
// We prefer environ $LINES/$COLUMNS, then fall back to tty-held information.
// We do not support use of termcap/terminfo to derive default size information.
func GetSize() (*Size, error) {
	envSize := GetEnvWindowSize()
	if envSize != nil && envSize.Lines != 0 && envSize.Columns != 0 {
		return envSize, nil
	}
	fh, err := os.Open("/dev/tty")
	if err != nil {
		// no tty, no point continuing; we only let the environ
		// avoid an error in this case because if someone has faked
		// up an environ with LINES/COLUMNS _both_ set, we should let
		// them
		return nil, err
	}
	// Fix: the tty file descriptor was previously leaked on every call;
	// we only need it for the ioctl below.
	defer fh.Close()
	size, err := GetTerminalWindowSize(fh)
	if err != nil {
		if envSize != nil {
			return envSize, nil
		}
		return nil, err
	}
	if envSize == nil {
		return size, err
	}
	// Fill in whichever dimension the environment did not provide.
	if envSize.Lines == 0 {
		envSize.Lines = size.Lines
	}
	if envSize.Columns == 0 {
		envSize.Columns = size.Columns
	}
	return envSize, nil
}

View File

@@ -0,0 +1,35 @@
// Copyright 2013 Apcera Inc. All rights reserved.
// +build !windows
package term
import (
"errors"
"os"
"syscall"
"unsafe"
)
// ErrGetWinsizeFailed indicates that the system call to extract the size of
// a Unix tty from the kernel failed.
var ErrGetWinsizeFailed = errors.New("term: syscall.TIOCGWINSZ failed")
// GetTerminalWindowSize returns the terminal size maintained by the kernel
// for a Unix TTY, passed in as an *os.File. This information can be seen
// with the stty(1) command, and changes in size (eg, terminal emulator
// resized) should trigger a SIGWINCH signal delivery to the foreground process
// group at the time of the change, so a long-running process might reasonably
// watch for SIGWINCH and arrange to re-fetch the size when that happens.
// GetTerminalWindowSize returns the terminal size maintained by the kernel
// for a Unix TTY, passed in as an *os.File. This information can be seen
// with the stty(1) command, and changes in size (eg, terminal emulator
// resized) should trigger a SIGWINCH signal delivery to the foreground process
// group at the time of the change, so a long-running process might reasonably
// watch for SIGWINCH and arrange to re-fetch the size when that happens.
func GetTerminalWindowSize(file *os.File) (*Size, error) {
	// Based on source from golang.org/x/crypto/ssh/terminal/util.go
	// TIOCGWINSZ fills a struct winsize {ws_row, ws_col, ws_xpixel, ws_ypixel},
	// four uint16s; we only use the first two.
	var dimensions [4]uint16
	if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, file.Fd(), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 {
		// err is a syscall.Errno; non-zero means the ioctl failed.
		return nil, err
	}
	return &Size{
		Lines:   int(dimensions[0]),
		Columns: int(dimensions[1]),
	}, nil
}

View File

@@ -0,0 +1,57 @@
// Copyright 2013 Apcera Inc. All rights reserved.
// +build windows
package term
// Used when we have no other source for getting platform-specific information
// about the terminal sizes available.
import (
"os"
"syscall"
"unsafe"
)
// Based on source from from golang.org/x/crypto/ssh/terminal/util_windows.go
var (
	// kernel32 provides access to the Windows console API; the DLL is loaded
	// lazily on first use.
	kernel32                       = syscall.NewLazyDLL("kernel32.dll")
	procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
)

// The following types mirror the Win32 CONSOLE_SCREEN_BUFFER_INFO structure
// layout so it can be filled directly by the syscall.
type (
	short int16
	word  uint16

	coord struct {
		x short
		y short
	}
	smallRect struct {
		left   short
		top    short
		right  short
		bottom short
	}
	consoleScreenBufferInfo struct {
		size              coord
		cursorPosition    coord
		attributes        word
		window            smallRect
		maximumWindowSize coord
	}
)
// GetTerminalWindowSize returns the width and height of a terminal in Windows.
func GetTerminalWindowSize(file *os.File) (*Size, error) {
	var info consoleScreenBufferInfo
	// GetConsoleScreenBufferInfo(handle, &info); e is the Win32 error code
	// (0 on success).
	_, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, file.Fd(), uintptr(unsafe.Pointer(&info)), 0)
	if e != 0 {
		return nil, error(e)
	}
	// NOTE(review): this reports the screen *buffer* size, which on Windows
	// can be taller than the visible window rectangle — confirm this is the
	// intended semantic versus using info.window.
	return &Size{
		Lines:   int(info.size.y),
		Columns: int(info.size.x),
	}, nil
}

View File

@@ -0,0 +1,23 @@
// Copyright 2013 Apcera Inc. All rights reserved.
//go:build ignore
// +build ignore
package main
import (
"fmt"
"os"
"github.com/apcera/termtables/term"
)
// main prints the detected terminal size, or exits non-zero on failure.
func main() {
	sz, err := term.GetSize()
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed: %s\n", err)
		os.Exit(1)
	}
	fmt.Printf("Lines %d Columns %d\n", sz.Lines, sz.Columns)
}

View File

@@ -0,0 +1,10 @@
language: go
sudo: false
script: go test -v
go:
- 1.3
- 1.12
- 1.13
- tip

View File

@@ -0,0 +1,19 @@
Copyright (c) 2015 Ryan Hileman
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@@ -0,0 +1,104 @@
[![Build Status](https://travis-ci.org/lunixbochs/struc.svg?branch=master)](https://travis-ci.org/lunixbochs/struc) [![GoDoc](https://godoc.org/github.com/lunixbochs/struc?status.svg)](https://godoc.org/github.com/lunixbochs/struc)
struc
====
Struc exists to pack and unpack C-style structures from bytes, which is useful for binary files and network protocols. It could be considered an alternative to `encoding/binary`, which requires massive boilerplate for some similar operations.
Take a look at an [example comparing `struc` and `encoding/binary`](https://bochs.info/p/cxvm9)
Struc considers usability first. That said, it does cache reflection data and aims to be competitive with `encoding/binary` struct packing in every way, including performance.
Example struct
----
```Go
type Example struct {
    Size    int    `struc:"int32,sizeof=Str"`
    Str     string
    Weird   []byte `struc:"[8]int64"`
    Padding []byte `struc:"[32]byte"`
    Nums    []int  `struc:"[]int32,little"`
}
```
Struct tag format
----
- ```Var []int `struc:"[]int32,little,sizeof=StringField"` ``` will pack Var as a slice of little-endian int32, and link it as the size of `StringField`.
- `sizeof=`: Indicates this field is a number used to track the length of another field. `sizeof` fields are automatically updated on `Pack()` based on the current length of the tracked field, and are used to size the target field during `Unpack()`.
- Bare values will be parsed as type and endianness.
Endian formats
----
- `big` (default)
- `little`
Recognized types
----
- `pad` - this type ignores field contents and is backed by a `[length]byte` containing nulls
- `bool`
- `byte`
- `int8`, `uint8`
- `int16`, `uint16`
- `int32`, `uint32`
- `int64`, `uint64`
- `float32`
- `float64`
Types can be indicated as arrays/slices using `[]` syntax. Example: `[]int64`, `[8]int32`.
Bare slice types (those with no `[size]`) must have a linked `Sizeof` field.
Private fields are ignored when packing and unpacking.
Example code
----
```Go
package main
import (
"bytes"
"github.com/lunixbochs/struc"
)
type Example struct {
A int `struc:"big"`
// B will be encoded/decoded as a 16-bit int (a "short")
// but is stored as a native int in the struct
B int `struc:"int16"`
// the sizeof key links a buffer's size to any int field
Size int `struc:"int8,little,sizeof=Str"`
Str string
// you can get freaky if you want
Str2 string `struc:"[5]int64"`
}
func main() {
var buf bytes.Buffer
t := &Example{1, 2, 0, "test", "test2"}
err := struc.Pack(&buf, t)
o := &Example{}
err = struc.Unpack(&buf, o)
}
```
Benchmark
----
`BenchmarkEncode` uses struc. `Stdlib` benchmarks use equivalent `encoding/binary` code. `Manual` encodes without any reflection, and should be considered an upper bound on performance (which generated code based on struc definitions should be able to achieve).
```
BenchmarkEncode 1000000 1265 ns/op
BenchmarkStdlibEncode 1000000 1855 ns/op
BenchmarkManualEncode 5000000 284 ns/op
BenchmarkDecode 1000000 1259 ns/op
BenchmarkStdlibDecode 1000000 1656 ns/op
BenchmarkManualDecode 20000000 89.0 ns/op
```

View File

@@ -0,0 +1,203 @@
package struc
import (
"bytes"
"encoding/binary"
"testing"
)
// BenchExample is the fixed-layout struct used by the encoding/binary
// comparison benchmarks; Length counts the trailing Data bytes that are
// written separately (binary cannot express sizeof links).
type BenchExample struct {
	Test    [5]byte
	A       int32
	B, C, D int16
	Test2   [4]byte
	Length  int32
}
// BenchmarkArrayEncode measures packing of the array-based reference value.
func BenchmarkArrayEncode(b *testing.B) {
	for n := 0; n < b.N; n++ {
		var out bytes.Buffer
		err := Pack(&out, arrayReference)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkSliceEncode measures packing of the slice-based reference value.
func BenchmarkSliceEncode(b *testing.B) {
	for n := 0; n < b.N; n++ {
		var out bytes.Buffer
		err := Pack(&out, sliceReference)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkArrayDecode measures unpacking into the array-based example type.
func BenchmarkArrayDecode(b *testing.B) {
	var out ExampleArray
	for n := 0; n < b.N; n++ {
		r := bytes.NewBuffer(arraySliceReferenceBytes)
		err := Unpack(r, &out)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkSliceDecode measures unpacking into the slice-based example type.
func BenchmarkSliceDecode(b *testing.B) {
	var out ExampleSlice
	for n := 0; n < b.N; n++ {
		r := bytes.NewBuffer(arraySliceReferenceBytes)
		err := Unpack(r, &out)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// BenchStrucExample mirrors BenchExample using struc tags, with Length
// linked as the sizeof for the trailing Data slice.
type BenchStrucExample struct {
	Test    [5]byte `struc:"[5]byte"`
	A       int     `struc:"int32"`
	B, C, D int     `struc:"int16"`
	Test2   [4]byte `struc:"[4]byte"`
	Length  int     `struc:"int32,sizeof=Data"`
	Data    []byte
}

// benchRef is the encoding/binary fixture; Data (8 bytes) is appended
// separately since the plain struct cannot carry a variable-length field.
var benchRef = &BenchExample{
	[5]byte{1, 2, 3, 4, 5},
	1, 2, 3, 4,
	[4]byte{1, 2, 3, 4},
	8,
}
var eightBytes = []byte("8bytestr")

// benchStrucRef is the struc fixture equivalent to benchRef plus its payload.
var benchStrucRef = &BenchStrucExample{
	[5]byte{1, 2, 3, 4, 5},
	1, 2, 3, 4,
	[4]byte{1, 2, 3, 4},
	8, eightBytes,
}
// BenchmarkEncode measures struc packing of the tagged benchmark fixture.
func BenchmarkEncode(b *testing.B) {
	for n := 0; n < b.N; n++ {
		var out bytes.Buffer
		if err := Pack(&out, benchStrucRef); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkStdlibEncode measures the equivalent encoding/binary packing,
// writing the fixed struct and then the variable payload.
func BenchmarkStdlibEncode(b *testing.B) {
	for n := 0; n < b.N; n++ {
		var out bytes.Buffer
		if err := binary.Write(&out, binary.BigEndian, benchRef); err != nil {
			b.Fatal(err)
		}
		if _, err := out.Write(eightBytes); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkManualEncode encodes the fixture with hand-written offset
// arithmetic and no reflection; it is the performance upper bound the
// README refers to.
func BenchmarkManualEncode(b *testing.B) {
	order := binary.BigEndian
	s := benchStrucRef
	for i := 0; i < b.N; i++ {
		var buf bytes.Buffer
		// 29 = 5 (Test) + 4 (A) + 3*2 (B,C,D) + 4 (Test2) + 4 (Length) + 6 remaining of Data? —
		// NOTE(review): 23 header bytes + len(Data)==8 would be 31; tmp of 29
		// truncates Data to 6 bytes via the bounded copy. Confirm intended.
		tmp := make([]byte, 29)
		copy(tmp[0:5], s.Test[:])
		order.PutUint32(tmp[5:9], uint32(s.A))
		order.PutUint16(tmp[9:11], uint16(s.B))
		order.PutUint16(tmp[11:13], uint16(s.C))
		order.PutUint16(tmp[13:15], uint16(s.D))
		copy(tmp[15:19], s.Test2[:])
		order.PutUint32(tmp[19:23], uint32(s.Length))
		copy(tmp[23:], s.Data)
		_, err := buf.Write(tmp)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecode measures struc unpacking; the packed bytes are produced
// once outside the timed loop.
func BenchmarkDecode(b *testing.B) {
	var out BenchStrucExample
	var packed bytes.Buffer
	if err := Pack(&packed, benchStrucRef); err != nil {
		b.Fatal(err)
	}
	raw := packed.Bytes()
	for n := 0; n < b.N; n++ {
		r := bytes.NewReader(raw)
		if err := Unpack(r, &out); err != nil {
			b.Fatal(err)
		}
		// Drop the decoded payload so each iteration re-allocates it.
		out.Data = nil
	}
}
// BenchmarkStdlibDecode measures the equivalent encoding/binary decode of
// the fixed header followed by a manual read of the variable payload.
func BenchmarkStdlibDecode(b *testing.B) {
	var out BenchExample
	var buf bytes.Buffer
	// Fix: the error from binary.Write was silently discarded, unlike the
	// matching check in BenchmarkStdlibEncode.
	if err := binary.Write(&buf, binary.BigEndian, *benchRef); err != nil {
		b.Fatal(err)
	}
	_, err := buf.Write(eightBytes)
	if err != nil {
		b.Fatal(err)
	}
	bufBytes := buf.Bytes()
	for i := 0; i < b.N; i++ {
		buf := bytes.NewReader(bufBytes)
		err := binary.Read(buf, binary.BigEndian, &out)
		if err != nil {
			b.Fatal(err)
		}
		// Read the trailing Length bytes of payload.
		tmp := make([]byte, out.Length)
		_, err = buf.Read(tmp)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkManualDecode decodes the packed fixture with hand-written offset
// arithmetic and no reflection; the performance upper bound for decoding.
func BenchmarkManualDecode(b *testing.B) {
	var o BenchStrucExample
	var buf bytes.Buffer
	if err := Pack(&buf, benchStrucRef); err != nil {
		b.Fatal(err)
	}
	tmp := buf.Bytes()
	order := binary.BigEndian
	for i := 0; i < b.N; i++ {
		// Field offsets mirror BenchmarkManualEncode above.
		copy(o.Test[:], tmp[0:5])
		o.A = int(order.Uint32(tmp[5:9]))
		o.B = int(order.Uint16(tmp[9:11]))
		o.C = int(order.Uint16(tmp[11:13]))
		o.D = int(order.Uint16(tmp[13:15]))
		copy(o.Test2[:], tmp[15:19])
		o.Length = int(order.Uint32(tmp[19:23]))
		o.Data = make([]byte, o.Length)
		copy(o.Data, tmp[23:])
	}
}
// BenchmarkFullEncode packs the full Example reference value.
func BenchmarkFullEncode(b *testing.B) {
	for n := 0; n < b.N; n++ {
		var out bytes.Buffer
		err := Pack(&out, reference)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkFullDecode unpacks the full Example reference bytes.
func BenchmarkFullDecode(b *testing.B) {
	var out Example
	for n := 0; n < b.N; n++ {
		r := bytes.NewBuffer(referenceBytes)
		err := Unpack(r, &out)
		if err != nil {
			b.Fatal(err)
		}
	}
}

View File

@@ -0,0 +1,52 @@
package struc
import (
"encoding/binary"
"io"
"reflect"
)
// byteWriter adapts a fixed-size []byte to io.Writer, tracking the write
// position and silently truncating writes that exceed the buffer.
type byteWriter struct {
	buf []byte
	pos int
}

// Write copies p (truncated to remaining capacity) into the buffer.
// Fix: this must be a pointer receiver — with the previous value receiver
// every call mutated a copy, so pos never advanced and Pack below always
// reported 0 bytes written.
func (b *byteWriter) Write(p []byte) (int, error) {
	capacity := len(b.buf) - b.pos
	if capacity < len(p) {
		p = p[:capacity]
	}
	if len(p) > 0 {
		copy(b.buf[b.pos:], p)
		b.pos += len(p)
	}
	return len(p), nil
}

// binaryFallback packs/unpacks values struc has no native handling for by
// delegating to encoding/binary.
type binaryFallback reflect.Value

// String describes the wrapped value.
// Fix: previously `return b.String()` called itself, recursing forever.
func (b binaryFallback) String() string {
	return reflect.Value(b).String()
}

// Sizeof reports the encoding/binary size of the wrapped value.
func (b binaryFallback) Sizeof(val reflect.Value, options *Options) int {
	return binary.Size(val.Interface())
}

// Pack encodes val into buf via encoding/binary, honoring options.Order
// (big-endian by default), and returns the number of bytes written.
func (b binaryFallback) Pack(buf []byte, val reflect.Value, options *Options) (int, error) {
	tmp := &byteWriter{buf: buf}
	var order binary.ByteOrder = binary.BigEndian
	if options.Order != nil {
		order = options.Order
	}
	err := binary.Write(tmp, order, val.Interface())
	return tmp.pos, err
}

// Unpack decodes val from r via encoding/binary, honoring options.Order
// (big-endian by default).
func (b binaryFallback) Unpack(r io.Reader, val reflect.Value, options *Options) error {
	var order binary.ByteOrder = binary.BigEndian
	if options.Order != nil {
		order = options.Order
	}
	return binary.Read(r, order, val.Interface())
}

View File

@@ -0,0 +1,33 @@
package struc

import (
	"io"
	"reflect"
)

// Custom is the interface a user type implements to control its own wire
// format: Pack writes into p, Unpack reads from r, Size reports the packed
// length, and String names the type for diagnostics.
type Custom interface {
	Pack(p []byte, opt *Options) (int, error)
	Unpack(r io.Reader, length int, opt *Options) error
	Size(opt *Options) int
	String() string
}

// customFallback adapts a Custom implementation to the internal packer
// interface (which passes a reflect.Value the Custom does not need).
type customFallback struct {
	custom Custom
}

func (c customFallback) Pack(p []byte, val reflect.Value, opt *Options) (int, error) {
	return c.custom.Pack(p, opt)
}

func (c customFallback) Unpack(r io.Reader, val reflect.Value, opt *Options) error {
	// NOTE(review): length is hardcoded to 1 here rather than derived from
	// the field — confirm whether Custom implementations ever rely on it.
	return c.custom.Unpack(r, 1, opt)
}

func (c customFallback) Sizeof(val reflect.Value, opt *Options) int {
	return c.custom.Size(opt)
}

func (c customFallback) String() string {
	return c.custom.String()
}

View File

@@ -0,0 +1,78 @@
package struc
import (
"encoding/binary"
"io"
"math"
"strconv"
)
// Float16 is a float64 that packs/unpacks as an IEEE 754 half-precision
// (binary16) value: 1 sign bit, 5 exponent bits, 10 fraction bits.
type Float16 float64

// Pack writes f as 2 bytes into p using opt.Order (big-endian by default)
// and returns the number of bytes written.
func (f *Float16) Pack(p []byte, opt *Options) (int, error) {
	order := opt.Order
	if order == nil {
		order = binary.BigEndian
	}
	sign := uint16(0)
	if *f < 0 {
		sign = 1
	}
	var frac, exp uint16
	if math.IsInf(float64(*f), 0) {
		// Infinity: all-ones exponent, zero fraction.
		exp = 0x1f
		frac = 0
	} else if math.IsNaN(float64(*f)) {
		// NaN: all-ones exponent, non-zero fraction.
		exp = 0x1f
		frac = 1
	} else {
		bits := math.Float64bits(float64(*f))
		exp64 := (bits >> 52) & 0x7ff
		if exp64 != 0 {
			// Rebias from float64's 1023 to half-precision's 15.
			// NOTE(review): exponents outside the half-precision range are
			// masked with &0x1f rather than clamped to inf/zero — values
			// outside [~6e-5, 65504] will wrap. Confirm acceptable.
			exp = uint16((exp64 - 1023 + 15) & 0x1f)
		}
		// Keep the top 10 fraction bits (truncation, not rounding).
		frac = uint16((bits >> 42) & 0x3ff)
	}
	var out uint16
	out |= sign << 15
	out |= exp << 10
	out |= frac & 0x3ff
	order.PutUint16(p, out)
	return 2, nil
}
// Unpack reads 2 bytes from r using opt.Order (big-endian by default) and
// decodes them as an IEEE 754 half-precision value into f.
func (f *Float16) Unpack(r io.Reader, length int, opt *Options) error {
	order := opt.Order
	if order == nil {
		order = binary.BigEndian
	}
	var tmp [2]byte
	// Fix: a bare r.Read may legally return fewer than 2 bytes without an
	// error; io.ReadFull guarantees the full value (or an error).
	if _, err := io.ReadFull(r, tmp[:]); err != nil {
		return err
	}
	val := order.Uint16(tmp[:2])
	sign := (val >> 15) & 1
	exp := int16((val >> 10) & 0x1f)
	frac := val & 0x3ff
	if exp == 0x1f {
		// All-ones exponent encodes NaN (non-zero fraction) or infinity.
		if frac != 0 {
			*f = Float16(math.NaN())
		} else {
			// sign 0 -> +Inf, sign 1 -> -Inf.
			*f = Float16(math.Inf(int(sign)*-2 + 1))
		}
	} else {
		// Rebuild the equivalent float64 bit pattern.
		var bits uint64
		bits |= uint64(sign) << 63
		bits |= uint64(frac) << 42
		if exp > 0 {
			// Rebias from half-precision's 15 to float64's 1023.
			bits |= uint64(exp-15+1023) << 52
		}
		*f = Float16(math.Float64frombits(bits))
	}
	return nil
}
// Size reports the packed width of a half-precision float: always 2 bytes.
func (f *Float16) Size(opt *Options) int {
	return 2
}

// String formats the value for diagnostics; 32-bit precision is ample for a
// value that round-trips through 16 bits.
func (f *Float16) String() string {
	return strconv.FormatFloat(float64(*f), 'g', -1, 32)
}

View File

@@ -0,0 +1,56 @@
package struc
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"strconv"
"strings"
"testing"
)
// TestFloat16 round-trips half-precision values through Pack/Unpack and
// checks the packed bit patterns against published reference encodings.
func TestFloat16(t *testing.T) {
	// test cases from https://en.wikipedia.org/wiki/Half-precision_floating-point_format#Half_precision_examples
	tests := []struct {
		B string
		F float64
	}{
		//s expnt significand
		{"0 01111 0000000000", 1},
		{"0 01111 0000000001", 1.0009765625},
		{"1 10000 0000000000", -2},
		{"0 11110 1111111111", 65504},
		// Subnormals are not yet supported by Pack; kept for reference:
		// {"0 00001 0000000000", 0.0000610352},
		// {"0 00000 1111111111", 0.0000609756},
		// {"0 00000 0000000001", 0.0000000596046},
		{"0 00000 0000000000", 0},
		// {"1 00000 0000000000", -0},
		{"0 11111 0000000000", math.Inf(1)},
		{"1 11111 0000000000", math.Inf(-1)},
		{"0 01101 0101010101", 0.333251953125},
	}
	for _, test := range tests {
		var buf bytes.Buffer
		f := Float16(test.F)
		if err := Pack(&buf, &f); err != nil {
			t.Error("pack failed:", err)
			continue
		}
		// Parse the human-readable bit string (spaces stripped) and compare
		// against the actual big-endian packed halfword.
		bitval, _ := strconv.ParseUint(strings.Replace(test.B, " ", "", -1), 2, 16)
		tmp := binary.BigEndian.Uint16(buf.Bytes())
		if tmp != uint16(bitval) {
			t.Errorf("incorrect pack: %s != %016b (%f)", test.B, tmp, test.F)
			continue
		}
		var f2 Float16
		if err := Unpack(&buf, &f2); err != nil {
			t.Error("unpack failed:", err)
			continue
		}
		// let sprintf deal with (im)precision for me here
		if fmt.Sprintf("%f", f) != fmt.Sprintf("%f", f2) {
			t.Errorf("incorrect unpack: %016b %f != %f", bitval, f, f2)
		}
	}
}

View File

@@ -0,0 +1,360 @@
package struc
import (
"bytes"
"encoding/binary"
"io"
"reflect"
"strconv"
"testing"
)
// Custom Type
// Int3 is a 24-bit unsigned integer implementing the Custom interface:
// it packs as exactly 3 big-endian bytes.
type Int3 uint32

// newInt3 returns a pointer to an Int3
func newInt3(in int) *Int3 {
	i := Int3(in)
	return &i
}

// Int3Struct embeds a single Int3 field for struct-level pack/unpack tests.
type Int3Struct struct {
	I Int3
}
// Pack writes the low 3 bytes of i into p in big-endian order and reports
// 3 bytes written.
func (i *Int3) Pack(p []byte, opt *Options) (int, error) {
	var scratch [4]byte
	binary.BigEndian.PutUint32(scratch[:], uint32(*i))
	// Drop the high byte: only the low 24 bits are encoded.
	copy(p, scratch[1:])
	return 3, nil
}
// Unpack reads exactly 3 big-endian bytes from r into the low 24 bits of i.
func (i *Int3) Unpack(r io.Reader, length int, opt *Options) error {
	var tmp [4]byte
	// Fix: a bare r.Read may legally return fewer than 3 bytes without an
	// error; io.ReadFull guarantees all 3 bytes (or an error).
	if _, err := io.ReadFull(r, tmp[1:]); err != nil {
		return err
	}
	*i = Int3(binary.BigEndian.Uint32(tmp[:]))
	return nil
}
// Size reports the packed width of an Int3: always 3 bytes.
func (i *Int3) Size(opt *Options) int {
	return 3
}

// String formats the value in decimal for diagnostics.
func (i *Int3) String() string {
	return strconv.FormatUint(uint64(*i), 10)
}
// Array of custom type
// ArrayInt3Struct holds a fixed array of the Custom Int3 type.
type ArrayInt3Struct struct {
	I [2]Int3
}

// Custom type of array of standard type
// DoubleUInt8 is a 2-byte array implementing the Custom interface itself.
type DoubleUInt8 [2]uint8

type DoubleUInt8Struct struct {
	I DoubleUInt8
}
// Pack copies the two bytes into p and reports 2 bytes written.
func (di *DoubleUInt8) Pack(p []byte, opt *Options) (int, error) {
	for i, value := range *di {
		p[i] = value
	}
	return 2, nil
}

// Unpack reads two bytes from r; io.EOF is normalized to ErrUnexpectedEOF
// since a partial value is always an error here.
func (di *DoubleUInt8) Unpack(r io.Reader, length int, opt *Options) error {
	for i := 0; i < 2; i++ {
		var value uint8
		if err := binary.Read(r, binary.LittleEndian, &value); err != nil {
			if err == io.EOF {
				return io.ErrUnexpectedEOF
			}
			return err
		}
		di[i] = value
	}
	return nil
}

// Size reports the packed width: always 2 bytes.
func (di *DoubleUInt8) Size(opt *Options) int {
	return 2
}

// String is unused by these tests.
func (di *DoubleUInt8) String() string {
	panic("not implemented")
}
// Custom type of array of custom type
// DoubleInt3 is a 2-element array of Int3 that itself implements Custom,
// packing as 6 bytes.
type DoubleInt3 [2]Int3

type DoubleInt3Struct struct {
	D DoubleInt3
}
// Pack encodes both Int3 elements back-to-back into p (6 bytes total).
func (di *DoubleInt3) Pack(p []byte, opt *Options) (int, error) {
	var out []byte
	for _, value := range *di {
		tmp := make([]byte, 3)
		if _, err := value.Pack(tmp, opt); err != nil {
			return 0, err
		}
		out = append(out, tmp...)
	}
	copy(p, out)
	return 6, nil
}

// Unpack decodes both Int3 elements from r in order.
func (di *DoubleInt3) Unpack(r io.Reader, length int, opt *Options) error {
	for i := 0; i < 2; i++ {
		// Fix: the element's Unpack error was silently discarded, so a
		// truncated stream produced a half-filled value with no error.
		if err := di[i].Unpack(r, 0, opt); err != nil {
			return err
		}
	}
	return nil
}

// Size reports the packed width: always 6 bytes (2 x 3-byte Int3).
func (di *DoubleInt3) Size(opt *Options) int {
	return 6
}

// String is unused by these tests.
func (di *DoubleInt3) String() string {
	panic("not implemented")
}
// Custom type of slice of standard type
// Slice of uint8, stored in a zero terminated list.
type SliceUInt8 []uint8

type SliceUInt8Struct struct {
	I SliceUInt8
	N uint8 // A field after to ensure the length is correct.
}
func (ia *SliceUInt8) Pack(p []byte, opt *Options) (int, error) {
for i, value := range *ia {
p[i] = value
}
return len(*ia) + 1, nil
}
func (ia *SliceUInt8) Unpack(r io.Reader, length int, opt *Options) error {
for {
var value uint8
if err := binary.Read(r, binary.LittleEndian, &value); err != nil {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
return err
}
if value == 0 {
break
}
*ia = append(*ia, value)
}
return nil
}
// Size reports the wire size: the slice bytes plus one terminator byte.
func (ia *SliceUInt8) Size(opt *Options) int {
	return len(*ia) + 1
}

// String is required by the Custom interface but unused in these tests.
func (ia *SliceUInt8) String() string {
	panic("not implemented")
}
// ArrayOfUInt8Struct exercises a plain [2]uint8 array as a struct field.
type ArrayOfUInt8Struct struct {
	I [2]uint8
}

// Custom 4-character fixed string, similar to CHAR(4) in SQL.
type Char4 string

// Size reports the fixed 4-byte wire size of a Char4.
func (*Char4) Size(opt *Options) int {
	return 4
}
// Pack writes the string into p, zero-padded to the fixed 4-byte size.
//
// Fix: the original called make with c.Size(nil)-len(buf), which is
// negative for strings longer than 4 characters and panics; copying
// into a fixed-size buffer truncates over-long values instead.
func (c *Char4) Pack(p []byte, opt *Options) (int, error) {
	size := c.Size(nil)
	buf := make([]byte, size) // zero-filled, so short strings are padded
	copy(buf, *c)             // truncates strings longer than size
	copy(p, buf)
	return size, nil
}
// Unpack reads the remainder of r into the string.
//
// NOTE(review): bytes.Buffer.ReadFrom consumes r until EOF and reports
// nil (never io.EOF) in that case, so the io.EOF branch below appears
// unreachable; this also means Char4 consumes the whole remaining
// stream, which only works when it is the last field — confirm callers.
func (c *Char4) Unpack(r io.Reader, length int, opt *Options) error {
	buf := bytes.Buffer{}
	if _, err := buf.ReadFrom(r); err != nil {
		if err == io.EOF {
			return io.ErrUnexpectedEOF
		}
		return err
	}
	*c = Char4(buf.String())
	return nil
}
// String returns the raw character content of the Char4.
func (c *Char4) String() string {
	return string(*c)
}

// Char4Struct exercises a Char4 used as a struct field.
type Char4Struct struct {
	C Char4
}
// TestCustomTypes round-trips each custom-type scenario through
// Pack/Unpack and compares the packed bytes and the unpacked value
// against the expected results.
func TestCustomTypes(t *testing.T) {
	testCases := []struct {
		name        string
		packObj     interface{} // value to pack
		emptyObj    interface{} // zero target to unpack into
		expectBytes []byte      // expected wire encoding
		skip        bool        // Skip the test, because it fails.
		// Switch to expectFail when possible:
		// https://github.com/golang/go/issues/25951
	}{
		// Start tests with unimplemented non-custom types.
		{
			name:        "ArrayOfUInt8",
			packObj:     [2]uint8{32, 64},
			emptyObj:    [2]uint8{0, 0},
			expectBytes: []byte{32, 64},
			skip:        true, // Not implemented.
		},
		{
			name:        "PointerToArrayOfUInt8",
			packObj:     &[2]uint8{32, 64},
			emptyObj:    &[2]uint8{0, 0},
			expectBytes: []byte{32, 64},
			skip:        true, // Not implemented.
		},
		{
			name:        "ArrayOfUInt8Struct",
			packObj:     &ArrayOfUInt8Struct{I: [2]uint8{32, 64}},
			emptyObj:    &ArrayOfUInt8Struct{},
			expectBytes: []byte{32, 64},
		},
		{
			name:        "CustomType",
			packObj:     newInt3(3),
			emptyObj:    newInt3(0),
			expectBytes: []byte{0, 0, 3},
		},
		{
			name:        "CustomType-Big",
			packObj:     newInt3(4000),
			emptyObj:    newInt3(0),
			expectBytes: []byte{0, 15, 160},
		},
		{
			name:        "CustomTypeStruct",
			packObj:     &Int3Struct{3},
			emptyObj:    &Int3Struct{},
			expectBytes: []byte{0, 0, 3},
		},
		{
			name:        "ArrayOfCustomType",
			packObj:     [2]Int3{3, 4},
			emptyObj:    [2]Int3{},
			expectBytes: []byte{0, 0, 3, 0, 0, 4},
			skip:        true, // Not implemented.
		},
		{
			name:        "PointerToArrayOfCustomType",
			packObj:     &[2]Int3{3, 4},
			emptyObj:    &[2]Int3{},
			expectBytes: []byte{0, 0, 3, 0, 0, 4},
			skip:        true, // Not implemented.
		},
		{
			name:        "ArrayOfCustomTypeStruct",
			packObj:     &ArrayInt3Struct{[2]Int3{3, 4}},
			emptyObj:    &ArrayInt3Struct{},
			expectBytes: []byte{0, 0, 3, 0, 0, 4},
			skip:        true, // Not implemented.
		},
		{
			name:        "CustomTypeOfArrayOfUInt8",
			packObj:     &DoubleUInt8{32, 64},
			emptyObj:    &DoubleUInt8{},
			expectBytes: []byte{32, 64},
		},
		{
			name:        "CustomTypeOfArrayOfUInt8Struct",
			packObj:     &DoubleUInt8Struct{I: DoubleUInt8{32, 64}},
			emptyObj:    &DoubleUInt8Struct{},
			expectBytes: []byte{32, 64},
			skip:        true, // Not implemented.
		},
		{
			name:        "CustomTypeOfArrayOfCustomType",
			packObj:     &DoubleInt3{Int3(128), Int3(256)},
			emptyObj:    &DoubleInt3{},
			expectBytes: []byte{0, 0, 128, 0, 1, 0},
		},
		{
			name:        "CustomTypeOfArrayOfCustomTypeStruct",
			packObj:     &DoubleInt3Struct{D: DoubleInt3{Int3(128), Int3(256)}},
			emptyObj:    &DoubleInt3Struct{},
			expectBytes: []byte{0, 0, 128, 0, 1, 0},
			skip:        true, // Not implemented.
		},
		{
			name:        "CustomTypeOfSliceOfUInt8",
			packObj:     &SliceUInt8{128, 64, 32},
			emptyObj:    &SliceUInt8{},
			expectBytes: []byte{128, 64, 32, 0},
		},
		{
			name:        "CustomTypeOfSliceOfUInt8-Empty",
			packObj:     &SliceUInt8{},
			emptyObj:    &SliceUInt8{},
			expectBytes: []byte{0},
		},
		{
			name:        "CustomTypeOfSliceOfUInt8Struct",
			packObj:     &SliceUInt8Struct{I: SliceUInt8{128, 64, 32}, N: 192},
			emptyObj:    &SliceUInt8Struct{},
			expectBytes: []byte{128, 64, 32, 0, 192},
			skip:        true, // Not implemented.
		},
		{
			name:        "CustomTypeOfChar4Struct",
			packObj:     &Char4Struct{C: Char4("foo\x00")},
			emptyObj:    &Char4Struct{},
			expectBytes: []byte{102, 111, 111, 0},
		},
	}
	for _, test := range testCases {
		// TODO: Switch to t.Run() when Go 1.7 is the minimum supported version.
		t.Log("RUN ", test.name)
		runner := func(t *testing.T) {
			// Convert unexpected panics into test failures so one bad
			// case doesn't abort the whole table.
			defer func() {
				if r := recover(); r != nil {
					t.Log("unexpected panic:", r)
					t.Error(r)
				}
			}()
			if test.skip {
				// TODO: Switch to t.Skip() when Go 1.7 is supported
				t.Log("skipped unimplemented")
				return
			}
			var buf bytes.Buffer
			if err := Pack(&buf, test.packObj); err != nil {
				t.Fatal(err)
			}
			if !bytes.Equal(buf.Bytes(), test.expectBytes) {
				t.Fatal("error packing, expect:", test.expectBytes, "found:", buf.Bytes())
			}
			if err := Unpack(&buf, test.emptyObj); err != nil {
				t.Fatal(err)
			}
			if !reflect.DeepEqual(test.packObj, test.emptyObj) {
				t.Fatal("error unpacking, expect:", test.packObj, "found:", test.emptyObj)
			}
		}
		runner(t)
	}
}

288
common/utils/sturc/field.go Normal file
View File

@@ -0,0 +1,288 @@
package struc
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"reflect"
)
// Field describes one struct field as seen by the packer: its resolved
// wire type, byte order, length information and, for nested structs,
// its child fields.
type Field struct {
	Name     string
	Ptr      bool             // field is a pointer; dereference before use
	Index    int              // field index within the enclosing struct
	Type     Type             // wire type (possibly from the struc tag)
	defType  Type             // default wire type derived from the Go kind
	Array    bool             // fixed-size array
	Slice    bool             // slice (or array) semantics apply
	Len      int              // static length; -1 means sized by another field
	Order    binary.ByteOrder // byte order from the tag (big by default)
	Sizeof   []int            // index path of the field this one is the size of
	Sizefrom []int            // index path of the field holding our length
	Fields   Fields           // child fields when Type == Struct
	kind     reflect.Kind
}
// String renders a debug description of the field: its type, byte
// order and any length/sizeof relationships.
func (f *Field) String() string {
	if f.Type == Pad {
		return fmt.Sprintf("{type: Pad, len: %d}", f.Len)
	}
	out := fmt.Sprintf("type: %s, order: %v", f.Type.String(), f.Order)
	switch {
	case f.Sizefrom != nil:
		out += fmt.Sprintf(", sizefrom: %v", f.Sizefrom)
	case f.Len > 0:
		out += fmt.Sprintf(", len: %d", f.Len)
	}
	if f.Sizeof != nil {
		out += fmt.Sprintf(", sizeof: %v", f.Sizeof)
	}
	return "{" + out + "}"
}
// Size computes the packed byte size of this field for the concrete
// value val, honoring nested structs, pad fields, custom types,
// slices/strings and the ByteAlign option.
func (f *Field) Size(val reflect.Value, options *Options) int {
	typ := f.Type.Resolve(options)
	size := 0
	if typ == Struct {
		// sum the sizes of each nested struct (one, or each slice element)
		vals := []reflect.Value{val}
		if f.Slice {
			vals = make([]reflect.Value, val.Len())
			for i := 0; i < val.Len(); i++ {
				vals[i] = val.Index(i)
			}
		}
		for _, val := range vals {
			size += f.Fields.Sizeof(val, options)
		}
	} else if typ == Pad {
		size = f.Len
	} else if typ == CustomType {
		// custom types report their own size
		return val.Addr().Interface().(Custom).Size(options)
	} else if f.Slice || f.kind == reflect.String {
		// static Len (when > 1) wins over the dynamic value length
		length := val.Len()
		if f.Len > 1 {
			length = f.Len
		}
		size = length * typ.Size()
	} else {
		size = typ.Size()
	}
	// round small fields up to the alignment boundary if requested
	align := options.ByteAlign
	if align > 0 && size < align {
		size = align
	}
	return size
}
// packVal encodes a single scalar/struct/custom value into buf using the
// field's resolved wire type, returning the number of bytes written.
func (f *Field) packVal(buf []byte, val reflect.Value, length int, options *Options) (size int, err error) {
	order := f.Order
	if options.Order != nil {
		order = options.Order // per-call byte order overrides the tag
	}
	if f.Ptr {
		val = val.Elem()
	}
	typ := f.Type.Resolve(options)
	switch typ {
	case Struct:
		return f.Fields.Pack(buf, val, options)
	case Bool, Int8, Int16, Int32, Int64, Uint8, Uint16, Uint32, Uint64:
		size = typ.Size()
		// normalize the Go value into a uint64 before writing
		var n uint64
		switch f.kind {
		case reflect.Bool:
			if val.Bool() {
				n = 1
			} else {
				n = 0
			}
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			n = uint64(val.Int())
		default:
			n = val.Uint()
		}
		switch typ {
		case Bool:
			if n != 0 {
				buf[0] = 1
			} else {
				buf[0] = 0
			}
		case Int8, Uint8:
			buf[0] = byte(n)
		case Int16, Uint16:
			order.PutUint16(buf, uint16(n))
		case Int32, Uint32:
			order.PutUint32(buf, uint32(n))
		case Int64, Uint64:
			order.PutUint64(buf, uint64(n))
		}
	case Float32, Float64:
		size = typ.Size()
		n := val.Float()
		switch typ {
		case Float32:
			order.PutUint32(buf, math.Float32bits(float32(n)))
		case Float64:
			order.PutUint64(buf, math.Float64bits(n))
		}
	case String:
		switch f.kind {
		case reflect.String:
			size = val.Len()
			copy(buf, []byte(val.String()))
		default:
			// TODO: handle kind != bytes here
			size = val.Len()
			copy(buf, val.Bytes())
		}
	case CustomType:
		// custom types encode themselves
		return val.Addr().Interface().(Custom).Pack(buf, options)
	default:
		panic(fmt.Sprintf("no pack handler for type: %s", typ))
	}
	return
}
// Pack encodes this field's value into buf, handling pad fields, the
// byte-slice/string fast path, slices of values, and single values.
// length is the number of elements (or pad bytes) to write.
func (f *Field) Pack(buf []byte, val reflect.Value, length int, options *Options) (int, error) {
	typ := f.Type.Resolve(options)
	if typ == Pad {
		// pad fields emit `length` zero bytes and carry no value
		for i := 0; i < length; i++ {
			buf[i] = 0
		}
		return length, nil
	}
	if f.Slice {
		// special case strings and byte slices for performance
		end := val.Len()
		if !f.Array && typ == Uint8 && (f.defType == Uint8 || f.kind == reflect.String) {
			var tmp []byte
			if f.kind == reflect.String {
				tmp = []byte(val.String())
			} else {
				tmp = val.Bytes()
			}
			copy(buf, tmp)
			if end < length {
				// zero-fill up to the declared fixed length
				// TODO: allow configuring pad byte?
				rep := bytes.Repeat([]byte{0}, length-end)
				copy(buf[end:], rep)
				return length, nil
			}
			return val.Len(), nil
		}
		// general slice path: pack each element, padding with zero
		// values when the declared length exceeds the actual length
		pos := 0
		var zero reflect.Value
		if end < length {
			zero = reflect.Zero(val.Type().Elem())
		}
		for i := 0; i < length; i++ {
			cur := zero
			if i < end {
				cur = val.Index(i)
			}
			if n, err := f.packVal(buf[pos:], cur, 1, options); err != nil {
				return pos, err
			} else {
				pos += n
			}
		}
		return pos, nil
	} else {
		return f.packVal(buf, val, length, options)
	}
}
// unpackVal decodes a single scalar value from buf into val according
// to the field's resolved wire type and byte order.
func (f *Field) unpackVal(buf []byte, val reflect.Value, length int, options *Options) error {
	order := f.Order
	if options.Order != nil {
		order = options.Order // per-call byte order overrides the tag
	}
	if f.Ptr {
		val = val.Elem()
	}
	typ := f.Type.Resolve(options)
	switch typ {
	case Float32, Float64:
		var n float64
		switch typ {
		case Float32:
			n = float64(math.Float32frombits(order.Uint32(buf)))
		case Float64:
			n = math.Float64frombits(order.Uint64(buf))
		}
		switch f.kind {
		case reflect.Float32, reflect.Float64:
			val.SetFloat(n)
		default:
			// refuse lossy float-to-integer conversions
			return fmt.Errorf("struc: refusing to unpack float into field %s of type %s", f.Name, f.kind.String())
		}
	case Bool, Int8, Int16, Int32, Int64, Uint8, Uint16, Uint32, Uint64:
		// decode into a uint64, sign-extending the signed wire types
		var n uint64
		switch typ {
		case Int8:
			n = uint64(int64(int8(buf[0])))
		case Int16:
			n = uint64(int64(int16(order.Uint16(buf))))
		case Int32:
			n = uint64(int64(int32(order.Uint32(buf))))
		case Int64:
			n = uint64(int64(order.Uint64(buf)))
		case Bool, Uint8:
			n = uint64(buf[0])
		case Uint16:
			n = uint64(order.Uint16(buf))
		case Uint32:
			n = uint64(order.Uint32(buf))
		case Uint64:
			n = uint64(order.Uint64(buf))
		}
		switch f.kind {
		case reflect.Bool:
			val.SetBool(n != 0)
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			val.SetInt(int64(n))
		default:
			val.SetUint(n)
		}
	default:
		panic(fmt.Sprintf("no unpack handler for type: %s", typ))
	}
	return nil
}
// Unpack decodes this field's value from buf: pad fields are skipped,
// strings copied whole, slices grown/decoded element by element, and
// scalars delegated to unpackVal. length is the element count.
func (f *Field) Unpack(buf []byte, val reflect.Value, length int, options *Options) error {
	typ := f.Type.Resolve(options)
	if typ == Pad || f.kind == reflect.String {
		if typ == Pad {
			// pad bytes carry no value
			return nil
		} else {
			val.SetString(string(buf))
			return nil
		}
	} else if f.Slice {
		// grow (or re-slice) the destination to exactly `length` elements
		if val.Cap() < length {
			val.Set(reflect.MakeSlice(val.Type(), length, length))
		} else if val.Len() < length {
			val.Set(val.Slice(0, length))
		}
		// special case byte slices for performance
		if !f.Array && typ == Uint8 && f.defType == Uint8 {
			copy(val.Bytes(), buf[:length])
			return nil
		}
		pos := 0
		size := typ.Size()
		for i := 0; i < length; i++ {
			if err := f.unpackVal(buf[pos:pos+size], val.Index(i), 1, options); err != nil {
				return err
			}
			pos += size
		}
		return nil
	} else {
		return f.unpackVal(buf, val, length, options)
	}
}

View File

@@ -0,0 +1,77 @@
package struc
import (
"bytes"
"testing"
)
// badFloat declares an int field tagged as float64, which unpack must
// reject rather than silently truncate.
type badFloat struct {
	BadFloat int `struc:"float64"`
}

// TestBadFloatField verifies Unpack errors when asked to put a float
// wire value into an integer Go field.
func TestBadFloatField(t *testing.T) {
	buf := bytes.NewReader([]byte("00000000"))
	err := Unpack(buf, &badFloat{})
	if err == nil {
		t.Fatal("failed to error on bad float unpack")
	}
}
// emptyLengthField pairs a sizeof field with its sized slice.
type emptyLengthField struct {
	Strlen int `struc:"sizeof=Str"`
	Str    []byte
}

// TestEmptyLengthField verifies the size field is filled in from the
// slice length during Pack even when left zero by the caller.
func TestEmptyLengthField(t *testing.T) {
	var buf bytes.Buffer
	s := &emptyLengthField{0, []byte("test")}
	o := &emptyLengthField{}
	if err := Pack(&buf, s); err != nil {
		t.Fatal(err)
	}
	if err := Unpack(&buf, o); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(s.Str, o.Str) {
		t.Fatal("empty length field encode failed")
	}
}
// fixedSlicePad declares a fixed-length byte slice via the tag only.
type fixedSlicePad struct {
	Field []byte `struc:"[4]byte"`
}

// TestFixedSlicePad verifies a nil slice with a fixed tag length packs
// as zero padding and unpacks back to that length.
func TestFixedSlicePad(t *testing.T) {
	var buf bytes.Buffer
	ref := []byte{0, 0, 0, 0}
	s := &fixedSlicePad{}
	if err := Pack(&buf, s); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(buf.Bytes(), ref) {
		t.Fatal("implicit fixed slice pack failed")
	}
	if err := Unpack(&buf, s); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(s.Field, ref) {
		t.Fatal("implicit fixed slice unpack failed")
	}
}
// sliceCap pairs a sizeof field with its sized slice.
type sliceCap struct {
	Len   int `struc:"sizeof=Field"`
	Field []byte
}

// TestSliceCap verifies Unpack handles a destination slice whose
// capacity exceeds its length (re-slice rather than reallocate).
func TestSliceCap(t *testing.T) {
	var buf bytes.Buffer
	tmp := &sliceCap{0, []byte("1234")}
	if err := Pack(&buf, tmp); err != nil {
		t.Fatal(err)
	}
	tmp.Field = make([]byte, 0, 4) // len 0, cap 4: exercises the re-slice path
	if err := Unpack(&buf, tmp); err != nil {
		t.Fatal(err)
	}
}

View File

@@ -0,0 +1,178 @@
package struc
import (
"encoding/binary"
"fmt"
"io"
"reflect"
"strings"
)
// Fields is the parsed field list of a struct; nil entries mark fields
// that are skipped (unexported or tagged skip).
type Fields []*Field

// SetByteOrder overrides the byte order of every parsed field.
func (f Fields) SetByteOrder(order binary.ByteOrder) {
	for _, field := range f {
		if field != nil {
			field.Order = order
		}
	}
}
// String renders a debug description of every non-skipped field.
func (f Fields) String() string {
	parts := make([]string, len(f))
	for i, fld := range f {
		if fld == nil {
			continue
		}
		parts[i] = fld.String()
	}
	return "{" + strings.Join(parts, ", ") + "}"
}
// Sizeof sums the packed size of all non-skipped fields for the
// concrete struct value val (pointers are dereferenced first).
func (f Fields) Sizeof(val reflect.Value, options *Options) int {
	for val.Kind() == reflect.Ptr {
		val = val.Elem()
	}
	total := 0
	for i, field := range f {
		if field == nil {
			continue
		}
		total += field.Size(val.Field(i), options)
	}
	return total
}
// sizefrom reads the integer length stored in the companion size field
// identified by index, clamping negative unsigned-truncation artifacts
// to zero. Panics if the field is not an integer type.
func (f Fields) sizefrom(val reflect.Value, index []int) int {
	field := val.FieldByIndex(index)
	switch field.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return int(field.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		n := int(field.Uint())
		// all the builtin array length types are native int
		// so this guards against weird truncation
		if n < 0 {
			return 0
		}
		return n
	default:
		name := val.Type().FieldByIndex(index).Name
		panic(fmt.Sprintf("sizeof field %T.%s not an integer type", val.Interface(), name))
	}
}
// Pack encodes every non-skipped field of the struct value val into
// buf, returning the total bytes written. Size (sizeof) fields are
// filled in from the current length of the field they describe.
func (f Fields) Pack(buf []byte, val reflect.Value, options *Options) (int, error) {
	for val.Kind() == reflect.Ptr {
		val = val.Elem()
	}
	pos := 0
	for i, field := range f {
		if field == nil {
			continue
		}
		v := val.Field(i)
		length := field.Len
		if field.Sizefrom != nil {
			length = f.sizefrom(val, field.Sizefrom)
		}
		if length <= 0 && field.Slice {
			length = v.Len()
		}
		if field.Sizeof != nil {
			// NOTE: this inner `length` deliberately shadows the outer
			// one — it is the element count of the *sized* field, which
			// becomes this size field's packed value.
			length := val.FieldByIndex(field.Sizeof).Len()
			switch field.kind {
			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
				// allocating a new int here has fewer side effects (doesn't update the original struct)
				// but it's a wasteful allocation
				// the old method might work if we just cast the temporary int/uint to the target type
				v = reflect.New(v.Type()).Elem()
				v.SetInt(int64(length))
			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
				v = reflect.New(v.Type()).Elem()
				v.SetUint(uint64(length))
			default:
				panic(fmt.Sprintf("sizeof field is not int or uint type: %s, %s", field.Name, v.Type()))
			}
		}
		if n, err := field.Pack(buf[pos:], v, length, options); err != nil {
			return n, err
		} else {
			pos += n
			//g.Dump(v)
		}
	}
	return pos, nil
}
// Unpack decodes wire data from r into the struct value val, field by
// field. Lengths for variable-size fields come from the field's static
// Len or from a companion size field read earlier (Sizefrom).
//
// Fixes: the stack-buffer condition was `size < 8`, which needlessly
// heap-allocated for exactly 8-byte reads even though tmp holds 8
// bytes; the wire type was also resolved twice per field.
func (f Fields) Unpack(r io.Reader, val reflect.Value, options *Options) error {
	for val.Kind() == reflect.Ptr {
		val = val.Elem()
	}
	// small scratch buffer reused for fixed-size reads to avoid allocating
	var tmp [8]byte
	var buf []byte
	for i, field := range f {
		if field == nil {
			continue
		}
		v := val.Field(i)
		length := field.Len
		if field.Sizefrom != nil {
			length = f.sizefrom(val, field.Sizefrom)
		}
		if v.Kind() == reflect.Ptr && !v.Elem().IsValid() {
			// allocate the pointed-to value so we can unpack into it
			v.Set(reflect.New(v.Type().Elem()))
		}
		if field.Type == Struct {
			if field.Slice {
				vals := v
				if !field.Array {
					vals = reflect.MakeSlice(v.Type(), length, length)
				}
				for i := 0; i < length; i++ {
					v := vals.Index(i)
					fields, err := parseFields(v)
					if err != nil {
						return err
					}
					if err := fields.Unpack(r, v, options); err != nil {
						return err
					}
				}
				if !field.Array {
					v.Set(vals)
				}
			} else {
				// TODO: DRY (we repeat the inner loop above)
				fields, err := parseFields(v)
				if err != nil {
					return err
				}
				if err := fields.Unpack(r, v, options); err != nil {
					return err
				}
			}
			continue
		}
		typ := field.Type.Resolve(options)
		if typ == CustomType {
			// custom types read from the stream themselves
			if err := v.Addr().Interface().(Custom).Unpack(r, length, options); err != nil {
				return err
			}
			continue
		}
		size := length * typ.Size()
		if size <= 8 {
			buf = tmp[:size]
		} else {
			buf = make([]byte, size)
		}
		if _, err := io.ReadFull(r, buf); err != nil {
			return err
		}
		if err := field.Unpack(buf, v, length, options); err != nil {
			return err
		}
	}
	return nil
}

View File

@@ -0,0 +1,81 @@
package struc
import (
"bytes"
"reflect"
"testing"
)
// refVal reflects the shared `reference` fixture used across field tests.
var refVal = reflect.ValueOf(reference)

// TestFieldsParse verifies the reference struct parses without error.
func TestFieldsParse(t *testing.T) {
	if _, err := parseFields(refVal); err != nil {
		t.Fatal(err)
	}
}

// TestFieldsString exercises the Fields debug formatter for panics.
func TestFieldsString(t *testing.T) {
	fields, _ := parseFields(refVal)
	fields.String()
}
// sizefromStruct exercises sizeof with both unsigned and signed size fields.
type sizefromStruct struct {
	Size1 uint `struc:"sizeof=Var1"`
	Var1  []byte
	Size2 int `struc:"sizeof=Var2"`
	Var2  []byte
}

// TestFieldsSizefrom round-trips a struct whose slice lengths are
// carried by companion size fields.
func TestFieldsSizefrom(t *testing.T) {
	var test = sizefromStruct{
		Var1: []byte{1, 2, 3},
		Var2: []byte{4, 5, 6},
	}
	var buf bytes.Buffer
	err := Pack(&buf, &test)
	if err != nil {
		t.Fatal(err)
	}
	err = Unpack(&buf, &test)
	if err != nil {
		t.Fatal(err)
	}
}
// sizefromStructBad declares a non-integer sizeof field, which must panic.
type sizefromStructBad struct {
	Size1 string `struc:"sizeof=Var1"`
	Var1  []byte
}

// TestFieldsSizefromBad verifies Pack panics when a sizeof field is not
// an integer type.
func TestFieldsSizefromBad(t *testing.T) {
	var test = &sizefromStructBad{Var1: []byte{1, 2, 3}}
	var buf bytes.Buffer
	defer func() {
		if err := recover(); err == nil {
			t.Fatal("failed to panic on bad sizeof type")
		}
	}()
	Pack(&buf, &test)
}
// StructWithinArray is a minimal element type for array-of-struct tests.
type StructWithinArray struct {
	a uint32
}

// StructHavingArray embeds a fixed-size array of structs via the tag.
type StructHavingArray struct {
	Props [1]StructWithinArray `struc:"[1]StructWithinArray"`
}

// TestStrucArray round-trips a struct containing an array of structs.
func TestStrucArray(t *testing.T) {
	var buf bytes.Buffer
	a := &StructHavingArray{[1]StructWithinArray{}}
	err := Pack(&buf, a)
	if err != nil {
		t.Fatal(err)
	}
	b := &StructHavingArray{}
	err = Unpack(&buf, b)
	if err != nil {
		t.Fatal(err)
	}
}

View File

@@ -0,0 +1,3 @@
module github.com/lunixbochs/struc
go 1.12

View File

@@ -0,0 +1,16 @@
package struc
import (
"encoding/binary"
"io"
)
// PackWithOrder packs data into w using an explicit byte order.
//
// Deprecated: Use PackWithOptions instead.
func PackWithOrder(w io.Writer, data interface{}, order binary.ByteOrder) error {
	return PackWithOptions(w, data, &Options{Order: order})
}

// UnpackWithOrder unpacks data from r using an explicit byte order.
//
// Deprecated: Use UnpackWithOptions instead.
func UnpackWithOrder(r io.Reader, data interface{}, order binary.ByteOrder) error {
	return UnpackWithOptions(r, data, &Options{Order: order})
}

View File

@@ -0,0 +1,123 @@
package struc
import (
"bytes"
"fmt"
"testing"
)
// packableReference is the expected big-endian encoding of the values
// packed in TestPackable: int8..int64 (1-4), uint8..uint64 (5-8), then
// a uint8 array (9-16) and a uint16 array (17-24).
var packableReference = []byte{
	1, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 5, 0, 6, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 8,
	9, 10, 11, 12, 13, 14, 15, 16,
	0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24,
}
// TestPackable packs each bare integer type and slice through the
// binaryFallback path, compares against packableReference, then unpacks
// everything back and checks the round-tripped values.
func TestPackable(t *testing.T) {
	var (
		buf bytes.Buffer

		i8  int8   = 1
		i16 int16  = 2
		i32 int32  = 3
		i64 int64  = 4
		u8  uint8  = 5
		u16 uint16 = 6
		u32 uint32 = 7
		u64 uint64 = 8

		u8a  = [8]uint8{9, 10, 11, 12, 13, 14, 15, 16}
		u16a = [8]uint16{17, 18, 19, 20, 21, 22, 23, 24}
	)
	// pack tests
	if err := Pack(&buf, i8); err != nil {
		t.Fatal(err)
	}
	if err := Pack(&buf, i16); err != nil {
		t.Fatal(err)
	}
	if err := Pack(&buf, i32); err != nil {
		t.Fatal(err)
	}
	if err := Pack(&buf, i64); err != nil {
		t.Fatal(err)
	}
	if err := Pack(&buf, u8); err != nil {
		t.Fatal(err)
	}
	if err := Pack(&buf, u16); err != nil {
		t.Fatal(err)
	}
	if err := Pack(&buf, u32); err != nil {
		t.Fatal(err)
	}
	if err := Pack(&buf, u64); err != nil {
		t.Fatal(err)
	}
	if err := Pack(&buf, u8a[:]); err != nil {
		t.Fatal(err)
	}
	if err := Pack(&buf, u16a[:]); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(buf.Bytes(), packableReference) {
		fmt.Println(buf.Bytes())
		fmt.Println(packableReference)
		t.Fatal("Packable Pack() did not match reference.")
	}
	// unpack tests: zero everything first so stale values can't pass
	i8 = 0
	i16 = 0
	i32 = 0
	i64 = 0
	u8 = 0
	u16 = 0
	u32 = 0
	u64 = 0
	if err := Unpack(&buf, &i8); err != nil {
		t.Fatal(err)
	}
	if err := Unpack(&buf, &i16); err != nil {
		t.Fatal(err)
	}
	if err := Unpack(&buf, &i32); err != nil {
		t.Fatal(err)
	}
	if err := Unpack(&buf, &i64); err != nil {
		t.Fatal(err)
	}
	if err := Unpack(&buf, &u8); err != nil {
		t.Fatal(err)
	}
	if err := Unpack(&buf, &u16); err != nil {
		t.Fatal(err)
	}
	if err := Unpack(&buf, &u32); err != nil {
		t.Fatal(err)
	}
	if err := Unpack(&buf, &u64); err != nil {
		t.Fatal(err)
	}
	if err := Unpack(&buf, u8a[:]); err != nil {
		t.Fatal(err)
	}
	if err := Unpack(&buf, u16a[:]); err != nil {
		t.Fatal(err)
	}
	// unpack checks
	if i8 != 1 || i16 != 2 || i32 != 3 || i64 != 4 {
		t.Fatal("Signed integer unpack failed.")
	}
	if u8 != 5 || u16 != 6 || u32 != 7 || u64 != 8 {
		t.Fatal("Unsigned integer unpack failed.")
	}
	for i := 0; i < 8; i++ {
		if u8a[i] != uint8(i+9) {
			t.Fatal("uint8 array unpack failed.")
		}
	}
	for i := 0; i < 8; i++ {
		if u16a[i] != uint16(i+17) {
			t.Fatal("uint16 array unpack failed.")
		}
	}
}

View File

@@ -0,0 +1,13 @@
package struc
import (
"io"
"reflect"
)
// Packer is the internal encoding strategy selected by prep(): parsed
// struct fields, the binary fallback for bare values, or a Custom
// implementation supplied by the caller.
type Packer interface {
	Pack(buf []byte, val reflect.Value, options *Options) (int, error)
	Unpack(r io.Reader, val reflect.Value, options *Options) error
	Sizeof(val reflect.Value, options *Options) int
	String() string
}

230
common/utils/sturc/parse.go Normal file
View File

@@ -0,0 +1,230 @@
package struc
import (
"encoding/binary"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
)
// strucTag holds the parsed contents of a `struc` struct tag, e.g.
// struc:"int32,big,sizeof=Data,skip,sizefrom=Len"
type strucTag struct {
	Type     string           // wire type override, e.g. "int32" or "[4]byte"
	Order    binary.ByteOrder // "big" (default) or "little"
	Sizeof   string           // name of the field this one is the size of
	Skip     bool             // field is ignored entirely
	Sizefrom string           // name of the field holding this one's length
}
// parseStrucTag splits a `struc` (or legacy `struct`) tag into its
// comma-separated options; unrecognized tokens are taken as the type.
func parseStrucTag(tag reflect.StructTag) *strucTag {
	parsed := &strucTag{
		Order: binary.BigEndian,
	}
	tagStr := tag.Get("struc")
	if tagStr == "" {
		// someone's going to typo this (I already did once)
		// sorry if you made a module actually using this tag
		// and you're mad at me now
		tagStr = tag.Get("struct")
	}
	for _, token := range strings.Split(tagStr, ",") {
		switch {
		case strings.HasPrefix(token, "sizeof="):
			parsed.Sizeof = strings.SplitN(token, "=", 2)[1]
		case strings.HasPrefix(token, "sizefrom="):
			parsed.Sizefrom = strings.SplitN(token, "=", 2)[1]
		case token == "big":
			parsed.Order = binary.BigEndian
		case token == "little":
			parsed.Order = binary.LittleEndian
		case token == "skip":
			parsed.Skip = true
		default:
			parsed.Type = token
		}
	}
	return parsed
}
// typeLenRe matches the optional "[N]" / "[]" length prefix in tag types.
var typeLenRe = regexp.MustCompile(`^\[(\d*)\]`)

// parseField builds a Field description for one struct field by
// combining its Go type information with the `struc` tag.
//
// Fix: errors.New(fmt.Sprintf(...)) replaced with fmt.Errorf
// (staticcheck S1028); behavior is unchanged.
func parseField(f reflect.StructField) (fd *Field, tag *strucTag, err error) {
	tag = parseStrucTag(f.Tag)
	var ok bool
	fd = &Field{
		Name:  f.Name,
		Len:   1,
		Order: tag.Order,
		Slice: false,
		kind:  f.Type.Kind(),
	}
	// collapse containers down to their element kind
	switch fd.kind {
	case reflect.Array:
		fd.Slice = true
		fd.Array = true
		fd.Len = f.Type.Len()
		fd.kind = f.Type.Elem().Kind()
	case reflect.Slice:
		fd.Slice = true
		fd.Len = -1 // length must come from a sizeof/sizefrom companion
		fd.kind = f.Type.Elem().Kind()
	case reflect.Ptr:
		fd.Ptr = true
		fd.kind = f.Type.Elem().Kind()
	}
	// check for custom types
	tmp := reflect.New(f.Type)
	if _, ok := tmp.Interface().(Custom); ok {
		fd.Type = CustomType
		return
	}
	var defTypeOk bool
	fd.defType, defTypeOk = reflectTypeMap[fd.kind]
	// find a type in the struct tag
	pureType := typeLenRe.ReplaceAllLiteralString(tag.Type, "")
	if fd.Type, ok = typeLookup[pureType]; ok {
		fd.Len = 1
		match := typeLenRe.FindAllStringSubmatch(tag.Type, -1)
		if len(match) > 0 && len(match[0]) > 1 {
			fd.Slice = true
			first := match[0][1]
			// Field.Len = -1 indicates a []slice
			if first == "" {
				fd.Len = -1
			} else {
				fd.Len, err = strconv.Atoi(first)
			}
		}
		return
	}
	// the user didn't specify a type
	switch f.Type {
	case reflect.TypeOf(Size_t(0)):
		fd.Type = SizeType
	case reflect.TypeOf(Off_t(0)):
		fd.Type = OffType
	default:
		if defTypeOk {
			fd.Type = fd.defType
		} else {
			err = fmt.Errorf("struc: Could not resolve field '%v' type '%v'.", f.Name, f.Type)
		}
	}
	return
}
// parseFieldsLocked parses the fields of the struct value v; it must be
// called with parseLock held (parseFields takes it).
func parseFieldsLocked(v reflect.Value) (Fields, error) {
	// we need to repeat this logic because parseFields() below can't be recursively called due to locking
	for v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	t := v.Type()
	if v.NumField() < 1 {
		return nil, errors.New("struc: Struct has no fields.")
	}
	sizeofMap := make(map[string][]int)
	fields := make(Fields, v.NumField())
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		f, tag, err := parseField(field)
		// NOTE: Skip is honored before the parse error so fields tagged
		// `skip` never fail parsing, even with unresolvable types.
		if tag.Skip {
			continue
		}
		if err != nil {
			return nil, err
		}
		// unexported fields can't be set via reflection; skip them
		if !v.Field(i).CanSet() {
			continue
		}
		f.Index = i
		if tag.Sizeof != "" {
			target, ok := t.FieldByName(tag.Sizeof)
			if !ok {
				return nil, fmt.Errorf("struc: `sizeof=%s` field does not exist", tag.Sizeof)
			}
			f.Sizeof = target.Index
			// remember so the sized field can find its size field later
			sizeofMap[tag.Sizeof] = field.Index
		}
		if sizefrom, ok := sizeofMap[field.Name]; ok {
			f.Sizefrom = sizefrom
		}
		if tag.Sizefrom != "" {
			source, ok := t.FieldByName(tag.Sizefrom)
			if !ok {
				return nil, fmt.Errorf("struc: `sizefrom=%s` field does not exist", tag.Sizefrom)
			}
			f.Sizefrom = source.Index
		}
		if f.Len == -1 && f.Sizefrom == nil {
			return nil, fmt.Errorf("struc: field `%s` is a slice with no length or sizeof field", field.Name)
		}
		// recurse into nested structs
		// TODO: handle loops (probably by indirecting the []Field and putting pointer in cache)
		if f.Type == Struct {
			typ := field.Type
			if f.Ptr {
				typ = typ.Elem()
			}
			if f.Slice {
				typ = typ.Elem()
			}
			f.Fields, err = parseFieldsLocked(reflect.New(typ))
			if err != nil {
				return nil, err
			}
		}
		fields[i] = f
	}
	return fields, nil
}
// fieldCache memoizes parsed Fields per struct type; guarded by
// fieldCacheLock, while parseLock serializes the parsing itself.
var fieldCache = make(map[reflect.Type]Fields)
var fieldCacheLock sync.RWMutex
var parseLock sync.Mutex

// fieldCacheLookup returns the cached Fields for t, or nil if absent.
func fieldCacheLookup(t reflect.Type) Fields {
	fieldCacheLock.RLock()
	defer fieldCacheLock.RUnlock()
	if cached, ok := fieldCache[t]; ok {
		return cached
	}
	return nil
}
// parseFields returns the (cached) field descriptions for the struct
// value v, parsing and caching them on first use. Safe for concurrent
// callers via the double-checked cache under parseLock.
func parseFields(v reflect.Value) (Fields, error) {
	for v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	t := v.Type()
	// fast path: hopefully the field parsing is already cached
	if cached := fieldCacheLookup(t); cached != nil {
		return cached, nil
	}
	// hold a global lock so multiple goroutines can't parse (the same) fields at once
	parseLock.Lock()
	defer parseLock.Unlock()
	// check cache a second time, in case parseLock was just released by
	// another thread who filled the cache for us
	if cached := fieldCacheLookup(t); cached != nil {
		return cached, nil
	}
	// no luck, time to parse and fill the cache ourselves
	fields, err := parseFieldsLocked(v)
	if err != nil {
		return nil, err
	}
	fieldCacheLock.Lock()
	fieldCache[t] = fields
	fieldCacheLock.Unlock()
	return fields, nil
}

View File

@@ -0,0 +1,62 @@
package struc
import (
"bytes"
"reflect"
"testing"
)
// parseTest runs field parsing on data and returns only the error.
func parseTest(data interface{}) error {
	_, err := parseFields(reflect.ValueOf(data))
	return err
}

// empty has no fields, which parseFields must reject.
type empty struct{}

// TestEmptyStruc verifies parsing errors on a field-less struct.
func TestEmptyStruc(t *testing.T) {
	if err := parseTest(&empty{}); err == nil {
		t.Fatal("failed to error on empty struct")
	}
}
// chanStruct contains a channel, which has no wire representation.
type chanStruct struct {
	Test chan int
}

// TestChanError verifies parsing errors on unsupported field kinds.
func TestChanError(t *testing.T) {
	if err := parseTest(&chanStruct{}); err == nil {
		// TODO: should probably ignore channel fields
		t.Fatal("failed to error on struct containing channel")
	}
}
// badSizeof points its sizeof tag at a field that does not exist.
type badSizeof struct {
	Size int `struc:"sizeof=Bad"`
}

// TestBadSizeof verifies parsing errors on a dangling sizeof target.
func TestBadSizeof(t *testing.T) {
	if err := parseTest(&badSizeof{}); err == nil {
		t.Fatal("failed to error on missing Sizeof target")
	}
}

// missingSize declares a slice with no length source at all.
type missingSize struct {
	Test []byte
}

// TestMissingSize verifies parsing errors on an unsized slice field.
func TestMissingSize(t *testing.T) {
	if err := parseTest(&missingSize{}); err == nil {
		t.Fatal("failed to error on missing field size")
	}
}
// badNested nests an unparseable (field-less) struct.
type badNested struct {
	Empty empty
}

// TestNestedParseError verifies nested parse errors propagate from Pack.
func TestNestedParseError(t *testing.T) {
	var buf bytes.Buffer
	if err := Pack(&buf, &badNested{}); err == nil {
		t.Fatal("failed to error on bad nested struct")
	}
}

122
common/utils/sturc/struc.go Normal file
View File

@@ -0,0 +1,122 @@
package struc
import (
"encoding/binary"
"fmt"
"io"
"reflect"
)
// Options controls packing/unpacking behavior: field alignment, the
// width used for size_t/off_t fields, and a byte-order override.
type Options struct {
	ByteAlign int              // minimum byte size per field (0 = none)
	PtrSize   int              // bit width for Size_t/Off_t: 8/16/32/64 (0 = 32)
	Order     binary.ByteOrder // overrides per-field tag order when non-nil
}

// Validate normalizes PtrSize (defaulting to 32) and rejects widths
// other than 8, 16, 32 or 64.
func (o *Options) Validate() error {
	if o.PtrSize == 0 {
		o.PtrSize = 32
	} else {
		switch o.PtrSize {
		case 8, 16, 32, 64:
		default:
			return fmt.Errorf("Invalid Options.PtrSize: %d. Must be in (8, 16, 32, 64)", o.PtrSize)
		}
	}
	return nil
}
// emptyOptions is the shared default used when callers pass nil options.
var emptyOptions = &Options{}

func init() {
	// fill default values to avoid data race to be reported by race detector.
	emptyOptions.Validate()
}
// prep inspects data and selects the Packer strategy: parsed struct
// fields for structs, the caller's Custom implementation if provided,
// or the encoding/binary fallback for bare values.
func prep(data interface{}) (reflect.Value, Packer, error) {
	value := reflect.ValueOf(data)
	// dereference pointers, but stop at the last pointer before a
	// non-struct so the fallback can still address the value
	for value.Kind() == reflect.Ptr {
		next := value.Elem().Kind()
		if next == reflect.Struct || next == reflect.Ptr {
			value = value.Elem()
		} else {
			break
		}
	}
	switch value.Kind() {
	case reflect.Struct:
		fields, err := parseFields(value)
		return value, fields, err
	default:
		if !value.IsValid() {
			return reflect.Value{}, nil, fmt.Errorf("Invalid reflect.Value for %+v", data)
		}
		if c, ok := data.(Custom); ok {
			return value, customFallback{c}, nil
		}
		return value, binaryFallback(value), nil
	}
}
// Pack encodes data into w using default options.
func Pack(w io.Writer, data interface{}) error {
	return PackWithOptions(w, data, nil)
}

// PackWithOptions encodes data into w. A nil options uses the shared
// defaults; the whole encoding is buffered and written in one call.
func PackWithOptions(w io.Writer, data interface{}, options *Options) error {
	if options == nil {
		options = emptyOptions
	}
	if err := options.Validate(); err != nil {
		return err
	}
	val, packer, err := prep(data)
	if err != nil {
		return err
	}
	// strings are packed via their byte representation
	if val.Type().Kind() == reflect.String {
		val = val.Convert(reflect.TypeOf([]byte{}))
	}
	size := packer.Sizeof(val, options)
	buf := make([]byte, size)
	if _, err := packer.Pack(buf, val, options); err != nil {
		return err
	}
	_, err = w.Write(buf)
	return err
}
// Unpack decodes data from r using default options.
func Unpack(r io.Reader, data interface{}) error {
	return UnpackWithOptions(r, data, nil)
}

// UnpackWithOptions decodes data from r into data, which must be a
// pointer for the decoded values to be visible to the caller.
func UnpackWithOptions(r io.Reader, data interface{}, options *Options) error {
	if options == nil {
		options = emptyOptions
	}
	if err := options.Validate(); err != nil {
		return err
	}
	val, packer, err := prep(data)
	if err != nil {
		return err
	}
	return packer.Unpack(r, val, options)
}
// Sizeof reports the packed byte size of data using default options.
func Sizeof(data interface{}) (int, error) {
	return SizeofWithOptions(data, nil)
}

// SizeofWithOptions reports the packed byte size of data under the
// given (or default) options without performing any encoding.
func SizeofWithOptions(data interface{}, options *Options) (int, error) {
	if options == nil {
		options = emptyOptions
	}
	if err := options.Validate(); err != nil {
		return 0, err
	}
	val, packer, err := prep(data)
	if err != nil {
		return 0, err
	}
	return packer.Sizeof(val, options), nil
}

View File

@@ -0,0 +1,310 @@
package struc
import (
"bytes"
"encoding/binary"
"fmt"
"reflect"
"testing"
)
// Nested is a minimal sub-struct used inside the Example fixture.
type Nested struct {
	Test2 int `struc:"int8"`
}

// Example exercises every supported tag/type combination; the trailing
// comment on each field shows its expected wire bytes (see referenceBytes).
type Example struct {
	Pad []byte `struc:"[5]pad"` // 00 00 00 00 00

	I8f  int `struc:"int8"`  // 01
	I16f int `struc:"int16"` // 00 02
	I32f int `struc:"int32"` // 00 00 00 03
	I64f int `struc:"int64"` // 00 00 00 00 00 00 00 04

	U8f  int `struc:"uint8,little"`  // 05
	U16f int `struc:"uint16,little"` // 06 00
	U32f int `struc:"uint32,little"` // 07 00 00 00
	U64f int `struc:"uint64,little"` // 08 00 00 00 00 00 00 00

	Boolf  int    `struc:"bool"`    // 01
	Byte4f []byte `struc:"[4]byte"` // "abcd"

	I8  int8  // 09
	I16 int16 // 00 0a
	I32 int32 // 00 00 00 0b
	I64 int64 // 00 00 00 00 00 00 00 0c

	U8  uint8  `struc:"little"` // 0d
	U16 uint16 `struc:"little"` // 0e 00
	U32 uint32 `struc:"little"` // 0f 00 00 00
	U64 uint64 `struc:"little"` // 10 00 00 00 00 00 00 00

	BoolT bool    // 01
	BoolF bool    // 00
	Byte4 [4]byte // "efgh"

	Float1 float32 // 41 a0 00 00
	Float2 float64 // 41 35 00 00 00 00 00 00

	I32f2 int64 `struc:"int32"`  // ff ff ff ff
	U32f2 int64 `struc:"uint32"` // ff ff ff ff
	I32f3 int32 `struc:"int64"`  // ff ff ff ff ff ff ff ff

	Size int    `struc:"sizeof=Str,little"` // 0a 00 00 00
	Str  string `struc:"[]byte"`            // "ijklmnopqr"
	Strb string `struc:"[4]byte"`           // "stuv"

	Size2 int    `struc:"uint8,sizeof=Str2"` // 04
	Str2  string // "1234"

	Size3 int    `struc:"uint8,sizeof=Bstr"` // 04
	Bstr  []byte // "5678"

	Size4 int    `struc:"little"`                 // 07 00 00 00
	Str4a string `struc:"[]byte,sizefrom=Size4"`  // "ijklmno"
	Str4b string `struc:"[]byte,sizefrom=Size4"`  // "pqrstuv"

	Size5 int    `struc:"uint8"`          // 04
	Bstr2 []byte `struc:"sizefrom=Size5"` // "5678"

	Nested  Nested  // 00 00 00 01
	NestedP *Nested // 00 00 00 02
	TestP64 *int    `struc:"int64"` // 00 00 00 05

	NestedSize int      `struc:"sizeof=NestedA"` // 00 00 00 02
	NestedA    []Nested // [00 00 00 03, 00 00 00 04]

	Skip int `struc:"skip"`

	CustomTypeSize    Int3   `struc:"sizeof=CustomTypeSizeArr"` // 00 00 00 04
	CustomTypeSizeArr []byte // "ABCD"
}
// five backs the Example fixture's TestP64 pointer field.
var five = 5

// ExampleStructWithin is the element type for the array/slice fixtures.
// (Its field is unexported, so it is skipped by the parser.)
type ExampleStructWithin struct {
	a uint8
}

// ExampleSlice pairs a size field with a slice of structs.
type ExampleSlice struct {
	PropsLen uint8 `struc:"sizeof=Props"`
	Props    []ExampleStructWithin
}

// ExampleArray declares a fixed-size array of structs via the tag.
type ExampleArray struct {
	PropsLen uint8
	Props    [16]ExampleStructWithin `struc:"[16]ExampleStructWithin"`
}
// arraySliceReferenceBytes is the expected encoding shared by the array
// and slice fixtures: a one-byte count (16) followed by 16 big-endian
// uint32 values 1..16.
var arraySliceReferenceBytes = []byte{
	16,
	0, 0, 0, 1,
	0, 0, 0, 1,
	0, 0, 0, 2,
	0, 0, 0, 3,
	0, 0, 0, 4,
	0, 0, 0, 5,
	0, 0, 0, 6,
	0, 0, 0, 7,
	0, 0, 0, 8,
	0, 0, 0, 9,
	0, 0, 0, 10,
	0, 0, 0, 11,
	0, 0, 0, 12,
	0, 0, 0, 13,
	0, 0, 0, 14,
	0, 0, 0, 15,
	0, 0, 0, 16,
}
var arrayReference = &ExampleArray{
16,
[16]ExampleStructWithin{
ExampleStructWithin{1},
ExampleStructWithin{2},
ExampleStructWithin{3},
ExampleStructWithin{4},
ExampleStructWithin{5},
ExampleStructWithin{6},
ExampleStructWithin{7},
ExampleStructWithin{8},
ExampleStructWithin{9},
ExampleStructWithin{10},
ExampleStructWithin{11},
ExampleStructWithin{12},
ExampleStructWithin{13},
ExampleStructWithin{14},
ExampleStructWithin{15},
ExampleStructWithin{16},
},
}
// sliceReference is the expected decoded form of arraySliceReferenceBytes
// when read into an ExampleSlice.
//
// The element type is elided inside the composite literal (gofmt -s style);
// repeating ExampleStructWithin on every element was redundant.
var sliceReference = &ExampleSlice{
	16,
	[]ExampleStructWithin{
		{1}, {2}, {3}, {4},
		{5}, {6}, {7}, {8},
		{9}, {10}, {11}, {12},
		{13}, {14}, {15}, {16},
	},
}
// reference is the canonical Example value used by the round-trip tests.
// Field values are positional, so this literal must stay in sync with both
// the Example struct definition and the packed form in referenceBytes.
var reference = &Example{
nil,
1, 2, 3, 4, 5, 6, 7, 8, 0, []byte{'a', 'b', 'c', 'd'},
9, 10, 11, 12, 13, 14, 15, 16, true, false, [4]byte{'e', 'f', 'g', 'h'},
20, 21,
-1,
4294967295,
-1,
10, "ijklmnopqr", "stuv",
4, "1234",
4, []byte("5678"),
7, "ijklmno", "pqrstuv",
4, []byte("5678"),
Nested{1}, &Nested{2}, &five,
6, []Nested{{3}, {4}, {5}, {6}, {7}, {8}},
0,
Int3(4), []byte("ABCD"),
}
// referenceBytes is the expected packed encoding of reference; the trailing
// comments map each byte run back to its source field.
var referenceBytes = []byte{
0, 0, 0, 0, 0, // pad(5)
1, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, // fake int8-int64(1-4)
5, 6, 0, 7, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, // fake little-endian uint8-uint64(5-8)
0, // fake bool(0)
'a', 'b', 'c', 'd', // fake [4]byte
9, 0, 10, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 12, // real int8-int64(9-12)
13, 14, 0, 15, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, // real little-endian uint8-uint64(13-16)
1, 0, // real bool(1), bool(0)
'e', 'f', 'g', 'h', // real [4]byte
65, 160, 0, 0, // real float32(20)
64, 53, 0, 0, 0, 0, 0, 0, // real float64(21)
255, 255, 255, 255, // fake int32(-1)
255, 255, 255, 255, // fake uint32(4294967295)
255, 255, 255, 255, 255, 255, 255, 255, // fake int64(-1)
10, 0, 0, 0, // little-endian int32(10) sizeof=Str
'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', // Str
's', 't', 'u', 'v', // fake string([4]byte)
04, '1', '2', '3', '4', // real string
04, '5', '6', '7', '8', // fake []byte(string)
7, 0, 0, 0, // little-endian int32(7)
'i', 'j', 'k', 'l', 'm', 'n', 'o', // Str4a sizefrom=Size4
'p', 'q', 'r', 's', 't', 'u', 'v', // Str4b sizefrom=Size4
04, '5', '6', '7', '8', // fake []byte(string)
1, 2, // Nested{1}, Nested{2}
0, 0, 0, 0, 0, 0, 0, 5, // &five
0, 0, 0, 6, // int32(6)
3, 4, 5, 6, 7, 8, // [Nested{3}, ...Nested{8}]
0, 0, 4, 'A', 'B', 'C', 'D', // Int3(4), []byte("ABCD")
}
// TestCodec round-trips the reference value through Pack and Unpack and
// verifies the decoded struct is deep-equal to the original.
func TestCodec(t *testing.T) {
	var encoded bytes.Buffer
	err := Pack(&encoded, reference)
	if err != nil {
		t.Fatal(err)
	}
	decoded := &Example{}
	if err = Unpack(&encoded, decoded); err != nil {
		t.Fatal(err)
	}
	if reflect.DeepEqual(reference, decoded) {
		return
	}
	fmt.Printf("got: %#v\nwant: %#v\n", decoded, reference)
	t.Fatal("encode/decode failed")
}
// TestEncode packs the reference value and compares the produced bytes
// against the golden referenceBytes.
func TestEncode(t *testing.T) {
	var encoded bytes.Buffer
	if err := Pack(&encoded, reference); err != nil {
		t.Fatal(err)
	}
	got := encoded.Bytes()
	if bytes.Equal(got, referenceBytes) {
		return
	}
	fmt.Printf("got: %#v\nwant: %#v\n", got, referenceBytes)
	t.Fatal("encode failed")
}
// TestDecode unpacks the golden referenceBytes and verifies the result is
// deep-equal to the reference value.
func TestDecode(t *testing.T) {
	reader := bytes.NewReader(referenceBytes)
	got := &Example{}
	if err := Unpack(reader, got); err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(reference, got) {
		fmt.Printf("got: %#v\nwant: %#v\n", got, reference)
		t.Fatal("decode failed")
	}
}
// TestSizeof checks that the computed packed size of reference matches the
// length of the golden byte sequence.
func TestSizeof(t *testing.T) {
	got, err := Sizeof(reference)
	if err != nil {
		t.Fatal(err)
	}
	want := len(referenceBytes)
	if got != want {
		t.Fatalf("sizeof failed; expected %d, got %d", want, got)
	}
}
// ExampleEndian holds a single int packed as a big-endian int16; used to
// exercise byte-order overrides.
type ExampleEndian struct {
T int `struc:"int16,big"`
}
// TestEndianSwap packs the value 1 big-endian (bytes 00 01) and unpacks it
// little-endian, expecting the byte-swapped value 256 (0x0100).
func TestEndianSwap(t *testing.T) {
	var wire bytes.Buffer
	src := &ExampleEndian{1}
	if err := PackWithOrder(&wire, src, binary.BigEndian); err != nil {
		t.Fatal(err)
	}
	dst := &ExampleEndian{}
	if err := UnpackWithOrder(&wire, dst, binary.LittleEndian); err != nil {
		t.Fatal(err)
	}
	if dst.T != 256 {
		t.Fatal("big -> little conversion failed")
	}
}
// TestNilValue verifies that Pack, Unpack and Sizeof all return an error
// (rather than panicking or succeeding) when given a nil value.
//
// Fix: the first two failure messages read "failed throw error" (typo,
// inconsistent with the third); all three now use the same wording.
func TestNilValue(t *testing.T) {
	var buf bytes.Buffer
	if err := Pack(&buf, nil); err == nil {
		t.Fatal("failed to throw error for bad struct value")
	}
	if err := Unpack(&buf, nil); err == nil {
		t.Fatal("failed to throw error for bad struct value")
	}
	if _, err := Sizeof(nil); err == nil {
		t.Fatal("failed to throw error for bad struct value")
	}
}
// sliceUnderrun declares fixed widths ([10]byte / [10]uint16) that are
// larger than the values supplied by the test, to exercise short input.
type sliceUnderrun struct {
Str string `struc:"[10]byte"`
Arr []uint16 `struc:"[10]uint16"`
}
// TestSliceUnderrun packs values shorter than their declared fixed widths
// (a 3-char string into [10]byte, a 3-element slice into [10]uint16) and
// expects Pack to succeed.
func TestSliceUnderrun(t *testing.T) {
	v := sliceUnderrun{
		Str: "foo",
		Arr: []uint16{1, 2, 3},
	}
	var buf bytes.Buffer
	if err := Pack(&buf, &v); err != nil {
		t.Fatal(err)
	}
}

View File

@@ -0,0 +1,4 @@
package test_pack_init
// This is a placeholder package for a test on specific race detector report on
// default Options initialization.

View File

@@ -0,0 +1,29 @@
package test_pack_init
import (
"bytes"
"github.com/lunixbochs/struc"
"sync"
"testing"
)
type Example struct {
I int `struc:int`
}
// TestParallelPack checks whether Pack is goroutine-safe; run with -race.
// It is kept as the single test in this package because the race detector
// is most likely to report a data race on the lazy initialization of
// package-global defaults triggered by the first concurrent Pack calls.
func TestParallelPack(t *testing.T) {
	val := Example{}
	var wg sync.WaitGroup
	wg.Add(2)
	for i := 0; i < 2; i++ {
		go func() {
			defer wg.Done()
			var buf bytes.Buffer
			_ = struc.Pack(&buf, &val)
		}()
	}
	wg.Wait()
}

136
common/utils/sturc/types.go Normal file
View File

@@ -0,0 +1,136 @@
package struc
import (
"fmt"
"reflect"
)
// Type enumerates the wire formats a struct field can be packed as.
type Type int

// Field type constants. SizeType and OffType are placeholders for the
// platform-dependent size_t/off_t widths and must be converted to a
// concrete fixed-width type via Type.Resolve before packing.
const (
Invalid Type = iota
Pad
Bool
Int
Int8
Uint8
Int16
Uint16
Int32
Uint32
Int64
Uint64
Float32
Float64
String
Struct
Ptr
SizeType
OffType
CustomType
)
// Resolve converts the platform-dependent SizeType/OffType pseudo-types to
// a concrete fixed-width type chosen by options.PtrSize (8, 16, 32 or 64
// bits; OffType resolves signed, SizeType unsigned). Any other type is
// returned unchanged. It panics on an unsupported PtrSize.
func (t Type) Resolve(options *Options) Type {
	if t != OffType && t != SizeType {
		return t
	}
	var widths map[int]Type
	if t == OffType {
		widths = map[int]Type{8: Int8, 16: Int16, 32: Int32, 64: Int64}
	} else {
		widths = map[int]Type{8: Uint8, 16: Uint16, 32: Uint32, 64: Uint64}
	}
	resolved, ok := widths[options.PtrSize]
	if !ok {
		panic(fmt.Sprintf("unsupported ptr bits: %d", options.PtrSize))
	}
	return resolved
}
// String returns the registered tag name for t; values without an entry in
// typeNames yield the empty string (the map zero value).
func (t Type) String() string {
	if name, ok := typeNames[t]; ok {
		return name
	}
	return ""
}
// Size reports the packed width of t in bytes. It panics for
// SizeType/OffType (which must first be resolved via options.PtrSize) and
// for any type without a fixed size.
func (t Type) Size() int {
	switch t {
	case Int64, Uint64, Float64:
		return 8
	case Int32, Uint32, Float32:
		return 4
	case Int16, Uint16:
		return 2
	case Pad, String, Int8, Uint8, Bool:
		return 1
	case SizeType, OffType:
		panic("Size_t/Off_t types must be converted to another type using options.PtrSize")
	default:
		panic("Cannot resolve size of type:" + t.String())
	}
}
// typeLookup maps struc tag names to their Type values.
var typeLookup = map[string]Type{
"pad": Pad,
"bool": Bool,
"byte": Uint8,
"int8": Int8,
"uint8": Uint8,
"int16": Int16,
"uint16": Uint16,
"int32": Int32,
"uint32": Uint32,
"int64": Int64,
"uint64": Uint64,
"float32": Float32,
"float64": Float64,
"size_t": SizeType,
"off_t": OffType,
}

// typeNames is the reverse of typeLookup; CustomType is seeded manually
// because it has no tag name in typeLookup.
var typeNames = map[Type]string{
CustomType: "Custom",
}

// init fills typeNames from typeLookup so Type.String works for every
// named type.
func init() {
for name, enum := range typeLookup {
typeNames[enum] = name
}
}
// Size_t mirrors C's size_t; its packed width is chosen at pack time via
// Options.PtrSize (see Type.Resolve).
type Size_t uint64

// Off_t mirrors C's off_t; its packed width is likewise chosen via
// Options.PtrSize.
type Off_t int64

// reflectTypeMap gives the default wire Type for each Go reflect.Kind when
// a field carries no explicit struc tag. Note that plain int/uint map to
// 32-bit wire types.
var reflectTypeMap = map[reflect.Kind]Type{
reflect.Bool: Bool,
reflect.Int8: Int8,
reflect.Int16: Int16,
reflect.Int: Int32,
reflect.Int32: Int32,
reflect.Int64: Int64,
reflect.Uint8: Uint8,
reflect.Uint16: Uint16,
reflect.Uint: Uint32,
reflect.Uint32: Uint32,
reflect.Uint64: Uint64,
reflect.Float32: Float32,
reflect.Float64: Float64,
reflect.String: String,
reflect.Struct: Struct,
reflect.Ptr: Ptr,
}

View File

@@ -0,0 +1,53 @@
package struc
import (
"bytes"
"testing"
)
// TestBadType ensures Type.Size panics for an invalid Type value; reaching
// the Fatal call means no panic occurred.
func TestBadType(t *testing.T) {
	defer func() {
		_ = recover() // swallow the expected panic so the test passes
	}()
	Type(-1).Size()
	t.Fatal("failed to panic for invalid Type.Size()")
}
// TestTypeString checks the registered tag name of the Pad type.
func TestTypeString(t *testing.T) {
	got := Pad.String()
	if got != "pad" {
		t.Fatal("type string representation failed")
	}
}
// sizeOffTest pairs the two platform-width pseudo-types so both code paths
// of Type.Resolve are exercised in one struct.
type sizeOffTest struct {
Size Size_t
Off Off_t
}
// TestSizeOffTypes packs the same Size_t/Off_t pair at every supported
// pointer width, compares the concatenated output against a golden byte
// sequence, then unpacks each width back and checks the values survive.
func TestSizeOffTypes(t *testing.T) {
	widths := []int{8, 16, 32, 64}
	test := &sizeOffTest{1, 2}
	var buf bytes.Buffer
	for _, w := range widths {
		if err := PackWithOptions(&buf, test, &Options{PtrSize: w}); err != nil {
			t.Fatal(err)
		}
	}
	want := []byte{
		1, 2,
		0, 1, 0, 2,
		0, 0, 0, 1, 0, 0, 0, 2,
		0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2,
	}
	if !bytes.Equal(want, buf.Bytes()) {
		t.Errorf("reference: %v != bytes: %v", want, buf.Bytes())
	}
	reader := bytes.NewReader(buf.Bytes())
	for _, w := range widths {
		out := &sizeOffTest{}
		if err := UnpackWithOptions(reader, out, &Options{PtrSize: w}); err != nil {
			t.Fatal(err)
		}
		if out.Size != 1 || out.Off != 2 {
			t.Errorf("Size_t/Off_t mismatch: {%d, %d}\n%v", out.Size, out.Off, buf.Bytes())
		}
	}
}

162
common/utils/test_test.go Normal file
View File

@@ -0,0 +1,162 @@
package utils
import (
"encoding/json"
"encoding/xml"
"fmt"
"strings"
"testing"
"github.com/apcera/termtables"
"github.com/gogf/gf/v2/os/glog"
"github.com/pointernil/bitset32"
)
// TestThree exercises the bitset32 package: chained Set calls, iteration
// via NextSet, Union of two sets, and the raw Bytes accessor. Output is
// printed for manual inspection only.
func TestThree(t *testing.T) {
	fmt.Printf("! \n")
	var b bitset32.BitSet32
	var a bitset32.BitSet32
	b.Set(10).Set(11)
	a = *bitset32.New(50)
	for i, ok := b.NextSet(0); ok; i, ok = b.NextSet(i + 1) {
		fmt.Println("The b bit is set:", i)
	}
	a.Set(10).Set(9)
	union := b.Union(&a)
	for i, ok := union.NextSet(0); ok; i, ok = union.NextSet(i + 1) {
		fmt.Println("The b+ bit is set:", i)
	}
	fmt.Println(b.Bytes())
}
// TestInit renders a small termtables table and prints the current glog
// stack trace; output is inspected manually rather than asserted.
func TestInit(t *testing.T) {
	tbl := termtables.CreateTable()
	tbl.AddHeaders("Name", "Age")
	tbl.AddRow("John", "30")
	tbl.AddRow("Sam", 18)
	tbl.AddRow("Julie", 20.14)
	fmt.Println(glog.GetStack())
	fmt.Println(tbl.Render())
}
// SuperMaps represents the XML root node (<superMaps>); XMLName is excluded
// from the JSON form via the json:"-" tag.
type SuperMaps struct {
XMLName xml.Name `xml:"superMaps" json:"-"`
Maps []Map `xml:"maps" json:"maps"`
}
// Map represents a single map node in the XML document; all values are XML
// attributes. ID may hold several space-separated IDs (see parseComplexID).
type Map struct {
ID string `xml:"id,attr" json:"id"`
Name string `xml:"name,attr" json:"name"`
X string `xml:"x,attr" json:"x"`
Y string `xml:"y,attr" json:"y"`
Galaxy string `xml:"galaxy,attr" json:"galaxy,omitempty"`
}
// TestXml demonstrates round-tripping map data between XML and JSON using
// the SuperMaps/Map types, then hands the XML to parseComplexID to handle
// entries whose id attribute holds several space-separated IDs.
func TestXml(t *testing.T) {
	// Sample XML input.
	xmlData := `<?xml version="1.0" encoding="UTF-8"?>
<superMaps>
<maps id="1" name="传送舱" x="" y=""/>
<maps id="4" name="船长室" x="" y=""/>
<maps id="10 11 12 13" name="克洛斯星" galaxy="1" x="358" y="46"/>
</superMaps>`
	// Sample JSON input.
	jsonData := `{
"maps": [
{
"id": "1",
"name": "传送舱",
"x": "",
"y": ""
},
{
"id": "4",
"name": "船长室",
"x": "",
"y": ""
},
{
"id": "10 11 12 13",
"name": "克洛斯星",
"galaxy": "1",
"x": "358",
"y": "46"
}
]
}`
	// 1. Parse the XML and convert it to JSON.
	fmt.Println("=== XML 到 JSON ===")
	superMaps := &SuperMaps{}
	err := xml.Unmarshal([]byte(xmlData), superMaps)
	if err != nil {
		fmt.Printf("解析XML失败: %v\n", err)
		return
	}
	jsonOutput, err := json.MarshalIndent(superMaps, "", " ")
	if err != nil {
		fmt.Printf("转换为JSON失败: %v\n", err)
		return
	}
	fmt.Println(string(jsonOutput))
	// 2. Parse the JSON and convert it back to XML.
	fmt.Println("\n=== JSON 到 XML ===")
	newSuperMaps := &SuperMaps{}
	err = json.Unmarshal([]byte(jsonData), newSuperMaps)
	if err != nil {
		fmt.Printf("解析JSON失败: %v\n", err)
		return
	}
	xmlOutput, err := xml.MarshalIndent(newSuperMaps, "", " ")
	if err != nil {
		fmt.Printf("转换为XML失败: %v\n", err)
		return
	}
	// Prepend the XML declaration (MarshalIndent does not emit one).
	xmlWithHeader := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>%s`, xmlOutput)
	fmt.Println(string(xmlWithHeader))
	// 3. Parse the XML again and process composite (multi-value) IDs.
	fmt.Println("\n=== 解析复杂ID的XML ===")
	parseComplexID(xmlData)
}
// parseComplexID unmarshals the XML document and prints each map entry.
// Entries whose id attribute contains several space-separated IDs are split
// with strings.Fields and printed as a list; single-ID entries are printed
// as-is.
func parseComplexID(xmlData string) {
	doc := &SuperMaps{}
	if err := xml.Unmarshal([]byte(xmlData), doc); err != nil {
		fmt.Printf("解析XML失败: %v\n", err)
		return
	}
	for _, entry := range doc.Maps {
		if !strings.Contains(entry.ID, " ") {
			fmt.Printf("地图名称: %s, ID: %s\n", entry.Name, entry.ID)
			continue
		}
		ids := strings.Fields(entry.ID)
		fmt.Printf("地图名称: %s, 拆分后的ID: %v\n", entry.Name, ids)
	}
}

View File

@@ -0,0 +1,41 @@
增加了支持短标签(自关闭标签)功能的golang xml官方包.
A copy of the official Go encoding/xml package extended with short-form (self-closing) tag support.
e.g.
```xml
<book></book> -> <book />
<happiness type="joy"></happiness> -> <happiness type="joy" />
```
# How to get
```
go get github.com/ECUST-XX/xml
```
# Usage
Replace
```go
import "encoding/xml"
xml.MarshalIndent(v, " ", " ")
```
with
```go
import "github.com/ECUST-XX/xml"
xml.MarshalIndentShortForm(v, " ", " ")
```
or
```go
import "github.com/ECUST-XX/xml"
enc := xml.NewEncoder(os.Stdout)
enc.Indent(" ", " ")
enc.ShortForm()
if err := enc.Encode(v); err != nil {
fmt.Printf("error: %v\n", err)
}
```

View File

@@ -0,0 +1,56 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import "time"
// atomValue is the reference Feed used by the Atom marshaling tests; it
// must stay in sync with its serialized form in atomXML.
var atomValue = &Feed{
XMLName: Name{"http://www.w3.org/2005/Atom", "feed"},
Title: "Example Feed",
Link: []Link{{Href: "http://example.org/"}},
Updated: ParseTime("2003-12-13T18:30:02Z"),
Author: Person{Name: "John Doe"},
ID: "urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6",
Entry: []Entry{
{
Title: "Atom-Powered Robots Run Amok",
Link: []Link{{Href: "http://example.org/2003/12/13/atom03"}},
ID: "urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a",
Updated: ParseTime("2003-12-13T18:30:02Z"),
Summary: NewText("Some text."),
},
},
}
// atomXML is the expected serialized form of atomValue.
var atomXML = `` +
`<feed xmlns="http://www.w3.org/2005/Atom" updated="2003-12-13T18:30:02Z">` +
`<title>Example Feed</title>` +
`<id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</id>` +
`<link href="http://example.org/"></link>` +
`<author><name>John Doe</name><uri></uri><email></email></author>` +
`<entry>` +
`<title>Atom-Powered Robots Run Amok</title>` +
`<id>urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a</id>` +
`<link href="http://example.org/2003/12/13/atom03"></link>` +
`<updated>2003-12-13T18:30:02Z</updated>` +
`<author><name></name><uri></uri><email></email></author>` +
`<summary>Some text.</summary>` +
`</entry>` +
`</feed>`
func ParseTime(str string) time.Time {
t, err := time.Parse(time.RFC3339, str)
if err != nil {
panic(err)
}
return t
}
// NewText wraps a plain string in a Text value with only Body set.
func NewText(text string) Text {
	var t Text
	t.Body = text
	return t
}

View File

@@ -0,0 +1,85 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml_test
import (
"fmt"
"log"
"strings"
"github.com/ECUST-XX/xml"
)
// Animal is a toy enum used to demonstrate custom XML marshaling.
type Animal int

// Known animals; Unknown is the zero value and the decoding fallback.
const (
Unknown Animal = iota
Gopher
Zebra
)
// UnmarshalXML decodes the element's text content and maps it
// (case-insensitively) onto an Animal, falling back to Unknown for any
// unrecognized name.
func (a *Animal) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
	var name string
	if err := d.DecodeElement(&name, &start); err != nil {
		return err
	}
	switch strings.ToLower(name) {
	case "gopher":
		*a = Gopher
	case "zebra":
		*a = Zebra
	default:
		*a = Unknown
	}
	return nil
}
// MarshalXML encodes an Animal as its lowercase name, emitting "unknown"
// for any value outside the defined constants.
func (a Animal) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
	name := "unknown"
	switch a {
	case Gopher:
		name = "gopher"
	case Zebra:
		name = "zebra"
	}
	return e.EncodeElement(name, start)
}
// Example_customMarshalXML unmarshals a list of animal names through the
// custom Animal.UnmarshalXML and tallies a census; the Output block is
// verified by the example runner, so it must not be edited casually.
func Example_customMarshalXML() {
	blob := `
<animals>
<animal>gopher</animal>
<animal>armadillo</animal>
<animal>zebra</animal>
<animal>unknown</animal>
<animal>gopher</animal>
<animal>bee</animal>
<animal>gopher</animal>
<animal>zebra</animal>
</animals>`
	var zoo struct {
		Animals []Animal `xml:"animal"`
	}
	if err := xml.Unmarshal([]byte(blob), &zoo); err != nil {
		log.Fatal(err)
	}
	census := make(map[Animal]int)
	for _, animal := range zoo.Animals {
		census[animal] += 1
	}
	fmt.Printf("Zoo Census:\n* Gophers: %d\n* Zebras: %d\n* Unknown: %d\n",
		census[Gopher], census[Zebra], census[Unknown])
	// Output:
	// Zoo Census:
	// * Gophers: 3
	// * Zebras: 2
	// * Unknown: 3
}

View File

@@ -0,0 +1,198 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml_test
import (
"fmt"
"os"
"github.com/ECUST-XX/xml"
)
// ExampleMarshalIndent renders a Person with nested name elements, an
// embedded Address and an XML comment field via xml.MarshalIndent; the
// Output block is verified by the example runner.
func ExampleMarshalIndent() {
	type Address struct {
		City, State string
	}
	type Person struct {
		XMLName xml.Name `xml:"person"`
		Id int `xml:"id,attr"`
		FirstName string `xml:"name>first"`
		LastName string `xml:"name>last"`
		Age int `xml:"age"`
		Height float32 `xml:"height,omitempty"`
		Married bool
		Address
		Comment string `xml:",comment"`
	}
	v := &Person{Id: 13, FirstName: "John", LastName: "Doe", Age: 42}
	v.Comment = " Need more details. "
	v.Address = Address{"Hanga Roa", "Easter Island"}
	output, err := xml.MarshalIndent(v, " ", " ")
	if err != nil {
		fmt.Printf("error: %v\n", err)
	}
	os.Stdout.Write(output)
	// Output:
	// <person id="13">
	// <name>
	// <first>John</first>
	// <last>Doe</last>
	// </name>
	// <age>42</age>
	// <Married>false</Married>
	// <City>Hanga Roa</City>
	// <State>Easter Island</State>
	// <!-- Need more details. -->
	// </person>
}
// ExampleEncoder produces the same document as ExampleMarshalIndent but
// through a streaming xml.Encoder writing to stdout; the Output block is
// verified by the example runner.
func ExampleEncoder() {
	type Address struct {
		City, State string
	}
	type Person struct {
		XMLName xml.Name `xml:"person"`
		Id int `xml:"id,attr"`
		FirstName string `xml:"name>first"`
		LastName string `xml:"name>last"`
		Age int `xml:"age"`
		Height float32 `xml:"height,omitempty"`
		Married bool
		Address
		Comment string `xml:",comment"`
	}
	v := &Person{Id: 13, FirstName: "John", LastName: "Doe", Age: 42}
	v.Comment = " Need more details. "
	v.Address = Address{"Hanga Roa", "Easter Island"}
	enc := xml.NewEncoder(os.Stdout)
	enc.Indent(" ", " ")
	if err := enc.Encode(v); err != nil {
		fmt.Printf("error: %v\n", err)
	}
	// Output:
	// <person id="13">
	// <name>
	// <first>John</first>
	// <last>Doe</last>
	// </name>
	// <age>42</age>
	// <Married>false</Married>
	// <City>Hanga Roa</City>
	// <State>Easter Island</State>
	// <!-- Need more details. -->
	// </person>
}
// This example demonstrates unmarshaling an XML excerpt into a value with
// some preset fields. Note that the Phone field isn't modified and that
// the XML <Company> element is ignored. Also, the Groups field is assigned
// considering the element path provided in its tag. The Output block is
// verified by the example runner.
func ExampleUnmarshal() {
	type Email struct {
		Where string `xml:"where,attr"`
		Addr string
	}
	type Address struct {
		City, State string
	}
	type Result struct {
		XMLName xml.Name `xml:"Person"`
		Name string `xml:"FullName"`
		Phone string
		Email []Email
		Groups []string `xml:"Group>Value"`
		Address
	}
	v := Result{Name: "none", Phone: "none"}
	data := `
<Person>
<FullName>Grace R. Emlin</FullName>
<Company>Example Inc.</Company>
<Email where="home">
<Addr>gre@example.com</Addr>
</Email>
<Email where='work'>
<Addr>gre@work.com</Addr>
</Email>
<Group>
<Value>Friends</Value>
<Value>Squash</Value>
</Group>
<City>Hanga Roa</City>
<State>Easter Island</State>
</Person>
`
	err := xml.Unmarshal([]byte(data), &v)
	if err != nil {
		fmt.Printf("error: %v", err)
		return
	}
	fmt.Printf("XMLName: %#v\n", v.XMLName)
	fmt.Printf("Name: %q\n", v.Name)
	fmt.Printf("Phone: %q\n", v.Phone)
	fmt.Printf("Email: %v\n", v.Email)
	fmt.Printf("Groups: %v\n", v.Groups)
	fmt.Printf("Address: %v\n", v.Address)
	// Output:
	// XMLName: xml.Name{Space:"", Local:"Person"}
	// Name: "Grace R. Emlin"
	// Phone: "none"
	// Email: [{home gre@example.com} {work gre@work.com}]
	// Groups: [Friends Squash]
	// Address: {Hanga Roa Easter Island}
}
// ExampleMarshalShortForm demonstrates MarshalIndentShortForm, the fork's
// extension: empty elements are emitted as self-closing tags (e.g.
// <last />) instead of open/close pairs. The Output block is verified by
// the example runner.
func ExampleMarshalShortForm() {
	type Address struct {
		City, State string
	}
	type Person struct {
		XMLName xml.Name `xml:"person"`
		Id int `xml:"id,attr"`
		FirstName string `xml:"name>first"`
		LastName string `xml:"name>last"`
		Age int `xml:"age"`
		Height float32 `xml:"height,omitempty"`
		Married bool
		Address
		Comment string `xml:",comment"`
		FavoriteMusic []string `xml:"favorite_music"`
		LuckyNumber int `xml:"lucky_number"`
	}
	v := &Person{Id: 13, FirstName: "John", LastName: ""}
	v.Comment = " Need more details. "
	v.Address = Address{"Hanga Roa", ""}
	v.FavoriteMusic = []string{"Made in Heaven", ""}
	output, err := xml.MarshalIndentShortForm(v, " ", " ")
	if err != nil {
		fmt.Printf("error: %v\n", err)
	}
	os.Stdout.Write(output)
	// Output:
	// <person id="13">
	// <name>
	// <first>John</first>
	// <last />
	// </name>
	// <age>0</age>
	// <Married>false</Married>
	// <City>Hanga Roa</City>
	// <State />
	// <!-- Need more details. -->
	// <favorite_music>Made in Heaven</favorite_music>
	// <favorite_music />
	// <lucky_number>0</lucky_number>
	// </person>
}

Some files were not shown because too many files have changed in this diff Show More