Question

Python function calls are relatively expensive. But I keep running into cases where I would like to be able to call a function in different ways, and the easiest way seems to be to create a light wrapper around the function that accepts the arguments in the other form.

Is there a more Pythonic and/or more efficient way to allow more than one way to call a function?


For a completely contrived, overly simple example that illustrates what I'm asking:

from math import sqrt
from collections import namedtuple

Point = namedtuple('Point', 'x y')

def distFunc(x1, y1, x2, y2):
    return sqrt((x1-x2)**2 + (y1-y2)**2)

def pointDistFunc(p1, p2):
    return distFunc(p1.x, p1.y, p2.x, p2.y)

Is there a better way to write pointDistFunc?

As it stands, this timeit benchmark:

p1 = Point(1, 1)
p2 = Point(100, 100)

if __name__ == '__main__':
    import timeit
    print(timeit.timeit('distFunc(1, 1, 100, 100)', setup='from __main__ import distFunc'))
    print(timeit.timeit('pointDistFunc(p1, p2)', setup='from __main__ import pointDistFunc, p1, p2'))

gives (total seconds for timeit's default of 1,000,000 runs each):

0.392938508373
0.977704155415

So the wrapped call takes roughly 2.5 times as long as the direct one, which seems like noticeable overhead.
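Because `Point` is a `namedtuple`, and therefore a tuple, one alternative that needs no wrapper at all is to unpack the points at the call site. This is a sketch assuming Python 3.5 or later, which allows several `*` unpackings in a single call (PEP 448), and it relies on the fields being declared in the order `x y`. Whether it is actually faster than the wrapper is something to measure rather than assume.

```python
from math import sqrt
from collections import namedtuple

Point = namedtuple('Point', 'x y')

def distFunc(x1, y1, x2, y2):
    return sqrt((x1-x2)**2 + (y1-y2)**2)

p1 = Point(1, 1)
p2 = Point(100, 100)

# A namedtuple is a tuple, so *p1 expands to p1.x, p1.y in field order.
# distFunc is called directly, with no intermediate wrapper function.
d = distFunc(*p1, *p2)
print(d)
```

The trade-off is that the call site now depends on the field order of `Point`, which an explicit wrapper hides.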


Solution

In general, I think the best thing is to write the clearest code and not worry too much about efficiency. In this case I would code it exactly the way you already did and not worry about it.

But if you know a piece of code will be called heavily and you want it to be as fast as possible, you may be able to speed it up by rewriting it. In your trivial example, you could gain speed by having the wrapper do the calculation itself instead of calling distFunc():

def pointDistFunc(p1, p2):
    return sqrt((p1.x-p2.x)**2 + (p1.y-p2.y)**2)
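To check whether the inlined rewrite pays off on a given interpreter, the variants can be timed side by side. This sketch assumes Python 3.6 or later (it passes `globals=` to `timeit.repeat` and uses f-strings) and reports the best of several repeats, since the minimum is usually less noisy than a single run. The names `wrapperDistFunc`, `inlinedDistFunc` and `best_time` are hypothetical, chosen here only to keep both versions of the wrapper in one module; no particular timings are implied.

```python
import timeit
from math import sqrt
from collections import namedtuple

Point = namedtuple('Point', 'x y')

def distFunc(x1, y1, x2, y2):
    return sqrt((x1-x2)**2 + (y1-y2)**2)

# Original wrapper: unpacks the Points and makes a second Python call.
def wrapperDistFunc(p1, p2):
    return distFunc(p1.x, p1.y, p2.x, p2.y)

# Inlined version: same arithmetic, one Python call fewer.
def inlinedDistFunc(p1, p2):
    return sqrt((p1.x-p2.x)**2 + (p1.y-p2.y)**2)

p1 = Point(1, 1)
p2 = Point(100, 100)

def best_time(stmt, number=100000, repeat=5):
    """Return the fastest of `repeat` runs, each executing `stmt` `number` times."""
    return min(timeit.repeat(stmt, globals=globals(), number=number, repeat=repeat))

results = {
    'distFunc': best_time('distFunc(1, 1, 100, 100)'),
    'wrapperDistFunc': best_time('wrapperDistFunc(p1, p2)'),
    'inlinedDistFunc': best_time('inlinedDistFunc(p1, p2)'),
}

if __name__ == '__main__':
    for name, seconds in results.items():
        print(f'{name:16} {seconds:.4f} s')
```

Passing `globals=globals()` avoids the `from __main__ import ...` setup strings, so the snippet also works when it is imported rather than run as a script.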

Ideally, you should have some unit tests somewhere that check that

pointDistFunc(p1, p2) == distFunc(p1.x, p1.y, p2.x, p2.y)

That way, if you end up changing distFunc() but forget to change pointDistFunc() as well, the test will fail and remind you.
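A minimal version of such a test, using the standard `unittest` module, might look like the sketch below. Both functions are defined inline so the example runs on its own; in a real project they would be imported from the module under test. Exact `==` is reasonable here only because both versions perform the same floating-point operations in the same order; if the formulas ever diverged in form, `math.isclose` would be the safer comparison.

```python
import unittest
from math import sqrt
from collections import namedtuple

Point = namedtuple('Point', 'x y')

def distFunc(x1, y1, x2, y2):
    return sqrt((x1-x2)**2 + (y1-y2)**2)

# Inlined version that must stay in sync with distFunc.
def pointDistFunc(p1, p2):
    return sqrt((p1.x-p2.x)**2 + (p1.y-p2.y)**2)

class PointDistTest(unittest.TestCase):
    def test_matches_distFunc(self):
        # Sample pairs, including negative and fractional coordinates.
        pairs = [
            (Point(1, 1), Point(100, 100)),
            (Point(0, 0), Point(3, 4)),
            (Point(-2.5, 7), Point(4, -1.25)),
        ]
        for p1, p2 in pairs:
            # subTest reports each failing pair separately.
            with self.subTest(p1=p1, p2=p2):
                self.assertEqual(pointDistFunc(p1, p2),
                                 distFunc(p1.x, p1.y, p2.x, p2.y))
```

Saved in a file such as `test_dist.py` (a hypothetical name), it can be run with `python -m unittest test_dist`.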

The guideline about function calls being expensive was not so much intended to keep you from writing wrappers; it was more a suggestion for how to rewrite hot spots that involve things like lists:

def gen_point_dist_from_lst(lst, p2):
    return (sqrt((p1.x-p2.x)**2 + (p1.y-p2.y)**2) for p1 in lst)

If the list has 1000 points, the above saves 2000 function calls (one to pointDistFunc() and one to distFunc() per point) compared to the straightforward generator expression:

(pointDistFunc(p1, p2) for p1 in lst)
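The 2000 figure can be checked directly by counting calls. In this sketch the module-level `calls` dictionary is a hypothetical instrumentation aid added only for the demonstration (it adds overhead of its own, so this is not a benchmark), and the 1000-point list and the reference point `origin` are made up.

```python
from math import sqrt
from collections import namedtuple

Point = namedtuple('Point', 'x y')

# Hypothetical counter, used only to tally calls for this demonstration.
calls = {'distFunc': 0, 'pointDistFunc': 0}

def distFunc(x1, y1, x2, y2):
    calls['distFunc'] += 1
    return sqrt((x1-x2)**2 + (y1-y2)**2)

def pointDistFunc(p1, p2):
    calls['pointDistFunc'] += 1
    return distFunc(p1.x, p1.y, p2.x, p2.y)

def gen_point_dist_from_lst(lst, p2):
    # Computes the distance inline: no call to pointDistFunc or distFunc.
    return (sqrt((p1.x-p2.x)**2 + (p1.y-p2.y)**2) for p1 in lst)

lst = [Point(i, 2 * i) for i in range(1000)]
origin = Point(0, 0)

wrapped = list(pointDistFunc(p1, origin) for p1 in lst)
print(sum(calls.values()))  # 2000: two Python-level calls per point

inlined = list(gen_point_dist_from_lst(lst, origin))
print(sum(calls.values()))  # still 2000: the inlined generator added none

print(wrapped == inlined)  # True: same arithmetic, same results
```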

The key is to actually have a problem before you try these tricks. If your program already runs fast enough, you may not need to optimize anything. If you need your code to be faster, you can try these tricks.

P.S. If you can use PyPy for what you are doing, it should largely eliminate the function-call overhead. PyPy has a just-in-time compiler that optimizes the hot spots in your program for you.

http://speed.pypy.org/

License: CC-BY-SA with attribution