Question

I'm writing some numerical simulation code in C++. In this simulation, there are some things that are "local", having a floating point value at every point on a two-dimensional grid, and others that are "global", having only a single global floating point value.

Aside from this difference, the two types of object behave similarly, and so I would like to be able to have an array that contains both types of object. However, because this is a numerical simulation, I need to do this in a way that (a) avoids virtual function call overheads as much as possible, and (b) allows the compiler to use optimisations as much as possible - and in particular, allows the compiler to do SIMD auto-vectorisation where possible.

Currently I'm finding myself writing code like this (which, I now realise, will not actually work as intended):

class Base {};

class Local: public Base {
public:
    float data[size];
    // plus constructors etc.
};

class Global: public Base {
public:
    float data;
    // ...
};

// Arguments are taken by reference so that the updates to `a` are visible to the caller.
void doStuff(Local &a, const Local &b) {
    for (int i = 0; i<size; ++i) {
        a.data[i] += b.data[i];
    }
}

void doStuff(Local &a, const Global &b) {
    for (int i = 0; i<size; ++i) {
        a.data[i] += b.data;
    }
}

void doStuff(Global &a, const Local &b) {
    for (int i = 0; i<size; ++i) {
        a.data += b.data[i];
    }
}

void doStuff(Global &a, const Global &b) {
    a.data += b.data*size;
}

My code is a bit more complex than this - the array is two dimensional, and there are several doStuff-type functions that have three rather than two arguments, so I have to write eight specialisations for each one.

The reason this doesn't work as intended is that the types of the arguments to doStuff are not actually known at compile time. What I want to do is to have an array of Base * and to call doStuff on two of its members. I then want the correct specialisation of doStuff to be called for the specific types of its arguments. (It doesn't matter if there's a virtual method call involved in doStuff - I just want to avoid them in the inner loop.)

The point of doing it this way rather than (for example) overloading operator[] is that the compiler can (hopefully) do SIMD auto-vectorisation for doStuff(Local, Local) and doStuff(Local, Global), and I can lose the loop entirely in doStuff(Global, Global). Perhaps there are other compiler optimisations that can happen in these functions as well.

However, it's annoying to have to write such repetitive code. Consequently I'm wondering whether there's a way to achieve this using templates, so that I can just write one function doStuff(Base, Base) and code equivalent to the above will be generated. (I hope that gcc is smart enough to optimise away the loop in the case of doStuff(Global, Global).)

I stress that the following solution is not what I'm looking for, since it involves a virtual function call on every iteration through the loop, which adds overhead and probably prevents a lot of compiler optimisations.

class Base {
public:
    virtual float &operator[](int) = 0;
};

class Local: public Base {
    float data[size];
public:
    float &operator[](int i) {
        return data[i];
    }
    // …
};

class Global: public Base {
    float data;
public:
    float &operator[](int i) {
        return data;
    }
    // ...
};

void doStuff(Base &a, Base &b) {
    for (int i = 0; i<size; ++i) {
        a[i] += b[i];
    }
}

I would like to achieve a similar effect to the above, but without the overhead of a virtual function call on every iteration through the inner loop. (Unless I'm completely wrong, and the compiler can actually optimise away all the virtual function calls and generate code like the above. In that case you could save me a lot of time by telling me this!)

I did have a look at CRTP, but it's not obvious how to adapt it to this case, at least not to me, because of the multiple overloaded arguments to doStuff.

No correct solution

OTHER TIPS

You almost have the answer. A template function like this should work (though I don't know where size is coming from):

template<typename A, typename B>
void doStuff(A & a, B & b) {
    for (int i = 0; i<size; ++i) {
        a[i] += b[i];
    }
}

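Here Local and Global each provide their own operator[], but it isn't virtual, so the compiler generates a separate loop for each combination of argument types and can inline the element access.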
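
As a rough illustration, here is a minimal sketch with a non-virtual operator[] on both classes. The constant gridSize is an assumption standing in for the question's size, and exampleSum is a hypothetical helper added only to exercise the code. Whether a given compiler actually vectorises each instantiation, or collapses the Global/Global loop, is something to check in the generated assembly rather than assume.

```cpp
#include <cstddef>

// Assumption: a compile-time grid size standing in for the question's `size`.
constexpr std::size_t gridSize = 8;

// One value per grid point; operator[] is deliberately non-virtual.
struct Local {
    float data[gridSize] = {};
    float &operator[](std::size_t i) { return data[i]; }
    const float &operator[](std::size_t i) const { return data[i]; }
};

// A single value, returned for every index.
struct Global {
    float data = 0.0f;
    float &operator[](std::size_t) { return data; }
    const float &operator[](std::size_t) const { return data; }
};

// Written once; instantiated separately for each (A, B) pair that is used.
template <typename A, typename B>
void doStuff(A &a, const B &b) {
    for (std::size_t i = 0; i < gridSize; ++i) {
        a[i] += b[i];
    }
}

// Hypothetical helper, only here to show the calls.
float exampleSum() {
    Local l;
    Global g;
    g.data = 2.0f;
    doStuff(l, g);  // Local += Global: every element of l becomes 2
    doStuff(g, l);  // Global += Local: g accumulates all gridSize elements
    return g.data;
}
```

This only works when the concrete types are known at the call site; with an array of Base * some form of dispatch is still needed first.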


If you don't know the concrete types at the call site, but you have a fixed number of derived types, then dispatching by hand to statically typed code is an option:

void doStuff( Base & a, Base & b ) {
    // Exactly one cast per argument succeeds; the others yield nullptr.
    Local * a_local = dynamic_cast<Local*>(&a);
    Global * a_global = dynamic_cast<Global*>(&a);
    Local * b_local = dynamic_cast<Local*>(&b);
    Global * b_global = dynamic_cast<Global*>(&b);
    if( a_local && b_local ) {
        doStuffImpl(*a_local, *b_local);
    } else if( a_local && b_global ) {
        doStuffImpl(*a_local, *b_global);
    } // ... and likewise for the remaining combinations
}

You'll notice the branches are nearly identical: each one passes the successfully cast pointers to doStuffImpl, which is assumed to be a template function. I'd suggest wrapping this up in a macro to reduce the repetition. You may also wish to track the type yourself and not use dynamic_cast: have an enum in your Base class which explicitly lists the types. This also acts as a safety mechanism that prevents unknown derived classes from reaching doStuff.
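
The enum idea could look roughly like the sketch below. The names Kind, kind, dispatchSecond, gridSize and exampleDispatch are all hypothetical, and a small helper template is used in place of the macro suggested above; it resolves one argument at a time, which should also extend to three-argument functions without writing out all eight branches.

```cpp
#include <cassert>
#include <cstddef>

constexpr std::size_t gridSize = 4;  // assumption: stands in for `size`

enum class Kind { Local, Global };   // hypothetical explicit type tag

struct Base {
    const Kind kind;
    explicit Base(Kind k) : kind(k) {}
};

struct Local : Base {
    float data[gridSize] = {};
    Local() : Base(Kind::Local) {}
    float &operator[](std::size_t i) { return data[i]; }
};

struct Global : Base {
    float data = 0.0f;
    Global() : Base(Kind::Global) {}
    float &operator[](std::size_t) { return data; }
};

// The statically typed inner loop, instantiated once per type pair.
template <typename A, typename B>
void doStuffImpl(A &a, B &b) {
    for (std::size_t i = 0; i < gridSize; ++i) a[i] += b[i];
}

// First argument already resolved; resolve the second from its tag.
template <typename A>
void dispatchSecond(A &a, Base &b) {
    switch (b.kind) {
        case Kind::Local:  doStuffImpl(a, static_cast<Local &>(b));  return;
        case Kind::Global: doStuffImpl(a, static_cast<Global &>(b)); return;
    }
    assert(false && "unknown Kind");
}

void doStuff(Base &a, Base &b) {
    switch (a.kind) {
        case Kind::Local:  dispatchSecond(static_cast<Local &>(a), b);  return;
        case Kind::Global: dispatchSecond(static_cast<Global &>(a), b); return;
    }
    assert(false && "unknown Kind");
}

// Hypothetical helper: dispatch through Base pointers, as in the question.
float exampleDispatch() {
    Local l;
    Global g;
    g.data = 1.5f;
    Base *items[] = {&l, &g};
    doStuff(*items[0], *items[1]);  // Local += Global
    doStuff(*items[1], *items[0]);  // Global += Local
    return g.data;
}
```

The dispatch cost is two switch statements per call to doStuff, paid once outside the inner loop.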

Unfortunately this type of approach is required. It's the only way to convert from dynamic types to static ones. And if you wish to use templates you need the static ones.
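
If C++17 is available, the standard library can generate this kind of dispatch for you. The following is only a sketch under that assumption: it stores the objects in a std::variant instead of behind a Base *, which may or may not suit the rest of the code, and Field, gridSize and exampleVariant are hypothetical names. std::visit is still a conversion from dynamic to static types; it just writes the branching for you.

```cpp
#include <cstddef>
#include <variant>
#include <vector>

constexpr std::size_t gridSize = 4;  // assumption: stands in for `size`

struct Local {
    float data[gridSize] = {};
    float &operator[](std::size_t i) { return data[i]; }
};

struct Global {
    float data = 0.0f;
    float &operator[](std::size_t) { return data; }
};

// Hypothetical alias: a closed set of alternatives replaces the Base class.
using Field = std::variant<Local, Global>;

void doStuff(Field &a, Field &b) {
    // The generic lambda is instantiated for all four (x, y) type pairs;
    // std::visit picks the right one at run time, once per call.
    std::visit([](auto &x, auto &y) {
        for (std::size_t i = 0; i < gridSize; ++i) x[i] += y[i];
    }, a, b);
}

// Hypothetical helper, only here to show the calls.
float exampleVariant() {
    std::vector<Field> fields;
    fields.emplace_back(Local{});
    fields.emplace_back(Global{});
    std::get<Global>(fields[1]).data = 3.0f;
    doStuff(fields[0], fields[1]);  // Local += Global
    return std::get<Local>(fields[0]).data[2];
}
```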

Is it OK in your code for Local to know about Global and Global to know about Local?

If the answer to the above question is yes, you can avoid the cost of a virtual function call at every point in the domain by using one virtual function call and a couple of dynamic casts per call to doStuff.

class Base {
    public:
       virtual void doStuff(Base& b) = 0;
};

class Local: public Base {
    public:
       virtual void doStuff(Base& b);
       float data[size];
       // plus constructors etc.
};

class Global: public Base {
    public:
       virtual void doStuff(Base& b);
       float data;
       // ...
};

void Local::doStuff(Base& b) {
    Local* lb = NULL;
    Global* gb = NULL;
    if ( NULL != (lb = dynamic_cast<Local*>(&b)) )
    {
       // Do Local+Local stuff.
    }
    else if ( NULL != (gb = dynamic_cast<Global*>(&b)))
    {
       // Do Local+Global stuff.
    }

}

void Global::doStuff(Base& b) {
    Local* lb = NULL;
    Global* gb = NULL;
    if ( NULL != (lb = dynamic_cast<Local*>(&b)) )
    {
       // Do Global+Local stuff.
    }
    else if ( NULL != (gb = dynamic_cast<Global*>(&b)))
    {
       // Do Global+Global stuff.
    }

}

// Base is abstract, so the arguments must be passed by reference.
void doStuff(Base& a, Base& b) {
    a.doStuff(b);
}
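
For completeness, here is one way the four placeholder bodies above might be filled in, using the loops from the question. This is a sketch rather than a definitive implementation: gridSize stands in for size, the virtual destructor and override keywords are additions, and exampleDoubleDispatch is a hypothetical helper.

```cpp
#include <cstddef>

constexpr std::size_t gridSize = 4;  // assumption: stands in for `size`

class Base {
public:
    virtual ~Base() = default;
    virtual void doStuff(Base &b) = 0;
};

class Local : public Base {
public:
    float data[gridSize] = {};
    void doStuff(Base &b) override;
};

class Global : public Base {
public:
    float data = 0.0f;
    void doStuff(Base &b) override;
};

void Local::doStuff(Base &b) {
    if (auto *lb = dynamic_cast<Local *>(&b)) {
        // Local + Local: plain element-wise loop, no virtual calls inside.
        for (std::size_t i = 0; i < gridSize; ++i) data[i] += lb->data[i];
    } else if (auto *gb = dynamic_cast<Global *>(&b)) {
        // Local + Global: read the scalar once, then add it everywhere.
        const float g = gb->data;
        for (std::size_t i = 0; i < gridSize; ++i) data[i] += g;
    }
}

void Global::doStuff(Base &b) {
    if (auto *lb = dynamic_cast<Local *>(&b)) {
        // Global + Local: accumulate every grid value into the scalar.
        for (std::size_t i = 0; i < gridSize; ++i) data += lb->data[i];
    } else if (auto *gb = dynamic_cast<Global *>(&b)) {
        // Global + Global: no loop needed, as in the question.
        data += gb->data * static_cast<float>(gridSize);
    }
}

// Hypothetical helper: everything is called through Base references.
float exampleDoubleDispatch() {
    Local l;
    Global g;
    g.data = 0.5f;
    Base &a = l;
    Base &b = g;
    a.doStuff(b);  // Local += Global
    b.doStuff(a);  // Global += Local
    b.doStuff(b);  // Global += Global
    return g.data;
}
```
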
Licensed under: CC-BY-SA with attribution