Sum of Distances in Tree

https://leetcode.com/problems/sum-of-distances-in-tree/

An undirected, connected tree with N nodes labelled 0...N-1 and N-1 edges is given.

The ith edge connects nodes edges[i][0] and edges[i][1] together.

Return a list ans, where ans[i] is the sum of the distances between node i and all other nodes.

Example 1:

Input: N = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]

Output: [8,12,6,10,10,10]

Explanation:

Here is a diagram of the given tree:

      0
     / \
    1   2
       /|\
      3 4 5

We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5) equals 1 + 1 + 2 + 2 + 2 = 8. Hence, ans[0] = 8, and so on.

Note: 1 <= N <= 10000

Thoughts

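  1. For node 0, traverse the tree with a divide-and-conquer DFS. Each call returns two values: the sum of distances from the current node to every node in its subtree, and the number of nodes in the subtree rooted at it. The parent's distance sum is then the sum over its children of (the child's distance sum + count[child]), because every node below a child is one edge farther from the parent. Repeating this with each of the nodes 1..N-1 as the root gives a total time complexity of O(N^2).

  2. Observe that for two adjacent nodes x and y, res[x] - res[y] = #(Y) - #(X), where #(X) is the number of nodes on x's side when the edge x-y is removed, and #(Y) is the number on y's side. So a pre-order traversal starting from 0 can compute, in turn, res[child] = res[parent] - #(C) + #(P) = res[parent] - #(C) + N - #(C), where #(C) is the size of the child's subtree. Only one extra pre-order traversal is needed, so the whole algorithm is O(N).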
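Approach 1 could be sketched roughly as follows. This is only an illustrative sketch, not the solution used below: the names subtreeSumAndCount and sumOfDistancesBruteForce are hypothetical, and the parent argument is an assumed way to avoid walking back up the tree instead of a visited set.

```cpp
#include <utility>
#include <vector>

// Hypothetical helper: returns {sum of distances from `node` to every node in
// its subtree, number of nodes in that subtree}, treating `parent` as the
// direction we came from.
std::pair<int, int> subtreeSumAndCount(int node, int parent,
                                       const std::vector<std::vector<int>>& adj) {
    int sum = 0;
    int count = 1;  // the subtree always contains `node` itself
    for (int child : adj[node]) {
        if (child == parent) continue;
        std::pair<int, int> sub = subtreeSumAndCount(child, node, adj);
        // Every node under `child` is one edge farther from `node`.
        sum += sub.first + sub.second;
        count += sub.second;
    }
    return {sum, count};
}

// O(N^2): run the DFS once with each node as the root.
std::vector<int> sumOfDistancesBruteForce(int N,
                                          const std::vector<std::vector<int>>& edges) {
    std::vector<std::vector<int>> adj(N);
    for (const auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }
    std::vector<int> res(N, 0);
    for (int root = 0; root < N; ++root) {
        res[root] = subtreeSumAndCount(root, -1, adj).first;
    }
    return res;
}
```

With N up to 10000 this quadratic version does far more work than necessary, but it can serve as a reference to cross-check the O(N) solution on small inputs.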

https://leetcode.com/problems/sum-of-distances-in-tree/solution/

Code

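#include <unordered_set>
#include <vector>
using namespace std;

class Solution {
public:
    vector<int> sumOfDistancesInTree(int N, vector<vector<int>>& edges) {
        vector<int> res(N, 0);
        if (N == 1) return res;
        // Build the adjacency list; it must be sized to N before indexing.
        vector<vector<int>> nodes(N);
        for (const auto& e : edges) {
            nodes[e[0]].push_back(e[1]);
            nodes[e[1]].push_back(e[0]);
        }
        vector<int> count(N, 0);
        unordered_set<int> visited;
        // Pass 1 (post-order): fill count[] and the within-subtree sums; res[0] becomes final.
        dfs(0, nodes, visited, res, count);
        visited.clear();
        // Pass 2 (pre-order): derive each child's answer from its parent's.
        dfs2(0, nodes, visited, res, count, N);
        return res;
    }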

    void dfs(int node, const vector<vector<int>>& nodes, unordered_set<int>& visited, vector<int>& res, vector<int>& count) {
        visited.insert(node);
        for (const auto i : nodes[node]) {
            if (visited.find(i) == visited.end()) {
                dfs(i, nodes, visited, res, count);
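                // Fold the child's subtree into this node: each of the
                // count[i] nodes below is one edge farther from `node`.
                count[node] += count[i];
                res[node] += res[i] + count[i];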
            }
        }
        ++count[node];
    }

    void dfs2(int node, const vector<vector<int>>& nodes, unordered_set<int>& visited, vector<int>& res, vector<int>& count, int N) {
        visited.insert(node);
        for (const auto i : nodes[node]) {
            if (visited.find(i) == visited.end()) {
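                // Moving the root from `node` to `i`: count[i] nodes get one
                // edge closer, the other N - count[i] nodes get one edge farther.
                res[i] = res[node] - count[i] + N - count[i];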
                dfs2(i, nodes, visited, res, count, N);
            }
        }
    }
};
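
As a sanity check, the two passes can be traced by hand on Example 1 (N = 6). After the first pass rooted at 0, count = [6,1,4,1,1,1] and res[0] = 8. The second pass then applies res[child] = res[parent] - count[child] + N - count[child]:

    res[1] = res[0] - 1 + 6 - 1 = 8 - 1 + 5 = 12
    res[2] = res[0] - 4 + 6 - 4 = 8 - 4 + 2 = 6
    res[3] = res[2] - 1 + 6 - 1 = 6 - 1 + 5 = 10
    res[4] = res[2] - 1 + 6 - 1 = 10
    res[5] = res[2] - 1 + 6 - 1 = 10

This gives [8,12,6,10,10,10], which matches the expected output.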

Analysis

Time complexity: O(N). Each of the two DFS passes visits every node and edge once.
