/*
* @lc app=leetcode id=834 lang=cpp
*
* [834] Sum of Distances in Tree
*/
// @lc code=start
/**
 * LeetCode 834 - Sum of Distances in Tree.
 *
 * For every node of an undirected tree with nodes labelled 0..N-1, computes
 * the sum of distances from that node to all other nodes, using two
 * depth-first traversals rooted at node 0 ("rerooting" technique):
 *   1. dfs  - bottom-up: subtree sizes and distance sums inside each subtree.
 *   2. dfs2 - top-down: derives each child's full answer from its parent's.
 *
 * Both traversals are recursive, so recursion depth equals the tree height.
 */
class Solution {
public:
    // count[v]: number of nodes in the subtree rooted at v (v included).
    // res[v]:   after dfs(), sum of distances from v to the nodes of its own
    //           subtree; after dfs2(), sum of distances from v to every node.
    std::vector<int> count, res;
    // tree[v]: set of neighbours of v (a set, so duplicate edges collapse).
    std::vector<std::unordered_set<int>> tree;

    /**
     * Post-order pass: fills count[cur] and the subtree-only res[cur].
     *
     * @param cur node being processed
     * @param pre parent of cur in the rooted tree (-1 for the root)
     */
    void dfs(int cur, int pre) {
        for (const int nei : tree[cur]) {
            if (nei == pre) continue;  // do not walk back up to the parent
            dfs(nei, cur);
            count[cur] += count[nei];
            // The distance sum from a node to everything below it equals, summed
            // over its children: (the child's own subtree distance sum) + (the
            // child's subtree size), because every node in the child's subtree
            // is exactly one edge farther from cur than from the child.
            res[cur] += res[nei] + count[nei];
        }
    }

    /**
     * Pre-order pass: turns subtree-only sums into whole-tree sums.
     * Requires res[cur] to already hold the whole-tree answer for cur.
     *
     * @param cur node whose children are being finalised
     * @param pre parent of cur in the rooted tree (-1 for the root)
     */
    void dfs2(int cur, int pre) {
        // Total node count as a signed int, so the formula below is evaluated
        // entirely in int arithmetic (no size_t/int mixing).
        const int total = static_cast<int>(count.size());
        for (const int nei : tree[cur]) {
            if (nei == pre) continue;
            // Let x = nei, y = cur, #x = nodes on x's side of edge (x, y),
            // x@x = distance sum from x to the nodes on its own side:
            //   res[x] = x@x + y@y + #y
            //   res[y] = y@y + x@x + #x
            //   res[x] = res[y] + #y - #x = res[y] + (N - #x) - #x
            res[nei] = res[cur] + (total - count[nei]) - count[nei];
            dfs2(nei, cur);
        }
    }

    /**
     * @param N     number of nodes, labelled 0..N-1
     * @param edges list of undirected edges {a, b}; expected to form a tree
     *              over the N nodes
     * @return vector of size N whose i-th entry is the sum of distances from
     *         node i to every other node; empty when N <= 0
     */
    std::vector<int> sumOfDistancesInTree(int N, std::vector<std::vector<int>>& edges) {
        if (N <= 0) {
            // No node 0 to root the traversal at; leave a clean empty state.
            tree.clear();
            count.clear();
            res.clear();
            return {};
        }
        // assign (not resize) so adjacency left over from a previous call on
        // the same object is discarded instead of leaking into this run.
        tree.assign(N, std::unordered_set<int>{});
        res.assign(N, 0);
        count.assign(N, 1);  // every node counts itself
        for (const auto& e : edges) {
            tree[e[0]].insert(e[1]);
            tree[e[1]].insert(e[0]);
        }
        dfs(0, -1);
        dfs2(0, -1);
        return res;
    }
};
// @lc code=end