Cast base class to derived class python (or more pythonic way of extending classes)
27-09-2019
Problem
I need to extend the Networkx Python package and add a few methods to the Graph class for my particular needs.
The way I thought about doing this is simply deriving a new class, say NewGraph, and adding the required methods.
However, there are several other functions in networkx which create and return Graph objects (e.g. generate a random graph). I now need to turn these Graph objects into NewGraph objects so that I can use my new methods.
What is the best way of doing this? Or should I be tackling the problem in a completely different manner?
Solution
If you are just adding behavior, and not depending on additional instance values, you can assign to the object's __class__:
from math import pi

class Circle:
    def __init__(self, radius):
        self.radius = radius

    def area(self):
        return pi * self.radius**2

class CirclePlus(Circle):
    def diameter(self):
        return self.radius*2

    def circumference(self):
        return self.radius*2*pi

c = Circle(10)
print(c.radius)
print(c.area())
print(repr(c))

c.__class__ = CirclePlus
print(c.diameter())
print(c.circumference())
print(repr(c))
Prints (Python 3; the addresses will differ):
10
314.1592653589793
<__main__.Circle object at 0x00A0E270>
20
62.83185307179586
<__main__.CirclePlus object at 0x00A0E270>
This is as close to a "cast" as you can get in Python, and like casting in C, it is not to be done without giving the matter some thought. I've posted a fairly limited example, but if you can stay within the constraints (just add behavior, no new instance vars), then this might help address your problem.
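A minimal sketch of where that constraint bites, assuming CPython 3: reassigning __class__ only succeeds when the old and new classes have a compatible instance layout. SlottedCircle is a hypothetical class added only to show the failure case.

```python
class Circle:
    def __init__(self, radius):
        self.radius = radius

class CirclePlus(Circle):
    # Adds behaviour only, no new instance state
    def diameter(self):
        return self.radius * 2

class SlottedCircle:
    # Hypothetical class: __slots__ removes the per-instance __dict__,
    # so its memory layout differs from Circle's
    __slots__ = ("radius",)

c = Circle(10)
c.__class__ = CirclePlus            # allowed: same layout, extra methods
print(type(c).__name__, c.diameter())

try:
    c.__class__ = SlottedCircle     # rejected: incompatible layout
    refused = False
except TypeError as exc:
    refused = True
    print("TypeError:", exc)
```

The failed assignment leaves c unchanged, so it is still a CirclePlus afterwards.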
Other tips
Here's how to "magically" replace a class in a module with a custom-made subclass without touching the module. It's only a few extra lines from a normal subclassing procedure, and therefore gives you (almost) all the power and flexibility of subclassing as a bonus. For instance this allows you to add new attributes, if you wish.
import networkx as nx

class NewGraph(nx.Graph):
    def __getattribute__(self, attr):
        "This is just to show off, not needed"
        print("getattribute %s" % (attr,))
        return nx.Graph.__getattribute__(self, attr)

    def __setattr__(self, attr, value):
        "More showing off."
        print("    setattr %s = %r" % (attr, value))
        return nx.Graph.__setattr__(self, attr, value)

    def plot(self):
        "A convenience method"
        import matplotlib.pyplot as plt
        nx.draw(self)
        plt.show()
So far this is exactly like normal subclassing. Now we need to hook this subclass into the networkx module so that all instantiation of nx.Graph results in a NewGraph object instead. Here's what normally happens when you instantiate an nx.Graph object with nx.Graph():
1. nx.Graph.__new__(nx.Graph) is called.
2. If the returned object is an instance of nx.Graph, __init__ is called on the object.
3. The object is returned as the instance.
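These steps can be checked without networkx. A minimal sketch, assuming Python 3; Base, Upgraded and Skipper are hypothetical classes. __new__ chooses the class of the new object, and __init__ only runs when that object is an instance of the class being called.

```python
class Base:
    def __new__(cls, *args, **kwargs):
        # Step 1: __new__ picks the concrete class; plain Base() is redirected
        target = Upgraded if cls is Base else cls
        return super().__new__(target)

    def __init__(self, value=0):
        # Step 2: runs because the returned object is an instance of Base
        self.value = value

class Upgraded(Base):
    def double(self):
        return self.value * 2

class Skipper:
    def __new__(cls):
        # Returns something that is not a Skipper instance...
        return "not a Skipper"

    def __init__(self):
        # ...so Python never calls this
        raise RuntimeError("unreachable")

b = Base(21)
print(type(b).__name__, b.double())   # Upgraded 42
s = Skipper()
print(s)                              # not a Skipper
```

This is the same mechanism the hook below relies on: because NewGraph is a subclass of nx.Graph, returning a NewGraph from nx.Graph.__new__ still triggers __init__.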
We will replace nx.Graph.__new__ and make it return NewGraph instead. In it, we call the __new__ method of object instead of the __new__ method of NewGraph, because the latter is just another way of calling the method we're replacing, and would therefore result in endless recursion.
def __new__(cls, *args, **kwargs):
    # Accept the constructor arguments (e.g. nx.Graph(data)); __init__ still receives them.
    if cls is nx.Graph:
        return object.__new__(NewGraph)
    return object.__new__(cls)

# We substitute the __new__ method of the nx.Graph class
# with our own.
nx.Graph.__new__ = staticmethod(__new__)

# Test if it works
graph = nx.generators.random_graphs.fast_gnp_random_graph(7, 0.6)
graph.plot()
In most cases this is all you need to know, but there is one gotcha. Our overriding of the __new__ method only affects nx.Graph, not its subclasses. For example, if you call nx.gn_graph, which returns an instance of nx.DiGraph, it will have none of our fancy extensions. You need to subclass each of the subclasses of nx.Graph that you wish to work with and add your required methods and attributes. Using mix-ins may make it easier to consistently extend the subclasses while obeying the DRY principle.
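A minimal sketch of the mix-in idea, assuming Python 3. To keep it runnable without networkx, Graph and DiGraph here are hypothetical stand-ins for nx.Graph and nx.DiGraph, and ExtrasMixin with its methods is illustrative only.

```python
class Graph:
    # Hypothetical stand-in for nx.Graph (undirected adjacency sets)
    def __init__(self):
        self.adj = {}

    def add_edge(self, u, v):
        self.adj.setdefault(u, set()).add(v)
        self.adj.setdefault(v, set()).add(u)

class DiGraph(Graph):
    # Hypothetical stand-in for nx.DiGraph (successor sets only)
    def add_edge(self, u, v):
        self.adj.setdefault(u, set()).add(v)
        self.adj.setdefault(v, set())

class ExtrasMixin:
    # The extra methods, written once and shared by every extended class
    def summary(self):
        return f"{type(self).__name__} with {len(self.adj)} nodes"

    def neighbors_of(self, n):
        return sorted(self.adj[n])

# Listing the mix-in first lets its methods take precedence in the MRO
class NewGraph(ExtrasMixin, Graph):
    pass

class NewDiGraph(ExtrasMixin, DiGraph):
    pass

d = NewDiGraph()
d.add_edge(1, 2)
d.add_edge(2, 3)
print(d.summary())   # NewDiGraph with 3 nodes
```

Each concrete subclass is then a one-line declaration, and the mix-in methods work unchanged on the directed and undirected variants.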
Though this example may seem straightforward enough, this method of hooking into a module is hard to generalize in a way that covers all the little problems that may crop up. I believe it's easier to just tailor it to the problem at hand. For instance, if the class you're hooking into defines its own custom __new__ method, you need to store it before replacing it, and call this method instead of object.__new__.
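A minimal sketch of that store-and-chain approach, assuming Python 3. Legacy and Extended are hypothetical stand-ins for the library class (with its own __new__) and your subclass; the saved original_new is called instead of object.__new__ so the library's own logic still runs.

```python
class Legacy:
    # Hypothetical library class that defines its own __new__
    created = 0

    def __new__(cls, *args, **kwargs):
        Legacy.created += 1          # bookkeeping we must not bypass
        return super().__new__(cls)

    def __init__(self, name):
        self.name = name

class Extended(Legacy):
    def shout(self):
        return self.name.upper()

original_new = Legacy.__new__        # plain function, saved before patching

def patched_new(cls, *args, **kwargs):
    # Redirect plain Legacy(...) to Extended, then delegate to the saved
    # __new__ (not object.__new__) so its bookkeeping still runs
    target = Extended if cls is Legacy else cls
    return original_new(target, *args, **kwargs)

Legacy.__new__ = staticmethod(patched_new)

obj = Legacy("graph")
print(type(obj).__name__, obj.shout(), Legacy.created)   # Extended GRAPH 1
```

Because patched_new calls the saved function rather than Legacy.__new__, there is no recursion even though Extended inherits the patched method.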
If a function is creating Graph objects, you can't turn them into NewGraph objects.
Another option is for NewGraph to have a Graph rather than be a Graph. You delegate the Graph methods to the Graph object you have, and you can wrap any Graph object in a new NewGraph object:
class NewGraph:
    def __init__(self, graph):
        self.graph = graph

    def some_graph_method(self, *args, **kwargs):
        return self.graph.some_graph_method(*args, **kwargs)

    # ... do this for the other Graph methods you need

    def my_newgraph_method(self):
        ...
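Writing a forwarding method for every Graph method gets tedious. A common alternative (a sketch, not this answer's own code) is __getattr__, which Python calls only when normal attribute lookup fails. The Graph below is a hypothetical stand-in so the example runs without networkx, and the body of my_newgraph_method is illustrative.

```python
class Graph:
    # Hypothetical stand-in for nx.Graph
    def __init__(self):
        self._adj = {}

    def add_edge(self, u, v):
        self._adj.setdefault(u, set()).add(v)
        self._adj.setdefault(v, set()).add(u)

    def number_of_nodes(self):
        return len(self._adj)

class NewGraph:
    def __init__(self, graph):
        self.graph = graph               # the wrapped Graph

    def __getattr__(self, name):
        # Only reached when normal lookup on NewGraph fails, so every
        # Graph method is forwarded without writing it out by hand
        if name == "graph":              # avoid infinite recursion if unset
            raise AttributeError(name)
        return getattr(self.graph, name)

    def my_newgraph_method(self):
        # Illustrative extra behaviour built on a forwarded method
        return self.number_of_nodes() > 2

g = Graph()
g.add_edge(1, 2)
wrapped = NewGraph(g)
wrapped.add_edge(2, 3)                   # forwarded to g.add_edge
print(wrapped.number_of_nodes())         # 3
```

One limitation: implicit special-method lookups such as len(wrapped) or iterating over wrapped bypass __getattr__, so methods like __len__ and __iter__ still have to be defined explicitly if you need them.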
For your simple case you could also write your subclass __init__ like this and assign the references to the Graph's data structures to your subclass's data.
from networkx import Graph

class MyGraph(Graph):
    # Relies on the networkx 1.x attribute layout (graph, node, adj, edge);
    # it will not work on networkx 2.x or later, where adj is a read-only property.
    def __init__(self, graph=None, **attr):
        if graph is not None:
            self.graph = graph.graph  # graph attributes
            self.node = graph.node    # node attributes
            self.adj = graph.adj      # adjacency dict
        else:
            self.graph = {}  # empty graph attr dict
            self.node = {}   # empty node attr dict
            self.adj = {}    # empty adjacency dict
        self.edge = self.adj     # alias
        self.graph.update(attr)  # update any command line attributes

if __name__ == '__main__':
    import networkx as nx
    R = nx.gnp_random_graph(10, 0.4)
    G = MyGraph(R)
You could also use copy() or deepcopy() in the assignments, but if you are doing that you might as well use
G = MyGraph()
G.add_nodes_from(R)
G.add_edges_from(R.edges())
to load your graph data.
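The difference between assigning references and copying matters once either graph is modified. A sketch, assuming Python 3, using a hypothetical stand-in class that stores its data in plain dicts like the networkx 1.x layout above (not networkx's own class); the deep flag is illustrative.

```python
import copy

class Graph:
    # Hypothetical stand-in that keeps its data in plain dicts
    def __init__(self):
        self.graph, self.node, self.adj = {}, {}, {}

    def add_edge(self, u, v):
        for a, b in ((u, v), (v, u)):
            self.node.setdefault(a, {})
            self.adj.setdefault(a, {})[b] = {}

class MyGraph(Graph):
    def __init__(self, source=None, deep=False):
        super().__init__()
        if source is not None:
            for name in ("graph", "node", "adj"):
                value = getattr(source, name)
                # deep=False shares the dict; deep=True takes an independent copy
                setattr(self, name, copy.deepcopy(value) if deep else value)

R = Graph()
R.add_edge(1, 2)
shared = MyGraph(R)              # references R's dicts
copied = MyGraph(R, deep=True)   # snapshot of R
shared.add_edge(2, 3)            # also visible through R
print(sorted(R.adj), sorted(copied.adj))   # [1, 2, 3] [1, 2]
```

Sharing is cheap but couples the two objects; copying decouples them at the cost of duplicating the data.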
You could simply define a NewGraph class derived from Graph and have its __init__ function include something like self.__dict__.update(vars(incoming_graph)) as the first line, before you define your own properties. This way you basically copy all the attributes of the Graph you already have onto a new object that is derived from Graph, but with your special sauce.
class NewGraph(Graph):
    def __init__(self, incoming_graph):
        self.__dict__.update(vars(incoming_graph))
        # rest of my __init__ code, including properties and such
Usage:
graph = function_that_returns_graph()
new_graph = NewGraph(graph)
cool_result = function_that_takes_new_graph(new_graph)
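One caveat with NewGraph(incoming_graph): the constructor now requires an argument, while networkx itself creates graphs with self.__class__() in places such as Graph.copy(). A sketch of a workaround, assuming Python 3: keep __init__ argument-free and do the attribute copying in a hypothetical from_graph class method. The Graph here is a stand-in so the example runs without networkx.

```python
class Graph:
    # Hypothetical stand-in for nx.Graph
    def __init__(self):
        self.adj = {}

    def add_edge(self, u, v):
        self.adj.setdefault(u, set()).add(v)
        self.adj.setdefault(v, set()).add(u)

    def copy(self):
        # Builds a fresh instance of the same class with no arguments
        new = self.__class__()
        new.adj = {n: set(nbrs) for n, nbrs in self.adj.items()}
        return new

class NewGraph(Graph):
    def __init__(self):
        super().__init__()
        self.tags = set()            # state that only NewGraph has

    @classmethod
    def from_graph(cls, graph):
        # Build a normal NewGraph, then adopt the incoming graph's attributes
        new = cls()
        new.__dict__.update(vars(graph))
        return new

base = Graph()
base.add_edge(1, 2)
ng = NewGraph.from_graph(base)
clone = ng.copy()                    # works because NewGraph() needs no arguments
print(type(clone).__name__)          # NewGraph
```

Note that vars() gives a shallow copy: ng.adj is the same dict object as base.adj, so edits through one are visible through the other.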
Have you tried "[Python] cast base class to derived class"? I have tested it, and it seems to work. I also think that method is a bit better than the one below, since the one below does not execute the __init__ function of the derived class:
c.__class__ = CirclePlus