Question

A friend of mine showed me a home exercise from a C++ course he attends. Since I already know C++ but have just started learning Haskell, I tried to solve the exercise the "Haskell way".

These are the exercise instructions (I translated from our native language so please comment if the instructions aren't clear):

Write a program which reads non-zero coefficients (A,B,C,D) from the user and places them in the following equation: A*x + B*y + C*z = D The program should also read from the user N, which represents a range. The program should find all possible integral solutions for the equation in the range -N/2 to N/2.

For example:

Input: A = 2,B = -3,C = -1, D = 5, N = 4
Output: (-1,-2,-1), (0,-2, 1), (0,-1,-2), (1,-1, 0), (2,-1,2), (2,0, -1)

The most straightforward algorithm is to try all possibilities by brute force. I implemented it in Haskell as follows:

triSolve :: Integer -> Integer -> Integer -> Integer -> Integer -> [(Integer,Integer,Integer)]
triSolve a b c d n =
  let equation x y z = (a * x + b * y + c * z) == d
      minN = div (-n) 2
      maxN = div n 2
  in [(x,y,z) | x <- [minN..maxN], y <- [minN..maxN], z <- [minN..maxN], equation x y z]

So far so good, but the exercise instructions note that a more efficient algorithm can be implemented, so I thought about how to improve it. Since the equation is linear, and based on the assumption that Z is always the first variable to be incremented, once a solution has been found there's no point in incrementing Z further. Instead, I should increment Y, reset Z to the minimum value of the range, and keep going. This way I can skip redundant iterations. Since there are no loops in Haskell (to my understanding, at least), I realized that such an algorithm should be implemented using recursion. I implemented it in the following way:

solutions :: (Integer -> Integer -> Integer -> Bool) -> Integer -> Integer -> Integer -> Integer -> Integer -> [(Integer,Integer,Integer)]
solutions f maxN minN x y z
  | solved = (x,y,z):nextCall x (y + 1) minN
  | x >= maxN && y >= maxN && z >= maxN = []
  | z >= maxN && y >= maxN = nextCall (x + 1) minN minN
  | z >= maxN = nextCall x (y + 1) minN
  | otherwise = nextCall x y (z + 1)
  where solved = f x y z
        nextCall = solutions f maxN minN

triSolve' :: Integer -> Integer -> Integer -> Integer -> Integer -> [(Integer,Integer,Integer)]
triSolve' a b c d n =
  let equation x y z = (a * x + b * y + c * z) == d
      minN = div (-n) 2
      maxN = div n 2
  in solutions equation maxN minN minN minN minN

Both yield the same results. However, trying to measure the execution time yielded the following results:

*Main> length $ triSolve' 2 (-3) (-1) 5 100
3398
(2.81 secs, 971648320 bytes)
*Main> length $ triSolve 2 (-3) (-1) 5 100
3398
(1.73 secs, 621862528 bytes)

Meaning that the dumb algorithm actually performs better than the more sophisticated one. Based on the assumption that my algorithm is correct (which I hope won't turn out to be wrong :) ), I assume that the second algorithm suffers from overhead created by the recursion, which the first algorithm avoids since it's implemented using a list comprehension. Is there a way to implement a better algorithm than the dumb one in Haskell? (Also, I'd be glad to receive general feedback on my coding style.)


Solution

Of course there is. We have:

a*x + b*y + c*z = d

and as soon as we assume values for x and y, we have that

a*x + b*y = n

where n is a number we know. Hence

c*z = d - n
z = (d - n) / c

And we keep only integral zs.
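This last step can be sketched in Haskell as follows (the helper name zFor is mine, not from the answer): given x and y, we recover the unique candidate z directly and keep it only when c divides d - (a*x + b*y).

```haskell
-- For fixed x and y, solve c*z = d - (a*x + b*y) for z.
-- divMod rounds toward negative infinity, but the remainder
-- is zero exactly when c divides the left-hand side, so the
-- divisibility test is valid for negative c as well.
zFor :: Integer -> Integer -> Integer -> Integer
     -> Integer -> Integer -> Maybe Integer
zFor a b c d x y =
  let (q, r) = (d - (a * x + b * y)) `divMod` c
  in if r == 0 then Just q else Nothing
```

Iterating x and y over the range and collecting the Just results (still checking that z lies within the range) replaces the O(N³) scan with an O(N²) one.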

Other tips

It's worth noticing that list comprehensions are given special treatment by GHC, and are generally very fast. This could explain why your triSolve (which uses a list comprehension) is faster than triSolve' (which doesn't).

For example, the solution

solve :: Integer -> Integer -> Integer -> Integer -> Integer -> [(Integer,Integer,Integer)]
-- "Buffalo buffalo buffalo buffalo Buffalo buffalo buffalo..."
solve a b c d n =
    [(x,y,z) | x <- vals, y <- vals
             , let p = a*x +b*y
             , let z = (d - p) `div` c
             , z >= minN, z <= maxN, c * z == d - p ]
    where
        minN = negate (n `div` 2)
        maxN = (n `div` 2)
        vals = [minN..maxN]

runs fast on my machine:

> length $ solve 2 (-3) (-1) 5 100
3398
(0.03 secs, 4111220 bytes)

whereas the equivalent code written using do notation:

solveM :: Integer -> Integer -> Integer -> Integer -> Integer -> [(Integer,Integer,Integer)]
solveM a b c d n = do
    x <- vals
    y <- vals
    let p = a * x + b * y
        z = (d - p) `div` c
    guard $ z >= minN
    guard $ z <= maxN
    guard $ z * c == d - p
    return (x,y,z)
    where
        minN = negate (n `div` 2)
        maxN = (n `div` 2)
        vals = [minN..maxN]

takes twice as long to run and uses twice as much memory:

> length $ solveM 2 (-3) (-1) 5 100
3398
(0.06 secs, 6639244 bytes) 

Usual caveats about benchmarking within GHCi apply -- if you really want to see the difference, you need to compile the code with -O2 and use a decent benchmarking library (like Criterion).
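An untested sketch of such a benchmark, assuming the solve and solveM definitions above are in scope in the same module and that the criterion package is installed:

```haskell
import Criterion.Main (bench, defaultMain, nf)

-- Benchmark both variants; nf forces the Int result fully,
-- so we measure the whole computation, not just a thunk.
main :: IO ()
main = defaultMain
  [ bench "solve"  $ nf (length . solve  2 (-3) (-1) 5) 100
  , bench "solveM" $ nf (length . solveM 2 (-3) (-1) 5) 100
  ]
```

Compile with `ghc -O2` and run the resulting binary; Criterion reports mean times with confidence intervals, which is far more reliable than GHCi's `:set +s` timings.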

License: CC BY-SA with attribution
Not affiliated with StackOverflow