I went ahead and let myself get a little distracted today after reading a tweet from Daniel Collin over at DICE. He posted a bit of code that uses SSE intrinsics to find the four highest valued floats in an array of floats (along with their indices). The original code looked like this:
// Find the four highest values in a[0..sz) and their indices, using SSE4.1.
// On return fres[0..3] holds the scores in descending order (fres[0] = best)
// and ires[0..3] the matching indices; slots never filled (sz < 4) keep
// their initial -FLT_MAX / -1 values. Ties: a later value equal to a kept
// score displaces it (the compare is >=).
// NOTE(review): fres and ires are written with _mm_store_ps, so both must be
// 16-byte aligned and hold at least 4 elements — confirm at call sites.
void find_four(const float * a, size_t sz, float * fres, int * ires)
{
__declspec(align(16)) float sinit = -FLT_MAX;
__declspec(align(16)) int iinit[4] = {-1, -1, -1, -1};
// Initialize all the scores to -FLT_MAX
__m128 s = _mm_load_ps1(&sinit);
// We just do shuffles and blends of the indices, so we store the ints as floats.
// (The lanes hold the raw int bit patterns; they are never used as float
// values, only moved in lockstep with s and stored back as ints at the end.)
__m128 index = _mm_load_ps((float*)iinit);
int i = 0;
for(const float* pa = a, *paend = a + sz; pa != paend; ++pa, ++i)
{
// Load the index into all 4 elements of im
// (type-puns &i to float* so the int bits travel through float lanes intact)
__m128 im = _mm_load_ps1((float*)&i);
// Load a value from the array into all 4 elements in v
__m128 v = _mm_load_ps1(pa);
// Compare with the currently best scores
__m128 cmp = _mm_cmpge_ps(v, s);
// Convert to a mask which is one of 0000, 1000, 1100, 1110 or 1111
// Switch on the mask and shuffle/blend as appropriate.
// The same operation is done on both s and index to keep them in sync.
// s is kept sorted descending (lane 0 = highest, lane 3 = lowest), so
// v >= lane k implies v >= every lane above k; only these five mask
// values can occur while that invariant holds.
switch(_mm_movemask_ps(cmp))
{
case 0x0:
// dcba -> dcba  (v beats nothing: no change)
break;
case 0x8:
// dcba -> Vcba  (v only beats the lowest score d: overwrite lane 3)
s = _mm_blend_ps(s, v, 8);
index = _mm_blend_ps(index, im, 8);
break;
case 0xc:
// dcba -> cVba  (shift c up into d's lane, insert v as the new c)
s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 2, 1, 0));
s = _mm_blend_ps(s, v, 4);
index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 2, 1, 0));
index = _mm_blend_ps(index, im, 4);
break;
case 0xe:
// dcba -> cbVa  (shift b and c up one lane, insert v as the new b)
s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 1, 1, 0));
s = _mm_blend_ps(s, v, 2);
index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 1, 1, 0));
index = _mm_blend_ps(index, im, 2);
break;
case 0xf:
// dcba -> cbaV  (new overall best: shift a, b, c up, insert v in lane 0)
s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 1, 0, 0));
s = _mm_blend_ps(s, v, 1);
index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 1, 0, 0));
index = _mm_blend_ps(index, im, 1);
break;
default:
// Unreachable while the descending-sort invariant on s holds.
assert(0);
break;
}
}
_mm_store_ps(fres, s);
// Store the index lanes back out as ints (same raw bits they carried).
_mm_store_ps((float*)ires, index);
}
You can write up a more straightforward plain scalar version of this code:
// Scalar reference version: find the four highest values in a[0..sz) and
// their indices. fres[0..3] ends up sorted descending (fres[0] = best);
// ires[0..3] holds the matching indices. Unfilled slots stay -FLT_MAX / -1.
// A later value equal to a kept score displaces it (the compare is >=).
void find_four_scalar(const float * a, size_t sz, float * fres, int * ires)
{
    // Seed all four slots with sentinel values.
    for (int k = 0; k < 4; ++k)
    {
        fres[k] = -FLT_MAX;
        ires[k] = -1;
    }
    for (size_t j = 0; j < sz; ++j)
    {
        const float v = a[j];
        // Find the first slot whose score v meets or beats.
        int slot = 0;
        while (slot < 4 && v < fres[slot])
            ++slot;
        if (slot == 4)
            continue; // v is worse than all four kept scores
        // Shift the weaker entries down one slot and insert v.
        for (int k = 3; k > slot; --k)
        {
            fres[k] = fres[k - 1];
            ires[k] = ires[k - 1];
        }
        fres[slot] = v;
        ires[slot] = static_cast<int>(j);
    }
}
Given an array of 2^25 random floats, and testing on my i5-3317U I get the following times:
find_four: 54 ms
find_four_scalar: 79 ms
We can provoke the worst case behaviour for the SSE implementation by making the array of values be monotonically increasing - giving us:
find_four: 148 ms
find_four_scalar: 68 ms
And best case behaviour by making sure the four highest values are at the very start:
find_four: 54 ms
find_four_scalar: 79 ms
So the question is - can we do better? As it turns out we can make a simple adjustment that improves performance quite a bit for the random and best case and has only a small impact on the worst case. The SSE version still works on one float at a time, but we can adjust it to potentially reject groups of floats at a time (in this case I will try 8). Specifically, I load up 8 floats and using some shuffles and _mm_max_ps I get the maximum value of those 8 floats; if the maximum is less than our current 4 best then we can just skip to the next 8. Simple. The code:
// Insert one candidate into the running top-four.
//   i     - the candidate's array index
//   v     - the candidate's value, broadcast to all 4 lanes
//   s     - in/out: best four scores, sorted descending (lane 0 = highest)
//   index - in/out: matching indices, carried as raw int bit patterns in
//           float lanes and moved in lockstep with s
inline void cmp_one_to_four(int i, __m128 v, __m128 & s, __m128 & index)
{
// Load the index into all 4 elements of im
// (type-puns &i to float* so the int bits travel through float lanes intact)
__m128 im = _mm_load_ps1((float*)&i);
// Compare with the currently best scores
__m128 cmp = _mm_cmpge_ps(v, s);
// Convert to a mask which is one of 0000, 1000, 1100, 1110 or 1111
// Switch on the mask and shuffle/blend as appropriate.
// The same operation is done on both s and index to keep them in sync.
// Because s stays sorted descending, v >= lane k implies v >= every lane
// above k; only these five mask values can occur.
switch(_mm_movemask_ps(cmp))
{
case 0x0:
// dcba -> dcba  (v beats nothing: no change)
break;
case 0x8:
// dcba -> Vcba  (v only beats the lowest score d: overwrite lane 3)
s = _mm_blend_ps(s, v, 8);
index = _mm_blend_ps(index, im, 8);
break;
case 0xc:
// dcba -> cVba  (shift c up into d's lane, insert v as the new c)
s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 2, 1, 0));
s = _mm_blend_ps(s, v, 4);
index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 2, 1, 0));
index = _mm_blend_ps(index, im, 4);
break;
case 0xe:
// dcba -> cbVa  (shift b and c up one lane, insert v as the new b)
s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 1, 1, 0));
s = _mm_blend_ps(s, v, 2);
index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 1, 1, 0));
index = _mm_blend_ps(index, im, 2);
break;
case 0xf:
// dcba -> cbaV  (new overall best: shift a, b, c up, insert v in lane 0)
s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 1, 0, 0));
s = _mm_blend_ps(s, v, 1);
index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 1, 0, 0));
index = _mm_blend_ps(index, im, 1);
break;
default:
// Unreachable while the descending-sort invariant on s holds.
assert(0);
break;
}
}
// Find the four highest values in a[0..sz) and their indices, rejecting
// groups of 8 candidates at a time when none of them can enter the top four.
// On return fres[0..3] holds the scores in descending order (fres[0] = best)
// and ires[0..3] the matching indices; unfilled slots keep -FLT_MAX / -1.
// Requirements: `a` must be 16-byte aligned (the grouped _mm_load_ps reads);
// fres and ires must be 16-byte aligned and hold at least 4 elements.
void find_four_mod(const float * a, size_t sz, float * fres, int * ires)
{
    __declspec(align(16)) float sinit = -FLT_MAX;
    __declspec(align(16)) int iinit[4] = {-1, -1, -1, -1};
    // Current best four scores, lane 0 = highest.
    __m128 s = _mm_load_ps1(&sinit);
    // Matching indices, carried as raw int bit patterns in float lanes;
    // they are only shuffled/blended in lockstep with s, never used as floats.
    __m128 index = _mm_load_ps((float*)iinit);
    int i = 0;
    const float * pa = a;
    const float * const paend = a + sz;
    // BUG FIX: the original stepped `pa += 8` against `pa != paend`, which
    // overruns the buffer (and never terminates) when sz is not a multiple
    // of 8. Bound the grouped loop explicitly and finish with a scalar tail.
    for (; paend - pa >= 8; pa += 8, i += 8)
    {
        // Horizontal max of the 8 candidates, broadcast to all four lanes.
        __m128 m = _mm_max_ps(_mm_load_ps(pa), _mm_load_ps(pa + 4));
        m = _mm_max_ps(_mm_max_ps(_mm_shuffle_ps(m,m,_MM_SHUFFLE(0,0,0,0)), _mm_shuffle_ps(m,m,_MM_SHUFFLE(1,1,1,1))),
                       _mm_max_ps(_mm_shuffle_ps(m,m,_MM_SHUFFLE(2,2,2,2)), _mm_shuffle_ps(m,m,_MM_SHUFFLE(3,3,3,3))));
        // If the best of the 8 doesn't beat the current worst kept score,
        // skip the whole group.
        if (_mm_movemask_ps(_mm_cmpge_ps(m, s)) == 0)
            continue;
        // First half of the group. (Locals renamed from a/b/c/d, which
        // shadowed the `a` parameter in the original.)
        __m128 v0 = _mm_load_ps1(pa);
        __m128 v1 = _mm_load_ps1(pa + 1);
        __m128 v2 = _mm_load_ps1(pa + 2);
        __m128 v3 = _mm_load_ps1(pa + 3);
        m = _mm_max_ps(_mm_max_ps(v0, v1), _mm_max_ps(v2, v3));
        if (_mm_movemask_ps(_mm_cmpge_ps(m, s)) != 0)
        {
            cmp_one_to_four(i, v0, s, index);
            cmp_one_to_four(i + 1, v1, s, index);
            cmp_one_to_four(i + 2, v2, s, index);
            cmp_one_to_four(i + 3, v3, s, index);
        }
        // Second half of the group.
        v0 = _mm_load_ps1(pa + 4);
        v1 = _mm_load_ps1(pa + 5);
        v2 = _mm_load_ps1(pa + 6);
        v3 = _mm_load_ps1(pa + 7);
        m = _mm_max_ps(_mm_max_ps(v0, v1), _mm_max_ps(v2, v3));
        if (_mm_movemask_ps(_mm_cmpge_ps(m, s)) != 0)
        {
            cmp_one_to_four(i + 4, v0, s, index);
            cmp_one_to_four(i + 5, v1, s, index);
            cmp_one_to_four(i + 6, v2, s, index);
            cmp_one_to_four(i + 7, v3, s, index);
        }
    }
    // Scalar tail: the final sz % 8 elements, one at a time.
    for (; pa != paend; ++pa, ++i)
        cmp_one_to_four(i, _mm_load_ps1(pa), s, index);
    _mm_store_ps(fres, s);
    _mm_store_ps((float*)ires, index);
}
How does this clock in? Running the same test with this code yields:
Random-Case: 12 ms
Worst-Case: 159 ms
Best-Case: 12 ms
So over a 4x improvement for the best/random cases, and only slightly slower in the worst case scenario.
Anyway, that was a fun little distraction from the networking code I was otherwise working on...