Skip to content

Commit

Permalink
feat: add RecvRaw (#896)
Browse files Browse the repository at this point in the history
  • Loading branch information
WqyJh authored Nov 30, 2024
1 parent 21fa42c commit c203ca0
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 17 deletions.
39 changes: 22 additions & 17 deletions stream_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,28 @@ type streamReader[T streamable] struct {
}

func (stream *streamReader[T]) Recv() (response T, err error) {
if stream.isFinished {
err = io.EOF
rawLine, err := stream.RecvRaw()
if err != nil {
return
}

response, err = stream.processLines()
return
err = stream.unmarshaler.Unmarshal(rawLine, &response)
if err != nil {
return
}
return response, nil
}

func (stream *streamReader[T]) RecvRaw() ([]byte, error) {
if stream.isFinished {
return nil, io.EOF
}

return stream.processLines()
}

//nolint:gocognit
func (stream *streamReader[T]) processLines() (T, error) {
func (stream *streamReader[T]) processLines() ([]byte, error) {
var (
emptyMessagesCount uint
hasErrorPrefix bool
Expand All @@ -53,9 +64,9 @@ func (stream *streamReader[T]) processLines() (T, error) {
if readErr != nil || hasErrorPrefix {
respErr := stream.unmarshalError()
if respErr != nil {
return *new(T), fmt.Errorf("error, %w", respErr.Error)
return nil, fmt.Errorf("error, %w", respErr.Error)
}
return *new(T), readErr
return nil, readErr
}

noSpaceLine := bytes.TrimSpace(rawLine)
Expand All @@ -68,11 +79,11 @@ func (stream *streamReader[T]) processLines() (T, error) {
}
writeErr := stream.errAccumulator.Write(noSpaceLine)
if writeErr != nil {
return *new(T), writeErr
return nil, writeErr
}
emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit {
return *new(T), ErrTooManyEmptyStreamMessages
return nil, ErrTooManyEmptyStreamMessages
}

continue
Expand All @@ -81,16 +92,10 @@ func (stream *streamReader[T]) processLines() (T, error) {
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
if string(noPrefixLine) == "[DONE]" {
stream.isFinished = true
return *new(T), io.EOF
}

var response T
unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
if unmarshalErr != nil {
return *new(T), unmarshalErr
return nil, io.EOF
}

return response, nil
return noPrefixLine, nil
}
}

Expand Down
13 changes: 13 additions & 0 deletions stream_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
_, err := stream.Recv()
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
}

func TestStreamReaderRecvRaw(t *testing.T) {
stream := &streamReader[ChatCompletionStreamResponse]{
reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))),
}
rawLine, err := stream.RecvRaw()
if err != nil {
t.Fatalf("Did not return raw line: %v", err)
}
if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) {
t.Fatalf("Did not return raw line: %v", string(rawLine))
}
}

0 comments on commit c203ca0

Please sign in to comment.