Sum of Distances in Tree
https://leetcode.com/problems/sum-of-distances-in-tree/
You are given an undirected, connected tree with N nodes labelled 0...N-1 and N-1 edges.
The ith edge connects nodes edges[i][0] and edges[i][1] together.
Return a list ans, where ans[i] is the sum of the distances between node i and all other nodes.
Example 1:
Input: N = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation:
Here is a diagram of the given tree:
    0
   / \
  1   2
     /|\
    3 4 5
We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
equals 1 + 1 + 2 + 2 + 2 = 8. Hence, ans[0] = 8, and so on.
Note: 1 <= N <= 10000
Thoughts
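Take node 0 as the root and recurse (divide and conquer). Each call returns two values: the sum of distances from the current node to every node in its subtree, and the number of nodes in that subtree. The parent's distance sum is then the sum, over its children, of (child's distance sum + count[child]), because every node below a child is one edge farther from the parent than from the child. Repeating this once with each of the nodes 1 to N-1 as the root gives a total time of O(N^2).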
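A minimal sketch of the O(N^2) idea above, useful mainly as a brute-force checker. The names subtreeSums and sumOfDistancesNaive are hypothetical, and this sketch passes the parent down instead of keeping a visited set.

```cpp
#include <utility>
#include <vector>
using namespace std;

// Returns {sum of distances from `node` to every node in its subtree, subtree size}.
// `parent` is the node we came from (-1 at the root), so no visited set is needed.
static pair<int, int> subtreeSums(int node, int parent, const vector<vector<int>>& adj) {
    int dist = 0, cnt = 1;  // cnt starts at 1 to count `node` itself
    for (int child : adj[node]) {
        if (child == parent) continue;
        pair<int, int> sub = subtreeSums(child, node, adj);
        dist += sub.first + sub.second;  // each of the sub.second nodes is one edge farther from `node`
        cnt += sub.second;
    }
    return {dist, cnt};
}

// O(N^2): rerun the subtree DFS once with every node as the root.
vector<int> sumOfDistancesNaive(int N, const vector<vector<int>>& edges) {
    vector<vector<int>> adj(N);
    for (const auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }
    vector<int> ans(N);
    for (int root = 0; root < N; ++root) ans[root] = subtreeSums(root, -1, adj).first;
    return ans;
}
```

On Example 1 this reproduces [8,12,6,10,10,10], but it does N full traversals, so it is only practical for small inputs.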
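Observation: for two adjacent nodes x and y, res[x] - res[y] = #(Y) - #(X). Here removing the edge x-y splits the tree into a part X containing x and a part Y containing y, and #(·) is the number of nodes in a part. Moving from y to x brings the #(X) nodes one edge closer and pushes the #(Y) nodes one edge farther. So, starting from 0, a preorder traversal computes each child from its parent: res[child] = res[parent] - #(C) + #(P) = res[parent] - #(C) + N - #(C), where #(C) = count[child] is the size of the child's subtree and #(P) = N - #(C). Only one extra preorder traversal is needed, so the total is O(N).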
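As a worked check of the rerooting formula on Example 1, rooted at 0: the subtree sizes are count[1] = 1, count[2] = 4 (nodes 2, 3, 4, 5) and count[3] = 1, and the first pass gives res[0] = 8.

```latex
\begin{aligned}
\text{res}[c] &= \text{res}[p] - \text{count}[c] + (N - \text{count}[c]) \\
\text{res}[2] &= 8 - 4 + (6 - 4) = 6 \\
\text{res}[1] &= 8 - 1 + (6 - 1) = 12 \\
\text{res}[3] &= 6 - 1 + (6 - 1) = 10
\end{aligned}
```

Nodes 4 and 5 follow the same step as node 3, which matches the expected output [8,12,6,10,10,10].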
https://leetcode.com/problems/sum-of-distances-in-tree/solution/

Code
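#include <unordered_set>
#include <vector>
using namespace std;

class Solution {
public:
    vector<int> sumOfDistancesInTree(int N, vector<vector<int>>& edges) {
        vector<int> res(N, 0);
        if (N == 1) return res;
        // Adjacency list; it must be sized N before it is indexed.
        vector<vector<int>> nodes(N);
        for (const auto& e : edges) {
            nodes[e[0]].push_back(e[1]);
            nodes[e[1]].push_back(e[0]);
        }
        vector<int> count(N, 0);  // count[i] = size of i's subtree when rooted at 0
        unordered_set<int> visited;
        // Pass 1 (post-order): res[i] = sum of distances from i to the nodes in its subtree.
        dfs(0, nodes, visited, res, count);
        visited.clear();
        // Pass 2 (pre-order): turn the subtree sums into sums over the whole tree.
        dfs2(0, nodes, visited, res, count, N);
        return res;
    }

    void dfs(int node, const vector<vector<int>>& nodes, unordered_set<int>& visited, vector<int>& res, vector<int>& count) {
        visited.insert(node);
        for (const auto i : nodes[node]) {
            if (visited.find(i) == visited.end()) {
                dfs(i, nodes, visited, res, count);
                count[node] += count[i];
                // Every node under child i is one edge farther from node than from i.
                res[node] += res[i] + count[i];
            }
        }
        ++count[node];  // count node itself
    }

    void dfs2(int node, const vector<vector<int>>& nodes, unordered_set<int>& visited, vector<int>& res, vector<int>& count, int N) {
        visited.insert(node);
        for (const auto i : nodes[node]) {
            if (visited.find(i) == visited.end()) {
                // Rerooting from node to i: count[i] nodes get one edge closer,
                // the other N - count[i] nodes get one edge farther.
                res[i] = res[node] - count[i] + N - count[i];
                dfs2(i, nodes, visited, res, count, N);
            }
        }
    }
};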
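The same two passes can also be written without recursion. This is a hedged sketch rather than the solution above: the name sumOfDistancesIterative is hypothetical, and it swaps the recursive DFS for a BFS order from node 0, processed backward for the subtree pass and forward for the rerooting pass.

```cpp
#include <vector>
using namespace std;

// Iterative variant of the two-pass rerooting: an explicit BFS order replaces
// recursion, so a long path-shaped tree cannot overflow the call stack.
vector<int> sumOfDistancesIterative(int N, const vector<vector<int>>& edges) {
    vector<vector<int>> adj(N);
    for (const auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }
    // BFS from 0 records each node's parent and an order in which parents precede children.
    vector<int> parent(N, -1), order;
    vector<bool> seen(N, false);
    order.reserve(N);
    order.push_back(0);
    seen[0] = true;
    for (size_t k = 0; k < order.size(); ++k) {
        int u = order[k];
        for (int v : adj[u]) {
            if (!seen[v]) {
                seen[v] = true;
                parent[v] = u;
                order.push_back(v);
            }
        }
    }
    vector<int> count(N, 1), res(N, 0);
    // Pass 1: reverse BFS order visits children before parents,
    // so count[v] and res[v] are final before they are added to the parent.
    for (int k = N - 1; k > 0; --k) {
        int v = order[k], p = parent[v];
        count[p] += count[v];
        res[p] += res[v] + count[v];
    }
    // Pass 2: forward BFS order visits parents before children.
    for (int k = 1; k < N; ++k) {
        int v = order[k];
        res[v] = res[parent[v]] - count[v] + N - count[v];
    }
    return res;
}
```

With N up to 10000 the recursive version is normally fine; the iterative form only removes the dependence on stack depth.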
Analysis
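Time complexity: O(N) (expected, because of the unordered_set lookups). Each of the two DFS passes visits every node and edge once. Space: O(N) for the adjacency list, count, visited and the recursion stack.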