Question

In Python, given a module X and a class Y, how can I iterate or generate a list of all subclasses of Y that exist in module X?

Solution

Here's one way to do it:

import inspect

def get_subclasses(mod, cls):
    """Yield the classes in module ``mod`` that inherit from ``cls``"""
    for name, obj in inspect.getmembers(mod):
        if hasattr(obj, "__bases__") and cls in obj.__bases__:
            yield obj

OTHER TIPS

Although Quamrana's suggestion works fine, there are a couple of possible improvements I'd like to suggest to make it more pythonic. They rely on using the inspect module from the standard library.

  1. You can avoid the getattr call by using inspect.getmembers()
  2. The try/except can be avoided by using inspect.isclass()

With those, you can reduce the whole thing to a single list comprehension if you like:

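import inspect

def find_subclasses(module, clazz):
    # Keep only members that are classes and are subclasses of clazz.
    # issubclass(clazz, clazz) is True, so clazz itself is included if
    # it is a member of the module.
    return [
        cls
        for name, cls in inspect.getmembers(module)
        if inspect.isclass(cls) and issubclass(cls, clazz)
    ]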
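
As a small illustration of the two points above: inspect.getmembers() also accepts an optional predicate, so the inspect.isclass test can be passed to it directly. This is only a sketch; the throwaway module demo and the names Base, Child, Unrelated and VALUE are made up for the example and are not part of the original answer.

```python
import inspect
import types

# Build a throwaway module so the example is self-contained.
# The names demo, Base, Child, Unrelated and VALUE are hypothetical.
demo = types.ModuleType("demo")
exec(
    "class Base: pass\n"
    "class Child(Base): pass\n"
    "class Unrelated: pass\n"
    "VALUE = 42\n",
    demo.__dict__,
)

def find_subclasses(module, clazz):
    # The predicate drops non-class members (such as VALUE) before
    # issubclass() runs, so no try/except is needed.
    return [
        cls
        for _name, cls in inspect.getmembers(module, inspect.isclass)
        if issubclass(cls, clazz)
    ]

# getmembers() returns members sorted by name.
names = [cls.__name__ for cls in find_subclasses(demo, demo.Base)]
print(names)
```

Note that Base itself appears in the result, because issubclass(Base, Base) is True.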

Can I suggest that neither of the answers from Chris AtLee and zacherates fulfills the requirements? I think this modification to zacherates' answer is better:

def find_subclasses(module, clazz):
    for name in dir(module):
        o = getattr(module, name)
        try:
            # The identity check skips clazz itself; issubclass() raises
            # TypeError when o is not a class.
            if o is not clazz and issubclass(o, clazz):
                yield name, o
        except TypeError:
            pass

The reason I disagree with the given answers is that the first does not produce classes that are indirect (distant) subclasses of the given class, and the second includes the given class itself.

Given the module foo.py

class foo(object): pass
class bar(foo): pass
class baz(foo): pass

class grar(Exception): pass

def find_subclasses(module, clazz):
    for name in dir(module):
        o = getattr(module, name)

        try:
            if issubclass(o, clazz):
                yield name, o
        except TypeError:
            pass

>>> import foo
>>> list(foo.find_subclasses(foo, foo.foo))
[('bar', <class 'foo.bar'>), ('baz', <class 'foo.baz'>), ('foo', <class 'foo.foo'>)]
>>> list(foo.find_subclasses(foo, object))
[('bar', <class 'foo.bar'>), ('baz', <class 'foo.baz'>), ('foo', <class 'foo.foo'>), ('grar', <class 'foo.grar'>)]
>>> list(foo.find_subclasses(foo, Exception))
[('grar', <class 'foo.grar'>)]
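
Putting the points from the answers together, here is one possible sketch that finds indirect subclasses, skips the given class itself, and can optionally ignore classes that were merely imported into the module. The function name iter_subclasses, the local_only flag, the demo module and the class qux are assumptions made for illustration; they do not come from any of the answers above.

```python
import inspect
import types

def iter_subclasses(module, base, local_only=False):
    """Yield (name, cls) for strict subclasses of base found in module."""
    for name, cls in inspect.getmembers(module, inspect.isclass):
        # Skip the base class itself and anything that does not inherit from it.
        if cls is base or not issubclass(cls, base):
            continue
        # Optionally skip classes that were imported from elsewhere.
        if local_only and cls.__module__ != module.__name__:
            continue
        yield name, cls

# Throwaway module mirroring foo.py, plus one indirect subclass (qux)
# and one imported class (OrderedDict). All of this is hypothetical.
demo = types.ModuleType("demo")
exec(
    "from collections import OrderedDict\n"
    "class foo(object): pass\n"
    "class bar(foo): pass\n"
    "class qux(bar): pass\n"
    "class grar(Exception): pass\n",
    demo.__dict__,
)

of_foo = [name for name, _ in iter_subclasses(demo, demo.foo)]
of_object = [name for name, _ in iter_subclasses(demo, object)]
local_of_object = [name for name, _ in iter_subclasses(demo, object, local_only=True)]
print(of_foo)
print(of_object)
print(local_of_object)
```

The __module__ comparison is a heuristic: it treats a class as belonging to the module only if it was defined there, which may or may not match what "exist in module X" means for a given use case.
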
Licensed under: CC-BY-SA with attribution