如何让类支持比较操作?

需求:
有时我们希望自定义类,实例间可以使用<,<=,>,>=,==,!=,符号进行比较,我们自定义比较的行为,例如,有一个矩形的类,我们希望比较两个矩形的实例时,比较的是他们的面积。

class Rectangel:
    def __init__(self,w,h):
        self.w = W
        self.h = h

    def area(self):
        return sefl.w * self.h

rect1 = Rectangel(5,3)
rect2 = Rectangel(4.4)
rect1 > rect2  # => rect1.area() > rect2.area()

思路
1、比较符号运算符重载,需要实现以下方法
__lt__,__le__,__gt__,__ge__,__eq__,__ne_
2、使用标准库下的functools下类装饰器total_ordering可以简化此过程

代码:

方法1:
class Rectangel:
    def __init__(self,w,h):
        self.w = w
        self.h = h

    def area(self):
        return self.w * self.h

    def __lt__(self,obj):
        print('in__lt__')
        return self.area() < obj.area()

    def __le__(self,obj):
        print('in__le__')
        return self.area() <= obj.area()
rect1 = Rectangel(5,3)
rect2 = Rectangel(4,4)
#rect1 > rect2  # => rect1.area() > rect2.area()

print(rect1 < rect2) # rect1.__lt__(rect2)
print(rect1 <= rect2)

方法二:
from functools import total_ordering

@total_ordering
class Rectangel:
    def __init__(self,w,h):
        self.w = w
        self.h = h

    def area(self):
        return self.w * self.h

    def __lt__(self,obj):
        print('in__lt__')
        return self.area() < obj.area()

    def __eq__(self,obj):
        print('in__eq__')
        return self.area() == obj.area()

class Circle(object):
    def __init__(self,r):
        self.r = r

    def area(self):
        return self.r ** 2 * 3.14


rect1 = Rectangel(5,3)
rect2 = Rectangel(4,4)
c1 = Circle(3)

#rect1 > rect2  # => rect1.area() > rect2.area()
print(rect1 < c1)
print(c1 > 1)
print(rect1 < rect2) # rect1.__lt__(rect2)
print(rect1 >= rect2)

方法三:
from functools import total_ordering
from abc import ABCMeta, abstractmethod  # 通过抽象类和抽象方法来实现公共的抽象基类,为什么要定义成抽象类呢?因为要让它的所有子类都必须实现抽象方法,否则就无法比较了。元类为ABCmeta,再定义一个抽象方法不用实现,等待让子类实现即可。

@total_ordering
class Shape(metaclass=ABCMeta):  # 把运算符重载的函数都放到公共的抽象基类中,这样可以避免其他的类都要写运算符的函数,其他的函数中只要实现area()的方法就可以了
                      # 再定义一个抽象的接口,能比较的都要实现这个area,否则不能进行比较
    @abstractmethod
    def area(self):   # 描述一下抽象的接口,它的子类都要实现这个接口
        pass

    def __lt__(self,obj):
        print('in__lt__')
        if not isinstance(obj,Shape):
            raise TypeError('obj is not Shape')
        return self.area() < obj.area()

    def __eq__(self,obj):
        if not isinstance(obj,Shape):
            raise TypeError('obj is not Shape')
        print('in__eq__')
        return self.area() == obj.area()


class Rectangel(Shape):
    def __init__(self,w,h):
        self.w = w
        self.h = h

    def area(self):
        return self.w * self.h


class Circle(Shape):
    def __init__(self,r):
        self.r = r

    def area(self):
        return self.r ** 2 * 3.14


rect1 = Rectangel(5,3)
rect2 = Rectangel(4,4)
c1 = Circle(3)

#rect1 > rect2  # => rect1.area() > rect2.area()
print(rect1 < c1)
print(c1 > rect2)
print(rect1 < rect2) # rect1.__lt__(rect2)
print(rect1 >= rect2)

===========================================================
>>> a = 5
>>> b = 3
>>> a < b
False
>>> a.__lt__(b)
False
>>> a >= b
True
>>> a.__ge__(b)
True
>>> s1 = 'abc'
>>> s2 = 'abd'
>>> s1 > s2
False
>>> s1.__gt__(s2)
False
>>> ord('c') > ord('d')
False
>>> {1,2,3} > {4}
False
>>> {1,2,3} < {4}
False
>>> {1,2,3} = {4}
  File "<ipython-input-15-1641dbbcfca1>", line 1
    {1,2,3} = {4}
                 ^
SyntaxError: can't assign to literal

>>> {1,2,3} == {4}
False
>>> {1,2,3} > {1,3} # 实际是包含的关系
True
>>> {1,2,3} > {1,2}
True
>>> 
====================================================
from functools import total_ordering

from abc import ABCMeta,abstractclassmethod

class Shape(metaclass=ABCMeta):  # 实现一个抽象基类
    @abstractclassmethod # 抽象方法等待子类去实现,要求子类都要实现area方法,否则没法比较
    def area(self):
        pass

    def __lt__(self,obj):
        print('__lt__',self,obj)
        return self.area() < obj.area()

    def __eq__(self,obj):
        return self.area() == obj.area()

@total_ordering
class Rect(Shape):
    def __init__(self,w,h):
        self.w = w
        self.h = h

    def area(self):
        return self.w * self.h

    def __str__(self):
        return 'Rect:(%s,%s)' % (self.w,self.h)
    

import math

class Circle(Shape):
    def __init__(self,r):
        self.r = r

    def area(self):
        return self.r ** 2 * math.pi

rect1 = Rect(6,9) # 54
'''
print(rect1 < rect2)
print(rect1 >= rect2) # rect2 < rect1
print(rect1 <= rect2)
'''
rect2 = Rect(7,8) # 56
c = Circle(8)

print(rect1<c)
print(c>rect2)

posted @ 2020-07-22 23:07  Richardo-M-Lu  阅读(150)  评论(0编辑  收藏  举报