如何比较两个ctypes对象是否相等?

5 投票
2 回答
2232 浏览
提问于 2025-04-18 10:21
import ctypes as ct

class Point(ct.Structure):
    _fields_ = [
        ('x', ct.c_int),
        ('y', ct.c_int),
    ]

p1 = Point(10, 10)
p2 = Point(10, 10)

print p1 == p2 # => False

在上面的简单例子中,等于运算符 '==' 返回了 False。有没有什么简单的方法可以解决这个问题呢?

补充说明:

这里有一个稍微改进的版本(基于被接受的答案),它也可以处理嵌套数组的情况:

import ctypes as ct

class CtStruct(ct.Structure):

    def __eq__(self, other):
        for field in self._fields_:
            attr_name = field[0]
            a, b = getattr(self, attr_name), getattr(other, attr_name)
            is_array = isinstance(a, ct.Array)
            if is_array and a[:] != b[:] or not is_array and a != b:
                return False
        return True

    def __ne__(self, other):
        for field in self._fields_:
            attr_name = field[0]
            a, b = getattr(self, attr_name), getattr(other, attr_name)
            is_array = isinstance(a, ct.Array)
            if is_array and a[:] != b[:] or not is_array and a != b:
                return True
        return False

class Point(CtStruct):
    _fields_ = [
        ('x', ct.c_int),
        ('y', ct.c_int),
        ('arr', ct.c_int * 2),
    ]

p1 = Point(10, 20, (30, 40))
p2 = Point(10, 20, (30, 40))

print p1 == p2 # True

2 个回答

3

在这个简单的例子中,p1.x == p2.x and p1.y = p2.y 这样写是可以的。

你也可以在你的 Point 类里面实现 __eq__()__ne__() 这两个方法:

class Point(ct.Structure):
    _fields_ = [
        ('x', ct.c_int),
        ('y', ct.c_int),
    ]
    def __eq__(self, other):
        return (self.x == other.x) and (self.y == other.y)
    def __ne__(self, other):
        return not self.__eq__(other)

>>> p1 = Point(10, 10)
>>> p2 = Point(10, 10)
>>> p3 = Point(10, 66)
>>> p1 == p2
True
>>> p1 != p2
False
>>> p1 == p3
False
>>> p1 != p3
True
9

创建一个名为 MyCtStructure 的类,这样它的所有子类就不需要再实现 __eq____ne__ 这两个方法了。对于你来说,定义相等的功能就不会再是一件麻烦的事情了。

import ctypes as ct
class MyCtStructure(ct.Structure):

    def __eq__(self, other):
        for fld in self._fields_:
            if getattr(self, fld[0]) != getattr(other, fld[0]):
                return False
        return True

    def __ne__(self, other):
        for fld in self._fields_:
            if getattr(self, fld[0]) != getattr(other, fld[0]):
                return True
        return False

class Point(MyCtStructure):
    _fields_ = [
        ('x', ct.c_int),
        ('y', ct.c_int),
    ]


p1 = Point(10, 11)
p2 = Point(10, 11)

print p1 == p2

撰写回答