This seems to do the trick:
mask = a[:,:,0] > 1   # True wherever channel 0 is greater than 1
a[:,:,4][mask] = 255  # a[:,:,4] is a view, so this writes back into a
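The trick is to slice out channel 4 first. Because a[:,:,4] is basic slicing, it returns a view into a, so assigning through the boolean mask modifies the original array. After that it is just standard boolean-mask assignment.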
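The order of the two indexing steps matters, because basic slicing returns a view while boolean (advanced) indexing returns a copy. The snippet below is a small illustration on hypothetical data (a (2, 3, 5) array is assumed; the question's real array and original attempt are not shown). It contrasts the working form with the reversed order, which silently writes into a temporary copy.

```python
import numpy as np

# Hypothetical example data: a 2x3 grid with 5 channels (shape assumed).
a = np.zeros((2, 3, 5), dtype=np.uint8)
a[0, 1, 0] = 2
a[1, 2, 0] = 7

mask = a[:, :, 0] > 1  # shape (2, 3): True at (0, 1) and (1, 2)

# Slice first: a[:, :, 4] is a view, so the masked write lands in a.
a[:, :, 4][mask] = 255
print(a[:, :, 4])

# Mask first: a[mask] is advanced indexing and returns a copy of shape (2, 5),
# so setting column 4 of that copy leaves b untouched.
b = np.zeros_like(a)
b[:, :, 0] = a[:, :, 0]
b[mask][:, 4] = 255
print(b[:, :, 4].any())  # False
```

If you ever need the mask-first order, do the whole thing in a single indexing expression so NumPy performs one assignment on the original array.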
Edit: @Ophion pointed out that this is better written as a single indexing operation:
a[mask, -1] = 255  # the 2-D mask covers axes 0 and 1; -1 picks the last channel (index 4 when a has 5 channels)
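As a quick sanity check, here is a sketch comparing the two-step form with the single-expression form, assuming a is a 3-D array with exactly five channels on the last axis (so index 4 and index -1 name the same channel). The random test data is hypothetical. Note that the 2-D mask already consumes the first two axes, so only the channel index follows it.

```python
import numpy as np

# Hypothetical data: small random integers in a (4, 4, 5) array (shape assumed).
rng = np.random.default_rng(0)
a = rng.integers(0, 4, size=(4, 4, 5))

mask = a[:, :, 0] > 1

# Two-step form: take a view of channel 4, then assign through the mask.
a1 = a.copy()
a1[:, :, 4][mask] = 255

# Single expression: mask selects (row, col) positions, -1 selects the last channel.
a2 = a.copy()
a2[mask, -1] = 255

print(np.array_equal(a1, a2))  # True
```

Besides being shorter, the single expression does not rely on the intermediate slice being a view, so it keeps working if the channel selection is later changed to something that would produce a copy.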