Question

I've written a DSL and a compiler that generates a .NET expression tree from it. All expressions within the tree are side-effect-free and the expression is guaranteed to be a "non-statement" expression (no locals, loops, blocks etc.). (Edit: The tree may include literals, property accesses, standard operators and function calls - which may be doing fancy things like memoization inside, but are externally side-effect free).

Now I would like to perform the "Common sub-expression elimination" optimization on it.

For example, given a tree corresponding to the C# lambda:

foo =>      (foo.Bar * 5 + foo.Baz * 2 > 7) 
         || (foo.Bar * 5 + foo.Baz * 2 < 3)  
         || (foo.Bar * 5 + 3 == foo.Xyz)

...I would like to generate the tree-equivalent of the following (disregard the fact that some of the short-circuiting semantics are lost):

foo =>
{
     var local1 = foo.Bar * 5;

     // Notice that this local depends on the first one.        
     var local2 = local1 + foo.Baz * 2; 

     // Notice that no unnecessary locals have been generated.
     return local2 > 7 || local2 < 3 || (local1 + 3 == foo.Xyz);
}

I'm familiar with writing expression-visitors, but the algorithm for this optimization isn't immediately obvious to me - I could of course find "duplicates" within a tree, but there's obviously some trick to analyzing the dependencies within and between sub-trees to eliminate sub-expressions efficiently and correctly.

I looked for algorithms on Google but they seem quite complicated to implement quickly. Also, they seem very "general" and don't necessarily take the simplicity of the trees I have into account.


Solution 2

You're correct in noting this is not a trivial problem.

The classical way that compilers handle it is a Directed Acyclic Graph (DAG) representation of the expression. The DAG is built in the same manner as the abstract syntax tree (and can be built by traversing the AST, perhaps a job for the expression visitor; I don't know much about the C# libraries), except that a dictionary of previously emitted subgraphs is maintained. Before generating a node of a given type with given children, the dictionary is consulted to see if an identical node already exists. Only if this check fails is a new node created and added to the dictionary.

Since now a node may descend from multiple parents, the result is a DAG.
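As a rough illustration of the dictionary lookup described above, here is a minimal Java sketch. The names DagNode, DagBuilder and DagDemo are made up for this example and are not from the Gist or from any library; operator names are plain strings, and the dictionary key is assumed to be the operator plus the ids of the (already unique) children.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical node type for this sketch; not part of any library.
final class DagNode {
    final int id;            // unique id, assigned when the node is first created
    final String op;         // operator name, identifier or literal text
    final DagNode lhs, rhs;  // both null for leaves
    int parents = 0;         // number of DAG nodes that reference this node

    DagNode(int id, String op, DagNode lhs, DagNode rhs) {
        this.id = id;
        this.op = op;
        this.lhs = lhs;
        this.rhs = rhs;
    }
}

// Creates nodes through a dictionary so that equal subtrees share one node.
final class DagBuilder {
    private final Map<String, DagNode> seen = new HashMap<>();
    private int nextId = 1;

    DagNode leaf(String text) { return make(text, null, null); }

    DagNode binary(String op, DagNode lhs, DagNode rhs) { return make(op, lhs, rhs); }

    private DagNode make(String op, DagNode lhs, DagNode rhs) {
        // Children are already unique, so their ids identify them completely.
        String key = op + "|" + (lhs == null ? "-" : lhs.id) + "|" + (rhs == null ? "-" : rhs.id);
        DagNode node = seen.get(key);
        if (node == null) {               // only create a node if none exists yet
            node = new DagNode(nextId++, op, lhs, rhs);
            seen.put(key, node);
            if (lhs != null) lhs.parents++;
            if (rhs != null) rhs.parents++;
        }
        return node;
    }

    int size() { return seen.size(); }
}

final class DagDemo {
    static DagNode bar5(DagBuilder b) {
        return b.binary("MULTIPLY", b.leaf("Bar"), b.leaf("5"));
    }

    static DagNode sum(DagBuilder b) {
        return b.binary("ADD", bar5(b), b.binary("MULTIPLY", b.leaf("Baz"), b.leaf("2")));
    }

    // Requests every sub-expression of the question's lambda, duplicates included.
    static DagNode buildExample(DagBuilder b) {
        DagNode gt = b.binary("GREATER", sum(b), b.leaf("7"));
        DagNode lt = b.binary("LESS", sum(b), b.leaf("3"));
        DagNode eq = b.binary("EQUALS", b.binary("ADD", bar5(b), b.leaf("3")), b.leaf("Xyz"));
        return b.binary("OR", b.binary("OR", gt, lt), eq);
    }

    public static void main(String[] args) {
        DagBuilder b = new DagBuilder();
        DagNode root = buildExample(b);
        System.out.println("distinct nodes: " + b.size());
        System.out.println("shared sum has " + root.lhs.lhs.lhs.parents + " parents");
    }
}
```

With this sketch the question's lambda should end up as 16 distinct nodes: Bar * 5 becomes node 3 and the shared sum node 7, each with two parents. Unlike a real parser, the ids here are only handed out to nodes that are actually created.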

Then the DAG is traversed depth first to generate code. Since common sub-expressions are now represented by a single node, the value is only computed once and stored in a temp for other expressions emitted later in the code generation to use. If the original code contains assignments, this phase gets complicated. Since your trees are side-effect free, the DAG ought to be the most straightforward way to solve your problem.

As I recall, the coverage of DAGs in the Dragon book is particularly nice.

As others have noted, if your trees will ultimately be compiled by an existing compiler, it's kind of futile to redo what's already there.

Addition

I had some Java code lying around from a student project (I teach), so I hacked up a little example of how this works. It's too long to post, but see the Gist here.

Running it on your input prints the DAG below. The numbers in parens are (unique id, DAG parent count). The parent count is needed to decide when to compute the local temp variables and when to just use the expression for a node.

Binary OR (27,1)
  lhs:
    Binary OR (19,1)
      lhs:
        Binary GREATER (9,1)
          lhs:
            Binary ADD (7,2)
              lhs:
                Binary MULTIPLY (3,2)
                  lhs:
                    Id 'Bar' (1,1)
                  rhs:
                    Number 5 (2,1)
              rhs:
                Binary MULTIPLY (6,1)
                  lhs:
                    Id 'Baz' (4,1)
                  rhs:
                    Number 2 (5,1)
          rhs:
            Number 7 (8,1)
      rhs:
        Binary LESS (18,1)
          lhs:
            ref to Binary ADD (7,2)
          rhs:
            Number 3 (17,2)
  rhs:
    Binary EQUALS (26,1)
      lhs:
        Binary ADD (24,1)
          lhs:
            ref to Binary MULTIPLY (3,2)
          rhs:
            ref to Number 3 (17,2)
      rhs:
        Id 'Xyz' (25,1)

Then it generates this code:

t3 = (Bar) * (5);
t7 = (t3) + ((Baz) * (2));
return (((t7) > (7)) || ((t7) < (3))) || (((t3) + (3)) == (Xyz));

You can see that the temp var numbers correspond to DAG nodes. You could make the code generator more complex to get rid of the unnecessary parentheses, but I'll leave that for others.
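For readers who want to see the shape of such a generator without opening the Gist, here is a small Java sketch of the code-generation pass. It is not the Gist's code: CseNode and CseEmitter are names invented for this example, the DAG is built by hand by reusing object references, and ids are assumed to be assigned in post-order. A node gets a temp only if it is an operator with more than one parent.

```java
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical DAG node for this sketch; sharing is expressed by reusing references.
final class CseNode {
    final String text;       // operator symbol, identifier or literal
    final CseNode lhs, rhs;  // both null for leaves

    CseNode(String text, CseNode lhs, CseNode rhs) {
        this.text = text;
        this.lhs = lhs;
        this.rhs = rhs;
    }

    static CseNode leaf(String text) { return new CseNode(text, null, null); }
    static CseNode bin(String op, CseNode l, CseNode r) { return new CseNode(op, l, r); }
    boolean isLeaf() { return lhs == null; }
}

final class CseEmitter {
    private final Map<CseNode, Integer> parentCount = new IdentityHashMap<>();
    private final Map<CseNode, Integer> ids = new IdentityHashMap<>();
    private final Map<CseNode, String> temps = new IdentityHashMap<>();
    private final List<String> lines = new ArrayList<>();

    // Returns the temp assignments followed by the final return statement.
    static List<String> emit(CseNode root) {
        CseEmitter e = new CseEmitter();
        e.count(root);
        String result = e.gen(root);
        e.lines.add("return " + result + ";");
        return e.lines;
    }

    // Pass 1: number the nodes in post-order and count the parents of each node.
    private void count(CseNode n) {
        if (ids.containsKey(n)) return;   // already visited through another parent
        if (!n.isLeaf()) {
            for (CseNode child : new CseNode[] { n.lhs, n.rhs }) {
                parentCount.merge(child, 1, Integer::sum);
                count(child);
            }
        }
        ids.put(n, ids.size() + 1);
    }

    // Pass 2: depth-first emission; shared operator nodes are stored in a temp once.
    private String gen(CseNode n) {
        if (n.isLeaf()) return n.text;
        String temp = temps.get(n);
        if (temp != null) return temp;    // value was already computed
        String expr = "(" + gen(n.lhs) + ") " + n.text + " (" + gen(n.rhs) + ")";
        if (parentCount.getOrDefault(n, 0) > 1) {
            temp = "t" + ids.get(n);
            temps.put(n, temp);
            lines.add(temp + " = " + expr + ";");
            return temp;
        }
        return expr;
    }

    // The question's lambda as a hand-built DAG.
    static CseNode example() {
        CseNode bar5 = CseNode.bin("*", CseNode.leaf("Bar"), CseNode.leaf("5"));
        CseNode sum = CseNode.bin("+", bar5, CseNode.bin("*", CseNode.leaf("Baz"), CseNode.leaf("2")));
        CseNode three = CseNode.leaf("3");
        return CseNode.bin("||",
                CseNode.bin("||", CseNode.bin(">", sum, CseNode.leaf("7")), CseNode.bin("<", sum, three)),
                CseNode.bin("==", CseNode.bin("+", bar5, three), CseNode.leaf("Xyz")));
    }

    public static void main(String[] args) {
        for (String line : emit(example())) System.out.println(line);
    }
}
```

Run on the hand-built DAG, this sketch prints the same three lines as shown above. Note that the temps are hoisted in front of the return, so t3 and t7 are always computed, which is only safe because the expressions are side-effect free.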

OTHER TIPS

You are doing unnecessary work; common sub-expression elimination is the job of the jitter optimizer. Let's take your example and look at the generated code. I wrote it like this:

using System;

class Program {
    static void Main(string[] args) {
        var lambda = new Func<Foo, bool>(foo => 
               (foo.Bar * 5 + foo.Baz * 2 > 7)
            || (foo.Bar * 5 + foo.Baz * 2 < 3) 
            || (foo.Bar * 5 + 3 == foo.Xyz));
        var obj = new Foo() { Bar = 1, Baz = 2, Xyz = 3 };
        var result = lambda(obj);
        Console.WriteLine(result);
    }
}

class Foo {
    public int Bar { get; internal set; }
    public int Baz { get; internal set; }
    public int Xyz { get; internal set; }
}

The x86 jitter generated this machine code for the lambda expression:

006526B8  push        ebp                          ; prologue
006526B9  mov         ebp,esp  
006526BB  push        esi  
006526BC  mov         esi,dword ptr [ecx+4]        ; esi = foo.Bar
006526BF  lea         esi,[esi+esi*4]              ; esi = 5 * foo.Bar
006526C2  mov         edx,dword ptr [ecx+8]        ; edx = foo.Baz
006526C5  add         edx,edx                      ; edx = 2 * foo.Baz
006526C7  lea         eax,[esi+edx]                ; eax = 5 * foo.Bar + 2 * foo.Baz
006526CA  cmp         eax,7                        ; > 7 test
006526CD  jg          006526E7                     ; > 7 then return true
006526CF  add         edx,esi                      ; HERE!!
006526D1  cmp         edx,3                        ; < 3 test
006526D4  jl          006526E7                     ; < 3 then return true
006526D6  add         esi,3                        ; HERE!!
006526D9  mov         eax,esi  
006526DB  cmp         eax,dword ptr [ecx+0Ch]      ; == foo.Xyz test
006526DE  sete        al                           ; convert to bool
006526E1  movzx       eax,al  
006526E4  pop         esi                          ; epilogue
006526E5  pop         ebp  
006526E6  ret 
006526E7  mov         eax,1  
006526EC  pop         esi  
006526ED  pop         ebp  
006526EE  ret   

I marked with HERE the places in the code where the foo.Bar * 5 sub-expression was eliminated. Notably, it did not eliminate the foo.Bar * 5 + foo.Baz * 2 sub-expression; the addition was performed again at address 006526CF. There is a good reason for that: the x86 jitter doesn't have enough registers available to store the intermediate result. If you look at the machine code generated by the x64 jitter, you do see it eliminated; the r9 register stores it.

This ought to give you enough reasons to reconsider your intent. You are doing work that doesn't need to be done. And not only that, you are liable to generate worse code than the jitter would, since you don't have the luxury of estimating the CPU register budget.

Don't do this.

  1. Make a SortedDictionary<Expression, object> that can compare arbitrary Expressions.
    (You can define your own arbitrary comparison function here -- for example, you can lexicographically compare the types of the expressions, and if they compare equal then you can compare the children one by one.)

  2. Go through all the leaves and add them to the dictionary; if they already exist, then they're duplicates, so merge them.
    (This is also a good time to emit code -- such as creating a new variable -- for this leaf if it's the first instance of it; you can then store the emitted code inside the object value in the dictionary.)

  3. Then go through the parents of all the previous leaves and add them to the dictionary; if they already exist, then they're duplicates, so merge them.

  4. Keep on going up level by level until you reach the root.

Now you know what all the duplicates are, and where they occur, and you've generated code for all of them.
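The steps above can be sketched roughly as follows. This is Java rather than C#, with TreeMap standing in for SortedDictionary; Expr, ExprOrder, Shared and ExprMerger are hypothetical names for this example, and a post-order walk is assumed in place of a literal level-by-level pass (children still go into the dictionary before their parents).

```java
import java.util.Comparator;
import java.util.TreeMap;

// Hypothetical expression node for this sketch (not System.Linq.Expressions).
final class Expr {
    final String kind;      // node type plus payload, e.g. "Multiply", "Member:Bar", "Const:5"
    final Expr[] children;

    Expr(String kind, Expr... children) {
        this.kind = kind;
        this.children = children;
    }
}

// Step 1: an arbitrary but consistent total order over expressions.
final class ExprOrder implements Comparator<Expr> {
    @Override
    public int compare(Expr a, Expr b) {
        if (a == b) return 0;                              // already merged: same object
        int c = a.kind.compareTo(b.kind);                  // compare node types first
        if (c != 0) return c;
        c = Integer.compare(a.children.length, b.children.length);
        for (int i = 0; c == 0 && i < a.children.length; i++) {
            c = compare(a.children[i], b.children[i]);     // then the children, one by one
        }
        return c;
    }
}

// Dictionary value: the single shared node plus how often it occurred.
final class Shared {
    final Expr node;
    int uses = 0;

    Shared(Expr node) { this.node = node; }
}

final class ExprMerger {
    final TreeMap<Expr, Shared> seen = new TreeMap<>(new ExprOrder());

    // Steps 2-4: children are merged before their parent, so leaves go in first.
    Expr merge(Expr e) {
        Expr[] kids = new Expr[e.children.length];
        for (int i = 0; i < kids.length; i++) {
            kids[i] = merge(e.children[i]);
        }
        Expr candidate = new Expr(e.kind, kids);
        Shared entry = seen.get(candidate);
        if (entry == null) {                               // first occurrence
            entry = new Shared(candidate);
            seen.put(candidate, entry);
        }
        entry.uses++;                                      // a repeat means a duplicate
        return entry.node;
    }
}

final class ExprMergerDemo {
    static Expr bar5() {
        return new Expr("Multiply", new Expr("Member:Bar"), new Expr("Const:5"));
    }

    static Expr sum() {
        return new Expr("Add", bar5(), new Expr("Multiply", new Expr("Member:Baz"), new Expr("Const:2")));
    }

    static Expr example() {
        Expr gt = new Expr("GreaterThan", sum(), new Expr("Const:7"));
        Expr lt = new Expr("LessThan", sum(), new Expr("Const:3"));
        Expr eq = new Expr("Equal", new Expr("Add", bar5(), new Expr("Const:3")), new Expr("Member:Xyz"));
        return new Expr("OrElse", new Expr("OrElse", gt, lt), eq);
    }

    public static void main(String[] args) {
        ExprMerger m = new ExprMerger();
        m.merge(example());
        for (Shared s : m.seen.values()) {
            if (s.uses > 1 && s.node.children.length > 0) {
                System.out.println(s.node.kind + " occurs " + s.uses + " times");
            }
        }
    }
}
```

Because the children handed to the comparer are already merged, equal subtrees are usually the same object and the comparison stops at the reference check. The uses counter is where emitted code (such as a variable name) could be stored instead.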

Disclaimer: I have never tackled a problem like this; I'm just throwing out an idea that seems reasonably efficient:

Give every node in the tree some sort of signature. A hash should do; collisions can be dealt with. The signature must map all Foo.Bar entries to the same value.

Traverse the tree (O(n)), building a list of signatures of INTERNAL nodes (ignore leaves), and sort it on a combined key of expression size and then signature (O(n log n)). Take the most common item of the smallest expression in the list (O(n)) and go through replacing the expression with a local variable. (Check at this time that they are true matches, just in case we had a hash collision.)

Repeat this until you accomplish nothing. This can't possibly run more than n/2 times, thus bounding the whole operation to O(n^2 log n).
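A rough Java sketch of this loop, under some simplifying assumptions: SigNode and SignatureCse are made-up names, the "signature" is the full structural string rather than a hash (so no collision check is needed), and signatures are recomputed on every pass instead of being cached, which is slower than the bound given above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical mutable tree node for this sketch.
final class SigNode {
    String text;        // operator, identifier, literal or local name
    SigNode lhs, rhs;   // both null for leaves

    SigNode(String text, SigNode lhs, SigNode rhs) {
        this.text = text;
        this.lhs = lhs;
        this.rhs = rhs;
    }

    static SigNode leaf(String text) { return new SigNode(text, null, null); }
    static SigNode bin(String op, SigNode l, SigNode r) { return new SigNode(op, l, r); }
    boolean isLeaf() { return lhs == null; }
    int size() { return isLeaf() ? 1 : 1 + lhs.size() + rhs.size(); }

    // Assumption: a full structural string stands in for the hash signature.
    String signature() {
        return isLeaf() ? text : "(" + lhs.signature() + " " + text + " " + rhs.signature() + ")";
    }
}

final class SignatureCse {
    // Returns the local definitions followed by the final return statement.
    static List<String> run(SigNode root) {
        List<String> out = new ArrayList<>();
        while (true) {
            Map<String, List<SigNode>> groups = new TreeMap<>();
            collect(root, groups);
            List<SigNode> best = null;    // smallest repeated expression, most common first
            for (List<SigNode> g : groups.values()) {
                if (g.size() < 2) continue;
                if (best == null
                        || g.get(0).size() < best.get(0).size()
                        || (g.get(0).size() == best.get(0).size() && g.size() > best.size())) {
                    best = g;
                }
            }
            if (best == null) break;      // nothing repeats any more
            String name = "local" + (out.size() + 1);
            out.add(name + " = " + best.get(0).signature() + ";");
            for (SigNode n : best) {      // turn every occurrence into a reference to the local
                n.text = name;
                n.lhs = null;
                n.rhs = null;
            }
        }
        out.add("return " + root.signature() + ";");
        return out;
    }

    // Groups the internal nodes by signature; leaves are ignored.
    private static void collect(SigNode n, Map<String, List<SigNode>> groups) {
        if (n.isLeaf()) return;
        groups.computeIfAbsent(n.signature(), k -> new ArrayList<>()).add(n);
        collect(n.lhs, groups);
        collect(n.rhs, groups);
    }

    static SigNode bar5() { return SigNode.bin("*", SigNode.leaf("Bar"), SigNode.leaf("5")); }

    static SigNode sum() {
        return SigNode.bin("+", bar5(), SigNode.bin("*", SigNode.leaf("Baz"), SigNode.leaf("2")));
    }

    // The question's lambda as a plain tree (no shared nodes).
    static SigNode example() {
        SigNode gt = SigNode.bin(">", sum(), SigNode.leaf("7"));
        SigNode lt = SigNode.bin("<", sum(), SigNode.leaf("3"));
        SigNode eq = SigNode.bin("==", SigNode.bin("+", bar5(), SigNode.leaf("3")), SigNode.leaf("Xyz"));
        return SigNode.bin("||", SigNode.bin("||", gt, lt), eq);
    }

    public static void main(String[] args) {
        for (String line : run(example())) System.out.println(line);
    }
}
```

On the question's lambda this sketch introduces three locals: one for Bar * 5, one for Baz * 2, and one for their sum. The local for Baz * 2 is one more than the question asks for; it is a consequence of always replacing the smallest repeated expression first.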

I agree with hans-passant about the practicality of doing this. However, if you're researching this academically you may be interested in the Quine-McCluskey algorithm. Beware this is a very complex problem. Mathematica has a very nice all-purpose expression optimizer and depending on your needs you may be able to just use it - e.g. if you feed it your expression:

[Image of the Mathematica output not preserved]

(foo.Bar = A, foo.Baz = B, foo.Xyz = X)

Licensed under: CC-BY-SA with attribution