Question

I want to test an unknown value against the constraints that a given NumPy dtype implies -- e.g., if I have an integer value, is it small enough to fit in a uint8?

As best I can ascertain, NumPy's dtype architecture doesn't offer a way to do something like this:

### FICTIONAL NUMPY CODE: I made this up ###
try:
    numpy.uint8.validate(rupees)
except numpy.dtype.ValidationError:
    print("Users can't hold more than 255 rupees.")

My little fantasy API is based on Django's model-field validators, but that's just one example -- the best mechanism I managed to contrive was along the lines of this:

>>> nd = numpy.array([0,0,0,0,0,0], dtype=numpy.dtype('uint8'))
>>> nd[0]
0
>>> nd[0] = 1
>>> nd[0] = -1
>>> nd
array([255,   0,   0,   0,   0,   0], dtype=uint8)
>>> nd[0] = 257
>>> nd
array([1, 0, 0, 0, 0, 0], dtype=uint8)

Round-tripping the questionable values through a numpy.ndarray explicitly typed as numpy.uint8 gives me back integers that have been wrapped modulo 256 into the uint8 range -- without raising an exception or any other sort of actionable error state.
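The wrapping in the session above is modular arithmetic: for uint8, a value v ends up stored as v mod 256, so -1 becomes 255 and 257 becomes 1. Recent NumPy releases may instead refuse to store an out-of-range Python int with an OverflowError, so this sketch reproduces the same wrap in a version-independent way by casting an existing int64 array with astype, which still converts silently:

```python
import numpy as np

values = np.array([-1, 257, 300], dtype=np.int64)

# astype uses "unsafe" casting by default, so out-of-range integers
# are wrapped modulo 2**8 rather than rejected.
wrapped = values.astype(np.uint8)
print(wrapped)  # [255   1  44]

# The result matches plain modular arithmetic on the original values.
print(np.array_equal(wrapped, values % 256))  # True
```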

I'd rather not put on the architecture-astronaut flight suit, of course, but that's preferable to the alternative, which looks like an unmaintainable spaghetti-monster mess of if dtype(this) ... elif dtype(that) statements. Is there anything I can do here besides embarking on the grandiose and indulgent act of writing my own API?


Solution

If a is your original iterable, you could do something along the following lines:

np.all(np.array(a, dtype=np.int8) == a)

Quite simply, this compares the resulting ndarray to the original values, and tells you whether the conversion to ndarray has been lossless.
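Wrapped in a function, the same round-trip check might look like the sketch below (fits_dtype is a hypothetical name). One assumption worth flagging: NumPy 2.0 and later raise OverflowError when an out-of-range Python int is converted to an integer dtype, while older releases wrap it (silently, or with a DeprecationWarning), so the sketch treats that exception, and the ValueError raised for unconvertible input such as non-numeric strings, as "does not fit":

```python
import numpy as np


def fits_dtype(values, dtype):
    """Return True if every value survives conversion to `dtype` unchanged."""
    try:
        converted = np.array(values, dtype=dtype)
    except (OverflowError, ValueError):
        # Newer NumPy rejects out-of-range Python ints outright;
        # ValueError covers input that cannot be converted at all.
        return False
    # Older NumPy wraps instead, which this comparison then detects.
    return bool(np.all(converted == values))


print(fits_dtype([0, 1, 255], np.uint8))  # True
print(fits_dtype([256], np.uint8))        # False
```

Non-integral floats are caught the same way: 0.5 becomes 0 in an int8 array, and 0 != 0.5.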

This will also catch things like using a floating-point type that's too narrow to represent some of the values exactly:

>>> a = [0, 0, 0, 0, 0, 0.123456789]
>>> np.all(np.array(a, dtype=np.float32) == a)
False
>>> np.all(np.array(a, dtype=np.float64) == a)
True

Edit: One caveat when using the above code with floating-point numbers is that NaN never compares equal to anything, including itself, so any NaN in the input makes the check return False. If required, it is easy to extend the code to handle that case too.
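One way to handle NaN is to compare with np.array_equal and its equal_nan flag (available since NumPy 1.19), which counts NaNs in matching positions as equal. A sketch, with roundtrip_is_lossless as a hypothetical name:

```python
import numpy as np


def roundtrip_is_lossless(values, dtype):
    """Like np.all(np.array(a, dtype=...) == a), but NaN counts as equal to NaN."""
    original = np.asarray(values)
    converted = np.array(values, dtype=dtype)
    # equal_nan=True treats NaNs at the same positions as matching.
    return np.array_equal(converted, original, equal_nan=True)


a = [0.0, float("nan"), 0.5]
print(np.all(np.array(a, dtype=np.float32) == a))  # False, because of the NaN
print(roundtrip_is_lossless(a, np.float32))        # True
```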

OTHER TIPS

Have a look at the numpy.iinfo / numpy.finfo structs. They report the limits (min and max, plus precision for floats) of the elementary numeric dtypes, which should provide all the information needed for a validation service that works for those dtypes. This won't work for composite or binary field dtypes, and you would still have to implement the service skeleton yourself.
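A minimal sketch of such a validator, closer to the fictional numpy.uint8.validate from the question. check_fits is a hypothetical helper, not a NumPy API; it is a range check only, so it says nothing about float precision or whether a value is integral:

```python
import numpy as np


def check_fits(value, dtype):
    """Raise ValueError if `value` lies outside the representable range of `dtype`.

    Hypothetical helper: a range check only, for integer and floating-point
    dtypes. It does not check float precision or whether a value is integral.
    """
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.integer):
        info = np.iinfo(dtype)
        lo, hi = int(info.min), int(info.max)
    elif np.issubdtype(dtype, np.floating):
        info = np.finfo(dtype)
        # Compare as Python floats so `value` is never cast down to `dtype`.
        lo, hi = float(info.min), float(info.max)
    else:
        raise TypeError(f"no iinfo/finfo range for dtype {dtype}")
    if not lo <= value <= hi:
        raise ValueError(f"{value!r} does not fit in {dtype} (range {lo} to {hi})")


# Usage, echoing the fictional API from the question:
rupees = 300
try:
    check_fits(rupees, np.uint8)
except ValueError:
    print("Users can't hold more than 255 rupees.")
```

Dispatching on np.issubdtype against the abstract np.integer and np.floating types keeps the if/elif chain to two branches instead of one per concrete dtype.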

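Try using numpy.seterr() with the over argument to trigger warnings or errors on overflow. Note that this controls how floating-point errors in arithmetic are handled (its documentation says arithmetic on NumPy integer scalars is treated the same way); it does not add range checks when values are cast or assigned into an integer array.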

e.g.

numpy.seterr(over='raise')
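A sketch of the tip using np.errstate, the context-manager counterpart of np.seterr that takes the same keywords, so the setting does not leak into the rest of the program. It shows over="raise" turning a float64 overflow into a FloatingPointError and, as a caveat, an integer cast inside the same block still wrapping silently:

```python
import numpy as np

big = np.array([1e308])

# np.errstate applies np.seterr-style settings only inside the with-block
# and restores the previous settings afterwards.
with np.errstate(over="raise"):
    try:
        big * 10  # float64 overflow: raises instead of returning inf
        float_overflow_raised = False
    except FloatingPointError:
        float_overflow_raised = True

    # Integer casts are not floating-point operations, so they still wrap.
    wrapped = np.array([300]).astype(np.uint8)

print(float_overflow_raised)  # True
print(wrapped)                # [44]
```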
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow