将基类转换为派生类的Python方法(或更Pythonic的扩展类方式)
我想扩展一下Networkx这个Python库,给它的Graph
类添加一些我需要的方法。
我想的办法是简单地创建一个新类,比如叫NewGraph
,然后在里面添加我需要的方法。
不过,Networkx里还有很多其他的函数,它们会创建并返回Graph
对象(比如生成一个随机图)。现在我需要把这些Graph
对象转换成NewGraph
对象,这样我才能使用我新增的方法。
那么,最好的做法是什么呢?或者我应该换个完全不同的方式来解决这个问题吗?
9 个回答
我把PaulMcG做的东西扩展了一下,把它变成了一个工厂模式。
class A:
def __init__(self, variable):
self.a = 10
self.a_variable = variable
def do_something(self):
print("do something A")
class B(A):
def __init__(self, variable=None):
super().__init__(variable)
self.b = 15
@classmethod
def from_A(cls, a: A):
# Create new b_obj
b_obj = cls()
# Copy all values of A to B
# It does not have any problem since they have common template
for key, value in a.__dict__.items():
b_obj.__dict__[key] = value
return b_obj
if __name__ == "__main__":
a = A(variable="something")
b = B.from_A(a=a)
print(a.__dict__)
print(b.__dict__)
b.do_something()
print(type(b))
结果:
{'a': 10, 'a_variable': 'something'}
{'a': 10, 'a_variable': 'something', 'b': 15}
do something A
<class '__main__.B'>
下面是如何“神奇地”用自己制作的子类替换模块中的一个类,而不需要修改这个模块。其实只需要在正常的子类化过程中多加几行代码,这样你就几乎可以享受到子类化带来的所有强大功能和灵活性。例如,这样可以让你添加新的属性,如果你想的话。
import networkx as nx
class NewGraph(nx.Graph):
def __getattribute__(self, attr):
"This is just to show off, not needed"
print "getattribute %s" % (attr,)
return nx.Graph.__getattribute__(self, attr)
def __setattr__(self, attr, value):
"More showing off."
print " setattr %s = %r" % (attr, value)
return nx.Graph.__setattr__(self, attr, value)
def plot(self):
"A convenience method"
import matplotlib.pyplot as plt
nx.draw(self)
plt.show()
到这里为止,这和正常的子类化完全一样。接下来,我们需要把这个子类连接到 networkx
模块,这样每次创建 nx.Graph
的实例时,都会得到一个 NewGraph
对象。通常情况下,当你用 nx.Graph()
创建一个 nx.Graph
对象时,会发生这样的事情:
1. nx.Graph.__new__(nx.Graph) is called 2. If the returned object is a subclass of nx.Graph, __init__ is called on the object 3. The object is returned as the instance
我们将替换 nx.Graph.__new__
,让它返回 NewGraph
对象。在这里,我们调用的是 object
的 __new__
方法,而不是 NewGraph
的 __new__
方法,因为后者只是另一种调用我们要替换的方法的方式,这样会导致无限递归。
def __new__(cls):
if cls == nx.Graph:
return object.__new__(NewGraph)
return object.__new__(cls)
# We substitute the __new__ method of the nx.Graph class
# with our own.
nx.Graph.__new__ = staticmethod(__new__)
# Test if it works
graph = nx.generators.random_graphs.fast_gnp_random_graph(7, 0.6)
graph.plot()
在大多数情况下,这些就是你需要知道的全部内容,但有一个小问题。我们重写的 __new__
方法只影响 nx.Graph
,而不影响它的子类。例如,如果你调用 nx.gn_graph
,它会返回一个 nx.DiGraph
的实例,这个实例就不会有我们添加的那些新功能。你需要为每一个你想要使用的 nx.Graph
的子类进行子类化,并添加你需要的方法和属性。使用 mix-ins 可能会让你更容易一致地扩展这些子类,同时遵循 DRY 原则。
虽然这个例子看起来很简单,但这种连接模块的方法很难普遍适用,因为可能会出现各种小问题。我认为最好是根据具体问题进行调整。例如,如果你要连接的类定义了自己的自定义 __new__
方法,你需要在替换之前先保存它,然后调用这个方法,而不是 object.__new__
。
如果你只是想给一个对象添加一些功能,而不需要依赖额外的实例值,你可以直接给这个对象的 __class__
属性赋值:
from math import pi
class Circle(object):
def __init__(self, radius):
self.radius = radius
def area(self):
return pi * self.radius**2
class CirclePlus(Circle):
def diameter(self):
return self.radius*2
def circumference(self):
return self.radius*2*pi
c = Circle(10)
print c.radius
print c.area()
print repr(c)
c.__class__ = CirclePlus
print c.diameter()
print c.circumference()
print repr(c)
输出结果是:
10
314.159265359
<__main__.Circle object at 0x00A0E270>
20
62.8318530718
<__main__.CirclePlus object at 0x00A0E270>
这在Python中算是最接近“类型转换”的方式了,就像在C语言中的类型转换一样,使用之前最好先想清楚。我这里举了一个比较简单的例子,但如果你能遵循这些限制(只添加功能,不添加新的实例变量),那么这可能会帮助你解决问题。