Question

How can I create an __m128i with the n most significant bits set (counting across the entire 128-bit vector)? I need this to mask the portions of a buffer that are relevant for a computation. If possible, the solution should be branchless, but this seems hard to achieve.

How can I do this?


Solution 2

You can use one of the methods from this question to generate a mask with the n most significant bytes set to all ones. You then just need to fix up the remaining bits when n is not a multiple of 8.

I suggest trying something like this:

- init vector A = all 8-bit elements set to the residual mask of the top n % 8 bits
- init vector B = mask of n / 8 bytes using one of the above-mentioned methods
- init vector C = mask of (n + 7) / 8 bytes using one of the above-mentioned methods
- result = (A & C) | B

So for example if n = 36:

A = f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0
B = ff ff ff ff 00 00 00 00 00 00 00 00 00 00 00 00
C = ff ff ff ff ff 00 00 00 00 00 00 00 00 00 00 00
==> ff ff ff ff f0 00 00 00 00 00 00 00 00 00 00 00

This would be branchless, as required, but it is probably on the order of 10 instructions. There may be a more efficient method, but I would need to give this some more thought.
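The three-step recipe can be sketched with SSE2 intrinsics. The byte-mask methods from the linked question are not reproduced here; as an assumed stand-in, byte_mask loads 16 bytes from a sliding window over a 32-byte constant table (16 zero bytes followed by 16 0xFF bytes). The names window, byte_mask and bit_mask_bytes are hypothetical. This is a minimal sketch for 0 <= n <= 128, not a tuned implementation:

```c
#include <emmintrin.h>  /* SSE2 */
#include <stdint.h>

/* Hypothetical sliding-window table: 16 zero bytes followed by 16 0xFF bytes. */
static const uint8_t window[32] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
};

/* Mask with the k most significant bytes set, 0 <= k <= 16.
   Loading 16 bytes at offset k makes bytes 16-k..15 equal to 0xFF. */
static __m128i byte_mask(unsigned int k)
{
    return _mm_loadu_si128((const __m128i *)(window + k));
}

/* Mask with the n most significant bits set, 0 <= n <= 128. */
__m128i bit_mask_bytes(unsigned int n)
{
    /* A: every byte holds the top n % 8 bits, e.g. 0xF0 for n % 8 == 4. */
    __m128i a = _mm_set1_epi8((char)((0xFF00u >> (n % 8)) & 0xFFu));
    __m128i b = byte_mask(n / 8);        /* B: n / 8 full bytes     */
    __m128i c = byte_mask((n + 7) / 8);  /* C: (n + 7) / 8 bytes    */
    /* Keep A only inside C, then fill in the full bytes from B. */
    return _mm_or_si128(_mm_and_si128(a, c), b);
}
```

For n = 36 this gives the mask from the example above: ff ff ff ff f0 followed by zeros when printed most significant byte first. There is no branch; the only data-dependent step is the address of the unaligned load.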

Other answers

I'm adding this as a second answer and leaving the first answer for historical interest. It looks like you can do something more efficient with _mm_slli_epi64:

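#include <emmintrin.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

__m128i bit_mask(int n)
{
    __m128i v0 = _mm_set_epi64x(-1, -(n > 64)); // AND mask: clears the low lane when n <= 64
    __m128i v1 = _mm_set_epi64x(-(n > 64), 0);  // OR mask: fills the high lane when n > 64
    __m128i v2 = _mm_slli_epi64(_mm_set1_epi64x(-1), (128 - n) & 63);
    v2 = _mm_and_si128(v2, v0);
    v2 = _mm_or_si128(v2, v1);
    return v2;
}

int main(int argc, char *argv[])
{
    int n = 36;
    int i;
    uint8_t b[16];

    if (argc > 1) n = atoi(argv[1]);

    _mm_storeu_si128((__m128i *)b, bit_mask(n));   // bytes in memory order, least significant first
    printf("bit_mask(%3d) =", n);
    for (i = 0; i < 16; i++) printf(" %02x", b[i]);
    printf("\n");

    return 0;
}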

Test (bytes are printed in memory order, least significant byte first):

$ gcc -Wall -msse2 sse_bit_mask.c
$ for n in 1 2 3 63 64 65 127 128 ; do ./a.out $n ; done
bit_mask(  1) = 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 80
bit_mask(  2) = 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 c0
bit_mask(  3) = 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 e0
bit_mask( 63) = 00 00 00 00 00 00 00 00 fe ff ff ff ff ff ff ff
bit_mask( 64) = 00 00 00 00 00 00 00 00 ff ff ff ff ff ff ff ff
bit_mask( 65) = 00 00 00 00 00 00 00 80 ff ff ff ff ff ff ff ff
bit_mask(127) = fe ff ff ff ff ff ff ff ff ff ff ff ff ff ff ff
bit_mask(128) = ff ff ff ff ff ff ff ff ff ff ff ff ff ff ff ff
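To check outputs like the ones above, a plain scalar model of the same mask can be handy. This is a minimal sketch assuming 0 <= n <= 128; the names ref_mask and print_ref are hypothetical, and the model deliberately uses ordinary conditionals because it is only meant as a reference, not as a fast implementation:

```c
#include <stdint.h>
#include <stdio.h>

/* Scalar reference: the n most significant bits of a 128-bit value,
   returned as two 64-bit halves. Valid for 0 <= n <= 128. */
static void ref_mask(unsigned int n, uint64_t *hi, uint64_t *lo)
{
    /* Shifting a 64-bit value by 64 is undefined in C, so the cases
       that would need it are handled explicitly. */
    *hi = (n >= 64) ? ~UINT64_C(0) : (n == 0) ? 0 : ~UINT64_C(0) << (64 - n);
    *lo = (n >= 128) ? ~UINT64_C(0) : (n <= 64) ? 0 : ~UINT64_C(0) << (128 - n);
}

/* Print the 16 bytes least significant first, in the same layout as the test output. */
static void print_ref(unsigned int n)
{
    uint64_t hi, lo;
    int i;
    ref_mask(n, &hi, &lo);
    printf("ref_mask(%3u) =", n);
    for (i = 0; i < 8; i++) printf(" %02x", (unsigned)((lo >> (8 * i)) & 0xFF));
    for (i = 0; i < 8; i++) printf(" %02x", (unsigned)((hi >> (8 * i)) & 0xFF));
    printf("\n");
}
```

For example, print_ref(65) prints seven 00 bytes, then 80, then eight ff bytes, matching the bit_mask( 65) line above.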

The next two solutions are alternatives to Paul R's answer. They are of interest when the masks are needed inside a performance-critical loop.


SSE2

__m128i bit_mask_v2(unsigned int n){                      /* Create an __m128i vector with the n most significant bits set to 1  */
    __m128i ones_hi   = _mm_set_epi64x(-1,0);             /* Binary vector of bits 1...1 and 0...0                               */
    __m128i ones_lo   = _mm_set_epi64x(0,-1);             /* Binary vector of bits 0...0 and 1...1                               */
    __m128i cnst64    = _mm_set1_epi64x(64);
    __m128i cnst128   = _mm_set1_epi64x(128);

    __m128i shift     = _mm_cvtsi32_si128(n);             /* Move n to SSE register                                              */
    __m128i shift_hi  = _mm_subs_epu16(cnst64,shift);     /* Subtract with saturation                                            */
    __m128i shift_lo  = _mm_subs_epu16(cnst128,shift);   
    __m128i hi        = _mm_sll_epi64(ones_hi,shift_hi);  /* Shift the hi bits 64-n positions if 64-n>=0, else no shift          */
    __m128i lo        = _mm_sll_epi64(ones_lo,shift_lo);  /* Shift the lo bits 128-n positions if 128-n>=0, else no shift        */
               return   _mm_or_si128(lo,hi);              /* Merge hi and lo                                                     */
}
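The branchless behaviour of bit_mask_v2 rests on two properties: _mm_sll_epi64 yields zero for shift counts of 64 or more (unlike a C << on a 64-bit integer, where such counts are undefined), and _mm_subs_epu16 clamps 64 - n at 0 instead of wrapping. The probe functions below (sll_ones and hi_count are hypothetical helper names) make both visible. This is a minimal sketch for x86-64, since _mm_cvtsi128_si64 is not available in 32-bit mode:

```c
#include <emmintrin.h>  /* SSE2 */
#include <stdint.h>

/* Shift an all-ones vector left by 'count' bits per 64-bit lane and return the
   low lane. The count comes from the low 64 bits of an XMM register; counts
   above 63 produce zero. */
static uint64_t sll_ones(unsigned int count)
{
    __m128i ones = _mm_set1_epi64x(-1);
    __m128i r = _mm_sll_epi64(ones, _mm_cvtsi32_si128((int)count));
    return (uint64_t)_mm_cvtsi128_si64(r);
}

/* High-lane shift count as computed in bit_mask_v2: 64 - n, saturated at 0.
   The low 16-bit element holds max(64 - n, 0); the next one stays 0 for n < 65536. */
static unsigned int hi_count(unsigned int n)
{
    __m128i c = _mm_subs_epu16(_mm_set1_epi64x(64), _mm_cvtsi32_si128((int)n));
    return (unsigned int)_mm_cvtsi128_si32(c);
}
```

For example, sll_ones(64) and sll_ones(100) both return 0, and hi_count(100) returns 0; these are exactly the effects that handle n = 0 and n > 64 without a branch.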


SSSE3

The SSSE3 case is more interesting. The pshufb instruction is used as a small lookup table. It took me some time to figure out the right combination of the (saturated) arithmetic and the constants.

__m128i bit_mask_SSSE3(unsigned int n){                   /* Create an __m128i vector with the n most significant bits set to 1   */
    __m128i sat_const = _mm_set_epi8(247,239,231,223,   215,207,199,191,   183,175,167,159,   151,143,135,127);  /* Constant used in combination with saturating addition */
    __m128i sub_const = _mm_set1_epi8(248);
    __m128i pshufb_lut = _mm_set_epi8(0,0,0,0,   0,0,0,0,     /* 0b literals: GCC/Clang extension before C23 */
                          0b11111111, 0b11111110, 0b11111100, 0b11111000,
                          0b11110000, 0b11100000, 0b11000000, 0b10000000);

    __m128i shift_bc  = _mm_set1_epi8(n);                         /* Broadcast n to the 16 8-bit elements.                                */
    __m128i shft_byte = _mm_adds_epu8(shift_bc,sat_const);        /* The constants sat_const and sub_const are selected such that         */
    __m128i shuf_indx = _mm_sub_epi8(shft_byte,sub_const);        /* _mm_shuffle_epi8 can be used as a tiny lookup table                  */
                return  _mm_shuffle_epi8(pshufb_lut,shuf_indx);   /* which finds the right bit pattern at the right position.             */
}
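To see why the constants work, here is a scalar model of the same computation. Byte i of sat_const is 127 + 8*i. If n + 127 + 8*i saturates at 255, the index becomes 7 and the byte is 0xFF; if it lands in 248..254, the index is 0..6 and selects a partial byte 0x80..0xFE; below 248 the subtraction wraps to a value with bit 7 set, so pshufb writes 0. The names pshufb_byte and bit_mask_ssse3_model are hypothetical, and the model assumes 0 <= n <= 128 (the vector version truncates n to 8 bits when broadcasting):

```c
#include <stdint.h>

/* Scalar model of one PSHUFB output byte: if bit 7 of the index is set the
   result is 0, otherwise the low 4 bits select a byte of the table. */
static uint8_t pshufb_byte(const uint8_t lut[16], uint8_t idx)
{
    return (idx & 0x80) ? 0 : lut[idx & 0x0F];
}

/* Scalar model of bit_mask_SSSE3: fills out[0..15], byte 0 least significant. */
static void bit_mask_ssse3_model(unsigned int n, uint8_t out[16])
{
    /* Same table as pshufb_lut: entry j has the top j+1 bits set, entries 8..15 are 0. */
    static const uint8_t lut[16] = {0x80, 0xC0, 0xE0, 0xF0, 0xF8, 0xFC, 0xFE, 0xFF,
                                    0, 0, 0, 0, 0, 0, 0, 0};
    int i;
    for (i = 0; i < 16; i++) {
        unsigned int s = n + 127u + 8u * (unsigned int)i;  /* byte i of sat_const is 127 + 8*i */
        uint8_t sat = (uint8_t)(s > 255u ? 255u : s);       /* _mm_adds_epu8 saturates at 255  */
        uint8_t idx = (uint8_t)(sat - 248u);                /* _mm_sub_epi8 wraps modulo 256   */
        out[i] = pshufb_byte(lut, idx);
    }
}
```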


Functionality
For 1<=n<=128, the range specified by the OP, the functions bit_mask_Paul_R(n) (Paul R's answer), bit_mask_v2(n) and bit_mask_SSSE3(n) produce the same results:

bit_mask_Paul_R(  0) = FFFFFFFFFFFFFFFF 0000000000000000
bit_mask_Paul_R(  1) = 8000000000000000 0000000000000000
bit_mask_Paul_R(  2) = C000000000000000 0000000000000000
bit_mask_Paul_R(  3) = E000000000000000 0000000000000000
.....
bit_mask_Paul_R(126) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFC
bit_mask_Paul_R(127) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFE
bit_mask_Paul_R(128) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFF


bit_mask_v2(  0) = 0000000000000000 0000000000000000
bit_mask_v2(  1) = 8000000000000000 0000000000000000
bit_mask_v2(  2) = C000000000000000 0000000000000000
bit_mask_v2(  3) = E000000000000000 0000000000000000
.....
bit_mask_v2(126) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFC
bit_mask_v2(127) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFE
bit_mask_v2(128) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFF


bit_mask_SSSE3(  0) = 0000000000000000 0000000000000000
bit_mask_SSSE3(  1) = 8000000000000000 0000000000000000
bit_mask_SSSE3(  2) = C000000000000000 0000000000000000
bit_mask_SSSE3(  3) = E000000000000000 0000000000000000
.....
bit_mask_SSSE3(126) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFC
bit_mask_SSSE3(127) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFE
bit_mask_SSSE3(128) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFF

For n=0 the most reasonable result is the zero vector. It is produced by bit_mask_v2(n) and bit_mask_SSSE3(n), whereas bit_mask_Paul_R(0) sets the upper 64 bits.


Performance
To get a rough impression of the performance of the different functions, the following piece of code is used:

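unsigned long long x[2];                         /* x[0]: low 64 bits, x[1]: high 64 bits */
unsigned int i;
__m128i sum = _mm_setzero_si128();
for (i=0;i<1000000000;i=i+1){
    sum=_mm_add_epi64(sum,bit_mask_Paul_R(i));   // or use one of the next lines instead
//    sum=_mm_add_epi64(sum,bit_mask_v2(i));
//    sum=_mm_add_epi64(sum,bit_mask_SSSE3(i));
}
_mm_storeu_si128((__m128i*)x,sum);
printf("sum = %016llX %016llX\n", x[1],x[0]);    /* printing the sum keeps the loop from being optimized away */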

The performance of the code depends slightly on the instruction encoding. The GCC options opts1 = -O3 -m64 -Wall -march=nehalem lead to non-VEX-encoded SSE instructions, while opts2 = -O3 -m64 -Wall -march=sandybridge produce VEX-encoded 128-bit AVX instructions.

The results with gcc 5.4 are:

Cycles per iteration on Intel Skylake, estimated with: perf stat -d ./a.out
                     opts1       opts2
bit_mask_Paul_R       6.0         7.0
bit_mask_v2           3.8         3.3
bit_mask_SSSE3        3.0         3.0

In practice, the performance will depend on the CPU type and the surrounding code. The performance of bit_mask_SSSE3 is limited by port 5 pressure: three instructions per iteration (one movd and two pshufb) are executed on port 5.

With AVX2, more efficient code is possible; see here.

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow