1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
| #include <iostream> #include <vector> #include <queue> #include <functional> #include <algorithm> #include <cstdlib> #include <ctime> using namespace std;
/// Interface for algorithms that extract the largest k values from a vector.
class MaxKNumberSolver {
public:
    /// Solvers are used polymorphically, so the destructor must be virtual for
    /// `delete` through a MaxKNumberSolver* to run the derived destructor.
    virtual ~MaxKNumberSolver() = default;

    /// @param input  values to search; implementations are allowed to reorder it.
    /// @param k      number of values requested.
    /// @return up to k of the largest values in `input`; the ordering of the
    ///         returned values is implementation-defined.
    virtual vector<int> Solve(vector<int> &input, int k) = 0;
};
class MaxKNumber_Heap : public MaxKNumberSolver { public: vector<int> Solve(vector<int> &input, int k) { priority_queue<int, vector<int>, greater<int>> heap; int n = input.size(); if (n <= k) return input; for (int i = 0; i < k; i++) { heap.push(input[i]); } for (int i = k; i < n; i++) { if (input[i] > heap.top()) { heap.pop(); heap.push(input[i]); } } vector<int> result; while (!heap.empty()) { result.push_back(heap.top()); heap.pop(); } return result; } };
class MaxKNumber_QuickSelect : public MaxKNumberSolver { int partition(vector<int> &input, int first, int last) { int i = first; for (int j = first; j < last; j++) { if (input[j] < input[last]) swap(input[i++], input[j]); } swap(input[i], input[last]); return i; }
int quickSelect(vector<int> &input, int first, int last, int k) { int pivot = partition(input, first, last); if (pivot == k) return input[k]; else if (pivot < k) return quickSelect(input, pivot + 1, last, k); else if (pivot > k) return quickSelect(input, first, pivot - 1, k); }
public: vector<int> Solve(vector<int> &input, int k) { int n = input.size(); quickSelect(input, 0, n - 1, n - k); vector<int> result; for (int i = n - k; i < n; i++) { result.push_back(input[i]); } return result; } };
/// Benchmark: times both solvers on 1,000,000 pseudo-random ints and checks
/// that they return the same k largest values. Exits non-zero on mismatch.
int main() {
    const int kCount = 1000000;
    const int k = 1000;

    vector<int> test;
    test.reserve(kCount);
    for (int i = 0; i < kCount; i++) {
        test.push_back(rand());
    }

    // clock() reports processor time in CLOCKS_PER_SEC ticks, not milliseconds.
    auto elapsedMs = [](clock_t begin, clock_t end) {
        return 1000.0 * static_cast<double>(end - begin) / CLOCKS_PER_SEC;
    };

    MaxKNumber_Heap solverHeap;
    MaxKNumber_QuickSelect solverQS;

    clock_t cbegin = clock();
    vector<int> r1 = solverHeap.Solve(test, k);
    clock_t cend = clock();
    cout << "MaxKNumber_Heap: " << elapsedMs(cbegin, cend) << " ms\n";

    // QuickSelect reorders its input, so it runs after the heap solver.
    cbegin = clock();
    vector<int> r2 = solverQS.Solve(test, k);
    cend = clock();
    cout << "MaxKNumber_QuickSelect: " << elapsedMs(cbegin, cend) << " ms\n";

    // The solvers may return different orderings; sort both and compare every
    // element (and the sizes), not just a prefix.
    sort(r1.begin(), r1.end());
    sort(r2.begin(), r2.end());
    if (r1 != r2) {
        cout << "Failed\n";
        return 1;
    }
    cout << "Success\n";
    return 0;
}
|