Question

So, I would like to convert part of a C program to Haskell. I wrote this part (it's a simplified example of what I want to do) in C, but being the newbie I am in Haskell, I can't really make it work.

float g(int n, float a, float p, float s)
{
    int c;
    while (n>0)
    {
        c = n % 2;
        if (!c) s += p;
        else s -= p;
        p *= a;
        n--;
    }
    return s;
}

Anyone got any ideas/solutions?

Solution

Lee's translation is already pretty good (well, he confused the odd and even cases(1)), but he fell into a couple of performance traps.

g n a p s =
  if n > 0
  then
    let c = n `mod` 2
        s' = (if c == 0 then (-) else (+)) s p
        p' = p * a
    in g (n-1) a p' s'        
  else s

  1. He used mod instead of rem. The latter maps to machine division; the former performs additional checks to ensure a non-negative result, so mod is a bit slower than rem. If either one satisfies the needs - because they yield identical results when both arguments are non-negative, or because the result is only compared to 0 (both conditions are satisfied here) - rem is preferable. Even better, and a bit more idiomatic, is to use even (which uses rem for the reasons mentioned above). The difference is not huge, though.
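
    For illustration (the helper names here are mine, not part of the answer), all three spellings below test the same thing for the non-negative n used here; even is itself defined in terms of rem:

    isEvenMod, isEvenRem, isEvenIdiomatic :: Int -> Bool
    isEvenMod       n = n `mod` 2 == 0   -- extra sign handling for negative n
    isEvenRem       n = n `rem` 2 == 0   -- maps directly to machine division
    isEvenIdiomatic n = even n           -- idiomatic; defined via rem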

  2. No type signature. That means that the code is (type-class) polymorphic, and thus no strictness analysis is possible, nor any specialisations. If the code is used in the same module at a specific type, GHC can (and usually will, if optimisations are enabled) create a specialised version for that specific type which allows strictness analysis and some other optimisations (inlining of class methods like (+) etc.); in that case, one does not pay the polymorphism penalty. But if the use site is in a different module, that cannot happen. If (type-class) polymorphic code is desired, one should mark it INLINABLE (or INLINE, for GHC < 7), so that its unfolding is exposed in the .hi file and the function can be specialised and optimised at the use site.

    Since g is recursive, it cannot be inlined [meaning, GHC cannot inline it; in principle it is possible] at use sites, which often would enable more optimisations than a mere specialisation.

    One technique that often allows better optimisation for recursive functions is the worker/wrapper transformation. One creates a wrapper that calls a recursive (local) worker, then the non-recursive wrapper can be inlined, and when the worker is called with known arguments, that can enable further optimisations like constant folding or, in the case of function arguments, inlining. In particular the latter often has an enormous impact, when combined with a static-argument-transformation (arguments that never change in the recursive calls are not passed as arguments to the recursive worker).
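
    As an aside, here is a hypothetical sketch (not from the original answer) of the shape such a transformation takes when the static argument is a function, which is the case where it pays off most:

    {-# LANGUAGE BangPatterns #-}

    -- `f` never changes across the recursive calls, so it is not passed to
    -- the worker; the non-recursive wrapper `mapSum` can be inlined at call
    -- sites, exposing a known `f` that GHC can then inline into the loop.
    {-# INLINE mapSum #-}
    mapSum :: (Int -> Int) -> [Int] -> Int
    mapSum f = worker 0
      where
        worker !acc []     = acc
        worker !acc (x:xs) = worker (acc + f x) xs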

    In this case, we only have one static argument of type Float, so a worker/wrapper transformation with a SAT typically makes no difference (as a rule of thumb, a SAT pays off when

    • the static argument is a function
    • several non-function arguments are static

    so by this rule, we shouldn't expect any benefit from w/w + SAT, and in general there is none). Here we have one special case where w/w + SAT can make a big difference, though: when the factor a is 1. GHC has {-# RULES #-} that eliminate multiplication by 1 for various types, and with such a short loop body, one multiplication more or less per iteration makes a difference; the running time is reduced by about 40% once points 3 and 4 below have been applied. (There are no RULES for multiplication by 0 or by -1 for floating point types, because 0*x = 0 and (-1)*x = -x don't hold for NaNs.) For all other a, the w/w + SATed

    {-# INLINABLE g #-}
    g n a p s = worker n p s
      where
        worker n p s
          | n <= 0    = s
          | otherwise = let s' = if even n then s + p else s - p
                        in worker (n-1) (p*a) s'
    

    does not perform measurably differently from the top-level recursive version with the same optimisations applied.

  3. Strictness. GHC's strictness analyser is good, but not perfect. It cannot see far enough through the algorithm to determine that the function is

    • strict in p if n >= 1 (assuming addition, (+), is strict in both arguments)
    • also strict in a if n >= 2 (assuming strictness of (*) in both arguments)

    and then produce a worker that is strict in both. Instead you get a worker that uses an unboxed Int# for n and an unboxed Float# for s (I'm using the type Int -> Float -> Float -> Float -> Float here, corresponding to the C), and boxed Floats for a and p. Thus in each iteration you get two unboxings and a re-boxing. That costs (relatively) a lot of time, since apart from that it's just a bit of simple arithmetic and tests. Help GHC along a bit, and make the worker (or g itself, if you don't do the worker/wrapper transform) strict in p (with a bang pattern, for example). That is enough to let GHC produce a worker using unboxed values throughout, as sketched below.
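
    Concretely, a sketch of that suggestion (needs the BangPatterns extension): a single bang on p in the worker from point 2 is all it takes.

        worker n !p s   -- the bang forces p on each call, letting GHC unbox it
          | n <= 0    = s
          | otherwise = let s' = if even n then s + p else s - p
                        in worker (n-1) (p*a) s'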

  4. Using division to test parity (not applicable if the type is Int and the LLVM backend is used).

    GHC's optimiser hasn't got down to the low-level bits very much yet, so the native code generator emits a division instruction for

    x `rem` 2 == 0
    

    and, when the rest of the loop body is as cheap as it is here, that costs a lot of time. LLVM's optimiser has already been taught to replace that with a bitmasking at type Int, so with ghc -O2 -fllvm you don't need to do that manually. With the native code generator, substituting that with

    x .&. 1 == 0
    

    (needs import Data.Bits of course) produces a significant speedup (on normal platforms where a bitwise and is much faster than a division).
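
    Spelled out as a self-contained sketch (the helper name is mine):

    import Data.Bits ((.&.))

    -- tests the lowest bit; on two's-complement machines this matches
    -- x `rem` 2 == 0, but compiles to a single AND with the native codegen
    evenBits :: Int -> Bool
    evenBits x = x .&. 1 == 0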

The final result

{-# LANGUAGE BangPatterns #-}
import Data.Bits ((.&.))

{-# INLINABLE g #-}
g n a p s = worker n p s
  where
    worker k !ap acc
        | k > 0     = worker (k-1) (ap*a) (if k .&. (1 :: Int) == 0 then acc + ap else acc - ap)
        | otherwise = acc

does not perform measurably differently (for the tested values) from the result of gcc -O3 -msse2 loop.c, except for a = -1, where gcc replaces the multiplication with a negation (assuming all NaNs equivalent).
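
For reference, a minimal driver to try this out (the test values are made up for illustration, not the ones benchmarked above):

main :: IO ()
main = print (g 1000000 (1.0001 :: Float) 1 0)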


(1) He's not alone in that,

c = n % 2;
if (!c) s += p;
else s -= p;

seems to be really tricky; as far as I can see, everybody(2) got that wrong.

(2) With one exception ;)

Other Tips

As a first step, let's simplify your code:

float g(int n, float a, float p, float s) {
    if (n <= 0) return s;

    float s2 = n % 2 == 0 ? s + p : s - p;
    return g(n - 1, a, a*p, s2);
}

We have turned your original function into a recursive one that exhibits a certain structure. It's a sequence! We can turn this into Haskell conveniently:

gs :: Bool -> Float -> Float -> Float -> [Float]
gs nb a p s = s : gs (not nb) a (a*p) (if nb then s - p else s + p)

Finally we just need to index this list:

g :: Int -> Float -> Float -> Float -> Float
g n a p s = gs (even n) a p s !! n

The code is not tested, but it should work. If not, it's probably just an off-by-one error.

Here is how I would tackle this problem in Haskell. First, I observe that there are several loops merged into one here: we are

  1. forming a geometric sequence (whose factor is -a, starting from a suitably signed version of p)
  2. taking a prefix of the sequence
  3. summing the result

So my solution follows this structure as well, with a tiny bit of s and p thrown in for good measure because that's what your code does. In a from-scratch version, I'd probably drop those two parameters entirely.

g n a p s = sum (s : take n (iterate (*(-a)) start)) where
    start | odd n     = -p
          | otherwise = p
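
As a quick sanity check (mine, not part of the answer): tracing the C loop by hand for n = 3, a = 2, p = 1, s = 0 gives 0 - 1 + 2 - 4 = -3, and this version agrees:

g 3 2 1 0   -- evaluates to -3.0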

A fairly direct translation would be:

g n a p s =
  if n > 0
  then
    let c = n `mod` 2
        s' = (if c == 0 then (-) else (+)) s p
        p' = p * a
    in g (n-1) a p' s'        
  else s

Looking at the signature of the C function g (i.e., float g(int n, float a, float p, float s)), you know that your Haskell function will receive 4 arguments and return a float, thus:

g :: Integer -> Float -> Float -> Float -> Float

Let us now look into the loop: it runs while n > 0, so n <= 0 is the stop case, and n--; will be the decreasing step used on the recursive call. Therefore:

g :: Integer -> Float -> Float -> Float -> Float
g n a p s | n <= 0 = s

For n > 0, you have another conditional, if (!(n % 2)) s += p; else s -= p;, inside the loop. If n is odd then you will do s += p, p *= a and n--. In Haskell it will be:

g :: Integer -> Float -> Float -> Float -> Float
g n a p s | n <= 0 = s
          | odd n = g (n-1) a (p*a) (s+p)

If n is even then you will do s -= p, p *= a and n--. Thus:

g :: Integer -> Float -> Float -> Float -> Float
g n a p s | n <= 0 = s
          | odd n = g (n-1) a (p*a) (s+p)
          | otherwise = g (n-1) a (p*a) (s-p)

To expand on @Landei's and @MathematicalOrchid's comments below the question: the algorithm proposed to solve the problem at hand is always O(n). However, if you realize that what you're actually doing is computing a partial sum of a geometric series, you can use the well-known summation formula:

g n a p s = s + (-1)^n * p * ((-a)^n - 1) / (-a - 1)

This will be faster, as the exponentiation can be done in O(log n) multiplications by repeated squaring (which is what Haskell's ^ operator does) or other clever methods, which are likely automatically employed for integer powers by modern compilers.
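
A sketch of this closed form in Haskell (the name and the guard for a == -1 are mine; at a = -1 the ratio -a is 1, the denominator vanishes, and the sum degenerates to n equal terms):

gClosed :: Int -> Float -> Float -> Float -> Float
gClosed n a p s
  | a == -1   = s + (-1)^n * p * fromIntegral n            -- ratio 1: n equal terms
  | otherwise = s + (-1)^n * p * ((-a)^n - 1) / (-a - 1)   -- geometric sum

Note that (-a)^n can overflow Float quickly, so the closed form trades the O(n) loop for floating-point range and rounding concerns.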

You can encode loops almost-naturally with the Haskell Prelude function until :: (a -> Bool) -> (a -> a) -> a -> a:

{-# LANGUAGE BangPatterns #-}

g :: Int -> Float -> Float -> Float -> Float
g n a p s =
  fst.snd $
    until ((<= 0).fst)
          (\(n,(!s,!p)) -> (n-1, (if even n then s+p else s-p, p*a)))
          (n,(s,p))

The bang-patterns !s and !p mark strictly-calculated intermediate variables, to prevent excessive laziness which would otherwise harm efficiency.

until pred step start repeatedly applies the step function until pred holds for the last generated value, starting with the initial value start. It can be represented by the pseudocode:

def until (pred, step, start):
  while( true ):
    if pred(start): return(start)
    start := step(start)

// well, actually:

def until (pred, step, start):
  if pred(start): return(start)
  call until(pred, step, step(start))

The first pseudocode is equivalent to the second (which is how until is actually implemented) in the presence of tail call optimization, which is why in many functional languages where TCO is present loops are encoded via recursion.

So in Haskell, until is coded as

until p f x  | p x       = x
             | otherwise = until p f (f x)

But it could have been coded differently, making explicit the interim results:

until p f x = last $ go x     -- or, last (go x)
  where go x | p x       = [x]
             | otherwise = x : go (f x)

Using the Haskell standard higher-order functions break and iterate, this could be written as stream-processing code,

until p f x = let (_,(r:_)) = break p (iterate f x) in r
                       -- or: span (not.p) ....

or just

until p f x = head $ dropWhile (not.p) $ iterate f x    -- or, equivalently,
           -- head . dropWhile (not.p) . iterate f $ x

If TCO weren't present in a given Haskell implementation, the last version would be the one to use.
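
As a quick check that all these definitions agree, in GHCi:

ghci> until (> 100) (* 2) 1
128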


Hopefully this makes clearer how the stream-processing code from Daniel Wagner's answer comes about,

g n a p s = s + (sum . take n . iterate (*(-a)) $ if odd n then (-p) else p)

because the predicate involved is about counting down from n, and

fst . snd . head . dropWhile ((> 0).fst) $
  iterate (\(n,(!s,!p)) -> (n-1, (if even n then s+p else s-p, p*a)))
          (n,(s,p))
===
fst . snd . head . dropWhile ((> 0).fst) $
  iterate (\(n,(!s,!p)) -> (n-1, (s+p, p*(-a))))
          (n,(s, if odd n then (-p) else p))          -- 0 is even
===
fst . (!! n) $
  iterate (\(!s,!p) -> (s+p, p*(-a)))
          (s, if odd n then (-p) else p)    
===
foldl' (+) s . take n . iterate (*(-a)) $ if odd n then (-p) else p

In pure FP, the stream-processing paradigm makes all history of a computation available, as a stream (list) of values.
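
For instance (a sketch of mine, riffing on the last formulation above), replacing the fold with a scan exposes that history; the original g's result is the last element:

import Data.List (scanl')

gHistory :: Int -> Float -> Float -> Float -> [Float]
gHistory n a p s = scanl' (+) s . take n . iterate (*(-a)) $
                     if odd n then (-p) else p   -- all intermediate sums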
