HackerRank: Spanning Tree Fraction 题解

题目:https://www.hackerrank.com/contests/w31/challenges/spanning-tree-fraction 一张连通图G=(V,E)上,每条边有a和b两个整数,求一个生成树使得Sum(a) / Sum(b) 最大,输出这个最大值的分数形式 p/q

\(\frac{\sum{a\_i}}{\sum{b\_i}} \ge c\) ,经过变换可得,\(\sum{(a\_i - b\_i c)} \ge 0\),这种形式下 ai - bi * c 就退化为一条边的 cost,能方便得用 Prim 或 Kruskal 算法求出 cost 最大的生成树。

因为我们知道 c 的取值是受限制的——显然不能任意大。根据题意,我们要求 c 的最大可能值 max(c),借助二分查找:如果 c 取值 (L+R)/2 时找不到生成树满足 sum cost >= 0,说明 max(c) 在右半边;反之,如果 c 的取值 (L+R)/2 时可行,说明在左半边(因为 C 是实数,这里说的都是闭区间)。 

Prim 和 Kruskal 算法的实现也是一个难点。本题中 Kruskal 算法更简单,从小到大取所有的边,利用并查集可以快速判断这条边是否已经被连通了,若还为连通就要选取此边。

最后的 p/q 怎么求呢?显然从实数 c 变回 p/q 是不可能的。假如循环 N 次,最后一次循环可以认为已经收敛了,最后一次计算中途的 \(\sum{a\_i}\)\(\sum{b\_i}\) 就是最优 case 下的取值,将 \(\frac{\sum{a\_i}}{\sum{b\_i}}\) 化简、消除公约数即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <cmath>
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
#include <numeric>
using namespace std;

int n, m;
const int N = 100002;
int u[N], v[N], a[N], b[N];

double w[N];
int p[N];

int set[N];

int find_set(int x) {
if (set[x] == x) return x;
return set[x] = find_set(set[x]);
}

bool union_set(int a, int b) {
a = find_set(a);
b = find_set(b);
if (a == b) return false;
set[a] = b;
return true;
}

int A, B;

bool check(double c) {
for (int i = 0; i < m; i++) w[i] = a[i] - b[i] * c;
sort(p, p + m, [](int i, int j){ return w[i] > w[j]; });
iota(set, set + n, 0);
A = B = 0;
for (int i = 0; i < m; i++) {
const int e = p[i];
if (union_set(u[e], v[e])) {
A += a[e];
B += b[e];
}
}
return A >= B * c;
}

int gcd(int a, int b) {
while (a && b) {
if (a > b) a %= b;
else b %= a;
}
return a | b;
}

int main() {
cin >> n >> m;
for (int i = 0; i < m; i++) {
scanf("%d %d %d %d", &u[i], &v[i], &a[i], &b[i]);
}
iota(p, p + m, 0);
double lo = 0, hi = 1e5;
for (int t = 0; t < 100; t++) {
double c = (hi + lo) * 0.5;
if (check(c)) lo = c;
else hi = c;
}
int g = gcd(A, B);
printf("%d/%d\n", A/g, B/g);
return 0;
}