Friday, February 7, 2014

SIMD-ifying Google Code Jam "Shop Credit" puzzle - Part 3

So last time I indulged in a reverie tuning the part of the 'Shop Credit' puzzle that doesn't affect performance - the min / max determination.  Let's get down to business and SIMD-ify the part that does matter.

Recall we are given a list of numbers and a desired sum.  We need to find which 2 numbers in the list add up to the sum.  So if we had:

sum: 50
list: 0, 1, 2, 3, 4, 5, 6, 20, 30

...we want the answer {20, 30}

With SIMD we are looking for a way to do 'n' checks at a time preferably on data that lives sequentially.

There are different strategies. Here's the one I used in my program.

1) Select the first item in the list.  Let's call this A.
2) Take 4 (say for SSE) more items from the list.  Let's call them B{1-4}
3) See if any of A + B{1-4} adds up to our sum
4) If so, determine which of the 4 did and leave
5) Else, pick 4 more items, goto 2

First, here's the whole code.  Then I'll break it down to show how it handles these steps.

    __m128i credit_exp = _mm_set1_epi32(credit);

    for (int i = 0; i < sz; i++) {
        int A = data[i];
        if (A + max < credit) continue; // hardly helps
        if (A + min > credit) continue; // helps; 2x

        __m128i A_broadcast = _mm_set1_epi32(A);

        int start_at = (i / 4)*4;
        const int* B_ptr = &data[start_at];
        int iters = (sz - start_at + 3) / 4; // chunks from the aligned start, rounded up
        for (int j = 0; j <  iters; j++) {
            __m128i B_vals = _mm_load_si128((__m128i*)B_ptr);
            __m128i sum    = _mm_add_epi32(B_vals, A_broadcast);

            __m128i cmp = _mm_cmpeq_epi32(sum, credit_exp);
            int lanes = _mm_movemask_ps(_mm_castsi128_ps(cmp));
            if (lanes != 0) {
                if (j == 0) {
                    lanes &= ~(1 << (i % 4));
                }
                if (j ==  iters -1) {
                    int mask = 0x0f >> (3 - ((sz + 3)% 4));
                    lanes &= mask;
                }
            }

            if (lanes) {
                int answ = 4*j + start_at + _bit_scan_reverse(lanes);
                ans1 = i < answ ? i + 1 : answ + 1;
                ans2 = i < answ ? answ + 1 : i + 1;
                return 1;
            }
            B_ptr += 4;
        }
    }


1) Select the first item in the list.  Let's call this A.

 __m128i credit_exp = _mm_set1_epi32(credit);

We will want to compare multiple lanes to the credit (the correct sum).  To do that, the target must live in every lane, so this takes the credit and broadcasts it to all 4 lanes.

for (int i = 0; i < sz; i++) {
    int A = data[i];
    if (A + max < credit) continue; // hardly helps
    if (A + min > credit) continue; // helps; 2x

Here we are simply walking through our loop just as before.  We are doing the same "don't proceed if impossible" check as before (that is, if the target sum is 500, the max is 400 and this value is 5, skip).  It would be interesting to go back and have this check 4 at a time.

    __m128i A_broadcast = _mm_set1_epi32(A);

Just like 'credit', we broadcast 'A' to all lanes so it can be easily compared later.

2) Take 4 more items from the list.  Let's call them B{1-4}

    int start_at = (i / 4)*4;
    const int* B_ptr = &data[start_at];


Recall that if we are at the outer ('A') check of the 10th list element, we can skip checking elements 0-9 in our inner loop.  We are doing the same thing here.

Almost.  At least with SSE, _mm_load_si128 requires an aligned chunk of data.  So when we are at outer element 5, we cannot start at 6; we must round down to the alignment boundary below it.  As you see below, we must start at '4':

| 0 | 1 | 2 | 3 |
| 4 | 5 | 6 | 7 |
That our inner loop must start behind the outer loop in most cases will create another problem to solve later.

3) See if any of A + B{1-4} adds up to our sum

    int iters = (sz - start_at + 3) / 4;
iters will tell us how many SIMD chunks we have to do.  If there are 10 elements like so....

| 0 | 1 | 2 | 3 |
| 4 | 5 | 6 | 7 |
| 8 | 9 |        

And our 'outer' index is at '2', the aligned start is '0', so we have 3 SIMD groups to check (0-3, 4-7, and 8-9); iters computes to '3'.  Let's run our 'iters' loop.

    for (int j = 0; j <  iters; j++) {

        __m128i B_vals = _mm_load_si128((__m128i*)B_ptr);
        __m128i sum    = _mm_add_epi32(B_vals, A_broadcast);  
 

These 2 intrinsics are pretty simple: we load 4 'B' values into 'B_vals', then add our broadcast 'A' to each lane.  If 'A' is 25 we have something like this.  All seems so simple, no?

   A: | 25 | 25 | 25 | 25 |
   B: | 14 |  5 |  9 |  7 |
 sum: | 39 | 30 | 34 | 32 |


4) If so, determine which of the 4 sums is the one we want

This is where things get nasty.  Finding the lane that compares to the answer we want is pretty easy:

        __m128i cmp = _mm_cmpeq_epi32(sum, credit_exp); 

This intrinsic compares 'sum' and the broadcast credit, lane by lane.  If sum was as above and the credit was 30, we'd have:


    sum: | 39 | 30 | 34 | 32 |

 credit: | 30 | 30 | 30 | 30 |

   cmp:  | 0x0 | 0xffffffff | 0x0 | 0x0 |


SSE (and AVX for that matter) has a somewhat odd property for comparisons like this: the result in a lane is not '0' or '1' as a C programmer would expect, but all-zeros (0x0) or all-ones (0xffffffff).  Since every bit is set on a match, the sign bit is set too, which is what the movemask below extracts.  This can create all kinds of puzzles if you are doing your own masking scheme, but I digress.

Luckily, pulling this 'cmp' out of lanes and into an integer is pretty easy:


        int lanes = _mm_movemask_ps(_mm_castsi128_ps(cmp));

In this case, 'lanes' would be 2.  The MOVMSKPS instruction this intrinsic triggers builds a bitfield from the sign bit of each lane - one bit per lane that compared true.  If the answer had been:

   cmp:  | 0x0 | 0xffffffff | 0x0 | 0xffffffff |


...'lanes' would instead be 10 (8 + 2).  We'll see why this is tricky in a minute.

4a) Do nasty stuff to make sure we really have the lane we want

I didn't have this listed as a step before so as to not scare you off.  Ok, here goes!

        if (lanes != 0) {

Ok, no big deal - we only have work to do if we have at least one match.

            if (j == 0) {
                lanes &= ~(1 << (i % 4));
            }


What?!? Well recall the bit about how our 'inner' loop must overlap the outer loop.  This means sometimes we are checking an item against itself.  If the desired answer is 50, and 'A' is 25, then this overlap means we could get a false hit at 25 (A) + 25 (A) - that is, buying the same item twice.

This bit of bit twiddling says "if we are on the first SIMD chunk (which is the one that might overlap), mask out the element that corresponds to 'A'".  Note we must preserve what might be a legitimate match in another lane.

It gets worse.

            if (j == iters - 1) {
                int mask = 0x0f >> (3 - ((sz + 3) % 4));
                lanes &= mask;
            }


Here's the other special case - when we get to the end of the line.  If we have a case like this:
| 0 | 1 | 2 | 3 |
| 4 | 5 | 6 | 7 |
| 8 | 9 |        


And we get to the third 'row', we must disregard the 2 slots after '9'.  They might be zero, or they might be garbage.  Either way they can trigger a false positive.  If the desired answer is 100 and we have 'A' = 100, then '0' would give a wrong answer.  The trickery on how this mask works is left as an exercise to the reader (gosh I've always wanted to say that).

Finally, we have the proper lane (if any).  Recall I said that the lanes are arranged in a bit field - bits 1, 2, 4, 8 for lanes 0, 1, 2, 3.  But to get our index we want 0, 1, 2, 3.  One long-handed way is to do this:

        if (lanes) {
            int answ = 4*j + start_at;
            if (lanes & 0x01) {
                answ += 0;
            }
            else if (lanes & 0x02) {
                answ += 1;
            }
            else if (lanes & 0x04) {
                answ += 2;
            }
            else if (lanes & 0x08) {
                answ += 3;
            }
        }


Yuck.  Well this is of course a find-first-set operation.  (Strictly, _bit_scan_reverse finds the highest set bit rather than the first, but any matching lane gives a valid answer.)  Again, intrinsics come to our rescue (not that we were in much danger - this is a small portion of the computation):

        if (lanes) {
            int answ = 4*j + start_at + _bit_scan_reverse(lanes);
        }
And finally we're done.  Was it worth it?  Let's see:


The red line is our baseline - the not-quite-totally-naive scalar version.  That is, it's the version that does the same tricks this version does, only with no intrinsic help.  As we'd expect, we see degradation at smaller datasets - the SIMD version is even worse in one case - but we see a nice ~2x average speedup as we grow larger.

But though this is great, the time investment was enormous compared to the non-SIMD versions.  No, I didn't code up those lane-conditioning tests perfectly in 5 minutes.  Truth be told, I probably invested a good day into shaping this up, versus a matter of minutes - maybe an hour total - for the previous scalar work.  2x is probably worth it to most people, but the bang for your programming-time buck is not there.

And we have neglected the more dominant trend, which is that this is still an O(N^2) solution.  This plot shows the clocks spent per point vs. how many points were checked.



This data is a bit skewed by the fact that the code often gets 'lucky' and finishes early.  Still, we would like this plot to be flat; the clear trend toward increasing clocks-per-point as the point count grows means this will eventually blow up.

Till next time...
