forked from WqyJh/tiktoken-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tiktoken_test.go
42 lines (35 loc) · 1.54 KB
/
tiktoken_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
package tiktoken
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestEncoding checks Encode against reference token sequences for the
// encoding returned by EncodingForModel("gpt-3.5-turbo-16k"), and checks that
// Encode panics when the input contains a special token that is disallowed.
func TestEncoding(t *testing.T) {
	ass := assert.New(t)

	enc, err := EncodingForModel("gpt-3.5-turbo-16k")
	// Bail out on failure rather than calling methods on an encoder that
	// failed to initialise.
	if !ass.NoError(err, "Encoding init should not be nil") {
		return
	}

	cases := []struct {
		text       string
		allowed    []string
		disallowed []string
		want       []int
	}{
		{
			// The punctuation after the CJK text is fullwidth (U+FF0C, U+FF01)
			// and is written with escapes so it cannot be confused with the
			// ASCII forms: ASCII "!" is token 0 in this very sequence, while
			// the final mark is 6447.
			text:       "hello world!你好\uff0c世界\uff01",
			allowed:    []string{"all"},
			disallowed: []string{"all"},
			// these tokens are converted from the original python code
			want: []int{15339, 1917, 0, 57668, 53901, 3922, 3574, 244, 98220, 6447},
		},
		{
			text:       "hello <|endoftext|>",
			allowed:    []string{"<|endoftext|>"},
			disallowed: nil,
			want:       []int{15339, 220, 100257},
		},
		{
			text:       "hello <|endoftext|>",
			allowed:    []string{"<|endoftext|>"},
			disallowed: []string{"all"},
			want:       []int{15339, 220, 100257},
		},
	}
	for _, tc := range cases {
		got := enc.Encode(tc.text, tc.allowed, tc.disallowed)
		// Equal, not ElementsMatch: a token sequence is ordered, so the same
		// ids in a different order must count as a failure.
		ass.Equal(tc.want, got, "Encoding should be equal")
	}

	// <|endofprompt|> is not in the allowed list while everything else is
	// disallowed, so encoding it is expected to panic.
	ass.Panics(func() {
		enc.Encode("hello <|endoftext|><|endofprompt|>", []string{"<|endoftext|>"}, []string{"all"})
	})
	// Listing the same special token as both allowed and disallowed is
	// expected to panic as well.
	ass.Panics(func() {
		enc.Encode("hello <|endoftext|>", []string{"<|endoftext|>"}, []string{"<|endoftext|>"})
	})
}
// TestDecoding checks that Decode turns the reference token sequence back
// into the original text, using the encoding returned by
// GetEncoding(MODEL_CL100K_BASE).
func TestDecoding(t *testing.T) {
	ass := assert.New(t)

	// NOTE(review): MODEL_CL100K_BASE presumably names the "cl100k_base"
	// encoding (an earlier revision passed that string directly) — confirm
	// against the constant's definition.
	enc, err := GetEncoding(MODEL_CL100K_BASE)
	// Bail out on failure rather than calling Decode on an encoder that
	// failed to initialise.
	if !ass.NoError(err, "Encoding init should not be nil") {
		return
	}

	sourceTokens := []int{15339, 1917, 0, 57668, 53901, 3922, 3574, 244, 98220, 6447}
	txt := enc.Decode(sourceTokens)

	// The punctuation after the CJK text is fullwidth (U+FF0C, U+FF01) and is
	// written with escapes so it cannot be confused with the ASCII forms: the
	// sequence uses token 0 for the ASCII "!" and a different id, 6447, for
	// the final mark.
	ass.Equal("hello world!你好\uff0c世界\uff01", txt, "Decoding should be equal")
}