“取最大的K的数”的两种解法

输入n个整数,输出其中最大的k个。例如输入1,2,3,4,5,6,7和8这8个数字,则最小的4个数字为5,6,7和8。

解法一

开一个容量为K的最堆。每次来一个数,比较堆顶(最小值)和这个数谁大,如果当前的数更大,就替换掉堆顶。时间复杂度 \(O(n \log k)\)

优点:对于海量数据(流),可以用\(O(k)\)的内存搞定

缺点:堆的实现略复杂,建议直接用STL

解法二

用QuickSelect(快速选择)算法,是基于快排的一种变体,平均复杂度为 \(O(n)\)

优点:平均情况下速度快

缺点:需要保存所有输入数据,并且会修改输入数据,对于海量数据不合适。

代码

包含一些测试写在main()里。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <iostream>
#include <vector>
#include <queue>
#include <functional>
#include <algorithm>
#include <cstdlib>
#include <ctime>
using namespace std;

class MaxKNumberSolver {
public:
virtual vector<int> Solve(vector<int> &input, int k) = 0;
};

class MaxKNumber_Heap : public MaxKNumberSolver {
public:
vector<int> Solve(vector<int> &input, int k) {
priority_queue<int, vector<int>, greater<int>> heap; // MinHeap
int n = input.size();
if (n <= k) return input;
for (int i = 0; i < k; i++) {
heap.push(input[i]);
}
for (int i = k; i < n; i++) {
if (input[i] > heap.top()) {
heap.pop();
heap.push(input[i]);
}
}
vector<int> result;
while (!heap.empty()) {
result.push_back(heap.top());
heap.pop();
}
return result;
}
};

class MaxKNumber_QuickSelect : public MaxKNumberSolver {
int partition(vector<int> &input, int first, int last) {
int i = first;
for (int j = first; j < last; j++) {
if (input[j] < input[last]) swap(input[i++], input[j]);
}
swap(input[i], input[last]);
return i;
}

int quickSelect(vector<int> &input, int first, int last, int k) {
int pivot = partition(input, first, last);
if (pivot == k) return input[k];
else if (pivot < k) return quickSelect(input, pivot + 1, last, k);
else if (pivot > k) return quickSelect(input, first, pivot - 1, k);
}

public:
vector<int> Solve(vector<int> &input, int k) {
int n = input.size();
quickSelect(input, 0, n - 1, n - k);
vector<int> result;
for (int i = n - k; i < n; i++) {
result.push_back(input[i]);
}
return result;
}
};

int main() {
clock_t cbegin, cend;
vector<int> test;
for (int i = 0; i < 1000000; i++) {
test.push_back(rand());
}
MaxKNumber_Heap solverHeap;
MaxKNumber_QuickSelect solverQS;
cbegin = clock();
vector<int> r1 = solverHeap.Solve(test, 1000);
cend = clock();
printf("MaxKNumber_Heap: %d ms\n", cend - cbegin);
cbegin = clock();
vector<int> r2 = solverQS.Solve(test, 1000);
cend = clock();
printf("MaxKNumber_QuickSelect: %d ms\n", cend - cbegin);
sort(r2.begin(), r2.end()); // convenient for comparing
for (int i = 0; i < 50; i++) {
if (r1[i] != r2[i]) {
cout << "Failed\n";
return 0;
}
}
cout << "Success\n";
}

结果

\(n=10^6, k=1000\)