diff --git a/README.md b/README.md index 6ee55f1..30d36af 100644 --- a/README.md +++ b/README.md @@ -147,7 +147,7 @@ go run channel_gen.go Uint64Channel uint64 channel_uint64.go 这种分包协议的包结构很简单,每个消息包由N个字节的固定长度的包头和不定长的包体组成,包头的数据是小端格式或者大段格式编码的包体长度值。分包的时候先读取包头,解码后获得包体长度,接着读取对应长度的数据即为包体。 -这种分包协议简单易用,但是需要注意消息包的大小要控制好,否则容易成为漏洞被黑客利用,比如伪造一个长度超长的包头信息,让服务器一次申请一大块内存,或者频繁发送无效的大消息包,导致服务器内存耗尽。`Packet()`函数有一个`MaxPacketSize`参数用来限制最大包大小,当接收或发送的消息超过这个体积限制时,内部将抛出`ErrTooLargePacket`错误,除了控制合理的`MaxPacketSize`之外,建议使用者在自己的网络层也要做好安全防范措施,比如控制每个用户的出错频率等。 +这种分包协议简单易用,但是需要注意消息包的大小要控制好,否则容易成为漏洞被黑客利用,比如伪造一个长度超长的包头信息,让服务器一次申请一大块内存,或者频繁发送无效的大消息包,导致服务器内存耗尽。`Packet()`函数有一个`MaxPacketSize`参数用来限制最大包大小,当接收或发送的消息超过这个体积限制时,内部将抛出`ErrPacketTooLarge`错误,除了控制合理的`MaxPacketSize`之外,建议使用者在自己的网络层也要做好安全防范措施,比如控制每个用户的出错频率等。 优化提示1:在实践中,建议采用2字节包头结构,在需要发送大消息包的地方在协议上做消息分帧,而不是一次性发送一个大体积的消息包,这样除了起到安全防范作用,也可以获得较好的性能表现。 diff --git a/codec_packet.go b/codec_packet.go index a014724..c5af1ce 100644 --- a/codec_packet.go +++ b/codec_packet.go @@ -16,13 +16,14 @@ var ( ) var ( - ErrUnsupportedPacketType = errors.New("unsupported packet type") - ErrTooLargePacket = errors.New("too large packet") + ErrPacketUnsupported = errors.New("funny/link: unsupported packet type") + ErrPacketTooLarge = errors.New("funny/link: too large packet") + ErrPacketNoReadAll = errors.New("funny/link: no read all content from packet") ) func Packet(n, maxPacketSize, readBufferSize int, byteOrder ByteOrder, base CodecType) CodecType { if n != 1 && n != 2 && n != 4 && n != 8 { - panic(ErrUnsupportedPacketType) + panic(ErrPacketUnsupported) } return &packetCodecType{ n: n, @@ -54,7 +55,7 @@ func (codecType *packetCodecType) NewEncoder(w io.Writer) Encoder { } encoder.buffer.data = make([]byte, 1024) encoder.buffer.n = codecType.n - encoder.buffer.max = codecType.maxPacketSize + encoder.buffer.max = codecType.n + codecType.maxPacketSize switch codecType.n { case 1: encoder.encodeHead = codecType.encodeHead1 @@ -74,7 +75,7 @@ func (codecType *packetCodecType) encodeHead1(b []byte) { if n := len(b) - 1; n <= 254 && n <= codecType.maxPacketSize { b[0] = byte(n) } else { - panic(ErrTooLargePacket) + panic(ErrPacketTooLarge) } } @@ -82,7 +83,7 @@ func (codecType *packetCodecType) encodeHead2(b []byte) { if n := len(b) - 2; n <= 65534 && n <= codecType.maxPacketSize { codecType.byteOrder.PutUint16(b, uint16(n)) } else { - panic(ErrTooLargePacket) + panic(ErrPacketTooLarge) } } @@ -90,7 +91,7 @@ func (codecType *packetCodecType) encodeHead4(b []byte) { if n := len(b) - 4; n <= codecType.maxPacketSize { codecType.byteOrder.PutUint32(b, uint32(n)) } else { - panic(ErrTooLargePacket) + panic(ErrPacketTooLarge) } } @@ -98,7 +99,7 @@ func (codecType *packetCodecType) encodeHead8(b []byte) { if n := len(b) - 8; n <= codecType.maxPacketSize { codecType.byteOrder.PutUint64(b, uint64(n)) } else { - panic(ErrTooLargePacket) + panic(ErrPacketTooLarge) } } @@ -131,28 +132,28 @@ func (codecType *packetCodecType) decodeHead1(b []byte) int { if n := int(b[0]); n <= 254 && n <= codecType.maxPacketSize { return n } - panic(ErrTooLargePacket) + panic(ErrPacketTooLarge) } func (codecType *packetCodecType) decodeHead2(b []byte) int { if n := int(codecType.byteOrder.Uint16(b)); n > 0 && n <= 65534 && n <= codecType.maxPacketSize { return n } - panic(ErrTooLargePacket) + panic(ErrPacketTooLarge) } func (codecType *packetCodecType) decodeHead4(b []byte) int { if n := int(codecType.byteOrder.Uint32(b)); n > 0 && n <= codecType.maxPacketSize { return n } - panic(ErrTooLargePacket) + panic(ErrPacketTooLarge) } func (codecType *packetCodecType) decodeHead8(b []byte) int { if n := int(codecType.byteOrder.Uint64(b)); n > 0 && n <= codecType.maxPacketSize { return n } - panic(ErrTooLargePacket) + panic(ErrPacketTooLarge) } type packetEncoder struct { @@ -193,7 +194,7 @@ func (decoder *packetDecoder) Decode(msg interface{}) (err error) { return } if decoder.reader.N != 0 { - decoder.reader.R.(*bufio.Reader).Discard(int(decoder.reader.N)) + panic(ErrPacketNoReadAll) } return } @@ -241,7 +242,7 @@ func (pb *PacketBuffer) Next(n int) (b []byte) { pb.gorws(n) n += pb.wpos if n > pb.max { - panic(ErrTooLargePacket) + panic(ErrPacketTooLarge) } b = pb.data[pb.wpos:n] pb.wpos = n diff --git a/example/echo/benchmark.go b/example/echo/benchmark.go index 68e8138..a4a55af 100644 --- a/example/echo/benchmark.go +++ b/example/echo/benchmark.go @@ -73,7 +73,12 @@ type TestDecoder struct { } func (decoder *TestDecoder) Decode(msg interface{}) error { - // message data already in buffer, no need copy. + d := make([]byte, decoder.r.(*io.LimitedReader).N) + _, err := io.ReadFull(decoder.r, d) + if err != nil { + return err + } + *(msg.(*[]byte)) = d return nil }