Skip to content

Commit

Permalink
Add a generic sync.Pool wrapper to go library
Browse files Browse the repository at this point in the history
Since we dropped support of Go 1.18-, use generic to avoid dealing with
type assertions with interface{}/any.

While I'm here, also remove the usages of ioutil, as that's officially
marked as deprecated in Go 1.19.

Client: go
  • Loading branch information
fishy committed Aug 10, 2022
1 parent 7ae180b commit bdfde85
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 100 deletions.
37 changes: 15 additions & 22 deletions lib/go/thrift/deserializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ package thrift

import (
"context"
"sync"
)

type TDeserializer struct {
Expand Down Expand Up @@ -81,19 +80,15 @@ func (t *TDeserializer) Read(ctx context.Context, msg TStruct, b []byte) (err er
// It must be initialized with either NewTDeserializerPool or
// NewTDeserializerPoolSizeFactory.
type TDeserializerPool struct {
pool sync.Pool
pool *pool[TDeserializer]
}

// NewTDeserializerPool creates a new TDeserializerPool.
//
// NewTDeserializer can be used as the arg here.
func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool {
return &TDeserializerPool{
pool: sync.Pool{
New: func() interface{} {
return f()
},
},
pool: newPool(f, nil),
}
}

Expand All @@ -104,28 +99,26 @@ func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool {
// larger than that. It just dictates the initial size.
func NewTDeserializerPoolSizeFactory(size int, factory TProtocolFactory) *TDeserializerPool {
return &TDeserializerPool{
pool: sync.Pool{
New: func() interface{} {
transport := NewTMemoryBufferLen(size)
protocol := factory.GetProtocol(transport)

return &TDeserializer{
Transport: transport,
Protocol: protocol,
}
},
},
pool: newPool(func() *TDeserializer {
transport := NewTMemoryBufferLen(size)
protocol := factory.GetProtocol(transport)

return &TDeserializer{
Transport: transport,
Protocol: protocol,
}
}, nil),
}
}

func (t *TDeserializerPool) ReadString(ctx context.Context, msg TStruct, s string) error {
d := t.pool.Get().(*TDeserializer)
defer t.pool.Put(d)
d := t.pool.get()
defer t.pool.put(&d)
return d.ReadString(ctx, msg, s)
}

func (t *TDeserializerPool) Read(ctx context.Context, msg TStruct, b []byte) error {
d := t.pool.Get().(*TDeserializer)
defer t.pool.Put(d)
d := t.pool.get()
defer t.pool.put(&d)
return d.Read(ctx, msg, b)
}
10 changes: 5 additions & 5 deletions lib/go/thrift/framed_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (p *TFramedTransport) Read(buf []byte) (read int, err error) {
// Make sure we return the read buffer back to pool
// after we finished reading from it.
if p.readBuf != nil && p.readBuf.Len() == 0 {
returnBufToPool(&p.readBuf)
bufPool.put(&p.readBuf)
}
}()

Expand Down Expand Up @@ -175,7 +175,7 @@ func (p *TFramedTransport) ReadByte() (c byte, err error) {

func (p *TFramedTransport) ensureWriteBufferBeforeWrite() {
if p.writeBuf == nil {
p.writeBuf = getBufFromPool()
p.writeBuf = bufPool.get()
}
}

Expand All @@ -196,7 +196,7 @@ func (p *TFramedTransport) WriteString(s string) (n int, err error) {
}

func (p *TFramedTransport) Flush(ctx context.Context) error {
defer returnBufToPool(&p.writeBuf)
defer bufPool.put(&p.writeBuf)
size := p.writeBuf.Len()
buf := p.buffer[:4]
binary.BigEndian.PutUint32(buf, uint32(size))
Expand All @@ -215,9 +215,9 @@ func (p *TFramedTransport) Flush(ctx context.Context) error {

func (p *TFramedTransport) readFrame() error {
if p.readBuf != nil {
returnBufToPool(&p.readBuf)
bufPool.put(&p.readBuf)
}
p.readBuf = getBufFromPool()
p.readBuf = bufPool.get()

buf := p.buffer[:4]
if _, err := io.ReadFull(p.reader, buf); err != nil {
Expand Down
12 changes: 6 additions & 6 deletions lib/go/thrift/header_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error {

// Read the frame fully into frameBuffer.
if t.frameBuffer == nil {
t.frameBuffer = getBufFromPool()
t.frameBuffer = bufPool.get()
}
_, err = io.CopyN(t.frameBuffer, t.reader, int64(frameSize))
if err != nil {
Expand Down Expand Up @@ -407,7 +407,7 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
// It closes frameReader, and also resets frame related states.
func (t *THeaderTransport) endOfFrame() error {
defer func() {
returnBufToPool(&t.frameBuffer)
bufPool.put(&t.frameBuffer)
t.frameReader = nil
}()
return t.frameReader.Close()
Expand Down Expand Up @@ -572,7 +572,7 @@ func (t *THeaderTransport) Read(p []byte) (read int, err error) {
// You need to call Flush to actually write them to the transport.
func (t *THeaderTransport) Write(p []byte) (int, error) {
if t.writeBuffer == nil {
t.writeBuffer = getBufFromPool()
t.writeBuffer = bufPool.get()
}
return t.writeBuffer.Write(p)
}
Expand All @@ -583,7 +583,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
return nil
}

defer returnBufToPool(&t.writeBuffer)
defer bufPool.put(&t.writeBuffer)

switch t.clientType {
default:
Expand Down Expand Up @@ -633,8 +633,8 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
}
}

payload := getBufFromPool()
defer returnBufToPool(&payload)
payload := bufPool.get()
defer bufPool.put(&payload)
meta := headerMeta{
MagicFlags: THeaderHeaderMagic + t.Flags&THeaderFlagsMask,
SequenceID: t.SequenceID,
Expand Down
3 changes: 1 addition & 2 deletions lib/go/thrift/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"context"
"errors"
"io"
"io/ioutil"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -136,7 +135,7 @@ func (p *THttpClient) closeResponse() error {
// reused. Errors are being ignored here because if the connection is invalid
// and this fails for some reason, the Close() method will do any remaining
// cleanup.
io.Copy(ioutil.Discard, p.response.Body)
io.Copy(io.Discard, p.response.Body)

err = p.response.Body.Close()
}
Expand Down
15 changes: 6 additions & 9 deletions lib/go/thrift/http_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"io"
"net/http"
"strings"
"sync"
)

// NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function
Expand All @@ -41,23 +40,21 @@ func NewThriftHandlerFunc(processor TProcessor,

// gz transparently compresses the HTTP response if the client supports it.
func gz(handler http.HandlerFunc) http.HandlerFunc {
sp := &sync.Pool{
New: func() interface{} {
return gzip.NewWriter(nil)
},
}
sp := newPool(func() *gzip.Writer {
return gzip.NewWriter(nil)
}, nil)

return func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
handler(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := sp.Get().(*gzip.Writer)
gz := sp.get()
gz.Reset(w)
defer func() {
_ = gz.Close()
sp.Put(gz)
gz.Close()
sp.put(&gz)
}()
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
handler(gzw, r)
Expand Down
69 changes: 69 additions & 0 deletions lib/go/thrift/pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package thrift

import (
"bytes"
"sync"
)

// pool is a generic sync.Pool wrapper with bells and whistles.
type pool[T any] struct {
pool sync.Pool
reset func(*T)
}

// newPool creates a new pool.
//
// Both generate and reset are optional.
// Default generate is just new(T),
// When reset is nil we don't do any additional resetting when calling get.
func newPool[T any](generate func() *T, reset func(*T)) *pool[T] {
if generate == nil {
generate = func() *T {
return new(T)
}
}
return &pool[T]{
pool: sync.Pool{
New: func() interface{} {
return generate()
},
},
reset: reset,
}
}

func (p *pool[T]) get() *T {
r := p.pool.Get().(*T)
if p.reset != nil {
p.reset(r)
}
return r
}

func (p *pool[T]) put(r **T) {
p.pool.Put(*r)
*r = nil
}

var bufPool = newPool(nil, func(buf *bytes.Buffer) {
buf.Reset()
})
51 changes: 25 additions & 26 deletions lib/go/thrift/buf_pool.go → lib/go/thrift/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,32 @@
package thrift

import (
"bytes"
"sync"
"testing"
"testing/quick"
)

var bufPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}

// getBufFromPool gets a buffer out of the pool and guarantees that it's reset
// before return.
func getBufFromPool() *bytes.Buffer {
buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
return buf
}
type poolTest int

// returnBufToPool returns a buffer to the pool, and sets it to nil to avoid
// accidental usage after it's returned.
//
// You usually want to use it this way:
//
// buf := getBufFromPool()
// defer returnBufToPool(&buf)
// // use buf
func returnBufToPool(buf **bytes.Buffer) {
bufPool.Put(*buf)
*buf = nil
func TestPoolReset(t *testing.T) {
p := newPool(nil, func(elem *poolTest) {
*elem = 0
})
f := func(i int) (passed bool) {
pt := p.get()
defer func() {
p.put(&pt)
if pt != nil {
t.Errorf("Expected pt to be nil after put, got %#v", pt)
passed = false
}
}()
if *pt != 0 {
t.Errorf("Expected *pt to be reset to 0 after get, got %d", *pt)
}
*pt = poolTest(i)
return !t.Failed()
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
}
6 changes: 3 additions & 3 deletions lib/go/thrift/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ package thrift
import (
"bytes"
"context"
"io/ioutil"
"io"
"math"
"net"
"net/http"
Expand Down Expand Up @@ -60,7 +60,7 @@ type HTTPEchoServer struct{}
type HTTPHeaderEchoServer struct{}

func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
buf, err := ioutil.ReadAll(req.Body)
buf, err := io.ReadAll(req.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write(buf)
Expand All @@ -71,7 +71,7 @@ func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
buf, err := ioutil.ReadAll(req.Body)
buf, err := io.ReadAll(req.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write(buf)
Expand Down
Loading

0 comments on commit bdfde85

Please sign in to comment.