-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,286 @@ | ||
/* | ||
* Copyright 2023 ByteDance Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* | ||
* Unsafe reflect package for mockey, copy most code from go/src/reflect/type.go, | ||
* allow to export the address of private member methods. | ||
*/ | ||
|
||
package unsafereflect | ||
|
||
import ( | ||
"reflect" | ||
"unsafe" | ||
) | ||
|
||
func MethodByName(target interface{}, name string) (fn unsafe.Pointer, ok bool) { | ||
r := castRType(target) | ||
rt := toRType(r) | ||
if r.Kind() == reflect.Interface { | ||
return funcPointer(r.MethodByName(name)) | ||
} | ||
|
||
for _, p := range rt.methods() { | ||
if rt.nameOff(p.name).name() == name { | ||
return rt.Method(p), true | ||
} | ||
} | ||
return nil, false | ||
} | ||
|
||
// rtype is the common implementation of most values. | ||
// It is embedded in other struct types. | ||
// | ||
// rtype must be kept in sync with src/runtime/type.go:/^type._type. | ||
type rtype struct { | ||
size uintptr | ||
ptrdata uintptr // number of bytes in the type that can contain pointers | ||
hash uint32 // hash of type; avoids computation in hash tables | ||
tflag tflag // extra type information flags | ||
align uint8 // alignment of variable with this type | ||
fieldAlign uint8 // alignment of struct field with this type | ||
kind uint8 // enumeration for C | ||
// function for comparing objects of this type | ||
// (ptr to object A, ptr to object B) -> ==? | ||
equal func(unsafe.Pointer, unsafe.Pointer) bool | ||
gcdata *byte // garbage collection data | ||
str nameOff // string form | ||
ptrToThis typeOff // type for pointer to this type, may be zero | ||
} | ||
|
||
func castRType(val interface{}) reflect.Type { | ||
if rTypeVal, ok := val.(reflect.Type); ok { | ||
return rTypeVal | ||
} | ||
return reflect.TypeOf(val) | ||
} | ||
|
||
func toRType(t reflect.Type) *rtype { | ||
i := *(*funcValue)(unsafe.Pointer(&t)) | ||
r := (*rtype)(i.p) | ||
return r | ||
} | ||
|
||
type funcValue struct { | ||
_ uintptr | ||
p unsafe.Pointer | ||
} | ||
|
||
func funcPointer(v reflect.Method, ok bool) (unsafe.Pointer, bool) { | ||
return (*funcValue)(unsafe.Pointer(&v.Func)).p, ok | ||
} | ||
|
||
func (t *rtype) Method(p method) (fn unsafe.Pointer) { | ||
tfn := t.textOff(p.tfn) | ||
fn = unsafe.Pointer(&tfn) | ||
return | ||
} | ||
|
||
const kindMask = (1 << 5) - 1 | ||
|
||
func (t *rtype) Kind() reflect.Kind { return reflect.Kind(t.kind & kindMask) } | ||
|
||
type tflag uint8 | ||
type nameOff int32 // offset to a name | ||
type typeOff int32 // offset to an *rtype | ||
type textOff int32 // offset from top of text section | ||
|
||
// resolveNameOff resolves a name offset from a base pointer. | ||
// The (*rtype).nameOff method is a convenience wrapper for this function. | ||
// Implemented in the runtime package. | ||
// | ||
//go:linkname resolveNameOff reflect.resolveNameOff | ||
func resolveNameOff(unsafe.Pointer, int32) unsafe.Pointer | ||
|
||
func (t *rtype) nameOff(off nameOff) name { | ||
return name{(*byte)(resolveNameOff(unsafe.Pointer(t), int32(off)))} | ||
} | ||
|
||
// resolveTextOff resolves a function pointer offset from a base type. | ||
// The (*rtype).textOff method is a convenience wrapper for this function. | ||
// Implemented in the runtime package. | ||
// | ||
//go:linkname resolveTextOff reflect.resolveTextOff | ||
func resolveTextOff(unsafe.Pointer, int32) unsafe.Pointer | ||
|
||
func (t *rtype) textOff(off textOff) unsafe.Pointer { | ||
return resolveTextOff(unsafe.Pointer(t), int32(off)) | ||
} | ||
|
||
const tflagUncommon tflag = 1 << 0 | ||
|
||
// uncommonType is present only for defined types or types with methods | ||
type uncommonType struct { | ||
pkgPath nameOff // import path; empty for built-in types like int, string | ||
mcount uint16 // number of methods | ||
xcount uint16 // number of exported methods | ||
moff uint32 // offset from this uncommontype to [mcount]method | ||
_ uint32 // unused | ||
} | ||
|
||
// ptrType represents a pointer type. | ||
type ptrType struct { | ||
rtype | ||
elem *rtype // pointer element (pointed at) type | ||
} | ||
|
||
// funcType represents a function type. | ||
type funcType struct { | ||
rtype | ||
inCount uint16 | ||
outCount uint16 // top bit is set if last input parameter is ... | ||
} | ||
|
||
func (t *funcType) in() []*rtype { | ||
Check failure on line 145 in internal/unsafereflect/type.go GitHub Actions / staticcheck
|
||
uadd := unsafe.Sizeof(*t) | ||
if t.tflag&tflagUncommon != 0 { | ||
uadd += unsafe.Sizeof(uncommonType{}) | ||
} | ||
if t.inCount == 0 { | ||
return nil | ||
} | ||
return (*[1 << 20]*rtype)(add(unsafe.Pointer(t), uadd, "t.inCount > 0"))[:t.inCount:t.inCount] | ||
} | ||
|
||
func (t *funcType) out() []*rtype { | ||
Check failure on line 156 in internal/unsafereflect/type.go GitHub Actions / staticcheck
|
||
uadd := unsafe.Sizeof(*t) | ||
if t.tflag&tflagUncommon != 0 { | ||
uadd += unsafe.Sizeof(uncommonType{}) | ||
} | ||
outCount := t.outCount & (1<<15 - 1) | ||
if outCount == 0 { | ||
return nil | ||
} | ||
return (*[1 << 20]*rtype)(add(unsafe.Pointer(t), uadd, "outCount > 0"))[t.inCount : t.inCount+outCount : t.inCount+outCount] | ||
} | ||
|
||
func (t *rtype) IsVariadic() bool { | ||
tt := (*funcType)(unsafe.Pointer(t)) | ||
return tt.outCount&(1<<15) != 0 | ||
} | ||
|
||
func add(p unsafe.Pointer, x uintptr, whySafe string) unsafe.Pointer { | ||
return unsafe.Pointer(uintptr(p) + x) | ||
} | ||
|
||
// interfaceType represents an interface type. | ||
type interfaceType struct { | ||
rtype | ||
pkgPath name // import path | ||
methods []imethod // sorted by hash | ||
} | ||
|
||
type imethod struct { | ||
name nameOff // name of method | ||
Check failure on line 185 in internal/unsafereflect/type.go GitHub Actions / staticcheck
|
||
typ typeOff // .(*FuncType) underneath | ||
Check failure on line 186 in internal/unsafereflect/type.go GitHub Actions / staticcheck
|
||
} | ||
|
||
func (t *rtype) methods() []method { | ||
if t.tflag&tflagUncommon == 0 { | ||
return nil | ||
} | ||
switch t.Kind() { | ||
case reflect.Pointer: | ||
type u struct { | ||
ptrType | ||
u uncommonType | ||
} | ||
return (*u)(unsafe.Pointer(t)).u.methods() | ||
case reflect.Func: | ||
type u struct { | ||
funcType | ||
u uncommonType | ||
} | ||
return (*u)(unsafe.Pointer(t)).u.methods() | ||
case reflect.Interface: | ||
type u struct { | ||
interfaceType | ||
u uncommonType | ||
} | ||
return (*u)(unsafe.Pointer(t)).u.methods() | ||
case reflect.Struct: | ||
type u struct { | ||
structType | ||
u uncommonType | ||
} | ||
return (*u)(unsafe.Pointer(t)).u.methods() | ||
default: | ||
return nil | ||
} | ||
} | ||
|
||
// Method on non-interface type | ||
type method struct { | ||
name nameOff // name of method | ||
mtyp typeOff // method type (without receiver), not valid for private methods | ||
Check failure on line 226 in internal/unsafereflect/type.go GitHub Actions / staticcheck
|
||
ifn textOff // fn used in interface call (one-word receiver) | ||
Check failure on line 227 in internal/unsafereflect/type.go GitHub Actions / staticcheck
|
||
tfn textOff // fn used for normal method call | ||
} | ||
|
||
func (t *uncommonType) methods() []method { | ||
if t.mcount == 0 { | ||
return nil | ||
} | ||
return (*[1 << 16]method)(add(unsafe.Pointer(t), uintptr(t.moff), "t.mcount > 0"))[:t.mcount:t.mcount] | ||
} | ||
|
||
// name is an encoded type name with optional extra data. | ||
type name struct { | ||
bytes *byte | ||
} | ||
|
||
func (n name) data(off int, whySafe string) *byte { | ||
return (*byte)(add(unsafe.Pointer(n.bytes), uintptr(off), whySafe)) | ||
} | ||
|
||
func (n name) readVarint(off int) (int, int) { | ||
v := 0 | ||
for i := 0; ; i++ { | ||
x := *n.data(off+i, "read varint") | ||
v += int(x&0x7f) << (7 * i) | ||
if x&0x80 == 0 { | ||
return i + 1, v | ||
} | ||
} | ||
} | ||
|
||
type _String struct { | ||
Data unsafe.Pointer | ||
Len int | ||
} | ||
|
||
func (n name) name() (s string) { | ||
if n.bytes == nil { | ||
return | ||
} | ||
i, l := n.readVarint(1) | ||
hdr := (*_String)(unsafe.Pointer(&s)) | ||
hdr.Data = unsafe.Pointer(n.data(1+i, "non-empty string")) | ||
hdr.Len = l | ||
return | ||
} | ||
|
||
// Struct field | ||
type structField struct { | ||
name name // name is always non-empty | ||
Check failure on line 276 in internal/unsafereflect/type.go GitHub Actions / staticcheck
|
||
typ *rtype // type of field | ||
Check failure on line 277 in internal/unsafereflect/type.go GitHub Actions / staticcheck
|
||
offset uintptr // byte offset of field | ||
Check failure on line 278 in internal/unsafereflect/type.go GitHub Actions / staticcheck
|
||
} | ||
|
||
// structType | ||
type structType struct { | ||
rtype | ||
pkgPath name | ||
fields []structField // sorted by offset | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* Copyright 2023 ByteDance Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package unsafereflect_test | ||
|
||
import ( | ||
"crypto/sha256" | ||
"hash" | ||
"reflect" | ||
"testing" | ||
"unsafe" | ||
|
||
"github.com/bytedance/mockey" | ||
"github.com/bytedance/mockey/internal/tool" | ||
"github.com/bytedance/mockey/internal/unsafereflect" | ||
) | ||
|
||
func TestMethodByName(t *testing.T) { | ||
// private structure private method: *sha256.digest.checkSum | ||
tfn, ok := unsafereflect.MethodByName(sha256.New(), "checkSum") | ||
tool.Assert(ok, "private member of private structure is allowed") | ||
// type of `func(*sha256.digest, []byte) [32]byte` | ||
pFn := unsafe.Pointer(&tfn) | ||
|
||
mockey.PatchConvey("ReflectFuncReturn", t, func() { | ||
f := reflect.FuncOf([]reflect.Type{reflect.TypeOf(sha256.New()), reflect.TypeOf([]byte{})}, | ||
[]reflect.Type{reflect.TypeOf([sha256.Size]byte{})}, false) | ||
fn := reflect.NewAt(f, pFn).Elem().Interface() | ||
// Such function cannot be exported as `(*sha256.digest).checkSum`, | ||
// since the receiver's type is *sha256.digest, only Return API can be used | ||
mockey.Mock(fn).Return([sha256.Size]byte{1: 1}).Build() | ||
rets := sha256.New().Sum(nil) | ||
want := make([]byte, sha256.Size) | ||
want[1] = 1 | ||
tool.Assert(reflect.DeepEqual(want, rets), "the method should be mocked") | ||
}) | ||
|
||
mockey.PatchConvey("InterfaceFuncReturn", t, func() { | ||
fn := *(*func(hash.Hash, []byte) [sha256.Size]byte)(pFn) | ||
// Interface to fit the function shape is allowed here | ||
mockey.Mock(fn).Return([sha256.Size]byte{1: 1}).Build() | ||
rets := sha256.New().Sum(nil) | ||
want := make([]byte, sha256.Size) | ||
want[1] = 1 | ||
tool.Assert(reflect.DeepEqual(want, rets), "the method should be mocked") | ||
}) | ||
|
||
mockey.PatchConvey("InterfaceFuncTo", t, func() { | ||
fn := *(*func(hash.Hash, []byte) [sha256.Size]byte)(pFn) | ||
// Interface to fit the function shape is allowed here, | ||
// since the receiver's type is interface, To API can be used here | ||
mockey.Mock(fn).To(func(hash.Hash, []byte) [sha256.Size]byte { | ||
return [sha256.Size]byte{1: 1} | ||
}).Build() | ||
rets := sha256.New().Sum(nil) | ||
want := make([]byte, sha256.Size) | ||
want[1] = 1 | ||
tool.Assert(reflect.DeepEqual(want, rets), "the method should be mocked") | ||
}) | ||
} |