Monday, March 3, 2014

SIMD-ifying Google Code Jam "Shop Credit" puzzle - Part 5, Binary Search

Yes, two in a row for today.  If you just read my post on doing a sort & search, you saw that I broke the O(n^2)-ness of the previous attempts, but was still slower than the best (vector) version.

1. "Tree" Version

And one reason why is that the sort & search is a bit wasteful.  It does n*log(n) work to sort and then another log(n) to search.  But the n*log(n) work is 'tossed' at the end as only one query is done on all that sorting work.

So then it occurred to me to instead 'search as you go'.  If I were to do a binary tree, I can insert an item in log(n) time.  I still must do O(n*log(n)) work, but maybe I can skip some of the sorting work.  That is, maybe I can skip sorting things I never need to scan.

Here's the code I came up with.  First I defined a struct to hold the data for an item.

struct item {
    int price;
    int idx;
    item* right;  // child with a larger (or equal) price
    item* left;   // child with a smaller price
};

The 'right' and 'left' pointers define our two child nodes - the left being smaller and the right being larger in price.  (They must start out null for the insertion code below to work.)  Next, as before, I do my min / max determination for later culling.

bool tree_version(int sz, item* items, int credit, int& ans1, int& ans2) {

    int min = items[0].price;
    int max = items[0].price;
    for (int i = 0; i < sz; i++) {
        min = items[i].price < min ? items[i].price : min;
        max = items[i].price > max ? items[i].price : max;
    }

Now it's time to build the binary tree.  I'm not being very fancy about it; it's not a balanced tree, for instance.

    for (int i = 1; i < sz; i++) {
        int goal = credit - items[i].price;
        if (goal > max) continue;
        if (goal < min) continue;

        item* it = &items[i];
        item* head = &items[0];

        while (true) {
            if (it->price < head->price) {
                if (head->left) {
                    head = head->left;
                } else {
                    head->left = it;
                    break;
                }
            } else {
                if (head->right) {
                    head = head->right;
                } else {
                    head->right = it;
                    break;
                }
            }
        }
        // look through the current tree
        if (scan(sz, &items[0], goal, ans1, ans2, i)) {
            return 1;
        }
    }
    return 0;
}

Note that the 'scan' call is inside the insert loop itself.  Since we are discarding the tree after one query, we can abort as soon as we find what we want.

The 'scan' implementation is pretty simple...except for one thing.  In my first version I did the "did I find the answer?" check first.  But then I realized this is the least common case - most of the time we don't find the answer.  By moving it to be the third conditional clause, I branch less often.  This showed up as a 6% gain.

bool scan(int sz, const item* items, int goal, int& ans1, int& ans2,
    int skip_idx) {
    // scan tree
    const item* head = items;
    while (head) {
        if (head->price > goal) {
            // traverse left
            head = head->left;
        } else if (head->price < goal) {
            // traverse right
            head = head->right;
        } else if (head->idx != skip_idx) {
            // have I found the answer?
            ans1 = skip_idx < head->idx ? skip_idx + 1 : head->idx + 1;
            ans2 = skip_idx > head->idx ? skip_idx + 1 : head->idx + 1;
            return 1;
        } else {
            // traverse left; this is the case where I have found an item
            // which, when added to itself, gives the goal - not allowed
            head = head->left;
        }
    }
    return 0;
}
   
2. Results

Version                  | Performance (abs) | Speedup | Time to Implement
-------------------------|-------------------|---------|------------------
Naive (scalar)           | 32M clocks        | 1x      | 15 mins
Less Naive (scalar)      | 16M clocks        | 2x      | 10 mins
Min / Max Cull (scalar)  | 9M clocks         | 3.5x    | 20 mins
Intrinsics (vector)      | 4M clocks         | 8x      | 540 mins
Sort + search (scalar)   | 5M clocks         | 6.4x    | 20 mins
Binary Tree (scalar)     | 3.4M clocks       | 9.4x    | 17 mins

Hooray!  A new record.  And again, there may be vector opportunities to add on top of this.



3. Conclusions

  1. In this case, at least so far, proper use of the scalar hardware has actually surpassed the best vector result.
  2. Moreover, recall that while all the 'scalar' versions are plainly-compiled C++, the vector version required tricky, non-portable, intricate intrinsics work.
  3. It would be exciting to see what vector + the BST can do.  Then again, maybe the characteristics of the vector instruction set will preclude the techniques that got us this far.
  4. This is an incremental improvement over the 'sort' version that takes advantage of 'early out' situations.


SIMD-ifying Google Code Jam "Shop Credit" puzzle - Part 4 - Sorting

Ok, time to take a break from SIMD, at least for now.

1. O(n^2)
One problem with all SIMD approaches thus far is that they have all been O(n^2).  Remember the super naive version?

bool naive(int sz, const int* data, int credit, int& ans1, int& ans2) {
    for (int i = 0; i < sz; i++) {
        for (int j = 0; j < sz; j++) {
            if (i == j) continue;
            if (data[i] + data[j] == credit) {
                ans1 = i < j ? i + 1 : j + 1;
                ans2 = i < j ? j + 1 : i + 1;
                return 1;
            }
        }
    }
    return 0;
}

We see that for each search we do an 'i' and a 'j' loop each (potentially) as large as 'sz'.  Hence, our O(n^2).

In our SIMD version we had cut down that 'j' loop by the vector width (4 for SSE)....

    for (int i = 0; i < sz; i++) {
        // <...>
        int iters = (sz - i) / 4 + 1;
        for (int j = 0; j < iters; j++) {
            // <...>
        }
    }

But O(n^2) / 4 is - at least for non-trivial sizes of 'n' - still dominated by the n^2 term.

2. Sorting
Of the "2 n's" that multiply together, there isn't much we can do about the 'outer' one.  One way or another we will have to scan all the items once on average.  But what about the 'inner' n?

We know that binary search over sorted data can find an item in log2(n) time.  But it also costs us n*log2(n) time to sort it, for a total of n*log2(n) + log2(n).

So, how does n^2 compare with n*log2(n) + log2(n)?  The latter of course reduces to (n + 1)*log2(n), which is O(n*log2(n)) - asymptotically far smaller than n^2.  But in case you don't believe me, here's what Wolfram Alpha has to say about it.

So I tried it out and implemented it like this:

bool sort_version(int sz, int* data, int credit, int& ans1, int& ans2) {
    int min = data[0];
    int max = data[0];
    int idxs[sz];  // variable-length array: non-standard C++, a common compiler extension
    for (int i = 0; i < sz; i++) {
        idxs[i] = i;
    }

    // custom helper (not the C library qsort): co-sorts 'data' and 'idxs'
    // so the original positions can be recovered after sorting
    qsort(data, idxs, range(0, sz - 1));

    // (after the sort, this could simply be data[0] and data[sz - 1])
    for (int i = 0; i < sz; i++) {
        min = data[i] < min ? data[i] : min;
        max = data[i] > max ? data[i] : max;
    }

    for (int i = 0; i < sz; i++) {
        if (data[i] + max < credit) continue; // hardly helps
        if (data[i] + min > credit) continue; // helps; 2x

        int j = find(data, credit - data[i], range(i+1, sz));
        if (j == -1) continue;
        ans1 = idxs[i] + 1;
        ans2 = idxs[j] + 1;

        if (ans1 > ans2) {
            int tmp = ans1;
            ans1 = ans2;
            ans2 = tmp;
        }
        return 1;
    }
    return 0;
}

Here are our results, along with our previous work:

Version                  | Performance (abs) | Speedup | Time to Implement
-------------------------|-------------------|---------|------------------
Naive (scalar)           | 32M clocks        | 1x      | 15 mins
Less Naive (scalar)      | 16M clocks        | 2x      | 10 mins
Min / Max Cull (scalar)  | 9M clocks         | 3.5x    | 20 mins
Intrinsics (vector)      | 4M clocks         | 8x      | 540 mins
Sort + search (scalar)   | 5M clocks         | 6.4x    | 20 mins

3. Conclusions
Yes, we lost ground to the vector version.  However note:

  1. We are not using vector hardware...yet.  This result really compares against the best non-vector result we achieved.  So, we achieved 1.8x by doing sorting / searching.  Who knows?  With a vectorized sort / search we might achieve 12.8x...
  2. This is in some sense a worst-case scenario for sorting.  We are doing n*log(n) work to sort, then log(n) to search, and then throwing all that sorting work away.  If, say, we were tasked with finding lots of item combinations within the same inventory, we could expect O(log(n)) per query, which would crush all results so far.  In fact, I project that if enough searches were done that the sorting cost became negligible, the result would be a gain of 3500x.
  3. The sort / search version tracks O(n*log(n)) very nicely, as you can see in the chart below.  The other versions 'get lucky' when they find answers close to the beginning.  The vector version is simply a scaled-down copy of the scalar curve.
  4. Realize also we are dealing with only 1800 points.  For much larger working sets the sort / search version would crush it.
  5. Nonetheless, it is clear that if you have to pick where to invest your tuning time in workloads like this, you likely want to do your vectorization work last.  I spent more time on the vectorized version than all the others combined.