Code
注意 sizeMap[p] = max(sizeMap[p], size[i]);
,因為經過 union and find 後,同一個 connected component 中的 element 其 parent 都一樣,但是 size 會不一樣,因此要取 max。
class Solution {
public:
long long countPairs(int n, vector<vector<int>>& edges) {
typedef long long ll;
ll pairs = (ll)n * (ll)(n - 1) / 2;
vector<int> parent(n);
vector<int> size(n);
for(int i = 0; i < n; i++) {
parent[i] = i;
size[i] = 1;
}
for(int i = 0; i < edges.size(); i++) {
int p1 = findParent(parent, edges[i][0]);
int p2 = findParent(parent, edges[i][1]);
if(p1 != p2) {
if(size[p1] < size[p2]) {
size[p2] += size[p1];
parent[p1] = p2;
} else {
size[p1] += size[p2];
parent[p2] = p1;
}
}
}
unordered_map<int, int> sizeMap;
for(int i = 0; i < n; i++) {
int p = findParent(parent, i);
sizeMap[p] = max(sizeMap[p], size[i]);
}
for (auto& it: sizeMap) {
pairs -= (ll) it.second * (ll) (it.second - 1) / 2;
}
return pairs;
}
int findParent(vector<int> &parent, int node) {
if(parent[node] == node) return node;
return parent[node] = findParent(parent, parent[node]);
}
};