Iterate over subclasses of a given class in a given module
Question
In Python, given a module X and a class Y, how can I iterate or generate a list of all subclasses of Y that exist in module X?
Solution
Here's one way to do it:
import inspect

def get_subclasses(mod, cls):
    """Yield the classes in module ``mod`` that inherit from ``cls``"""
    for name, obj in inspect.getmembers(mod):
        # __bases__ holds only the direct parents, so indirect subclasses are not matched.
        if hasattr(obj, "__bases__") and cls in obj.__bases__:
            yield obj
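To see what the `__bases__` test above actually matches, the sketch below runs it next to an `issubclass()` test on a small hierarchy. The classes `Root`, `Mid` and `Leaf` and the module name `hierarchy` are hypothetical; `types.ModuleType` just creates an in-memory module object so no file is needed.

```python
import inspect
import types

# Hypothetical hierarchy: Mid is a direct subclass of Root, Leaf an indirect one.
class Root: pass
class Mid(Root): pass
class Leaf(Mid): pass

mod = types.ModuleType("hierarchy")
mod.Root, mod.Mid, mod.Leaf = Root, Mid, Leaf

# getmembers() returns (name, value) pairs sorted by name; keep only classes.
classes = [obj for _, obj in inspect.getmembers(mod, inspect.isclass)]

# __bases__ test: only classes that list Root as an immediate parent.
direct = [c.__name__ for c in classes if Root in c.__bases__]
# issubclass() test: every descendant of Root, plus Root itself.
transitive = [c.__name__ for c in classes if issubclass(c, Root)]
# issubclass() test with Root excluded.
strict = [c.__name__ for c in classes if issubclass(c, Root) and c is not Root]

print(direct)      # ['Mid']
print(transitive)  # ['Leaf', 'Mid', 'Root']
print(strict)      # ['Leaf', 'Mid']
```

Whether the indirect `Leaf` should count depends on what "subclass" means for your use case; `issubclass()` follows the full inheritance chain.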
OTHER TIPS
Although Quamrana's suggestion works fine, there are a couple of improvements I'd suggest to make it more Pythonic. Both rely on the inspect module from the standard library.
- You can avoid the getattr call by using inspect.getmembers().
- The try/except can be avoided by using inspect.isclass().
With those, you can reduce the whole thing to a single list comprehension if you like:
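import inspect

def find_subclasses(module, clazz):
    return [
        cls
        for name, cls in inspect.getmembers(module)
        # isclass() replaces the try/except; "is not clazz" keeps clazz itself out.
        if inspect.isclass(cls) and issubclass(cls, clazz) and cls is not clazz
    ]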
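`inspect.getmembers()` also accepts an optional predicate as its second argument, so the `isclass()` check can move into the call itself. A sketch of that variant, returning `(name, class)` pairs the way Quamrana's generator does and leaving out `clazz` itself; the `shapes` module and its classes are hypothetical:

```python
import inspect
import types

def find_subclasses(module, clazz):
    """Return (name, class) pairs in ``module`` that subclass ``clazz``, excluding ``clazz``."""
    # With inspect.isclass as the predicate, getmembers() yields only class objects,
    # so issubclass() never sees a non-class and cannot raise TypeError.
    return [
        (name, cls)
        for name, cls in inspect.getmembers(module, inspect.isclass)
        if issubclass(cls, clazz) and cls is not clazz
    ]

# Hypothetical module used only for this example.
shapes = types.ModuleType("shapes")

class Shape: pass
class Square(Shape): pass
class Unit(Square): pass

shapes.Shape, shapes.Square, shapes.Unit = Shape, Square, Unit

print([name for name, _ in find_subclasses(shapes, Shape)])  # ['Square', 'Unit']
```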
Can I suggest that neither of the answers from Chris AtLee and zacherates fulfills the requirements? I think this modification to zacherates' answer is better:
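def find_subclasses(module, clazz):
    for name in dir(module):
        o = getattr(module, name)
        try:
            # Skip clazz itself, since issubclass(clazz, clazz) is True.
            if (o is not clazz) and issubclass(o, clazz):
                yield name, o
        except TypeError:
            # issubclass() raises TypeError when o is not a class.
            pass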
The reason I disagree with the given answers is that the first checks only the direct bases, so it misses classes that inherit from the given class indirectly, and the second includes the given class itself.
Given the module foo.py, which contains the unmodified function from zacherates' answer:
class foo(object): pass
class bar(foo): pass
class baz(foo): pass
class grar(Exception): pass
def find_subclasses(module, clazz):
    for name in dir(module):
        o = getattr(module, name)
        try:
            if issubclass(o, clazz):
                yield name, o
        except TypeError:
            pass
>>> import foo
>>> list(foo.find_subclasses(foo, foo.foo))
[('bar', <class 'foo.bar'>), ('baz', <class 'foo.baz'>), ('foo', <class 'foo.foo'>)]
>>> list(foo.find_subclasses(foo, object))
[('bar', <class 'foo.bar'>), ('baz', <class 'foo.baz'>), ('foo', <class 'foo.foo'>), ('grar', <class 'foo.grar'>)]
>>> list(foo.find_subclasses(foo, Exception))
[('grar', <class 'foo.grar'>)]
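Because `dir()` lists every name bound in the module, a class imported into foo.py (say, via `from collections import OrderedDict`) would also show up in `find_subclasses(foo, object)`. If only classes defined in the module itself are wanted, one heuristic is to compare each class's `__module__` with the module's `__name__`. A sketch, using a hypothetical `mymod` module and `Local` class; `__module__` is an ordinary writable attribute, so this is a convention check rather than a guarantee:

```python
import collections
import inspect
import types

def find_local_subclasses(module, clazz):
    """Yield (name, class) for subclasses of ``clazz`` defined in ``module`` itself.

    Classes merely imported into ``module`` are skipped by comparing each
    class's ``__module__`` with the module's ``__name__``.
    """
    for name, obj in inspect.getmembers(module, inspect.isclass):
        if obj is not clazz and issubclass(obj, clazz) and obj.__module__ == module.__name__:
            yield name, obj

# Hypothetical module: one class "defined" in it, one imported from collections.
mod = types.ModuleType("mymod")

class Local(dict): pass
Local.__module__ = "mymod"  # simulate Local having been defined in mymod

mod.Local = Local
mod.OrderedDict = collections.OrderedDict  # as if "from collections import OrderedDict"

print([name for name, _ in find_local_subclasses(mod, dict)])  # ['Local']
```

Without the `__module__` comparison, the same call would also return `OrderedDict`, since it is a subclass of `dict`.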