I get close to a 4x speed-up in computing the product `M.T * D * M` by working from the three diagonal arrays instead of doing general sparse products. If `d0` and `d1` are the main diagonal of `M` and its upper diagonal (the one at offset `nc`), and `d` is the main diagonal of `D`, then the following code builds `M.T * D * M` directly:
```python
import numpy as np
import scipy.sparse as sparse

def make_tridi_bis(d0, d1, d, nc=10):
    # Assumes 1 <= nc <= len_ (nc = 0 would break the negative slices below)
    # Products that make up the three non-zero diagonals of M.T * D * M
    d00 = d0*d0*d
    d11 = d1*d1*d
    d01 = d0*d1*d
    len_ = d0.size
    # First nc and last nc rows hold 2 entries, the other rows hold 3,
    # so nnz = 2*nc + 3*(len_ - nc) + 2*nc = 3*len_ + nc
    data = np.empty((3*len_ + nc,))
    indices = np.empty((3*len_ + nc,), dtype=np.intp)  # np.int is gone in NumPy >= 1.24
    # Fill main diagonal
    data[:2*nc:2] = d00[:nc]
    indices[:2*nc:2] = np.arange(nc)
    data[2*nc+1:-2*nc:3] = d00[nc:] + d11[:-nc]
    indices[2*nc+1:-2*nc:3] = np.arange(nc, len_)
    data[-2*nc+1::2] = d11[-nc:]
    indices[-2*nc+1::2] = np.arange(len_, len_ + nc)
    # Fill top diagonal
    data[1:2*nc:2] = d01[:nc]
    indices[1:2*nc:2] = np.arange(nc, 2*nc)
    data[2*nc+2:-2*nc:3] = d01[nc:]
    indices[2*nc+2:-2*nc:3] = np.arange(2*nc, len_+nc)
    # Fill bottom diagonal
    data[2*nc:-2*nc:3] = d01[:-nc]
    indices[2*nc:-2*nc:3] = np.arange(len_ - nc)
    data[-2*nc::2] = d01[-nc:]
    indices[-2*nc::2] = np.arange(len_ - nc, len_)
    # Row pointers: entries per row, turned into offsets by a running sum
    indptr = np.empty((len_ + nc + 1,), dtype=np.intp)
    indptr[0] = 0
    indptr[1:nc+1] = 2
    indptr[nc+1:len_+1] = 3
    indptr[-nc:] = 2
    np.cumsum(indptr, out=indptr)
    return sparse.csr_matrix((data, indices, indptr), shape=(len_+nc, len_+nc))
```
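The `indptr` bookkeeping is probably the least obvious step, so here is a small sketch of just that part. `build_indptr` is a hypothetical helper that mirrors the last few lines of the function, run with `len_ = 5` and `nc = 2` so the per-row entry counts (2, 2, 3, 3, 3, 2, 2) are easy to check by hand:

```python
import numpy as np

def build_indptr(len_, nc):
    """Hypothetical helper: CSR row pointers for the (len_ + nc)-square result.

    Rows 0..nc-1 hold 2 entries (main + upper), rows nc..len_-1 hold 3
    (lower + main + upper) and the last nc rows hold 2 (lower + main).
    """
    counts = np.empty(len_ + nc + 1, dtype=np.intp)
    counts[0] = 0
    counts[1:nc + 1] = 2
    counts[nc + 1:len_ + 1] = 3
    counts[-nc:] = 2
    # Running sum turns per-row counts into start offsets into data/indices
    return np.cumsum(counts)

indptr = build_indptr(5, 2)
print(indptr.tolist())  # [0, 2, 4, 7, 10, 13, 15, 17]
```

The last pointer equals the total number of stored entries, `3*len_ + nc`, which is why `data` and `indices` are allocated with that size.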
If your matrix `M` is in CSR format, you can extract `d0` and `d1` as `d0 = M.data[::2]` and `d1 = M.data[1::2]` (this relies on every row storing exactly two entries, sorted by column). I modified your toy-data routine to return those arrays as well, and here's what I get:
```
In [90]: np.allclose((M.T * sparse.diags(d, 0) * M).A, make_tridi_bis(d0, d1, d).A)
Out[90]: True

In [92]: %timeit make_tridi_bis(d0, d1, d)
10 loops, best of 3: 124 ms per loop

In [93]: %timeit M.T * sparse.diags(d, 0) * M
1 loops, best of 3: 501 ms per loop
```
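The toy-data routine from the question isn't reproduced here, so `make_toy_data` below is a hypothetical stand-in (random values, fixed seed) that builds `M` in CSR form with the layout assumed above: `d0` on the main diagonal and `d1` at offset `nc`. It checks the `M.data[::2]` / `M.data[1::2]` extraction and confirms that the reference product only has non-zeros on offsets 0 and ±`nc`:

```python
import numpy as np
import scipy.sparse as sparse

def make_toy_data(len_=1000, nc=10, seed=0):
    """Hypothetical toy-data routine (not the one from the question).

    Returns a len_ x (len_ + nc) CSR matrix M with d0 on its main diagonal
    and d1 on the diagonal at offset nc, the diagonal d of D, and d0, d1.
    """
    rng = np.random.default_rng(seed)
    d0, d1, d = rng.random(len_), rng.random(len_), rng.random(len_)
    rows = np.arange(len_)
    indices = np.empty(2 * len_, dtype=np.intp)
    indices[::2] = rows            # column of the main-diagonal entry
    indices[1::2] = rows + nc      # column of the offset-nc entry
    data = np.empty(2 * len_)
    data[::2] = d0
    data[1::2] = d1
    indptr = np.arange(0, 2 * len_ + 1, 2)   # exactly two entries per row
    M = sparse.csr_matrix((data, indices, indptr), shape=(len_, len_ + nc))
    return M, d, d0, d1

M, d, d0, d1 = make_toy_data(len_=50, nc=5)

# The extraction only works because every row stores two entries in column order
assert np.array_equal(M.data[::2], d0)
assert np.array_equal(M.data[1::2], d1)

# Reference product computed with the generic sparse machinery
ref = (M.T @ sparse.diags(d, 0) @ M).toarray()
r, c = np.nonzero(ref)
offsets = set((c - r).tolist())
print(ref.shape, sorted(offsets))  # (55, 55) [-5, 0, 5]
```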
The whole point of the code above is to exploit the structure of the non-zero entries. If you draw a diagram of the matrices you are multiplying together, it is relatively easy to convince yourself that the main (`d_0`) and the top and bottom (`d_1`) diagonals of the resulting matrix, which sit at offsets 0 and ±`nc`, are simply:
```python
d_0 = np.zeros((len_ + nc,))
d_0[:len_] = d00
d_0[-len_:] += d11
d_1 = d01
```
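Writing the product out entry by entry makes the diagram argument explicit. Here $n$ stands for `len_` and $n_c$ for `nc`, and $M$ is taken to be $n \times (n + n_c)$ with `d0` on its main diagonal and `d1` on the diagonal at offset $n_c$, which is the layout the function assumes; $[\cdot]$ is 1 when its condition holds and 0 otherwise:

$$
\begin{aligned}
(M^\top D M)_{jk} &= \sum_{i=0}^{n-1} M_{ij}\, d_i\, M_{ik}, \qquad M_{i,i} = \mathtt{d0}_i,\quad M_{i,i+n_c} = \mathtt{d1}_i,\\
(M^\top D M)_{jj} &= [\,j < n\,]\,\mathtt{d0}_j^2\, d_j + [\,j \ge n_c\,]\,\mathtt{d1}_{j-n_c}^2\, d_{j-n_c} = [\,j < n\,]\,\mathtt{d00}_j + [\,j \ge n_c\,]\,\mathtt{d11}_{j-n_c},\\
(M^\top D M)_{j,\,j+n_c} &= (M^\top D M)_{j+n_c,\,j} = \mathtt{d0}_j\, d_j\, \mathtt{d1}_j = \mathtt{d01}_j, \qquad 0 \le j < n.
\end{aligned}
$$

Every other entry is zero, because row $i$ of $M$ only touches columns $i$ and $i + n_c$. The two bracketed terms are exactly `d_0[:len_] = d00` and `d_0[-len_:] += d11`.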
The rest of the code in that function simply builds the CSR representation of this matrix by hand, because calling `sparse.diags` with the above data is several times slower.
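For comparison, here is a sketch of the slower route: handing the three diagonals to `sparse.diags`. The name `make_tridi_diags` and the random test data are hypothetical; the diagonal formulas are the ones given above, and the check compares against the generic `M.T @ diag(d) @ M` product:

```python
import numpy as np
import scipy.sparse as sparse

def make_tridi_diags(d0, d1, d, nc=10):
    """Hypothetical alternative: build M.T * D * M with sparse.diags.

    Easier to read than filling data/indices/indptr by hand, but the answer
    reports this route is several times slower.
    """
    len_ = d0.size
    d00 = d0 * d0 * d
    d11 = d1 * d1 * d
    d01 = d0 * d1 * d
    main = np.zeros(len_ + nc)
    main[:len_] = d00          # contribution of M's main diagonal
    main[-len_:] += d11        # contribution of M's offset-nc diagonal
    n = len_ + nc
    # Off-diagonals at +/-nc both have length len_ in an n x n matrix
    return sparse.diags([d01, main, d01], [-nc, 0, nc],
                        shape=(n, n), format="csr")

# Check against the generic product on hypothetical random data
rng = np.random.default_rng(1)
len_, nc = 40, 4
d0, d1, d = rng.random(len_), rng.random(len_), rng.random(len_)
M = sparse.diags([d0, d1], [0, nc], shape=(len_, len_ + nc), format="csr")
ref = M.T @ sparse.diags(d, 0) @ M
alt = make_tridi_diags(d0, d1, d, nc)
print(np.allclose(ref.toarray(), alt.toarray()))  # True
```

Timing this against `make_tridi_bis` on your own data is the easiest way to see whether the hand-built CSR arrays are worth the extra code.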