Question

Pretty straightforward question: given an N x N symmetric matrix A and an N-vector x, is there a built-in Matlab function to calculate x'*A*x? That is, instead of y = x'*A*x, is there a function quadraticform such that y = quadraticform(A, x)?

Obviously I can just do y = x'*A*x, but I need performance, and it seems like there ought to be a way to take advantage of the facts that:

  1. A is symmetric
  2. The left and right multipliers are the same vector

If there's not a single built-in function, is there a method that's faster than x'*A*x? Or is the Matlab parser smart enough to optimize x'*A*x? If so, can you point me to a place in the documentation that verifies this?


Solution

I couldn't find such a built-in function, and I have an idea why.

y=x'*A*x can be written as a sum of n^2 terms A(i,j)*x(i)*x(j), where i and j run from 1 to n (A is an n x n matrix). A is symmetric: A(i,j) = A(j,i) for all i and j. Due to symmetry, every term appears twice in the sum, except for those where i equals j, so there are only n*(n+1)/2 distinct terms. Each has two floating-point multiplications, so a naive method would need n*(n+1) multiplications in total. It is easy to see that the naive calculation of x'*A*x, that is, calculating z=A*x (n^2 multiplications) and then y=x'*z (n more), also needs n*(n+1) multiplications.

However, there is a faster way to sum our n*(n+1)/2 distinct terms: for every i, we can factor x(i) out of the off-diagonal terms with j < i, so that only n*(n-1)/2+3*n multiplications are needed (n*(n-1)/2 for the products A(i,j)*x(j), n for multiplying each partial sum by x(i), and 2*n for the diagonal terms A(i,i)*x(i)*x(i); the factor 2 on the off-diagonal part can be applied as an addition). But this does not really help: the running time of the calculation of y=x'*A*x is still O(n^2).
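In formula form, the factoring step uses A(i,j) = A(j,i) to fold every off-diagonal term into the strictly lower triangle, so each inner sum over j < i is formed once and then multiplied by x(i):

```latex
y = x^{\mathsf T} A x
  = \sum_{i=1}^{n} \sum_{j=1}^{n} A_{ij}\, x_i x_j
  = \sum_{i=1}^{n} A_{ii}\, x_i^2
    + 2 \sum_{i=1}^{n} x_i \Bigl( \sum_{j=1}^{i-1} A_{ij}\, x_j \Bigr),
  \qquad A_{ij} = A_{ji}.
```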

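So I think that quadratic forms cannot be computed faster than O(n^2): for a general dense symmetric A, each of the n*(n+1)/2 distinct entries affects the result, so each one must be read at least once. Since O(n^2) is also what the formula y=x'*A*x achieves, there would be no real advantage to a special "quadraticform" function.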
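To make the operation counts concrete, here is a C sketch that computes x'*A*x both ways and counts the floating-point multiplications. The function names are hypothetical, and A is assumed to be a dense, symmetric, row-major array.

```c
/* Hypothetical helpers for illustration: both compute y = x'*A*x for a
   dense, symmetric n x n matrix A stored row-major, and store the number
   of floating-point multiplications they performed in *mults. */

/* Two-step method: z = A*x (n*n multiplications), then y = x'*z (n more). */
double quadform_two_step(const double *A, const double *x, int n, long *mults)
{
    double y = 0.0;
    long m = 0;
    for (int i = 0; i < n; ++i) {
        double z = 0.0;                 /* z(i) = A(i,:)*x */
        for (int j = 0; j < n; ++j) {
            z += A[i * n + j] * x[j];
            ++m;
        }
        y += x[i] * z;
        ++m;
    }
    *mults = m;
    return y;
}

/* Factored method: uses A(i,j) = A(j,i) and pulls x(i) out of each
   partial sum over j < i (the strictly lower triangle). */
double quadform_factored(const double *A, const double *x, int n, long *mults)
{
    double y = 0.0;
    long m = 0;
    for (int i = 0; i < n; ++i) {
        double z = 0.0;
        for (int j = 0; j < i; ++j) {   /* sum_{j<i} A(i,j)*x(j) */
            z += A[i * n + j] * x[j];
            ++m;
        }
        z *= x[i];                      /* 1 multiplication */
        y += A[i * n + i] * x[i] * x[i] /* 2 multiplications */
             + z + z;                   /* factor 2 applied as an addition */
        m += 3;
    }
    *mults = m;
    return y;
}
```

On a symmetric 4 x 4 matrix both functions return the same value, with n*(n+1) = 20 and n*(n-1)/2+3*n = 18 multiplications respectively; both counts still grow as n^2.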

=== UPDATE ===

I've written the function "quadraticform" in C, as a Matlab extension:

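// y = quadraticform(A, x)
// Computes the quadratic form y = x'*A*x for a real, full, symmetric
// n x n matrix A and an n x 1 column vector x. Only the diagonal and the
// upper triangle of A are read, so A must be exactly symmetric.
#include "mex.h"

/* Input Arguments */
#define A_in prhs[0]
#define x_in prhs[1]

/* Output Arguments */
#define y_out plhs[0]

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  mwSize mA, nA, n, mx, nx;
  mwSize i, j, k;       /* same type as n: no signed/unsigned mixing */
  const double *A, *x;
  double z, y;

  if (nrhs != 2) {
      mexErrMsgTxt("Two input arguments required.");
  } else if (nlhs > 1) {
      mexErrMsgTxt("Too many output arguments.");
  }

  /* Both inputs must be real, full (non-sparse) double arrays. */
  if (!mxIsDouble(A_in) || mxIsComplex(A_in) || mxIsSparse(A_in) ||
      !mxIsDouble(x_in) || mxIsComplex(x_in) || mxIsSparse(x_in))
    mexErrMsgTxt("Both input arguments must be real, full double arrays.");

  mA = mxGetM(A_in);
  nA = mxGetN(A_in);
  if (mA != nA)
    mexErrMsgTxt("The first input argument must be a square matrix.");
  n = mA;

  mx = mxGetM(x_in);
  nx = mxGetN(x_in);
  if (mx != n || nx != 1)
    mexErrMsgTxt("The second input argument must be a column vector of proper size.");

  A = mxGetPr(A_in);
  x = mxGetPr(x_in);
  y = 0.0;
  k = 0;                /* k = i*n: offset of column i (MATLAB is column-major) */
  for (i = 0; i < n; ++i)
  {
    /* z = x(i) * sum_{j<i} A(j,i)*x(j), using the upper triangle */
    z = 0.0;
    for (j = 0; j < i; ++j)
      z += A[k + j] * x[j];
    z *= x[i];
    /* diagonal term plus the off-diagonal terms counted twice (z + z) */
    y += A[k + i] * x[i] * x[i] + z + z;
    k += n;
  }

  y_out = mxCreateDoubleScalar(y);
}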
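To try the loop outside MATLAB, here is a sketch of it as plain C with the MEX input handling stripped out; the function name quadform_upper_colmajor is hypothetical. It assumes column-major storage, as MATLAB uses, and makes explicit that only the diagonal and the upper triangle of A are read, so the result matches x'*A*x only when A is exactly symmetric.

```c
#include <stddef.h>

/* Hypothetical standalone version of the loop in quadraticform.c.
   A is an n x n matrix stored column-major (as in MATLAB), so A(r,c)
   is A[c*n + r]. Only the diagonal and the upper triangle (r < c) are
   read; the result equals x'*A*x only if A is exactly symmetric. */
double quadform_upper_colmajor(const double *A, const double *x, size_t n)
{
    double y = 0.0;
    size_t k = 0;                       /* k = i*n: start of column i */
    for (size_t i = 0; i < n; ++i) {
        double z = 0.0;
        for (size_t j = 0; j < i; ++j)  /* rows j < i of column i */
            z += A[k + j] * x[j];
        z *= x[i];
        y += A[k + i] * x[i] * x[i] + z + z;
        k += n;
    }
    return y;
}
```

Overwriting entries below the diagonal therefore leaves the result unchanged, which is worth keeping in mind if A is only symmetric up to rounding error.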

I saved this code as "quadraticform.c" and compiled it with Matlab:

mex -O quadraticform.c

I wrote a simple performance test to compare this function with x'*A*x:

clear all; close all; clc;

sizes = int32(logspace(2, 3, 25));
nsizes = length(sizes);
etimes = zeros(nsizes, 2); % Matlab vs. C
nrepeats = 100;
h = waitbar(0, 'Please wait...');
for i = 1 : nrepeats
  for j = 1 : nsizes
    n = sizes(j);
    A = randn(n);
    A = (A + A') / 2;   % random symmetric test matrix
    x = randn(n, 1);
    if randn > 0        % randomize the call order so neither method is always timed first
      start = tic;
      y1 = x' * A * x;
      etimes(j, 1) = etimes(j, 1) + toc(start);
      start = tic;
      y2 = quadraticform(A, x);
      etimes(j, 2) = etimes(j, 2) + toc(start);      
    else
      start = tic;
      y2 = quadraticform(A, x);
      etimes(j, 2) = etimes(j, 2) + toc(start);      
      start = tic;
      y1 = x' * A * x;
      etimes(j, 1) = etimes(j, 1) + toc(start);
    end;
    if abs((y1 - y2) / y2) > 1e-10
      error('"x'' * A * x" is not equal to "quadraticform(A, x)"');
    end;
    waitbar(((i - 1) * nsizes + j) / (nrepeats * nsizes), h);
  end;
end;
close(h);
clear A x y1 y2;
etimes = etimes / nrepeats;

n = double(sizes);
n2 = n .^ 2.0;
i = nsizes - 2 : nsizes;  % scale the O(n^2) reference lines to the three largest sizes
n2_1 = mean(etimes(i, 1)) * n2 / mean(n2(i));
n2_2 = mean(etimes(i, 2)) * n2 / mean(n2(i));

figure;
loglog(n, etimes(:, 1), 'r.-', 'LineSmoothing', 'on');
hold on;
loglog(n, etimes(:, 2), 'g.-', 'LineSmoothing', 'on');
loglog(n, n2_1, 'k-', 'LineSmoothing', 'on');
loglog(n, n2_2, 'k-', 'LineSmoothing', 'on');
axis([n(1) n(end) 1e-4 1e-2]);
xlabel('Matrix size, n');
ylabel('Running time (a.u.)');
legend('x'' * A * x', 'quadraticform(A, x)', 'O(n^2)', 'Location', 'NorthWest');

W = 16 / 2.54; H = 12 / 2.54; dpi = 100;
set(gcf, 'PaperPosition', [0, 0, W, H]);
set(gcf, 'PaperSize', [W, H]);
print(gcf, sprintf('-r%d',dpi), '-dpng', 'quadraticformtest.png');

The result is very interesting. The running times of both x'*A*x and quadraticform(A,x) grow as O(n^2) for large n, but x'*A*x has the smaller constant factor:

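(Figure: quadraticformtest.png, log-log plot of running time (a.u.) against matrix size n for x'*A*x and quadraticform(A, x), with O(n^2) reference lines.)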

OTHER TIPS

MATLAB is clever enough to recognise and optimise some sorts of compound matrix expression, and I believe (although I can't definitely confirm) that the quadratic form is one of the optimisations that it does make.

However, it's not the sort of thing MathWorks tend to document, because (a) it will typically only be optimised within functions, not in scripts, at the command line or while debugging; (b) it may only work in some circumstances, such as for real, non-sparse A; (c) it may change from release to release, so they don't want you to rely on it; and (d) it's one of the proprietary things that make MATLAB so good.

To confirm, you could try comparing timings for y=x'*A*x against B=A*x; y=x'*B. You could also try the undocumented feature('accel','off'), which turns off most of those sorts of optimisations.

Finally, if you contact MathWorks support, you might be able to get one of the developers to confirm whether the optimisation is being made.

I'm not sure if this will work in your case, but I came across a similar situation where I wanted to calculate many sums of squares. After tinkering with the algebra, I realized I was approaching this like a mathematician and not like a computer engineer:

If the rows of X are your data points, then the ith element of Q below will be the ith quadratic form X(i,:)*A*X(i,:)':

Q = sum((X * A) .* X, 2)
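When the vectors are stored as the rows of X, the quadratic form of every row can be computed in one pass as the row sums of (X*A).*X. Below is a C sketch of that loop, assuming row-major storage for both X and A; the function name quadforms_rows is hypothetical.

```c
#include <stddef.h>

/* Hypothetical helper: for an m x d data matrix X stored row-major (one
   data point per row) and a d x d matrix A stored row-major, computes
   q[i] = X(i,:) * A * X(i,:)' for every row i. This is the loop form of
   the MATLAB expression sum((X*A).*X, 2). */
void quadforms_rows(const double *X, size_t m, size_t d,
                    const double *A, double *q)
{
    for (size_t i = 0; i < m; ++i) {
        const double *xi = X + i * d;   /* row i of X */
        double acc = 0.0;
        for (size_t k = 0; k < d; ++k) {
            double t = 0.0;             /* t = (X*A)(i,k) */
            for (size_t j = 0; j < d; ++j)
                t += xi[j] * A[j * d + k];
            acc += t * xi[k];           /* elementwise product, summed over k */
        }
        q[i] = acc;
    }
}
```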

Hope that helps!

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow