백준 1289번: 트리의 가중치 [C++]

플래티넘 III 플래티넘 III

문제

트리의 가중치

풀이

$u$에서 $v$로 가는 경로의 가중치를 $\operatorname{W}(u, v)$라 합시다.

노드 $v$와 $v$의 서브트리에 속하는 노드 $u$에 대해서 $\operatorname{up}[v] = \sum\operatorname{W}(u, v)$로 정의합시다. 또, $v$의 서브트리에 속하는 두 노드 $a, b$에 대해 $\operatorname{in}[v] = \sum\operatorname{W}(a, b)$로 정의합시다.

트리의 루트를 임의로 1로 잡으면 구하고자 하는 답은 $\operatorname{in}[1]$입니다.

$v$의 자식 $c$에 대해 $c$와 $v$를 잇는 간선의 가중치를 $w$라 합시다. 그렇다면 $\operatorname{up}[v] = \sum (\operatorname{up}[c] + 1) \times w$ 입니다.

$\operatorname{in}[v]$의 경우에는 경로가 $c$의 서브트리 내에만 포함되는 경우, $c$의 서브트리에서 $v$로 가는 경우, 한 자식 서브트리에서 다른 자식 서브트리로 가는 경우로 나눌 수 있습니다.

경로가 $c$의 서브트리 내에만 포함되는 경우는 $\sum \operatorname{in}[c]$으로 구할 수 있습니다. $c$의 서브트리에서 $v$로 가는 경우는 $\sum (\operatorname{up}[c] + 1) \times w = \operatorname{up}[v]$로 구할 수 있습니다.

한 자식 서브트리에서 다른 자식 서브트리로 가는 경우를 $\mathcal{O}(N)$에 처리하는 것이 관건입니다.

$upSum = \operatorname{up}[v]$로 초기화합시다. 각 자식 $c$에 대해 $upSum$에서 $(\operatorname{up}[c] + 1) \times w$를 뺀 뒤 $upSum$과 $(\operatorname{up}[c] + 1) \times w$를 곱하면 $c$의 서브트리에서 다른 자식 서브트리로 가는 경우를 중복 없이 셀 수 있습니다.

좀 더 자세히 설명해 보겠습니다. $(\operatorname{up}[c] + 1) \times w$를 $f(c)$로 나타내 봅시다. 우리가 구하고 싶은 것은 $v$의 자식 $a, b, c, \dots$에서 두개를 뽑아 곱한 것을 모두 더한 값 $f(a)f(b) + f(a)f(c) + \dots + f(b)f(c) + f(b)f(d) + \dots$입니다.

두개를 뽑을 때 하나가 $a$인 경우는 $f(a)(f(b) + f(c) + f(d) + \dots)$로 계산할 수 있습니다. 이는 $f(a)(\operatorname{up}[v] - f(a))$와 같음을 알 수 있습니다.

두개를 뽑을 때 하나가 $b$인 경우는 $f(b)(f(a) + f(c) + f(d) + \dots)$이지만 $a$가 뽑히는 경우는 이미 계산했으므로 이를 제외하면 $f(b)(f(c) + f(d) + \dots)$이고 이는 $f(b)(\operatorname{up}[v] - f(a) - f(b))$와 같음을 알 수 있습니다.

따라서 $upSum = \operatorname{up}[v]$로 초기화한 뒤 $upSum$에서 $f(c)$를 빼고 $upSum$과 $f(c)$를 곱한 값을 모두 더해주면 한 자식 서브트리에서 다른 자식 서브트리로 가는 경로의 가중치의 합을 구할 수 있습니다.

시간 복잡도는 $\mathcal{O}(N)$입니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <bits/stdc++.h>
using namespace std;

using ll = long long;

constexpr ll MOD = 1e9 + 7;

vector<pair<int, ll>> adj[100'100];
ll in[100'100], up[100'100];

void solve(int node, int parent) {
    for (auto [child, weight] : adj[node]) {
        if (child == parent)
            continue;

        solve(child, node);

        up[node] += (up[child] + 1) * weight % MOD;
        up[node] %= MOD;

        in[node] += in[child];
        in[node] %= MOD;
    }
    in[node] += up[node];

    ll upSum = up[node];
    for (auto [child, weight] : adj[node]) {
        if (child == parent)
            continue;

        upSum -= (up[child] + 1) * weight % MOD;
        upSum = (upSum + MOD) % MOD;

        in[node] += ((up[child] + 1) * weight % MOD) * upSum % MOD;
        in[node] %= MOD;
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    cin >> N;

    for (int i = 0; i < N - 1; i++) {
        int A, B, W;
        cin >> A >> B >> W;

        adj[A].emplace_back(B, W);
        adj[B].emplace_back(A, W);
    }

    solve(1, -1);
    cout << in[1];

    return 0;
}

백준 1289번: 트리의 가중치 [C++]