Elegant ways to support equivalence ("equality") in Python classes
Consider this simple problem:
```python
class Number:

    def __init__(self, number):
        self.number = number


n1 = Number(1)
n2 = Number(1)

n1 == n2  # False -- oops
```
So, Python by default uses the object identifiers for comparison operations:
```python
id(n1)  # 140400634555856
id(n2)  # 140400634555920
```
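A minimal Python 3 sketch of what that default means (the class name Plain is hypothetical): the inherited object.__eq__ only treats an object as equal to itself, so == on two distinct instances gives the same answer as is, even when their attributes match.

```python
class Plain:
    """Hypothetical class that does not override __eq__."""

    def __init__(self, number):
        self.number = number


a = Plain(1)
b = Plain(1)

# The inherited object.__eq__ falls back to identity.
print(a == a)                # True: an object is equal to itself
print(a == b)                # False, although a.number == b.number
print((a == b) == (a is b))  # True: default == agrees with `is` here
print(id(a) == id(b))        # False: two live objects never share an id
```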
Overriding the __eq__ function seems to solve the problem:
```python
def __eq__(self, other):
    """Overrides the default implementation"""
    if isinstance(other, Number):
        return self.number == other.number
    return False


n1 == n2  # True
n1 != n2  # True in Python 2 -- oops, False in Python 3
```
In Python 2, always remember to override the __ne__ function as well, as the documentation states:

There are no implied relationships among the comparison operators. The truth of x==y does not imply that x!=y is false. Accordingly, when defining __eq__(), one should also define __ne__() so that the operators will behave as expected.
```python
def __ne__(self, other):
    """Overrides the default implementation (unnecessary in Python 3)"""
    return not self.__eq__(other)


n1 == n2  # True
n1 != n2  # False
```
In Python 3, this is no longer necessary, as the documentation states:
By default, __ne__() delegates to __eq__() and inverts the result unless it is NotImplemented. There are no other implied relationships among the comparison operators, for example, the truth of (x<y or x==y) does not imply x<=y.
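A short Python 3 sketch of both halves of that quote. It uses a trimmed-down Number that, as an assumption made only for this demo, also defines __lt__: != is derived from __eq__, NotImplemented is passed through rather than inverted, and <= is not derived from < and ==.

```python
class Number:
    """Python 3 sketch: only __eq__ and __lt__ are defined."""

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, Number):
            return self.number == other.number
        return NotImplemented

    def __lt__(self, other):
        if isinstance(other, Number):
            return self.number < other.number
        return NotImplemented

    # No __ne__ or __le__: object.__ne__ inverts __eq__, but nothing
    # synthesizes __le__ from __lt__ and __eq__.


n1, n2, n3 = Number(1), Number(1), Number(2)

print(n1 != n2)        # False: the inverted result of n1.__eq__(n2)
print(n1 != n3)        # True
print(n1.__ne__("1"))  # NotImplemented: passed through, not inverted

try:
    n1 <= n2
    le_defined = True
except TypeError:
    # Neither Number.__le__ nor the reflected Number.__ge__ is implemented.
    le_defined = False
print(le_defined)      # False, even though n1 == n2 is True
```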
But that does not solve all our problems. Let’s add a subclass:
```python
class SubNumber(Number):
    pass


n3 = SubNumber(1)

n1 == n3  # False for classic-style classes -- oops, True for new-style classes
n3 == n1  # True
n1 != n3  # True for classic-style classes -- oops, False for new-style classes
n3 != n1  # False
```
Note: Python 2 has two kinds of classes:

- classic-style (or old-style) classes, which do not inherit from object and are declared as class A:, class A(): or class A(B): where B is a classic-style class;
- new-style classes, which do inherit from object and are declared as class A(object) or class A(B): where B is a new-style class.

Python 3 has only new-style classes, which are declared as class A:, class A(object): or class A(B):.
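For classic-style classes, a comparison operation first calls the method of the first operand, while for new-style classes, when one operand is an instance of a subclass of the other operand's class, it first calls the method of the subclass operand, regardless of the order of the operands. The other operand's method is only tried if that first call returns NotImplemented.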
So here, if Number is a classic-style class:

- n1 == n3 calls n1.__eq__;
- n3 == n1 calls n3.__eq__;
- n1 != n3 calls n1.__ne__;
- n3 != n1 calls n3.__ne__.
And if Number is a new-style class:

- both n1 == n3 and n3 == n1 call n3.__eq__;
- both n1 != n3 and n3 != n1 call n3.__ne__.
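Since every Python 3 class is new-style, only the second list applies there. The sketch below, assuming CPython 3, instruments __eq__ to record the class of the object whose method runs; the calls list and the trace helper are hypothetical names used only for the demo.

```python
calls = []  # names of the classes whose __eq__ ran, in call order


class Number:
    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        calls.append(type(self).__name__)  # record whose method is running
        if isinstance(other, Number):
            return self.number == other.number
        return NotImplemented


class SubNumber(Number):
    pass


def trace(compare):
    """Run a zero-argument comparison and return the __eq__ calls it made."""
    calls.clear()
    compare()
    return list(calls)


n1, n3 = Number(1), SubNumber(1)

print(trace(lambda: n1 == n3))         # ['SubNumber']: subclass operand goes first
print(trace(lambda: n3 == n1))         # ['SubNumber']
print(trace(lambda: n1 != n3))         # ['SubNumber']: inherited __ne__ calls __eq__
print(trace(lambda: n1 == Number(1)))  # ['Number']: same class, left operand first
```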
To fix the non-commutativity issue of the == and != operators for Python 2 classic-style classes, the __eq__ and __ne__ methods should return the NotImplemented value when an operand type is not supported. The documentation defines the NotImplemented value as:
Numeric methods and rich comparison methods may return this value if they do not implement the operation for the operands provided. (The interpreter will then try the reflected operation, or some other fallback, depending on the operator.) Its truth value is true.
In this case the operator delegates the comparison operation to the reflected method of the other operand. The documentation defines reflected methods as:
There are no swapped-argument versions of these methods (to be used when the left argument does not support the operation but the right argument does); rather, __lt__() and __gt__() are each other’s reflection, __le__() and __ge__() are each other’s reflection, and __eq__() and __ne__() are their own reflection.
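A minimal Python 3 sketch of that fallback, using two hypothetical, unrelated classes: Left declines everything by returning NotImplemented, so the interpreter tries the reflected method of Right (__eq__ for ==, __gt__ for <). When both sides decline, == falls back to an identity comparison.

```python
log = []  # method names in the order the interpreter called them


class Left:
    """Hypothetical class that declines every comparison."""

    def __eq__(self, other):
        log.append("Left.__eq__")
        return NotImplemented

    def __lt__(self, other):
        log.append("Left.__lt__")
        return NotImplemented


class Right:
    """Hypothetical unrelated class that accepts comparisons with anything."""

    def __eq__(self, other):
        log.append("Right.__eq__")
        return True

    def __gt__(self, other):
        log.append("Right.__gt__")
        return True


left, right = Left(), Right()

log.clear()
print(left == right, log)  # True ['Left.__eq__', 'Right.__eq__']

log.clear()
print(left < right, log)   # True ['Left.__lt__', 'Right.__gt__']

# When both sides decline, == falls back to an identity comparison.
print(left == Left())      # False
print(left == left)        # True
```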
The result looks like this:
```python
def __eq__(self, other):
    """Overrides the default implementation"""
    if isinstance(other, Number):
        return self.number == other.number
    return NotImplemented

def __ne__(self, other):
    """Overrides the default implementation (unnecessary in Python 3)"""
    x = self.__eq__(other)
    if x is NotImplemented:
        return NotImplemented
    return not x
```
Returning the NotImplemented value instead of False is the right thing to do even for new-style classes if commutativity of the == and != operators is desired when the operands are of unrelated types (no inheritance).
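A sketch of that difference in Python 3. StrictNumber, PoliteNumber and Box are hypothetical names: Box is an unrelated type that knows how to compare itself with both number classes. Returning False cuts Box out of the decision, while returning NotImplemented lets Box answer from either side.

```python
class StrictNumber:
    """Returns False for unsupported types."""

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, StrictNumber):
            return self.number == other.number
        return False


class PoliteNumber:
    """Returns NotImplemented for unsupported types."""

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, PoliteNumber):
            return self.number == other.number
        return NotImplemented


class Box:
    """Hypothetical unrelated type that can compare itself with both."""

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, (Box, StrictNumber, PoliteNumber)):
            return self.number == other.number
        return NotImplemented


print(StrictNumber(1) == Box(1), Box(1) == StrictNumber(1))  # False True
print(PoliteNumber(1) == Box(1), Box(1) == PoliteNumber(1))  # True True
print(StrictNumber(1) != Box(1))  # True: Box never gets asked
print(PoliteNumber(1) != Box(1))  # False: Box answers via its inherited __ne__
```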
Are we there yet? Not quite. How many unique numbers do we have?
```python
len(set([n1, n2, n3]))  # 3 in Python 2 -- oops; TypeError (unhashable type) in Python 3
```
Sets use the hashes of objects, and by default Python returns the hash of the identifier of the object. Let’s try to override it:
```python
def __hash__(self):
    """Overrides the default implementation"""
    return hash(tuple(sorted(self.__dict__.items())))


len(set([n1, n2, n3]))  # 1
```
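In Python 3 the override is not optional: a class that defines __eq__ without __hash__ gets __hash__ set to None, so its instances cannot go into a set at all. A sketch, with hypothetical class names; the Hashed variant hashes a tuple of the field that __eq__ compares, which is an alternative to hashing the sorted __dict__ items and keeps the rule that equal objects must have equal hashes.

```python
class EqOnly:
    """Defines __eq__ but not __hash__ (hypothetical name)."""

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, EqOnly):
            return self.number == other.number
        return NotImplemented


# Python 3 sets __hash__ to None on a class that defines __eq__ without __hash__.
print(EqOnly.__hash__ is None)  # True

try:
    set([EqOnly(1)])
    hashable = True
except TypeError:
    hashable = False
print(hashable)  # False


class Hashed(EqOnly):
    """Adds a hash built from the same field that __eq__ compares."""

    def __hash__(self):
        return hash((self.number,))


print(len(set([Hashed(1), Hashed(1), Hashed(2)])))  # 2
```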
The end result looks like this (I added some assertions at the end for validation):
```python
class Number:

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        """Overrides the default implementation"""
        if isinstance(other, Number):
            return self.number == other.number
        return NotImplemented

    def __ne__(self, other):
        """Overrides the default implementation (unnecessary in Python 3)"""
        x = self.__eq__(other)
        if x is not NotImplemented:
            return not x
        return NotImplemented

    def __hash__(self):
        """Overrides the default implementation"""
        return hash(tuple(sorted(self.__dict__.items())))


class SubNumber(Number):
    pass


n1 = Number(1)
n2 = Number(1)
n3 = SubNumber(1)
n4 = SubNumber(4)

assert n1 == n2
assert n2 == n1
assert not n1 != n2
assert not n2 != n1

assert n1 == n3
assert n3 == n1
assert not n1 != n3
assert not n3 != n1

assert not n1 == n4
assert not n4 == n1
assert n1 != n4
assert n4 != n1

assert len(set([n1, n2, n3])) == 1
assert len(set([n1, n2, n3, n4])) == 2
```
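You need to be careful with inheritance. The session below is from Python 2, where Foo is a classic-style class; in Python 3 both comparisons return False, because Bar.__eq__ is tried first either way:

```python
>>> class Foo:
...     def __eq__(self, other):
...         if isinstance(other, self.__class__):
...             return self.__dict__ == other.__dict__
...         else:
...             return False
...
>>> class Bar(Foo): pass
...
>>> b = Bar()
>>> f = Foo()
>>> f == b
True
>>> b == f
False
```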
Check types more strictly, like this:
```python
def __eq__(self, other):
    if type(other) is type(self):
        return self.__dict__ == other.__dict__
    return False
```
Besides that, your approach will work fine; that's what special methods are there for.
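A Python 3 sketch of the stricter check, with one change that is not in the answer above: it returns NotImplemented instead of False for other types, in line with the earlier advice. The item attribute and its None default are assumptions for the demo. Comparisons between Foo and Bar are now symmetric in both directions.

```python
class Foo:
    def __init__(self, item=None):
        self.item = item

    def __eq__(self, other):
        # Only instances of exactly the same class are compared by state.
        if type(other) is type(self):
            return self.__dict__ == other.__dict__
        # Decline instead of returning False, so the other operand gets a turn.
        return NotImplemented


class Bar(Foo):
    pass


f, b = Foo(), Bar()

print(f == b, b == f)    # False False: symmetric either way round
print(f != b, b != f)    # True True
print(Foo(1) == Foo(1))  # True
print(Bar(1) == Bar(1))  # True
```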
The way you describe is the way I've always done it. Since it's totally generic, you can always break that functionality out into a mixin class and inherit it in classes where you want that functionality.
```python
class CommonEqualityMixin(object):

    def __eq__(self, other):
        return (isinstance(other, self.__class__)
            and self.__dict__ == other.__dict__)

    def __ne__(self, other):
        return not self.__eq__(other)

    # Note: in Python 3, defining __eq__ without __hash__ makes instances
    # of classes using this mixin unhashable.


class Foo(CommonEqualityMixin):

    def __init__(self, item):
        self.item = item
```
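A hedged variant of that mixin that folds in the advice from the first answer (return NotImplemented for unsupported types, and add a __hash__ built from the sorted __dict__ items). This is a sketch, not the answerer's code, and it assumes every attribute value is hashable.

```python
class CommonEqualityMixin(object):
    """Value-based equality for classes whose state lives in __dict__."""

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def __ne__(self, other):
        # Redundant in Python 3, kept for Python 2 compatibility.
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result

    def __hash__(self):
        # Equal objects have equal __dict__ contents, hence equal hashes
        # (assumes every attribute value is hashable).
        return hash(tuple(sorted(self.__dict__.items())))


class Foo(CommonEqualityMixin):
    def __init__(self, item):
        self.item = item


print(Foo(1) == Foo(1), Foo(1) != Foo(2))  # True True
print(len({Foo(1), Foo(1), Foo(2)}))       # 2
```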