[NOTE] Updated August 27, 2023. This article may have outdated content or subject matter.
《算法设计与分析(第2版)》黄宇著 14章 堆与偏序关系
14.6 有一组元素,它们不断地被动态加入和删除,但是我们需要随时找出当前所有元素的中位数。为此,请设计一个数据结构,以支持对数时间的插入、删除和常数时间的找出中位数(提示:利用两个堆来实现该数据结构)。
2021.2.5更新:本文属于对顶堆求数据量第k大数的特例:求中位数,基本思想和代码可以从下述特例中泛化。
首先,我们思考一下中位数的性质:如果一个数是中位数,那么在这个数组中,大于中位数的数目和小于中位数的数目,要么相等,要么就相差一。
我们自然可以考虑到利用堆维护这个数据结构,大顶堆存放前半小的数,小顶堆存放后半大的数。
C++ 优先队列实现:
1
2
|
priority_queue<int, vector<int>, greater<int>> minheap;
priority_queue<int> maxheap;
|
分以下两种情况:
当总数奇数个时,大顶堆比小顶堆多一个元素,中位数就是大顶堆堆顶。
当总数偶数个时,大顶堆比小顶堆元素一样多,中位数就是大顶堆和小顶堆堆顶的平均数。
接下来我们考虑维护这个动态找寻中位数的数据结构 MedianFinder 接口如下:
-
insert(num):将一个数 num 加入数据结构;
-
erase(num): 将一个数 num 移出数据结构;
-
getMedian():返回当前数据结构中所有数的中位数。
getMedian() 就如上所说:
1
2
3
4
5
6
|
double getMedian(){
if (minheap.size() == maxheap.size() )
return ((double)minheap.top() + maxheap.top()) / 2;//防范int溢出,增强鲁棒性
else
return (double)maxheap.top();
}
|
下面主要分析insert(num) 与 erase(num) 。
insert(num)
对于insert(num)而言,初始堆全空情况num加入大顶堆,然后加入的 num 与大顶堆顶元素比较,num 较小加入小顶堆,更大加入大顶堆。insert(num)之后还要注意保持两个堆的相对数量关系。即包含元素数量相等
,或者储存了小半的大顶堆多一个元素
。当初次失去平衡也即 大顶堆比小顶堆多了两个元素
或 小顶堆比大顶堆多了一个元素
时把堆顶转移。
1
2
3
4
5
6
7
8
9
10
11
|
void makebalance()
{
if (maxheap.size() > minheap.size() + 1){
minheap.push(maxheap.top());
maxheap.pop();
}
else if (maxheap.size() < minheap.size() ){
maxheap.push(minheap.top());
minheap.pop();
}
}
|
有了辅助函数makebalance()
后,insert(num)就可以给出这样的参考实现了:
1
2
3
4
5
6
7
8
9
10
11
12
13
|
void insert(int num)
{
if (minheap.empty() && maxheap.empty())
maxheap.push(num);
else{
int topnum = maxheap.top();
if (topnum < num)
minheap.push(num);
else
maxheap.push(num);
}
makebalance();
}
|
erase(num)
由于堆是不支持移出非堆顶元素这一操作的,因此我们可以考虑使用「延迟删除」的技巧,即:
当我们需要移出优先队列中的某个元素时,我们只将这个删除操作「记录」下来,而不去真的删除这个元素。当这个元素出现在大顶堆或小顶堆堆顶时,我们再去将其移出。
「延迟删除」使用到的辅助数据结构一般为哈希表 delayed,其中的每个键值对 (num,freq),表示元素 num 还需要被删除 freq 次。
C++ 哈希表实现:
1
|
unordered_map<int, int> delayed;
|
增添辅助函数prune(T &heap)
对每次堆顶改变后的堆进行修剪。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
template <typename T>
void prune(T &heap){
while (!heap.empty()){
int num = heap.top();
if (delayed.count(num)){
delayed[num]--;
if (delayed[num] == 0)
delayed.erase(num);
heap.pop();
}
else
break;
}
}
|
结合prune()
和makebalance()
之后erase(num)的实现就容易了,要注意的一点是因为「延迟删除」,堆的 size 不再是当前真正的堆中元素数量,因为有一些应该要被删除的还暂时存放在堆之中,所以再设置全局变量maxSize,minSize,每当heap进行push时Size++,每当pop时Size–。根据Size做出大小判断。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
void erase(int num)
{
delayed[num]++;
if(num<=maxheap.top()){
maxSize--;
if (num == maxheap.top())
prune(maxheap);
}
else{
minSize--;
if (num == minheap.top())
prune(minheap);
}
makebalance();
}
|
全局变量Size的思想对之前insert(num) 等函数皆同理,对前面一致改写后给出最后的参考实现MedianFinder
如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
#include <queue>
#include <unordered_map>
using namespace std;
class MedianFinder
{
priority_queue<int, vector<int>, greater<int>> minheap;
priority_queue<int> maxheap;
unordered_map<int, int> delayed;
int minSize,maxSize; //decrease delayed
template <typename T>
void prune(T &heap)
{
while (!heap.empty())
{
int num = heap.top();
if (delayed.count(num))
{
delayed[num]--;
if (delayed[num] == 0)
delayed.erase(num);
heap.pop();
}
else
break;
}
}
void makebalance()
{
if (maxSize > minSize + 1)
{
minheap.push(maxheap.top());
maxheap.pop();
minSize++;
maxSize--;
prune(maxheap);
}
else if (maxSize < minSize )
{
maxheap.push(minheap.top());
minheap.pop();
maxSize++;
minSize--;
prune(minheap);
}
}
public:
MedianFinder():minSize(0),maxSize(0){}
void insert(int num)
{
if (minheap.empty() && maxheap.empty()){
maxheap.push(num);
maxSize++;
}
else
{
int topnum = maxheap.top();
if (topnum < num){
minheap.push(num);
minSize++;
}
else{
maxheap.push(num);
maxSize++;
}
}
makebalance();
}
void erase(int num)
{
delayed[num]++;
if(num<=maxheap.top()){
maxSize--;
if (num == maxheap.top())
prune(maxheap);
}
else{
minSize--;
if (num == minheap.top())
prune(minheap);
}
makebalance();
}
double getMedian()
{
if (minSize == maxSize )
return ((double)minheap.top() + maxheap.top()) / 2;//防范int溢出
else
return (double)maxheap.top();
}
};
|
代码里涉及细节处理问题颇多,还是很要仔细。
Author
lawrshen
LastMod
2023-08-27
(73a236c)
License
CC BY-NC-ND 4.0