《算法设计与分析(第2版)》黄宇著 14章 堆与偏序关系

14.6 有一组元素,它们不断地被动态加入和删除,但是我们需要随时找出当前所有元素的中位数。为此,请设计一个数据结构,以支持对数时间的插入、删除和常数时间的找出中位数(提示:利用两个堆来实现该数据结构)。

2021.2.5更新:本文属于对顶堆求数据量第k大数的特例:求中位数,基本思想和代码可以从下述特例中泛化。


​ 首先,我们思考一下中位数的性质:如果一个数是中位数,那么在这个数组中,大于中位数的数目和小于中位数的数目,要么相等,要么就相差一

​ 我们自然可以考虑到利用维护这个数据结构,大顶堆存放前半小的数,小顶堆存放后半大的数。

​ C++ 优先队列实现:

1
2
priority_queue<int, vector<int>, greater<int>> minheap;
priority_queue<int> maxheap;

​ 分以下两种情况:

当总数数个时,大顶堆比小顶堆多一个元素,中位数就是大顶堆堆顶

当总数数个时,大顶堆比小顶堆元素一样多,中位数就是大顶堆和小顶堆堆顶的平均数

接下来我们考虑维护这个动态找寻中位数的数据结构 MedianFinder 接口如下:

  • insert(num):将一个数 num 加入数据结构;

  • erase(num): 将一个数 num 移出数据结构;

  • getMedian():返回当前数据结构中所有数的中位数

getMedian() 就如上所说:

1
2
3
4
5
6
double getMedian(){
    if (minheap.size() == maxheap.size() )
        return ((double)minheap.top() + maxheap.top()) / 2;//防范int溢出,增强鲁棒性
    else
        return (double)maxheap.top();
}

下面主要分析insert(num) 与 erase(num) 。

insert(num)

对于insert(num)而言,初始堆全空情况num加入大顶堆,然后加入的 num 与大顶堆顶元素比较,num 较小加入小顶堆,更大加入大顶堆。insert(num)之后还要注意保持两个堆的相对数量关系。即包含元素数量相等,或者储存了小半的大顶堆多一个元素。当初次失去平衡也即 大顶堆比小顶堆多了两个元素小顶堆比大顶堆多了一个元素 时把堆顶转移。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
void makebalance()
{
    if (maxheap.size() > minheap.size() + 1){
        minheap.push(maxheap.top());
        maxheap.pop();
    }
    else if (maxheap.size() < minheap.size() ){
        maxheap.push(minheap.top());
        minheap.pop();
    }
}

有了辅助函数makebalance()后,insert(num)就可以给出这样的参考实现了:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
void insert(int num)
{
    if (minheap.empty() && maxheap.empty())
        maxheap.push(num);
    else{
        int topnum = maxheap.top();
        if (topnum < num)
            minheap.push(num);
        else
            maxheap.push(num);
    }
    makebalance();
}

erase(num)

​ 由于堆是不支持移出非堆顶元素这一操作的,因此我们可以考虑使用「延迟删除」的技巧,即:

当我们需要移出优先队列中的某个元素时,我们只将这个删除操作「记录」下来,而不去真的删除这个元素。当这个元素出现在大顶堆或小顶堆堆顶时,我们再去将其移出。

​ 「延迟删除」使用到的辅助数据结构一般为哈希表 delayed,其中的每个键值对 (num,freq),表示元素 num 还需要被删除 freq 次。

​ C++ 哈希表实现:

1
unordered_map<int, int> delayed;

​ 增添辅助函数prune(T &heap)对每次堆顶改变后的堆进行修剪。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
template <typename T>
void prune(T &heap){
    while (!heap.empty()){
        int num = heap.top();
        if (delayed.count(num)){
            delayed[num]--;
            if (delayed[num] == 0)
                delayed.erase(num);
            heap.pop();
        }
        else
            break;
    }
}

​ 结合prune()makebalance()之后erase(num)的实现就容易了,要注意的一点是因为「延迟删除」,堆的 size 不再是当前真正的堆中元素数量,因为有一些应该要被删除的还暂时存放在堆之中,所以再设置全局变量maxSize,minSize,每当heap进行push时Size++,每当pop时Size–。根据Size做出大小判断。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
void erase(int num)
{
    delayed[num]++;
    if(num<=maxheap.top()){
        maxSize--;
        if (num == maxheap.top())
            prune(maxheap);
    }
    else{
        minSize--;
        if (num == minheap.top())
            prune(minheap);
    }
    makebalance();
}

​ 全局变量Size的思想对之前insert(num) 等函数皆同理,对前面一致改写后给出最后的参考实现MedianFinder如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <queue>
#include <unordered_map>
using namespace std;

class MedianFinder
{
    priority_queue<int, vector<int>, greater<int>> minheap;
    priority_queue<int> maxheap;
    unordered_map<int, int> delayed;
    int minSize,maxSize;    //decrease delayed
    template <typename T>
    void prune(T &heap)
    {
        while (!heap.empty())
        {
            int num = heap.top();
            if (delayed.count(num))
            {
                delayed[num]--;
                if (delayed[num] == 0)
                    delayed.erase(num);
                heap.pop();
            }
            else
                break;
        }
    }

    void makebalance()
    {
        if (maxSize > minSize + 1)
        {
            minheap.push(maxheap.top());
            maxheap.pop();
            minSize++;
            maxSize--;
            prune(maxheap);
        }
        else if (maxSize < minSize )
        {
            maxheap.push(minheap.top());
            minheap.pop();
            maxSize++;
            minSize--;
            prune(minheap);
        }
    }

public:
    MedianFinder():minSize(0),maxSize(0){}

    void insert(int num)
    {
        if (minheap.empty() && maxheap.empty()){
            maxheap.push(num);
            maxSize++;
        }
        else
        {
            int topnum = maxheap.top();
            if (topnum < num){
                minheap.push(num);
                minSize++;
            }
            else{
                maxheap.push(num);
                maxSize++;
            }
        }
        makebalance();
    }

    void erase(int num)
    {
        delayed[num]++;
        if(num<=maxheap.top()){
            maxSize--;
            if (num == maxheap.top())
                prune(maxheap);
        }
        else{
            minSize--;
            if (num == minheap.top())
                prune(minheap);
        }
        makebalance();
    }

    double getMedian()
    {
        if (minSize == maxSize )
            return ((double)minheap.top() + maxheap.top()) / 2;//防范int溢出
        else
            return (double)maxheap.top();
    }
};

代码里涉及细节处理问题颇多,还是很要仔细。