An approach to median with SIMD

Introduction

In a previous post I looked into using Single Instruction, Multiple Data (SIMD) to sum an array of integers.

When working with statistical data, we often need to find the mean or the median of a sample. Calculating the sum is useful for finding the mean, but finding the median requires a different approach.

The median is defined as follows:

  • for an odd sample size, it is the middle element of the ordered sample.

  • for an even sample size, it is the average of the two middle elements of the ordered sample.
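The definition translates directly into code. The following is a minimal sketch that assumes sorting a copy of the whole sample is acceptable. `MedianBySorting` is a hypothetical name, not one of the implementations discussed later.

```csharp
using System;

// Median by definition: sort a copy of the sample, then take the middle element
// (odd size) or the average of the two middle elements (even size).
static double MedianBySorting(int[] sample)
{
    var sorted = (int[])sample.Clone(); // leave the caller's array untouched
    Array.Sort(sorted);
    int m = sorted.Length;
    if (m % 2 == 1)
        return sorted[m / 2]; // the middle element, 0-based index
    // Average the two middle elements in floating point, so 2 and 3 give 2.5.
    return (sorted[m / 2 - 1] + (double)sorted[m / 2]) / 2;
}

Console.WriteLine(MedianBySorting(new[] { 7, 1, 5 }));    // median 5
Console.WriteLine(MedianBySorting(new[] { 4, 1, 3, 2 })); // median 2.5
```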

To calculate the median exactly, the input data needs to be ordered, at least partially. While statistical approaches exist to determine the median of a data set, these approaches only estimate the value, for example by using a given confidence interval or linear interpolation. This post looks for approaches that find the median element of a sample based on sorting.

Let us first set an assumption and a generalization of the original problem.

  • Assumption: the input sample is typed as an integer array (int[]).

  • Generalization: find the nth element of a collection, where 1 ≤ n ≤ m, m is the size of the collection, and n points towards the middle of the collection. I consider this a generalization because, if a function f(sample, n) exists that returns the nth (1-based) element of the ordered items for an unordered input, we can calculate the median with it: f(sample, (sample.Length + 1)/2) gives the median of an odd-sized input, and (f(sample, sample.Length/2) + f(sample, sample.Length/2 + 1)) / 2 gives the median of an even-sized input, where the final division should be a floating-point division.
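The two formulas can be written as a small sketch that takes f as a delegate. `NthBySorting` is a hypothetical stand-in for f, which is assumed here to be 1-based. Any of the implementations later in this post could be passed in its place.

```csharp
using System;
using System.Linq;

// NthBySorting is a hypothetical stand-in for f: it returns the nth (1-based)
// element of the ordered sample. Any implementation from this post fits the delegate.
static int NthBySorting(int[] sample, int n) => sample.OrderBy(x => x).ElementAt(n - 1);

// The median expressed through f, following the two formulas above.
static double MedianVia(Func<int[], int, int> f, int[] sample)
{
    int m = sample.Length;
    if (m % 2 == 1)
        return f(sample, (m + 1) / 2);
    // Floating-point division, so the average of 4 and 5 is 4.5 rather than 4.
    return (f(sample, m / 2) + (double)f(sample, m / 2 + 1)) / 2;
}

Console.WriteLine(MedianVia(NthBySorting, new[] { 9, 2, 6, 4 })); // median 5
```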

In practice, for an even sample size one would look for an optimization that preserves an intermediate state of the first f call, so that the second f invocation can reuse it.
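One way to preserve that state is to keep a heap between the two calls. The sketch below assumes .NET 6 or later for PriorityQueue and uses it only for illustration. After extracting the (m/2)th item, the same queue yields the (m/2 + 1)th item with a single extra Dequeue. `MedianReusingHeap` is a hypothetical name.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Sketch of reusing state: for an even-sized sample, the heap left behind after
// extracting the (m/2)th item already yields the (m/2 + 1)th item with one more
// Dequeue, so the second lookup does not start over.
static double MedianReusingHeap(int[] sample)
{
    var queue = new PriorityQueue<int, int>(sample.Select(x => (x, x)));
    int m = sample.Length;
    int k = (m + 1) / 2; // (m + 1) / 2 for odd m, m / 2 for even m (integer division)

    int kth = 0;
    for (int i = 0; i < k; i++)
        kth = queue.Dequeue(); // after the loop, kth is the kth smallest item

    if (m % 2 == 1)
        return kth;
    return (kth + (double)queue.Dequeue()) / 2; // next smallest item from the same heap
}

Console.WriteLine(MedianReusingHeap(new[] { 1, 6, 3, 5, 2, 4 })); // median 3.5
```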

In this post, I explore five different implementations of a method f that, given an int[] source and an int n, returns the nth (1-based) item of the ordered input. From a data type point of view, IReadOnlyCollection<int> would be semantically more accurate; however, for brevity I use int[] in this post. The code samples can be adapted to use that interface type if required.

The code samples in this post are not production grade; they exist purely for demonstration purposes. All code samples below expect valid inputs and omit input validation.

Non-SIMD solutions

There are plenty of non-SIMD approaches to find the median element of a sample. This section describes three of them. None of these solutions use SIMD explicitly, which means the developer does not vectorize the processing of the input. Note that the underlying implementation of the methods used may still be vectorized, as the BCL implements more and more methods with SIMD in mind.

Linq

The very first solution is a functional style implementation using Linq.

public int GetByLinq(int[] source, int n) => source.Order().Skip(n - 1).First();

The input source is ordered using the Order() method, then n-1 items are skipped, so that the nth element can be returned by the First() extension method. Note that input arguments are not validated, hence this method throws exceptions on invalid inputs such as a null source or an n greater than the length of source, while an n less than 1 silently returns the smallest element.

Array.Sort

The second solution builds on arrays: it uses the Array.Sort method to sort the items, then returns the item at index n - 1, as arrays are 0-indexed while n is 1-based.

public int GetByOrder(int[] source, int n)
{
    var copy = source.ToArray();
    Array.Sort(copy);
    return copy[n - 1];
}

Array.Sort sorts in place, hence a defensive copy of the original input is made, which avoids modifying the input argument. Also note that using the ToArray() method gives us the freedom to choose a different input type for the source parameter, for example IReadOnlyCollection<int> as described above.

PriorityQueue

.NET 6 added a new type, PriorityQueue<TElement, TPriority>. This type orders items based on their priority. When all items of the source are added with their own value as the priority, PriorityQueue returns them in ascending order.

public int GetNthPriorityQueue(int[] source, int n)
{
    var queue = new PriorityQueue<int, int>(source.Select(x => (x, x)));
    int result = 0;
    for (int i = 0; i < n; i++)
        result = queue.Dequeue();
    return result;
}

To find the nth element, one can dequeue n elements and return the value of the last dequeued item. The PriorityQueue implementation is based on a heap: it is an array-backed, quaternary (4-ary) min-heap.

Heaps

Heaps are a data structure that fits the above problem well. A min-heap is a tree structure that keeps the minimum element at the root. We can use the heapsort algorithm to solve the problem of medians, as heapsort orders the items of a collection. For a collection of m items, it builds a heap in O(m), then removes the root m times: each removal takes the current minimum, which is the next item in the ordered sequence, moves the last item of the heap to the root, and restores the heap with the siftDown() operation. The complexity of siftDown() is O(log m), so completely ordering a collection takes O(m * log m). However, we do not need to run siftDown() m times to solve the problem of medians: when searching for the kth element, only k-1 removals are needed. For the median, k is about m/2, so this only changes a constant factor in the O() notation, yet in practice it can yield a performance boost for larger arrays.
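As a rough cost estimate, assume a heap of m items that is built in O(m), and root removals that each cost one O(log m) sift-down. Under these assumptions, full heapsort and a kth-element search compare as follows:

```latex
\begin{aligned}
T_{\text{heapsort}}(m) &= O(m) + m \cdot O(\log m) = O(m \log m) \\
T_{k}(m) &= O(m) + (k-1) \cdot O(\log m) = O(m + k \log m) \\
k = \lceil m/2 \rceil \;&\Rightarrow\; T_{k}(m) = O\!\left(m + \tfrac{m}{2}\log m\right) = O(m \log m)
\end{aligned}
```

For the median only the constant changes. For a k much smaller than m, the k log m term shrinks proportionally.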

The heapify operation is typically used to build a heap, and it uses the siftDown() operation internally. Another advantage of heaps is that they can easily be represented by an array; in this particular case an int[] can represent the heap structure. In a binary heap, a parent node has 2 children. When stored in an array, a parent node with index i has its children at indexes 2i+1 and 2i+2, and a child at index i has its parent at index floor((i-1)/2).
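Before vectorizing, it helps to see the scalar version. The following is a minimal sketch on a binary min-heap stored in an array, assuming the 1-based n used throughout this post. It heapifies a copy and then removes the root n-1 times. `SiftDown` and `GetNthBinaryHeap` are hypothetical names and are not taken from the benchmarked code.

```csharp
using System;

// Restores the min-heap property below `index` in heap[0..length).
// Children of i are at 2i+1 and 2i+2.
static void SiftDown(int[] heap, int length, int index)
{
    while (true)
    {
        int left = 2 * index + 1;
        if (left >= length)
            return; // leaf node
        int right = left + 1;
        // Pick the smaller child (the right child may not exist).
        int smallest = right < length && heap[right] < heap[left] ? right : left;
        if (heap[smallest] >= heap[index])
            return; // heap property holds
        (heap[index], heap[smallest]) = (heap[smallest], heap[index]);
        index = smallest;
    }
}

static int GetNthBinaryHeap(int[] source, int n)
{
    var heap = (int[])source.Clone();
    // Heapify: sift down every non-leaf node, starting from the last one.
    for (int i = heap.Length / 2 - 1; i >= 0; i--)
        SiftDown(heap, heap.Length, i);
    // Remove the root n-1 times; afterwards the root is the nth smallest item.
    int length = heap.Length;
    for (int i = 1; i < n; i++)
    {
        length--;
        heap[0] = heap[length];
        SiftDown(heap, length, 0);
    }
    return heap[0];
}

Console.WriteLine(GetNthBinaryHeap(new[] { 5, 3, 9, 1, 7 }, 3)); // 5
```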

SIMD

The advantage of SIMD is processing multiple data with the same instruction. To apply SIMD to heap operations, we need to identify the places where multiple data could be processed by the same operation. In the case of a heap, a closer look at siftDown() suggests parallelizing the comparison and swap operations. siftDown() compares all children with the parent and, in the case of a min-heap, swaps the minimum child with the parent, so that the parent becomes the minimum. In the case of a binary heap, this could mean comparing both children and the parent with a single comparison.

Comparing only 2-3 items is vectorizable; however, with 256-bit AVX2 vectors we could execute the same instruction on 8 integers at the same time (32-bit integer * 8 = 256-bit vector). This raises the question of how to utilize the remaining lanes of the vector. One solution is to increase the number of children per node in the heap.

A d-ary heap is a heap where nodes have d children instead of 2. With this in mind, we could build an 8-heap while keeping the semantics of the heap described above. Indexing of child and parent nodes is slightly different, but the overall idea does not change: node i has children 8i+1, 8i+2, ..., 8i+8, and a node i > 0 has its parent at index floor((i-1)/8).
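The index arithmetic generalizes to any d. The following is a small sketch in which d = 2 gives the binary heap and d = 16 the 16-heap used later. `FirstChild` and `Parent` are hypothetical helper names.

```csharp
using System;

// Index arithmetic of a d-ary heap stored in an array.
static int FirstChild(int i, int d) => d * i + 1; // children are d*i+1 .. d*i+d
static int Parent(int i, int d) => (i - 1) / d;    // only meaningful for i > 0

// In an 8-heap, node 3 has children 25..32, and all of them map back to parent 3.
int first = FirstChild(3, 8);
for (int c = first; c < first + 8; c++)
    Console.WriteLine($"child {c} -> parent {Parent(c, 8)}");
```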

While an 8-heap seems very convenient, a 7-heap or a 16-heap could be just as convenient. For example, in the case of a 7-heap, the 8th integer of the Vector256<int> could hold the parent item. In the case of a 16-heap, the children fit into two vectors, and only one additional instruction (a Min of the two vectors) is needed to reduce the 16 children to 8 candidates, which are then processed the same way as in an 8-heap.
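The 7-heap variant can be sketched as follows. The 7 children are loaded into lanes 0-6 and the parent is placed into lane 7, so one 8-lane reduction covers all of them. This sketch assumes .NET 7 or later for the cross-platform Vector256.Shuffle, and it assumes the array holds at least 7 * index + 9 elements, for example thanks to padding. `MinOfParentAndChildren7` is a hypothetical name.

```csharp
using System;
using System.Runtime.Intrinsics;

// Returns min(parent, its 7 children) for node `index` of a 7-heap.
// LoadUnsafe reads 8 ints without bounds checks on the extra lanes, so the array
// must hold at least 7 * index + 9 elements.
static int MinOfParentAndChildren7(int[] heap, int index)
{
    int firstChild = 7 * index + 1;
    // Lanes 0..6: children; lane 7: replaced by the parent value.
    var v = Vector256.LoadUnsafe(ref heap[firstChild]).WithElement(7, heap[index]);
    v = Vector256.Min(v, Vector256.Shuffle(v, Vector256.Create(4, 5, 6, 7, 4, 5, 6, 7)));
    v = Vector256.Min(v, Vector256.Shuffle(v, Vector256.Create(2, 3, 2, 3, 2, 3, 2, 3)));
    return Math.Min(v.GetElement(0), v.GetElement(1));
}

// Parent 50 at index 0, children at 1..7, one padding slot at index 8.
var sample7 = new[] { 50, 61, 42, 75, 58, 90, 47, 66, int.MaxValue };
Console.WriteLine(MinOfParentAndChildren7(sample7, 0)); // 42
```

If the returned value equals the parent's value, the heap property already holds for that node.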

16-heap

Based on my measurements, the 16-heap yields the best performance out of these options. Hence, in this blog post I focus on vectorizing a 16-heap. GetNth16Heap implements f. First, it creates a defensive copy of the source data, appending 16 additional items with the value int.MaxValue to the array. When operating with vectors, the data loaded into a vector has to contain at least as many elements as the vector has lanes. This can be a problem when, for example, only 6 items remain in the collection while the algorithm loads 16 (two vectors of 8). In such cases an algorithm can fall back to non-vectorized instructions, or in certain cases it can process some of the items twice as part of the last instruction. This post chooses a third approach: when creating the copy of the input collection, the copy is extended with items that do not affect the final outcome of the algorithm but allow full vectors to be loaded beyond the length of the original source.

All extended values are set to the maximum value, so that in a min-heap they do not influence the results (assuming there is at least one item in the collection).
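The padding step on its own might look like the following sketch. `CreatePaddedCopy` is a hypothetical helper name. The padding items are never smaller than a real item, and with 16 of them a full 16-element load starting at any first-child index up to source.Length stays inside the array.

```csharp
using System;

// Copies the input and appends `padding` slots holding int.MaxValue,
// the neutral element for a min-heap.
static int[] CreatePaddedCopy(int[] source, int padding)
{
    var copy = new int[source.Length + padding];
    source.CopyTo(copy, 0);
    copy.AsSpan(source.Length).Fill(int.MaxValue);
    return copy;
}

var padded = CreatePaddedCopy(new[] { 5, 3, 8 }, 16);
Console.WriteLine(padded.Length); // 19
```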

The code below then runs the heapify step: it finds the last non-leaf node of the collection and applies Sift16Down to this node and all preceding nodes. Once the heap is built, the root item has the minimum value. To return the items in order, remove the root (which is the next item in the order), move the last element of the heap to the root position, and apply Sift16Down to the root node. This is repeated n-1 times, after which the root is the nth item and can be returned.

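// Assumes a field not shown in this post: private const int v16Length = 16;
public int GetNth16Heap(int[] source, int n)
{
    // Defensive copy, padded with v16Length int.MaxValue items so that
    // loading 16 children never reads past the end of the array.
    var copy = new int[source.Length + v16Length];
    source.CopyTo(copy.AsSpan());
    for (int i = 1; i <= v16Length; i++)
        copy[^i] = int.MaxValue;

    // Heapify: sift down every non-leaf node, starting from the last one.
    int lastNonLeaf = (source.Length - 2) / v16Length;
    for (int i = lastNonLeaf; i >= 0; i--)
        Sift16Down(copy, source.Length, i);

    // Remove the root n-1 times, always operating on the copy, not on the input.
    for (int i = 1; i < n; i++)
    {
        var last = source.Length - i;
        copy[0] = copy[last];
        copy[last] = int.MaxValue;
        Sift16Down(copy, last, 0);
    }
    return copy[0];
}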

SIMD is used in the Sift16Down method. The goal of siftDown is to find the smallest value among the given node and its children, then swap it with the parent node, and repeat this for the swapped child node (as the new parent) and its children.

There are 2 places where SIMD may be applied: finding the minimum value, and swapping the parent with the minimum node. Finding the minimum value is rather straightforward. The code loads all children into two vectors (a single vector contains 8 elements), then uses the Min method to find the minimum at the matching indexes of the two vectors. Then it permutes the items so that it can compare the remaining items of the vector with each other. For example, for a vector [0,1,2,3,4,5,6,7], to compare lanes 0-3 with lanes 4-7, a new vector is created by permuting the items of the original vector so that it yields [4,5,6,7,...]. When taking the Min of this new vector and the original, the first 4 lanes contain the smaller item of each pair. This operation is repeated until only two relevant items are left in the vector. Finally, the two remaining items are reduced with Math.Min, and the result is compared with the parent node's value.
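The reduction steps can be sketched in isolation. This version uses the cross-platform Vector256.Shuffle (assumed .NET 7 or later) in place of Avx2.PermuteVar8x32, so it also runs on machines without AVX2. `Min16` is a hypothetical name, and the caller is assumed to provide at least offset + 16 elements.

```csharp
using System;
using System.Runtime.Intrinsics;

// Horizontal minimum of values[offset .. offset + 16): 16 -> 8 -> 4 -> 2 -> 1.
// LoadUnsafe does not bounds-check the extra lanes, hence the length requirement.
static int Min16(int[] values, int offset)
{
    var v0 = Vector256.LoadUnsafe(ref values[offset]);     // children 0..7
    var v1 = Vector256.LoadUnsafe(ref values[offset + 8]); // children 8..15
    var m8 = Vector256.Min(v0, v1);                         // 8 candidates
    // Bring lanes 4..7 into lanes 0..3: lanes 0..3 now hold 4 candidates.
    var m4 = Vector256.Min(m8, Vector256.Shuffle(m8, Vector256.Create(4, 5, 6, 7, 4, 5, 6, 7)));
    // Bring lanes 2..3 into lanes 0..1: lanes 0..1 now hold 2 candidates.
    var m2 = Vector256.Min(m4, Vector256.Shuffle(m4, Vector256.Create(2, 3, 2, 3, 2, 3, 2, 3)));
    return Math.Min(m2.GetElement(0), m2.GetElement(1));
}

var children = new[] { 9, 4, 7, 12, 30, 5, 8, 6, 11, 3, 15, 10, 14, 13, 2, 16 };
Console.WriteLine(Min16(children, 0)); // 2
```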

The second possibility to vectorize would be finding the index of the child node with the minimum value. The algorithm needs the index of this minimum value so that it can swap it with the parent node. However, based on my testing, I did not find any SIMD based solution that was faster than manually iterating the children. Hence, the code iterates the child nodes and compares them one by one to the minimum value. The first value equal to the minimum is swapped with the parent.

Note that the SIMD based implementation depends on Avx2 and assumes the CPU supports 256-bit vector instructions. In a production grade application, the code would need to check whether the given instructions are available (for example with Avx2.IsSupported) and provide a fallback.

// Requires: using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.X86;
// Fields assumed by this snippet (not shown in this post):
//   private const int v16Length = 16;
//   permutateInt  = Vector256.Create(4, 5, 6, 7, 4, 5, 6, 7)
//   permutateInt1 = Vector256.Create(2, 3, 2, 3, 2, 3, 2, 3)
private void Sift16Down(int[] source, int sourceLength, int index)
{
    while (true)
    {
        var firstChild = index * v16Length + 1;
        if (firstChild >= sourceLength) // no children inside the heap
            return;

        // Load all 16 children into two vectors of 8 ints each
        var children0 = Vector256.LoadUnsafe(ref source[firstChild]);
        var children1 = Vector256.LoadUnsafe(ref source[firstChild + v16Length / 2]);

        // Reduce 16 -> 8 -> 4 -> 2 candidates, then take the smaller of the last two
        Vector256<int> vMin0 = Vector256.Min(children0, children1);
        Vector256<int> vMin1 = Vector256.Min(vMin0, Avx2.PermuteVar8x32(vMin0, permutateInt));
        var vMin2 = Vector256.Min(vMin1, Avx2.PermuteVar8x32(vMin1, permutateInt1));
        var min = Math.Min(vMin2.GetElement(0), vMin2.GetElement(1));

        var parentValue = source[index];
        if (min >= parentValue) // heap property holds
            return;

        // Find the first child holding the minimum and swap it with the parent
        for (int i = 0; i < v16Length; i++)
        {
            if (source[firstChild + i] == min)
            {
                source[index] = min;
                index = firstChild + i;
                source[index] = parentValue;
                break;
            }
        }
    }
}
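The Avx2 note above can be handled with a runtime check. The following is a minimal sketch, not the benchmarked code. It takes the AVX2 path only when Avx2.IsSupported reports support and otherwise falls back to a scalar loop. `MinOf16` is a hypothetical helper, and the caller is assumed to provide at least offset + 16 elements.

```csharp
using System;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

// Minimum of values[offset .. offset + 16), using AVX2 when available.
static int MinOf16(int[] values, int offset)
{
    if (Avx2.IsSupported)
    {
        var v = Avx2.Min(Vector256.LoadUnsafe(ref values[offset]),
                         Vector256.LoadUnsafe(ref values[offset + 8]));
        v = Avx2.Min(v, Avx2.PermuteVar8x32(v, Vector256.Create(4, 5, 6, 7, 4, 5, 6, 7)));
        v = Avx2.Min(v, Avx2.PermuteVar8x32(v, Vector256.Create(2, 3, 2, 3, 2, 3, 2, 3)));
        return Math.Min(v.GetElement(0), v.GetElement(1));
    }

    // Scalar fallback, e.g. on Arm64 or on x64 CPUs without AVX2.
    int min = values[offset];
    for (int i = 1; i < 16; i++)
        min = Math.Min(min, values[offset + i]);
    return min;
}

Console.WriteLine(Avx2.IsSupported ? "AVX2 path" : "scalar fallback");
```

Because the check is a JIT-time constant, the branch that does not apply is typically eliminated from the generated code.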

16-heap with Positive Numbers

A special case of the problem is when the numbers are known to be positive integers less than or equal to 268435455 (2^28 - 1). Such values could be duration measurements in milliseconds or people's heights in millimeters.

One less desirable part of the above solution is the manual iteration of the children to find the index of the minimum value. What if we could encode both the index and the value in a single element? With the above limitation, when using uint, we can use the four least significant bits to store the child node's index after the value.

The idea is to shift each value left by 4 bits, then use these 4 bits to store the index of the child, representing values 0-15. Hence 268435455 is the largest value that can be shifted without overflowing a 32-bit uint. When finding the minimum, larger values remain larger and smaller values remain smaller, while equal values become distinct, ordered by their child index. Other than deciding which of several equal children gets swapped, this does not affect the final outcome of the algorithm.
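The encoding itself is only a few bit operations. The following sketch assumes that the low 4 bits hold a node's slot among its 16 siblings (0-15). `Encode`, `DecodeValue` and `DecodeSlot` are hypothetical helper names.

```csharp
using System;

// (value << 4) | slot: the value occupies the upper 28 bits, the slot the lower 4.
static uint Encode(uint value, int slot) => (value << 4) | (uint)slot;
static uint DecodeValue(uint encoded) => encoded >> 4;
static int DecodeSlot(uint encoded) => (int)(encoded & 0xFu);

// 268435455 = 2^28 - 1 is the largest value whose shifted form still fits in 32 bits.
Console.WriteLine(Encode(268435455, 15) == uint.MaxValue); // True
// Ordering is preserved, and equal values are ordered by their slot.
Console.WriteLine(Encode(7, 3) < Encode(8, 0));            // True
Console.WriteLine(Encode(7, 0) < Encode(7, 3));            // True
```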

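// Fields assumed by this snippet (not shown in this post):
//   private const int v16Length = 16, v8Length = 8;
//   permutate  = Vector256.Create(4u, 5u, 6u, 7u, 4u, 5u, 6u, 7u)
//   permutate1 = Vector256.Create(2u, 3u, 2u, 3u, 2u, 3u, 2u, 3u)
// Each element holds (value << 4) | slot, where slot (0-15) is the element's position
// among its siblings. Like GetNth16Heap, this assumes the array is padded past sourceLength.
private void Sift16DownPositiveNumbers(uint[] source, int sourceLength, int index)
{
    var firstChild = index * v16Length + 1;
    do
    {
        // Horizontal minimum of the 16 encoded children, as in Sift16Down
        var children0 = Vector256.LoadUnsafe(ref source[firstChild]);
        var children1 = Vector256.LoadUnsafe(ref source[firstChild + v8Length]);

        Vector256<uint> vMin = Vector256.Min(children0, children1);
        vMin = Vector256.Min(vMin, Avx2.PermuteVar8x32(vMin, permutate));
        vMin = Vector256.Min(vMin, Avx2.PermuteVar8x32(vMin, permutate1));
        var vMin30 = vMin.GetElement(0);
        var vMin31 = vMin.GetElement(1);
        uint min = Math.Min(vMin30, vMin31);

        // Compare the values only, with the slot bits cleared
        uint parent = source[index];
        var minValue = (min >> 4) << 4;
        var parentValue = (parent >> 4) << 4;
        if (minValue >= parentValue)
            return;

        // Swap the values while every slot keeps its own position bits;
        // the low 4 bits of min tell which child to swap, no search needed
        uint parentPosition = (parent << 28) >> 28;
        source[index] = minValue | parentPosition;
        var minPosition = (min << 28) >> 28;
        index = firstChild + (int)minPosition;
        source[index] = parentValue | minPosition;

        firstChild = index * v16Length + 1;
    } while (firstChild <= sourceLength);
}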

Performance Characteristics

In this section, let's compare these solutions from a performance point of view. I focus on inputs with 100, 1000 and 10000 elements, fetching the 50th, 500th and 5000th element respectively.

There is no clear winner among the solutions presented above; each has its strengths and weaknesses depending on the input. In general, the Array.Sort based approach works reasonably well on shorter source collections, while the Linq based approach works well for larger collections. For fetching the median item, one of these two implementations is likely the best choice, and both are available out of the box.

For large source inputs, Linq's performance tends to vary more depending on the actual input. Measurements with source length 10000 and n = 5000 yield execution times from ~140 us to ~220 us on my machine. As a direct comparison, for the same inputs GetNth16HeapPositiveNumbers yields a nearly constant ~160 us. This is likely explained by the sorting algorithm behind the Order() method used in the Linq approach: Introsort, a combination of quicksort and heapsort, whose running time depends on the pivots chosen for a given input.

Source Length 100: n: 50

|                      Method |       Mean |    Error |   StdDev |
|---------------------------- |-----------:|---------:|---------:|
|                   GetByLinq | 1,209.3 ns | 21.78 ns | 25.09 ns |
|                  GetByOrder |   351.6 ns |  2.50 ns |  1.95 ns |
|         GetNthPriorityQueue | 1,569.5 ns | 16.89 ns | 14.11 ns |
|                GetNth16Heap |   557.2 ns |  7.84 ns |  6.95 ns |
| GetNth16HeapPositiveNumbers |   749.6 ns |  4.24 ns |  3.54 ns |

Source Length 1000: n: 500

|                      Method |      Mean |     Error |    StdDev |
|---------------------------- |----------:|----------:|----------:|
|                   GetByLinq | 12.219 us | 0.1583 us | 0.1403 us |
|                  GetByOrder |  5.545 us | 0.0782 us | 0.0989 us |
|         GetNthPriorityQueue | 19.953 us | 0.3785 us | 0.4649 us |
|                GetNth16Heap |  6.157 us | 0.0605 us | 0.0566 us |
| GetNth16HeapPositiveNumbers | 10.121 us | 0.0681 us | 0.0604 us |

Source Length 10000: n: 5000 (first measurement)

|                      Method |     Mean |   Error |  StdDev |
|---------------------------- |---------:|--------:|--------:|
|                   GetByLinq | 166.2 us | 1.49 us | 1.24 us |
|                  GetByOrder | 361.0 us | 3.53 us | 3.30 us |
|         GetNthPriorityQueue | 536.4 us | 3.92 us | 3.28 us |
|                GetNth16Heap | 306.8 us | 2.87 us | 2.55 us |
| GetNth16HeapPositiveNumbers | 159.3 us | 1.68 us | 1.57 us |

Source Length 10000: n: 5000 (second measurement)

|                      Method |     Mean |   Error |  StdDev |
|---------------------------- |---------:|--------:|--------:|
|                   GetByLinq | 207.3 us | 1.73 us | 1.62 us |
|                  GetByOrder | 357.6 us | 2.61 us | 2.32 us |
|         GetNthPriorityQueue | 536.9 us | 5.07 us | 4.74 us |
|                GetNth16Heap | 309.3 us | 2.57 us | 2.40 us |
| GetNth16HeapPositiveNumbers | 160.0 us | 2.00 us | 1.87 us |

Another aspect is the value of n relative to the source length:

Source Length 10000: n: 1000

|                      Method |      Mean |    Error |   StdDev |
|---------------------------- |----------:|---------:|---------:|
|                   GetByLinq | 116.91 us | 2.288 us | 2.448 us |
| GetNth16HeapPositiveNumbers |  30.38 us | 0.538 us | 0.503 us |

In the above case, where n is significantly smaller than the source length, the vectorized solution is roughly 4 times faster than the Linq based approach.

Is it worth vectorizing the f function presented above? Based on my measurements, for most common cases it is not. The built-in methods provide similar or better performance with significantly less, and more readable, code.

There are certain input combinations where vectorization can give better performance. However, these cases are quite specific, hence I would only use SIMD for them if actual measurements on the given inputs show a clear performance advantage and the code is executed in a tight loop.