Question

How can I find all the abstract base classes that a given class is a "virtual subclass" of?

In other words, I'm looking for a magic function virtual_base_classes() that does something like this:

>>> for cls in virtual_base_classes(list):
...     print(cls)
<class 'collections.abc.MutableSequence'>
<class 'collections.abc.Sequence'>
<class 'collections.abc.Sized'>
<class 'collections.abc.Iterable'>
<class 'collections.abc.Container'>

(I don't know all the abc classes that list is registered with, so the above example may not be complete.)

Note that not every abstract base class will be defined in collections.abc. There is a module abc (distinct from collections.abc) which provides the metaclass ABCMeta. Any class that is an instance of ABCMeta supports registration of "virtual subclasses" using the standard interface (the register method). There's nothing that stops someone (whether a programmer or Python library) from creating an instance of ABCMeta that does not belong in collections.abc.

Was it helpful?

Solution

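Use issubclass and a list comprehension. Note that this only checks the ABCs defined in collections.abc, and that issubclass also counts ABCs matched through __subclasshook__, not only explicit registrations: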

>>> import collections.abc
>>> import inspect
>>> [v for k, v in vars(collections.abc).items()
...      if inspect.isclass(v) and issubclass(list, v)]
[<class 'collections.abc.Container'>,
 <class 'collections.abc.Sequence'>,
 <class 'collections.abc.MutableSequence'>,
 <class 'collections.abc.Iterable'>,
 <class 'collections.abc.Sized'>
]

OTHER TIPS

This is, in fact, possible. The trick is to wrap the ABCMeta.register method and to scan object.__subclasses__() recursively to find anything registered before this module is imported.

''' Allow listing virtual bases from ABCMeta.

    Copyright © 2018 Ben Longbons <b.r.longbons@gmail.com>
    Licensed under GPL3+ or CC BY-SA 3.0+

    IMPORTANT CAVEATS:

    1. C++ programmers, remember "virtual base" means something
    different in Python!

    2. Import this module as soon as possible, before creating any
    threads, and don't create bound references to `ABCMeta.register`.

    3. It's possible for there to be no valid MRO.

    4. This can't see ABCs that match via __subclasshook__ (e.g. Callable).

    5. The MRO can change any time that `register` is called.
'''


import abc
import functools
import threading


# Maps each class to the list of ABCs it was explicitly registered with as
# a virtual subclass, in the order the registrations were recorded.  Keys
# and values are strong references, so tracked classes stay alive.
tracked_bases = {}
# Held while `tracked_bases` is written (wrapped register, back-fill) and
# while the shared calculator below is swapped.  Note: not reentrant.
tracker_lock = threading.Lock()
# Shared VirtualMroCalculator, and the `abc.get_cache_token()` value it was
# built for; see get_global_virtual_mro_calculator().
global_virtual_mro_calculator = None
global_virtual_mro_calculator_token = None


def dotted_name(obj, *, relative=None):
    ''' Return the qualified name of a class, function, or module.

        The defining module is prefixed unless the object has none (modules
        themselves carry only `__name__`) or it equals the module of the
        `relative` class, in which case the bare qualified name is returned.
    '''
    home = None if relative is None else relative.__module__
    # Fallback for objects without `__qualname__` (e.g. modules); it is
    # evaluated up front, just like a getattr default argument.
    fallback = obj.__name__
    short = getattr(obj, '__qualname__', fallback)
    owner = getattr(obj, '__module__', None)
    if owner is None or owner == home:
        return short
    return '%s.%s' % (owner, short)

def sorted_types(collection):
    ''' Return the items of `collection` as a new list ordered by dotted_name. '''
    ordered = list(collection)
    ordered.sort(key=dotted_name)
    return ordered

def sanity_check():
    ''' Cross-check both MRO calculators against every reachable class.

        For each class found by `hunt_types(object)` (visited in dotted-name
        order), the plain C3 calculator must reproduce `__mro__` exactly and
        the virtual calculator's MRO must contain every class of `__mro__`.
        Mismatches raise AssertionError; an inconsistent virtual hierarchy
        surfaces as the TypeError raised by the merge.
    '''
    plain = MroCalculator()
    with_virtual = get_global_virtual_mro_calculator()
    for klass in sorted_types(hunt_types(object)):
        expected = klass.__mro__
        computed = plain.mro(klass)
        assert tuple(computed) == expected, (computed, expected)
        assert set(with_virtual.mro(klass)) >= set(expected)

def bases_string(cls, *, virtual=False, relative=None):
    ''' Format the bases of `cls` as "(A, B, +C)" for the tree printout.

        Real bases come first; with `virtual=True` the tracked virtual bases
        follow, each prefixed with '+'.  Names are produced by dotted_name
        relative to `relative`.  When there is at most one base in total
        the result is the empty string.
    '''
    assert cls is not relative
    names = [dotted_name(b, relative=relative) for b in cls.__bases__]
    if virtual:
        names.extend('+' + dotted_name(b, relative=relative)
                     for b in tracked_bases.get(cls, ()))
    if len(names) > 1:
        return '(%s)' % ', '.join(names)
    return ''

def dump_types(cls, *, virtual=False, depth=0, relative=None, sigil=''):
    ''' Print `cls` and, recursively, its subclasses as an indented tree.

        Each line is two spaces per `depth`, then `sigil`, the class name
        (via dotted_name, relative to its parent) and its bases_string.
        '!' is appended to the sigil for ABCMeta instances.  With
        `virtual=True`, classes registered on an ABC are also listed under
        it, with a leading '+' sigil.
    '''
    name = dotted_name(cls, relative=relative)
    bases = bases_string(cls, virtual=virtual, relative=relative)
    is_abc = isinstance(cls, abc.ABCMeta)
    if is_abc:
        sigil += '!'
    print('  ' * depth, sigil, name, bases, sep='')
    depth += 1

    # Look `__subclasses__` up on the metaclass: for a metaclass such as
    # `type`, `cls.__subclasses__` is the unbound method and would need an
    # explicit argument.
    for child in sorted_types(type(cls).__subclasses__(cls)):
        dump_types(child, virtual=virtual, depth=depth, relative=cls)
    if virtual and is_abc:
        registry = getattr(cls, '_abc_registry', None)
        if registry is None:
            # CPython 3.7+ implements ABCMeta in C (the `_abc` module) and
            # no longer exposes `_abc_registry`; the private debug helper
            # `abc._get_dump` returns a copy of the registry as a set of
            # weak references instead.
            registry = [ref() for ref in abc._get_dump(cls)[0]]
        # Skip weak references whose class has already been collected.
        registered = [child for child in registry if child is not None]
        for child in sorted_types(registered):
            dump_types(child, virtual=virtual, depth=depth, relative=cls, sigil='+')

def hunt_types(cls, seen=None):
    ''' Collect every class reachable from `cls` through `__subclasses__`.

        Returns `seen`, extended in place with each direct or transitive
        subclass found.  When `seen` is omitted, a fresh set containing
        `cls` is used.  A caller-supplied set does not get `cls` added, and
        classes already in it are not visited again.
    '''
    # Author's measurements: a fresh interpreter has about 500 classes, and
    # this walk takes under ~0.5 ms on CPython and ~0.05 ms on PyPy.
    # `gc.get_referrers(abc.ABCMeta)` costs about the same on CPython but is
    # far slower on PyPy (~50 ms).
    if seen is None:
        seen = {cls}
    pending = [cls]
    while pending:
        current = pending.pop()
        # Call through the metaclass so metaclasses such as `type` itself
        # work (`type.__subclasses__` needs an explicit argument).
        for sub in type(current).__subclasses__(current):
            if sub not in seen:
                seen.add(sub)
                pending.append(sub)
    return seen

def _abc_registry_contents(cls):
    ''' Return a list of the classes explicitly registered on the ABC `cls`.

        The pure-Python ABCMeta (`_py_abc`) keeps them in a WeakSet stored
        as `cls._abc_registry`.  The C implementation used by CPython 3.7+
        (`_abc`) has no such attribute; its private debug helper
        `abc._get_dump` returns a copy of the registry as a set of weak
        references.  Entries whose class has been collected are dropped.
    '''
    registry = getattr(cls, '_abc_registry', None)
    if registry is None:
        registry = [ref() for ref in abc._get_dump(cls)[0]]
    return [vscls for vscls in registry if vscls is not None]

def _register_types(cls):
    ''' Append `cls` to `tracked_bases[x]` for every class x registered on it.

        `cls` must be an ABCMeta instance.  Used to back-fill registrations
        made before `ABCMeta.register` was wrapped.
    '''
    assert isinstance(cls, abc.ABCMeta)
    # Iteration order over the registry is arbitrary but harmless: it is a
    # set, so each registered class gets `cls` appended exactly once.
    for vscls in _abc_registry_contents(cls):
        tracked_bases.setdefault(vscls, []).append(cls)

def _install():
    ''' Hook `abc.ABCMeta.register` and back-fill `tracked_bases`.

        Replaces `ABCMeta.register` with a wrapper that records each new
        registration in `tracked_bases`.  It then scans every class
        reachable from `object` and records the registrations that already
        exist on each ABC.  Both steps hold `tracker_lock`, so a
        registration made through the wrapper from another thread waits
        for the scan to finish.
    '''
    with tracker_lock:
        old_register = abc.ABCMeta.register

        # functools.wraps keeps the original name, docstring and signature
        # visible to introspection (help(), inspect.signature()).
        @functools.wraps(old_register)
        def abc_tracker_register(cls, subclass):
            with tracker_lock:
                # Already a subclass (real or virtual): the wrapper registers
                # nothing, so there is nothing to track.
                if issubclass(subclass, cls):
                    return subclass
                rv = old_register(cls, subclass)
                assert rv is subclass
                tracked_bases.setdefault(subclass, []).append(cls)
                return rv

        abc.ABCMeta.register = abc_tracker_register
        # `hunt_types` returns an unordered set.  Visiting ABCs in name
        # order makes `tracked_bases[x]` reproducible when a class was
        # registered directly with several ABCs.  The true registration
        # order of these earlier registrations is not recoverable.
        for t in sorted_types(hunt_types(object)):
            if isinstance(t, abc.ABCMeta):
                _register_types(t)


class MroCalculator:
    ''' Compute C3 method resolution orders in pure Python.

        Follows https://www.python.org/download/releases/2.3/mro/, but each
        class's bases come from the overridable `bases()` method, so
        subclasses can linearize over more than the real `__bases__`.
        Results are memoized per class in `mro_cache` for the lifetime of
        the instance.
    '''

    def __init__(self):
        # class -> its linearization (a tuple of classes, starting with it)
        self.mro_cache = {}

    def bases(self, cls):
        ''' Return the direct bases of `cls` to linearize over.

            The default is the real `cls.__bases__`.
        '''
        assert isinstance(cls, type)
        return cls.__bases__

    def merge(self, mros):
        ''' C3-merge `mros`, a list of lists of classes, into one order.

            Returns the merged order as a tuple.  Neither `mros` nor the
            lists inside it are modified.

            Raises TypeError if no consistent order exists.
        '''
        assert isinstance(mros, list)
        for mro in mros:
            assert isinstance(mro, list)
            for cls in mro:
                assert isinstance(cls, type)

        # The merge consumes sequences from the front; work on private
        # copies so the caller's lists are not emptied as a side effect.
        pending = [list(mro) for mro in mros]
        rv = []
        while True:
            pending = [mro for mro in pending if mro]
            if not pending:
                return tuple(rv)
            # C3 "good head": the first head that is not in the tail of any
            # remaining sequence.
            cand = None
            for mro in pending:
                head = mro[0]
                if not any(head in m[1:] for m in pending):
                    cand = head
                    break
            if cand is None:
                bad_base_str = ', '.join(sorted({dotted_name(mro[0]) for mro in pending}))
                raise TypeError('Cannot create a consistent method resolution order (MRO) for bases %s' % bad_base_str)
            rv.append(cand)
            for mro in pending:
                if mro[0] is cand:
                    del mro[0]

    def mro(self, cls):
        ''' Return the linearization of `cls` as a tuple (memoized).

            Raises TypeError if its bases admit no consistent order.
        '''
        assert isinstance(cls, type)
        rv = self.mro_cache.get(cls, None)
        if rv is not None:
            return rv

        bases = list(self.bases(cls))
        if cls is not object:
            # Every class other than `object` has at least one base.
            assert bases, cls
        rv = self.merge([[cls]] + [list(self.mro(b)) for b in bases] + [bases])
        assert cls not in self.mro_cache
        assert isinstance(rv, tuple)
        self.mro_cache[cls] = rv
        return rv

class VirtualMroCalculator(MroCalculator):
    ''' MroCalculator that also treats tracked ABC registrations as bases. '''

    def bases(self, cls):
        ''' Return the virtual bases recorded for `cls`, then its real ones.

            Virtual bases must be placed *before* the real bases, or C3
            cannot linearize hierarchies such as:

              class _io.BytesIO(_io._BufferedIOBase): pass
              class io.BufferedIOBase(_io._BufferedIOBase): pass
              io.BufferedIOBase.register(_io.BytesIO)

            The same shape also occurs in _frozen_importlib_external.
        '''
        real = super().bases(cls)
        virtual = tuple(tracked_bases.get(cls, ()))
        return virtual + real


def get_global_virtual_mro_calculator():
    ''' Return the shared VirtualMroCalculator, rebuilt when ABCs change.

        `abc.get_cache_token()` changes with every ABC registration, so a
        token mismatch means cached virtual MROs may be stale.  In that case
        a fresh calculator (with an empty cache) replaces the old one.
    '''
    global global_virtual_mro_calculator, global_virtual_mro_calculator_token
    with tracker_lock:
        # Registrations through the wrapped `ABCMeta.register` also take
        # `tracker_lock`, so the token is stable while we hold it (assuming
        # nobody calls a saved reference to the original method).
        token = abc.get_cache_token()
        if token == global_virtual_mro_calculator_token:
            return global_virtual_mro_calculator
        global_virtual_mro_calculator_token = token
        global_virtual_mro_calculator = VirtualMroCalculator()
        return global_virtual_mro_calculator

def virtual_mro(cls):
    ''' Return the MRO of `cls` including its tracked virtual (registered) bases.

        Raises TypeError (from the C3 merge) if no consistent order exists.
    '''
    # A registration racing with this call may or may not be reflected (and
    # cached) here.  Any call made after it completes sees it, because the
    # registration changes the ABC cache token.
    return get_global_virtual_mro_calculator().mro(cls)


# Importing this module installs the register() hook and back-fills
# registrations that already exist (see caveat 2 in the module docstring).
_install()


if __name__ == '__main__':
    # When run as a script: verify the MRO calculators against every
    # reachable class, then print the whole class tree, including classes
    # registered as virtual subclasses of ABCs.
    sanity_check()
    dump_types(object, virtual=True)
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow
scroll top