Question

I'm writing a library that uses expression templates with CRTP. The source files can be found here: https://github.com/mspraggs/pyQCD/tree/master/lib/include/base

The expression templates are based on the example given in the Wikipedia article on the subject. I reproduce the code here in case the Wikipedia article changes in the future:

#include <vector>
#include <cassert>

// CRTP base for every vector-like expression. The derived type E must supply
// size() and a const operator[]; this base forwards to them with static_cast,
// so generic code can accept any `VecExpression<E> const&` without virtuals.
template <typename E>
class VecExpression {
public:
  typedef std::vector<double>         container_type;
  typedef container_type::size_type   size_type;
  typedef container_type::value_type  value_type;
  typedef container_type::reference   reference;

  // Number of elements, as reported by the derived expression.
  size_type size() const {
    E const& self = static_cast<E const&>(*this);
    return self.size();
  }

  // Element i, evaluated by the derived expression.
  value_type operator[](size_type i) const {
    E const& self = static_cast<E const&>(*this);
    return self[i];
  }

  // Conversions back to the concrete (derived) expression type.
  operator E&() {
    E& self = static_cast<E&>(*this);
    return self;
  }
  operator E const&() const {
    E const& self = static_cast<E const&>(*this);
    return self;
  }
};

// The actual Vec class:
// Owning vector of doubles. Because it is itself a VecExpression<Vec>, a Vec
// can be used directly as an operand of the lazy expression types.
class Vec : public VecExpression<Vec> {
public:
  // Allocates n value-initialised (zero) elements.
  Vec(size_type n) : _data(n) {}

  // Evaluates any expression element by element into this new Vec.
  template <typename E>
  Vec(VecExpression<E> const& expr) {
    E const& src = expr;  // downcast through VecExpression's conversion operator
    _data.resize(src.size());
    for (size_type k = 0; k < src.size(); ++k)
      _data[k] = src[k];
  }

  size_type  size()                  const { return _data.size(); }
  reference  operator[](size_type i)       { return _data[i]; }
  value_type operator[](size_type i) const { return _data[i]; }

private:
  container_type _data;
};

// Lazy elementwise difference u - v. Both operands are held by const
// reference, so this object must not outlive them (consume it, e.g. by
// constructing a Vec, within the same full-expression).
template <typename E1, typename E2>
class VecDifference : public VecExpression<VecDifference<E1, E2> > {
  E1 const& _u;
  E2 const& _v;
public:
  typedef Vec::size_type  size_type;
  typedef Vec::value_type value_type;

  VecDifference(VecExpression<E1> const& u, VecExpression<E2> const& v)
    : _u(u), _v(v)
  {
    assert(u.size() == v.size());
  }

  // Operands have equal length; report the right-hand one.
  size_type size() const { return _v.size(); }

  // Computed on demand; nothing is stored.
  value_type operator[](Vec::size_type i) const {
    value_type const lhs = _u[i];
    value_type const rhs = _v[i];
    return lhs - rhs;
  }
};

// Lazy scalar multiple alpha * v. The scalar is copied, but v is held by
// const reference, so the operand must outlive this expression.
template <typename E>
class VecScaled : public VecExpression<VecScaled<E> > {
  double   _alpha;
  E const& _v;
public:
  VecScaled(double alpha, VecExpression<E> const& v)
    : _alpha(alpha), _v(v) {}

  Vec::size_type size() const { return _v.size(); }

  Vec::value_type operator[](Vec::size_type i) const {
    Vec::value_type const elem = _v[i];
    return _alpha * elem;
  }
};

// Now we can overload operators:

// Builds, without evaluating, the expression u - v.
template <typename E1, typename E2>
VecDifference<E1,E2> const
operator-(VecExpression<E1> const& u, VecExpression<E2> const& v) {
  typedef VecDifference<E1,E2> result_type;
  return result_type(u, v);
}

// Builds, without evaluating, the expression alpha * v (scalar on the left).
template <typename E>
VecScaled<E> const
operator*(double alpha, VecExpression<E> const& v) {
  VecScaled<E> const scaled(alpha, v);
  return scaled;
}

What I want to do is add another expression template that allows assignment to part of the original template object (the Vec class in the code above, and the LatticeBase class in the code I've linked to). Possible usage:

// Desired usage of the proposed Vec::head view:
Vec myvector(10);
Vec another_vector(5);
myvector.head(5) = another_vector; // Assign all 5 elements of another_vector to the first 5 elements of myvector
myvector.head(2) = another_vector.head(2); // (Added in an edit) assign the first 2 elements of another_vector to the first 2 elements of myvector

So I'd create a new function Vec::head that would return an expression template for a portion of the Vec object. I don't know how this would fit into the framework I currently have. In particular, I have the following questions/comments:

  1. I've seen examples of what I want to achieve in expression templates that don't use CRTP. What do I gain by using CRTP in this case? Is there any point? Should I ditch it and follow the other examples I've found?
  2. In the current framework, the assignment to the _data member of the Vec class is handled by the templated converting constructor `Vec(VecExpression<E> const&)` in the Vec class. This won't work if I want to use the expression template returned by Vec::head, since the assignment happens within the class that holds the data, not within the expression template.
  3. I've tried creating an assignment operator within the new expression template, but that won't work with the code above as all the expression template members are const references, and so the assignment operator is deleted at compile time. Can I just switch the members to being values instead of references? Will this impact on performance if additional storage is needed? Will this even work (if I change a stored copy of the expression rather than the expression itself)?

Overall I'm confused about how to go about adding an expression template that can be used as an lvalue in the code above. Any guidance on this would be greatly appreciated.

Was it helpful?

Solution

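Try this: add a `VecHead` expression that holds a non-const reference to the `Vec` it views, and give it assignment operators that write element by element through that reference. A class with a reference member gets a deleted implicit copy-assignment operator, so `VecHead` defines its own. It forwards to the templated assignment, so `vec1.head(2) = vec2.head(2)` copies elements instead of trying to rebind the view.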

#include <vector>
#include <cassert>

// CRTP base shared by Vec and every expression type. E is the derived class;
// it must provide size() and a const operator[], which this base dispatches
// to statically (no virtual calls).
template <typename E>
class VecExpression {
public:
    typedef std::vector<double>         container_type;
    typedef container_type::size_type   size_type;
    typedef container_type::value_type  value_type;
    typedef container_type::reference   reference;

    size_type  size()                  const { return derived().size(); }
    value_type operator[](size_type i) const { return derived()[i]; }

    // Recover the concrete expression type.
    operator E&()             { return derived(); }
    operator E const&() const { return derived(); }

private:
    // The single place where the CRTP downcast happens.
    E&       derived()       { return static_cast<E&>(*this); }
    E const& derived() const { return static_cast<E const&>(*this); }
};

class VecHead; // Vec::head returns one; the full definition follows Vec.

// Owning vector of doubles; usable directly as an expression operand.
class Vec : public VecExpression<Vec> {
public:
    // n value-initialised (zero) elements.
    Vec(size_type n) : _data(n) {}

    // Materialise any expression element by element.
    template <typename E>
    Vec(VecExpression<E> const& expr) {
        E const& src = expr;  // downcast through VecExpression's conversion operator
        _data.resize(src.size());
        for (size_type k = 0; k < src.size(); ++k) {
            _data[k] = src[k];
        }
    }

    size_type  size()                  const { return _data.size(); }
    reference  operator[](size_type i)       { return _data[i]; }
    value_type operator[](size_type i) const { return _data[i]; }

    // Writable view of the first s elements of this vector.
    VecHead head(size_type s);

private:
    container_type _data;
};

// Writable view of the first _s elements of a Vec. Reading goes through the
// usual VecExpression interface; assigning to the view writes straight into
// the referenced Vec, which is what makes `v.head(n) = expr;` work as an
// lvalue. The view holds a plain reference, so it must not outlive that Vec.
class VecHead : public VecExpression< VecHead >
{
    Vec::size_type _s;
    Vec& _e;
public:

    typedef Vec::size_type size_type;
    typedef Vec::value_type value_type;
    typedef Vec::reference reference;

    // s: number of leading elements exposed; must not exceed e.size().
    VecHead(size_type s, Vec& e)
        : _s(s)
        , _e(e)
    {
        assert(_e.size() >= _s);
    }

    size_type size() const { return _s; }

    // Read element i (requires i < size()).
    value_type operator[](size_type i) const { assert(i < _s); return _e[i]; }

    // Write access to element i (requires i < size()), e.g. `v.head(3)[1] = x;`.
    reference operator[](size_type i) { assert(i < _s); return _e[i]; }

    // Copy-assignment copies *elements* into the viewed Vec; it never rebinds
    // the view. It must be user-defined because the reference member makes the
    // implicitly-declared copy-assignment operator deleted.
    VecHead& operator = (const VecHead& rhs)
    {
        return operator=(static_cast<const VecExpression<VecHead>&>(rhs));
    }

    // Copy the first size() elements of any expression into the viewed Vec.
    // rhs must have at least size() elements; any extra elements are ignored.
    template <typename E>
    VecHead& operator = (const VecExpression<E>& rhs)
    {
        assert(rhs.size() >= _s);
        // The second bound keeps NDEBUG builds in range if rhs is too short.
        for (size_type i = 0; i < _s && i < rhs.size(); ++i)
            _e[i] = rhs[i];
        return *this;
    }
};

// Returns a writable view of this vector's first s elements.
VecHead Vec::head(size_type s)
{
    return VecHead(s, *this);
}

// Lazy elementwise difference of two equally sized expressions. Only const
// references to the operands are kept, so a VecDifference must be consumed
// (e.g. converted to a Vec) before its operands go away.
template <typename E1, typename E2>
class VecDifference : public VecExpression<VecDifference<E1, E2> > {
public:
    typedef Vec::size_type size_type;
    typedef Vec::value_type value_type;

    VecDifference(VecExpression<E1> const& u, VecExpression<E2> const& v)
        : _u(u)
        , _v(v)
    {
        assert(u.size() == v.size());
    }

    size_type size() const { return _v.size(); }

    value_type operator[](Vec::size_type i) const
    {
        return _u[i] - _v[i];
    }

private:
    E1 const& _u;
    E2 const& _v;
};

// Lazy product alpha * v. The scalar is copied; the operand is only referenced
// and therefore must outlive this expression.
template <typename E>
class VecScaled : public VecExpression<VecScaled<E> > {
public:
    VecScaled(double alpha, VecExpression<E> const& v)
        : _alpha(alpha)
        , _v(v)
    {}

    Vec::size_type size() const { return _v.size(); }

    Vec::value_type operator[](Vec::size_type i) const
    {
        return _alpha * _v[i];
    }

private:
    double   _alpha;
    E const& _v;
};

// Now we can overload operators:

// Builds (lazily) the expression u - v.
template <typename E1, typename E2>
VecDifference<E1, E2> const
operator-(VecExpression<E1> const& u, VecExpression<E2> const& v) {
    VecDifference<E1, E2> const diff(u, v);
    return diff;
}

// Builds (lazily) the expression alpha * v, with the scalar on the left.
template <typename E>
VecScaled<E> const
operator*(double alpha, VecExpression<E> const& v) {
    typedef VecScaled<E> result_type;
    return result_type(alpha, v);
}

int main()
{
    // Scenario 1: copy a whole 5-element Vec into the first five slots of a
    // 10-element Vec; the remaining five must keep their initial zeros.
    Vec myvector(10);
    Vec another_vector(5);
    for (int i = 0; i != 5; ++i)
        another_vector[i] = i;

    myvector.head(5) = another_vector;
    assert(myvector.head(5).size() == 5);
    for (int i = 0; i != 10; ++i)
        assert(myvector[i] == (i < 5 ? static_cast<double>(i) : 0.));

    // Scenario 2: view-to-view assignment, which goes through VecHead's
    // copy-assignment operator. Beforehand vec1 = 0..9, vec2 = 0, 2, ..., 18.
    Vec vec1(10), vec2(10);
    for (int i = 0; i != 10; ++i)
    {
        vec1[i] = i;
        vec2[i] = 2 * vec1[i];
    }

    vec1.head(2) = vec2.head(2);

    // Only the first two elements of vec1 were overwritten.
    for (int i = 0; i != 10; ++i)
    {
        if (i < 2)
            assert(vec1[i] == vec2[i]);
        else
            assert(vec1[i] != vec2[i]);
    }

    return 0;
}
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow
scroll top