internal/runtime: impl type.equal

This commit is contained in:
visualfc
2024-07-02 13:20:34 +08:00
parent 4286a510b4
commit 46423ed166
4 changed files with 82 additions and 88 deletions

View File

@@ -442,7 +442,7 @@ bucketloop:
if t.IndirectKey() { if t.IndirectKey() {
k = *((*unsafe.Pointer)(k)) k = *((*unsafe.Pointer)(k))
} }
if mapKeyEqual(t, key, k) { if t.Key.Equal(key, k) {
e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.KeySize)+i*uintptr(t.ValueSize)) e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.KeySize)+i*uintptr(t.ValueSize))
if t.IndirectElem() { if t.IndirectElem() {
e = *((*unsafe.Pointer)(e)) e = *((*unsafe.Pointer)(e))
@@ -503,7 +503,7 @@ bucketloop:
if t.IndirectKey() { if t.IndirectKey() {
k = *((*unsafe.Pointer)(k)) k = *((*unsafe.Pointer)(k))
} }
if mapKeyEqual(t, key, k) { if t.Key.Equal(key, k) {
e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.KeySize)+i*uintptr(t.ValueSize)) e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.KeySize)+i*uintptr(t.ValueSize))
if t.IndirectElem() { if t.IndirectElem() {
e = *((*unsafe.Pointer)(e)) e = *((*unsafe.Pointer)(e))
@@ -547,7 +547,7 @@ bucketloop:
if t.IndirectKey() { if t.IndirectKey() {
k = *((*unsafe.Pointer)(k)) k = *((*unsafe.Pointer)(k))
} }
if mapKeyEqual(t, key, k) { if t.Key.Equal(key, k) {
e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.KeySize)+i*uintptr(t.ValueSize)) e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.KeySize)+i*uintptr(t.ValueSize))
if t.IndirectElem() { if t.IndirectElem() {
e = *((*unsafe.Pointer)(e)) e = *((*unsafe.Pointer)(e))
@@ -635,7 +635,7 @@ bucketloop:
if t.IndirectKey() { if t.IndirectKey() {
k = *((*unsafe.Pointer)(k)) k = *((*unsafe.Pointer)(k))
} }
if !mapKeyEqual(t, key, k) { if !t.Key.Equal(key, k) {
continue continue
} }
// already have a mapping for key. Update it. // already have a mapping for key. Update it.
@@ -747,7 +747,7 @@ search:
if t.IndirectKey() { if t.IndirectKey() {
k2 = *((*unsafe.Pointer)(k2)) k2 = *((*unsafe.Pointer)(k2))
} }
if !mapKeyEqual(t, key, k2) { if !t.Key.Equal(key, k2) {
continue continue
} }
// Only clear key if there are pointers in it. // Only clear key if there are pointers in it.
@@ -935,7 +935,7 @@ next:
// through the oldbucket, skipping any keys that will go // through the oldbucket, skipping any keys that will go
// to the other new bucket (each oldbucket expands to two // to the other new bucket (each oldbucket expands to two
// buckets during a grow). // buckets during a grow).
if t.ReflexiveKey() || mapKeyEqual(t, k, k) { if t.ReflexiveKey() || t.Key.Equal(k, k) {
// If the item in the oldbucket is not destined for // If the item in the oldbucket is not destined for
// the current new bucket in the iteration, skip it. // the current new bucket in the iteration, skip it.
hash := t.Hasher(k, uintptr(h.hash0)) hash := t.Hasher(k, uintptr(h.hash0))
@@ -956,7 +956,7 @@ next:
} }
} }
if (b.tophash[offi] != evacuatedX && b.tophash[offi] != evacuatedY) || if (b.tophash[offi] != evacuatedX && b.tophash[offi] != evacuatedY) ||
!(t.ReflexiveKey() || mapKeyEqual(t, k, k)) { !(t.ReflexiveKey() || t.Key.Equal(k, k)) {
// This is the golden data, we can return it. // This is the golden data, we can return it.
// OR // OR
// key!=key, so the entry can't be deleted or updated, so we can just return it. // key!=key, so the entry can't be deleted or updated, so we can just return it.
@@ -1214,7 +1214,7 @@ func evacuate(t *maptype, h *hmap, oldbucket uintptr) {
// Compute hash to make our evacuation decision (whether we need // Compute hash to make our evacuation decision (whether we need
// to send this key/elem to bucket x or bucket y). // to send this key/elem to bucket x or bucket y).
hash := t.Hasher(k2, uintptr(h.hash0)) hash := t.Hasher(k2, uintptr(h.hash0))
if h.flags&iterator != 0 && !t.ReflexiveKey() && !mapKeyEqual(t, k2, k2) { if h.flags&iterator != 0 && !t.ReflexiveKey() && !t.Key.Equal(k2, k2) {
// If key != key (NaNs), then the hash could be (and probably // If key != key (NaNs), then the hash could be (and probably
// will be) entirely different from the old hash. Moreover, // will be) entirely different from the old hash. Moreover,
// it isn't reproducible. Reproducibility is required in the // it isn't reproducible. Reproducibility is required in the

View File

@@ -218,6 +218,11 @@ func Interface(pkgPath, name string, methods []Imethod) *InterfaceType {
PkgPath_: pkgPath, PkgPath_: pkgPath,
Methods: methods, Methods: methods,
} }
if len(methods) == 0 {
ret.Equal = nilinterequal
} else {
ret.Equal = interequal
}
return ret return ret
} }
@@ -355,12 +360,6 @@ func Implements(T, V *abi.Type) bool {
} }
func EfaceEqual(v, u eface) bool { func EfaceEqual(v, u eface) bool {
if v.Kind() == abi.Interface {
v = v.Elem()
}
if u.Kind() == abi.Interface {
u = u.Elem()
}
if v._type == nil || u._type == nil { if v._type == nil || u._type == nil {
return v._type == u._type return v._type == u._type
} }
@@ -370,52 +369,10 @@ func EfaceEqual(v, u eface) bool {
if isDirectIface(v._type) { if isDirectIface(v._type) {
return v.data == u.data return v.data == u.data
} }
switch v.Kind() { if equal := v._type.Equal; equal != nil {
case abi.Bool, return equal(v.data, u.data)
abi.Int, abi.Int8, abi.Int16, abi.Int32, abi.Int64,
abi.Uint, abi.Uint8, abi.Uint16, abi.Uint32, abi.Uint64, abi.Uintptr,
abi.Float32, abi.Float64:
return *(*uintptr)(v.data) == *(*uintptr)(u.data)
case abi.Complex64:
return *(*complex64)(v.data) == *(*complex64)(u.data)
case abi.Complex128:
return *(*complex128)(v.data) == *(*complex128)(u.data)
case abi.String:
return *(*string)(v.data) == *(*string)(u.data)
case abi.Pointer, abi.UnsafePointer:
return v.data == u.data
case abi.Array:
n := v._type.Len()
tt := v._type.ArrayType()
index := func(data unsafe.Pointer, i int) eface {
offset := i * int(tt.Elem.Size_)
return eface{tt.Elem, c.Advance(data, offset)}
} }
for i := 0; i < n; i++ { panic(errorString("comparing uncomparable type " + v._type.String()).Error())
if !EfaceEqual(index(v.data, i), index(u.data, i)) {
return false
}
}
return true
case abi.Struct:
st := v._type.StructType()
field := func(data unsafe.Pointer, ft *abi.StructField) eface {
ptr := c.Advance(data, int(ft.Offset))
if isDirectIface(ft.Typ) {
ptr = *(*unsafe.Pointer)(ptr)
}
return eface{ft.Typ, ptr}
}
for _, ft := range st.Fields {
if !EfaceEqual(field(v.data, &ft), field(u.data, &ft)) {
return false
}
}
return true
case abi.Func, abi.Map, abi.Slice:
break
}
panic("not comparable")
} }
func (v eface) Kind() abi.Kind { func (v eface) Kind() abi.Kind {

View File

@@ -82,29 +82,3 @@ func MapIterNext(it *hiter) (ok bool, k unsafe.Pointer, v unsafe.Pointer) {
mapiternext(it) mapiternext(it)
return return
} }
func mapKeyEqual(t *maptype, p, q unsafe.Pointer) bool {
if isDirectIface(t.Key) {
switch t.Key.Size_ {
case 0:
return true
case 1:
return memequal8(p, q)
case 2:
return memequal16(p, q)
case 4:
return memequal32(p, q)
case 8:
return memequal64(p, q)
}
}
switch t.Key.Kind() {
case abi.String:
return strequal(p, q)
case abi.Complex64:
return c64equal(p, q)
case abi.Complex128:
return c128equal(p, q)
}
return t.Key.Equal(p, q)
}

View File

@@ -30,6 +30,36 @@ var (
tyBasic [abi.UnsafePointer + 1]*Type tyBasic [abi.UnsafePointer + 1]*Type
) )
func basicEqual(kind Kind, size uintptr) func(a, b unsafe.Pointer) bool {
switch kind {
case abi.Bool, abi.Int, abi.Int8, abi.Int16, abi.Int32, abi.Int64,
abi.Uint, abi.Uint8, abi.Uint16, abi.Uint32, abi.Uint64, abi.Uintptr:
switch size {
case 1:
return memequal8
case 2:
return memequal16
case 4:
return memequal32
case 8:
return memequal64
}
case abi.Float32:
return f32equal
case abi.Float64:
return f64equal
case abi.Complex64:
return c64equal
case abi.Complex128:
return c128equal
case abi.String:
return strequal
case abi.UnsafePointer:
return ptrequal
}
panic("unreachable")
}
func Basic(kind Kind) *Type { func Basic(kind Kind) *Type {
if tyBasic[kind] == nil { if tyBasic[kind] == nil {
name, size, align := basicTypeInfo(kind) name, size, align := basicTypeInfo(kind)
@@ -39,10 +69,8 @@ func Basic(kind Kind) *Type {
Align_: uint8(align), Align_: uint8(align),
FieldAlign_: uint8(align), FieldAlign_: uint8(align),
Kind_: uint8(kind), Kind_: uint8(kind),
Equal: basicEqual(kind, size),
Str_: name, Str_: name,
Equal: func(a, b unsafe.Pointer) bool {
return uintptr(a) == uintptr(b)
},
} }
} }
return tyBasic[kind] return tyBasic[kind]
@@ -115,15 +143,33 @@ func Struct(pkgPath string, size uintptr, fields ...abi.StructField) *Type {
PkgPath_: pkgPath, PkgPath_: pkgPath,
Fields: fields, Fields: fields,
} }
var comparable bool = true
var typalign uint8 var typalign uint8
for _, f := range fields { for _, f := range fields {
ft := f.Typ ft := f.Typ
if ft.Align_ > typalign { if ft.Align_ > typalign {
typalign = ft.Align_ typalign = ft.Align_
} }
comparable = comparable && (ft.Equal != nil)
} }
ret.Align_ = typalign ret.Align_ = typalign
ret.FieldAlign_ = typalign ret.FieldAlign_ = typalign
if comparable {
if size == 0 {
ret.Equal = memequal0
} else {
ret.Equal = func(p, q unsafe.Pointer) bool {
for _, ft := range fields {
pi := add(p, ft.Offset)
qi := add(q, ft.Offset)
if !ft.Typ.Equal(pi, qi) {
return false
}
}
return true
}
}
}
return &ret.Type return &ret.Type
} }
@@ -149,6 +195,7 @@ func newPointer(elem *Type) *Type {
Align_: pointerAlign, Align_: pointerAlign,
FieldAlign_: pointerAlign, FieldAlign_: pointerAlign,
Kind_: uint8(abi.Pointer), Kind_: uint8(abi.Pointer),
Equal: ptrequal,
}, },
Elem: elem, Elem: elem,
} }
@@ -192,6 +239,22 @@ func ArrayOf(length uintptr, elem *Type) *Type {
Slice: SliceOf(elem), Slice: SliceOf(elem),
Len: length, Len: length,
} }
if eequal := elem.Equal; eequal != nil {
if elem.Size_ == 0 {
ret.Equal = memequal0
} else {
ret.Equal = func(p, q unsafe.Pointer) bool {
for i := uintptr(0); i < length; i++ {
pi := add(p, i*elem.Size_)
qi := add(q, i*elem.Size_)
if !eequal(pi, qi) {
return false
}
}
return true
}
}
}
return &ret.Type return &ret.Type
} }