Question

I'm writing a simple maths library with a template vector type:

template<typename T, size_t N>
class Vector {
    public:
        Vector<T, N> &operator+=(Vector<T, N> const &other);
        // ... more operators, functions ...
};

Now I want some additional functionality specifically for some of these. Let's say I want functions x() and y() on Vector<T, 3> to access particular coordinates. I could create a partial specialization for this:

template<typename T>
class Vector<T, 3> {
    public:
        Vector<T, 3> &operator+=(Vector<T, 3> const &other);
        // ... and again all the operators and functions ...
        T x() const;
        T y() const;
};

But now I'm repeating everything that already existed in the generic template.

I could also use inheritance. Renaming the generic template to VectorBase, I could do this:

template<typename T, size_t N>
class Vector : public VectorBase<T, N> {
};

template<typename T>
class Vector<T, 3> : public VectorBase<T, 3> {
    public:
        T x() const;
        T y() const;
};

However, now the problem is that all operators are defined on VectorBase, so they return VectorBase instances. These cannot be assigned to Vector variables:

Vector<float, 3> v;
Vector<float, 3> w;
w = 5 * v; // error: no conversion from VectorBase<float, 3> to Vector<float, 3>

I could give Vector an implicit conversion constructor to make this possible:

template<typename T, size_t N>
class Vector : public VectorBase<T, N> {
    public:
        Vector(VectorBase<T, N> const &other);
};

However, now I'm converting from Vector to VectorBase and back again. Even though the types are identical in memory, and the compiler might optimize all of this away, it feels clunky, and I'd rather not risk run-time overhead for what is essentially a compile-time problem.

Is there any other way to solve this?


Solution

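I think you can use the Curiously Recurring Template Pattern (CRTP) to solve this problem: the base class takes the derived class as a template parameter, so its operators can return the derived type directly. This is the idiom the Boost.Operators library is built on.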

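#include <cstddef>

template<typename ChildT, typename T, std::size_t N>
class VectorBase 
{    
public:
    T data[N] = {};  // element storage, zero-initialized

    // ChildT derives from VectorBase, so the operators can build and return a
    // ChildT directly. Where a member needs the derived type,
    // static_cast<ChildT&>(*this) is safe for the same reason.
    friend ChildT operator*(double lhs, ChildT const &rhs)
    {
        ChildT result;
        for (std::size_t i = 0; i < N; ++i)
            result.data[i] = static_cast<T>(lhs * rhs.data[i]);
        return result;
    }
    friend ChildT operator*(ChildT const &lhs, double rhs) { return rhs * lhs; }
};

template<typename T, std::size_t N>
class Vector : public VectorBase<Vector<T, N>, T, N> 
{
};

template<typename T>
class Vector<T, 3> : public VectorBase<Vector<T, 3>, T, 3>
{
public:
    // 'this->' is required because the base class depends on T.
    T x() const { return this->data[0]; }
    T y() const { return this->data[1]; }
};

void test()
{
    Vector<float, 3> v;
    Vector<float, 3> w;
    w = 5 * v;   // returns Vector<float, 3>, no conversion needed
    w = v * 5;
    v.x();

    Vector<float, 5> y;
    Vector<float, 5> z;
    y = 5 * z;
    y = z * 5;
    //z.x(); // Error !! Vector<float, 5> has no member x()
}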
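The same CRTP base can also provide the question's operator+=. Because the base knows the derived type, it can return ChildT& by downcasting *this with static_cast, which is safe because every ChildT derives from VectorBase<ChildT, T, N>. A minimal sketch, assuming the elements live in a public array named data (a storage layout chosen for illustration, not taken from the question); chained_x_example and sum_y_example are hypothetical helpers that only exercise the operators:

```cpp
#include <cassert>
#include <cstddef>

// CRTP base: ChildT is the concrete vector class that derives from this one.
// The public 'data' array is an assumed storage layout for this sketch.
template <typename ChildT, typename T, std::size_t N>
class VectorBase {
public:
    T data[N] = {};  // zero-initialized elements

    // Returns ChildT&, not VectorBase&, so the result keeps the derived interface.
    ChildT &operator+=(ChildT const &other) {
        for (std::size_t i = 0; i < N; ++i)
            data[i] += other.data[i];
        // Safe downcast: ChildT is known to derive from VectorBase<ChildT, T, N>.
        return static_cast<ChildT &>(*this);
    }

    // Binary + reuses +=; it takes and returns the derived type by value.
    friend ChildT operator+(ChildT lhs, ChildT const &rhs) {
        lhs += rhs;
        return lhs;
    }
};

template <typename T, std::size_t N>
class Vector : public VectorBase<Vector<T, N>, T, N> {};

template <typename T>
class Vector<T, 3> : public VectorBase<Vector<T, 3>, T, 3> {
public:
    T x() const { return this->data[0]; }
    T y() const { return this->data[1]; }
};

// Hypothetical usage helper: chained += keeps access to x().
float chained_x_example() {
    Vector<float, 3> v;
    v.data[0] = 1.0f;
    Vector<float, 3> w;
    w += v;
    assert(w.x() == 1.0f);
    return (w += v).x();  // x() is callable on the Vector& returned by +=
}

// Hypothetical usage helper: + yields a Vector<float, 3> directly.
float sum_y_example() {
    Vector<float, 3> a;
    a.data[1] = 2.0f;
    Vector<float, 3> b = a + a;  // no VectorBase-to-Vector conversion involved
    return b.y();
}
```

Deriving + from += like this is the part Boost.Operators automates for you.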

OTHER TIPS

Here's something I came up with while playing with C++0x (now C++11) features a while back. The only C++0x feature it uses is static_assert, so you could replace that with Boost's BOOST_STATIC_ASSERT.

Basically, we can use a compile-time size check function that checks that a given index is less than the size of the vector, using static_assert to generate a compiler error if the index is out of bounds:

template <std::size_t Index> 
void size_check_lt() const 
{ 
    static_assert(Index < N, "the index is not within the range of the vector"); 
}

Then we can provide a get() method that returns a reference to the element at a given index (obviously a const overload would be useful too):

template <std::size_t Index> 
T& get()
{ 
    size_check_lt<Index>();
    return data_[Index];
}

Then we can write simple accessors like so:

T& x() { return get<0>(); }
T& y() { return get<1>(); }
T& z() { return get<2>(); }

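If the vector has only two elements, you can use x and y but not z. If the vector has three or more elements, you can use all three. This works because a member function of a class template is only instantiated when it is called, so the static_assert inside z() fires only if someone actually calls z() on a two-element vector.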
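Put together, the pieces above fit into a class along these lines. A minimal sketch: the data_ member name comes from the snippet above, while the rest of the class (the const overloads and zero-initialized storage) is filled in for illustration, and sum_xy_example is a hypothetical helper:

```cpp
#include <cassert>
#include <cstddef>

template <typename T, std::size_t N>
class Vector {
public:
    // Compile-time bounds check: fails to compile when Index >= N.
    template <std::size_t Index>
    void size_check_lt() const {
        static_assert(Index < N, "the index is not within the range of the vector");
    }

    template <std::size_t Index>
    T &get() {
        size_check_lt<Index>();
        return data_[Index];
    }

    template <std::size_t Index>
    T const &get() const {
        size_check_lt<Index>();
        return data_[Index];
    }

    // Named accessors; each body is only instantiated when it is called.
    T &x() { return get<0>(); }
    T &y() { return get<1>(); }
    T &z() { return get<2>(); }
    T const &x() const { return get<0>(); }
    T const &y() const { return get<1>(); }
    T const &z() const { return get<2>(); }

private:
    T data_[N] = {};
};

// Hypothetical usage helper: a 2-D vector exposes x() and y() only.
int sum_xy_example() {
    Vector<int, 2> v;
    v.x() = 3;
    v.y() = 4;
    assert(v.get<0>() == 3);
    // v.z() = 5;  // would not compile: the static_assert in size_check_lt fires
    return v.x() + v.y();
}
```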

I ended up doing the same thing for constructors: I created constructors for vectors of dimension two, three, and four, and added a size_check_eq that allowed each of them to be instantiated only for vectors of dimension two, three, and four, respectively. I can try and post the complete code when I get home tonight, if anyone is interested.
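The constructor idea can be sketched the same way. The answer never shows size_check_eq, so its body below is an assumption modeled on size_check_lt, and sum3_example is a hypothetical helper. Each constructor body is only instantiated when that constructor is used, so the check only fires on a mismatched call:

```cpp
#include <cassert>
#include <cstddef>

template <typename T, std::size_t N>
class Vector {
public:
    Vector() = default;

    // Vector<T, 3>(1, 2, 3) compiles; Vector<T, 2>(1, 2, 3) does not.
    Vector(T x, T y) : data_{x, y} { size_check_eq<2>(); }
    Vector(T x, T y, T z) : data_{x, y, z} { size_check_eq<3>(); }
    Vector(T x, T y, T z, T w) : data_{x, y, z, w} { size_check_eq<4>(); }

    T operator[](std::size_t i) const { return data_[i]; }

private:
    // Assumed implementation: compiles only when Size == N.
    template <std::size_t Size>
    void size_check_eq() const {
        static_assert(Size == N, "wrong number of constructor arguments for this vector");
    }

    T data_[N] = {};
};

// Hypothetical usage helper: builds a 3-D vector with the 3-argument constructor.
int sum3_example() {
    Vector<int, 3> v(1, 2, 3);
    assert(v[2] == 3);
    // Vector<int, 2> bad(1, 2, 3);  // would not compile
    return v[0] + v[1] + v[2];
}
```

With this member-initializer layout, a mismatched call may also produce an "excess initializers" error alongside the static_assert message.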

I dropped the project halfway through, so there might be some huge problem with doing it this way that I didn't run into... at least it's an option to consider.

The simplest way? Use free (non-member) functions, here assuming Vector provides a member template at<I>() for element access:

// 'template' is required because 'vector' has a dependent type.
template <class T>
T& x(Vector<T,2>& vector) { return vector.template at<0>(); }

template <class T>
T const& x(Vector<T,2> const& vector) { return vector.template at<0>(); }
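A self-contained version of the free-function approach might look like this. A minimal sketch: the Vector below is a stand-in for the question's class, and its at<I>() member template is an assumption (the question's Vector does not show one), as is the hypothetical helper length_squared_example:

```cpp
#include <cassert>
#include <cstddef>

// Stand-in for the question's Vector; at<I>() is assumed for this sketch.
template <typename T, std::size_t N>
class Vector {
public:
    template <std::size_t I>
    T &at() {
        static_assert(I < N, "index out of range");
        return data_[I];
    }
    template <std::size_t I>
    T const &at() const {
        static_assert(I < N, "index out of range");
        return data_[I];
    }

private:
    T data_[N] = {};
};

// Free accessors: they only accept Vector<T, 2>, so Vector itself needs no specialization.
template <class T>
T &x(Vector<T, 2> &vector) { return vector.template at<0>(); }

template <class T>
T const &x(Vector<T, 2> const &vector) { return vector.template at<0>(); }

template <class T>
T &y(Vector<T, 2> &vector) { return vector.template at<1>(); }

template <class T>
T const &y(Vector<T, 2> const &vector) { return vector.template at<1>(); }

// Hypothetical usage helper: writes through the non-const overloads, reads through the const ones.
double length_squared_example() {
    Vector<double, 2> v;
    x(v) = 3.0;
    y(v) = 4.0;
    Vector<double, 2> const &cv = v;
    assert(x(cv) == 3.0);
    return x(cv) * x(cv) + y(cv) * y(cv);
}
```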

In template programming, free functions are the simplest way to add functionality, precisely because they sidestep the specialization issue you just ran into.

Alternatively, you could still provide x, y and z for any N, or use enable_if/disable_if (std::enable_if in C++11, or Boost's enable_if and disable_if) to restrict them to vectors of a suitable size.
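One way to do the restriction with std::enable_if is to make each accessor drop out of overload resolution when the vector is too small. A minimal sketch; the Vector below is again a stand-in with an assumed public data array, and sum_xyz_example is a hypothetical helper:

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Stand-in for the question's Vector; the public 'data' array is an assumption.
template <typename T, std::size_t N>
struct Vector {
    T data[N] = {};
};

// Each accessor exists only when the vector is large enough; otherwise
// substitution fails and the call has no matching function.
template <typename T, std::size_t N>
typename std::enable_if<(N >= 1), T &>::type x(Vector<T, N> &v) { return v.data[0]; }

template <typename T, std::size_t N>
typename std::enable_if<(N >= 2), T &>::type y(Vector<T, N> &v) { return v.data[1]; }

template <typename T, std::size_t N>
typename std::enable_if<(N >= 3), T &>::type z(Vector<T, N> &v) { return v.data[2]; }

// Hypothetical usage helper.
int sum_xyz_example() {
    Vector<int, 3> v3;
    x(v3) = 1;
    y(v3) = 2;
    z(v3) = 3;
    Vector<int, 2> v2;
    y(v2) = 5;
    // z(v2);  // error: no matching function, because N >= 3 is false
    assert(x(v2) == 0);
    return x(v3) + y(v3) + z(v3) + y(v2);
}
```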

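I don't know whether you can get around the typing problem with the assignment operator, but you can make life a little easier by defining templated versions of the various operators plus helper functions that implement them, and then using inheritance. The derived class still has to redeclare operator=, because its implicitly declared copy-assignment operator hides the base-class templates.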

#include <cstddef>
#include <stdexcept>

template <typename T, std::size_t N>
class fixed_array {
public:
    virtual ~fixed_array() {}
    template <std::size_t K>
    fixed_array& operator+=(fixed_array<T,K> const& other) {
        // Combine only the elements both arrays have.
        std::size_t const count = N < K ? N : K;
        for (std::size_t i=0; i<count; ++i)
            this->contents[i] += other[i];
        return *this;
    }
    template <std::size_t K>
    fixed_array& operator=(fixed_array<T,K> const& other) {
        assign_from(other);
        return *this;
    }
    T& operator[](std::size_t idx) {
        if (idx >= N)
            throw std::runtime_error("invalid index in fixed_array[]");
        return contents[idx];
    }
    T const& operator[](std::size_t idx) const {
        if (idx >= N)
            throw std::runtime_error("invalid index in fixed_array[]");
        return contents[idx];
    }
protected:
    template <std::size_t K>
    void assign_from(fixed_array<T,K> const& other) {
        // Copy the overlapping elements; any extra elements of *this keep their values.
        std::size_t const count = N < K ? N : K;
        for (std::size_t i=0; i<count; ++i)
            this->contents[i] = other[i];
    }
private:
    T contents[N] = {};  // value-initialized, so no indeterminate values are read
};

template <typename T>
class fixed_2d_array: public fixed_array<T,2> {
public:
    T x_coord() const { return (*this)[0]; }
    T y_coord() const { return (*this)[1]; }
    template <std::size_t K>
    fixed_2d_array& operator=(fixed_array<T,K> const& other) {
        this->assign_from(other);  // 'this->' is required: the base class depends on T
        return *this;
    }
};

int
main() {
    fixed_array<int,5> ary1;
    fixed_2d_array<int> ary2;
    ary2 = ary1;
    ary1 = ary2;
    ary2 += ary1;
    ary1 += ary2;
    return 0;
}
Licensed under: CC-BY-SA with attribution