将基类转换为派生类的Python方法(或更Pythonic的扩展类方式)

63 投票
9 回答
68580 浏览
提问于 2025-04-16 02:40

我想扩展一下Networkx这个Python库,给它的Graph类添加一些我需要的方法。

我想的办法是简单地创建一个新类,比如叫NewGraph,然后在里面添加我需要的方法。

不过,Networkx里还有很多其他的函数,它们会创建并返回Graph对象(比如生成一个随机图)。现在我需要把这些Graph对象转换成NewGraph对象,这样我才能使用我新增的方法。

那么,最好的做法是什么呢?或者我应该换个完全不同的方式来解决这个问题吗?

9 个回答

4

我把PaulMcG做的东西扩展了一下,把它变成了一个工厂模式。

class A:
 def __init__(self, variable):
    self.a = 10
    self.a_variable = variable

 def do_something(self):
    print("do something A")


class B(A):

 def __init__(self, variable=None):
    super().__init__(variable)
    self.b = 15

 @classmethod
 def from_A(cls, a: A):
    # Create new b_obj
    b_obj = cls()
    # Copy all values of A to B
    # It does not have any problem since they have common template
    for key, value in a.__dict__.items():
        b_obj.__dict__[key] = value
    return b_obj

if __name__ == "__main__":
 a = A(variable="something")
 b = B.from_A(a=a)
 print(a.__dict__)
 print(b.__dict__)
 b.do_something()
 print(type(b))

结果:

{'a': 10, 'a_variable': 'something'}
{'a': 10, 'a_variable': 'something', 'b': 15}
do something A
<class '__main__.B'>
17

下面是如何“神奇地”用自己制作的子类替换模块中的一个类,而不需要修改这个模块。其实只需要在正常的子类化过程中多加几行代码,这样你就几乎可以享受到子类化带来的所有强大功能和灵活性。例如,这样可以让你添加新的属性,如果你想的话。

import networkx as nx

class NewGraph(nx.Graph):
    def __getattribute__(self, attr):
        "This is just to show off, not needed"
        print "getattribute %s" % (attr,)
        return nx.Graph.__getattribute__(self, attr)

    def __setattr__(self, attr, value):
        "More showing off."
        print "    setattr %s = %r" % (attr, value)
        return nx.Graph.__setattr__(self, attr, value)

    def plot(self):
        "A convenience method"
        import matplotlib.pyplot as plt
        nx.draw(self)
        plt.show()

到这里为止,这和正常的子类化完全一样。接下来,我们需要把这个子类连接到 networkx 模块,这样每次创建 nx.Graph 的实例时,都会得到一个 NewGraph 对象。通常情况下,当你用 nx.Graph() 创建一个 nx.Graph 对象时,会发生这样的事情:

1. nx.Graph.__new__(nx.Graph) is called
2. If the returned object is a subclass of nx.Graph, 
   __init__ is called on the object
3. The object is returned as the instance

我们将替换 nx.Graph.__new__,让它返回 NewGraph 对象。在这里,我们调用的是 object__new__ 方法,而不是 NewGraph__new__ 方法,因为后者只是另一种调用我们要替换的方法的方式,这样会导致无限递归。

def __new__(cls):
    if cls == nx.Graph:
        return object.__new__(NewGraph)
    return object.__new__(cls)

# We substitute the __new__ method of the nx.Graph class
# with our own.     
nx.Graph.__new__ = staticmethod(__new__)

# Test if it works
graph = nx.generators.random_graphs.fast_gnp_random_graph(7, 0.6)
graph.plot()

在大多数情况下,这些就是你需要知道的全部内容,但有一个小问题。我们重写的 __new__ 方法只影响 nx.Graph,而不影响它的子类。例如,如果你调用 nx.gn_graph,它会返回一个 nx.DiGraph 的实例,这个实例就不会有我们添加的那些新功能。你需要为每一个你想要使用的 nx.Graph 的子类进行子类化,并添加你需要的方法和属性。使用 mix-ins 可能会让你更容易一致地扩展这些子类,同时遵循 DRY 原则。

虽然这个例子看起来很简单,但这种连接模块的方法很难普遍适用,因为可能会出现各种小问题。我认为最好是根据具体问题进行调整。例如,如果你要连接的类定义了自己的自定义 __new__ 方法,你需要在替换之前先保存它,然后调用这个方法,而不是 object.__new__

94

如果你只是想给一个对象添加一些功能,而不需要依赖额外的实例值,你可以直接给这个对象的 __class__ 属性赋值:

from math import pi

class Circle(object):
    def __init__(self, radius):
        self.radius = radius

    def area(self):
        return pi * self.radius**2

class CirclePlus(Circle):
    def diameter(self):
        return self.radius*2

    def circumference(self):
        return self.radius*2*pi

c = Circle(10)
print c.radius
print c.area()
print repr(c)

c.__class__ = CirclePlus
print c.diameter()
print c.circumference()
print repr(c)

输出结果是:

10
314.159265359
<__main__.Circle object at 0x00A0E270>
20
62.8318530718
<__main__.CirclePlus object at 0x00A0E270>

这在Python中算是最接近“类型转换”的方式了,就像在C语言中的类型转换一样,使用之前最好先想清楚。我这里举了一个比较简单的例子,但如果你能遵循这些限制(只添加功能,不添加新的实例变量),那么这可能会帮助你解决问题。

撰写回答