백준 17526번: Star Trek [C++]

다이아몬드 V 다이아몬드 V

문제

Star Trek

풀이

$i$번째 행성의 위치를 $pos_i$, 우주선의 준비 시간을 $pre_i$, 속도를 $pace_i$라 하고 $dp[i]$를 $i$번째 행성까지 여행하는 데 드는 시간의 최솟값이라 합시다. 그렇다면 아래 점화식을 통해 $dp[n]$을 찾을 수 있습니다.

\[dp[i] = \min(dp[j] + (pos_i - pos_j) \times pace_j + pre_j)\]

식을 전개해서 아래처럼 써 봅시다.

\[dp[i] = \min(pace_j \times pos_i - pace_j \times pos_j + dp[j] + pre_j)\]

$\min$ 안의 식이 변수가 $pos_i$, 기울기가 $pace_j$, y 절편이 $-pace_j \times pos_j + dp[j] + pre_j$인 일차함수꼴임을 알 수 있고, 따라서 볼록 껍질을 이용해 최적화할 수 있습니다. 다만 기울기의 단조조건이 없으므로 리-차오 트리, BBST 기반 LineContainer, 제곱근 분할법 등의 방법이 필요합니다.

이 글에서는 가장 구현이 쉬운 제곱근 분할법 풀이에 대해 설명하겠습니다.

제곱근 분할법

구사과님의 블로그에서 본 방법입니다.

직선들을 기울기 단조성을 유지하도록 저장하는 주 배열과 그렇지 않은 버퍼까지 총 두 개의 배열을 준비합니다. 새로 삽입된 직선이 단조성을 유지하면서 삽입될 수 없다면 버퍼에 삽입합니다.

버퍼의 크기가 $\sqrt n$ 을 초과한다면 기울기 단조성을 만족하도록 주 배열에 병합 정렬하듯이 합쳐줍니다.

쿼리를 처리할 때에는 주 배열에는 이분 탐색을, 버퍼에 담긴 직선은 그냥 하나씩 돌면서 처리하면 시간 복잡도 $\mathcal{O}(\log n + \sqrt n)$에 처리할 수 있습니다.

쿼리는 최대 $n$번 주어지므로 시간 복잡도는 $\mathcal{O}(n\sqrt n)$입니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <bits/stdc++.h>
#ifndef ONLINE_JUDGE
#define ASSERT(x) assert(x)
#else
#define ASSERT(ignore) ((void)0)
#endif

using namespace std;

using ll = long long;
using ldouble = long double;

struct Line {
    ll slope, yIntercept;
    ldouble start;

    Line(ll _slope, ll _yIntercept)
        : slope(_slope), yIntercept(_yIntercept), start(-INFINITY) {}

    ldouble getIntersection(const Line& l) const {
        return (ldouble)(l.yIntercept - yIntercept) / (slope - l.slope);
    }

    ll operator()(ll x) const { return slope * x + yIntercept; }
};

// lower hull
struct LineContainer {
    size_t bufSize;
    vector<Line> lines, buffer;

    LineContainer(size_t _bufSize) : bufSize(_bufSize) {}

    void push(Line l) {
        while (not lines.empty()) {
            const Line& back = lines.back();

            if (l.slope > back.slope) {
                buffer.push_back(l);
                return;
            }

            if (l.slope == back.slope) {
                if (l.yIntercept < back.yIntercept) {
                    lines.pop_back();
                    continue;
                }
                return;
            }

            l.start = l.getIntersection(back);
            if (l.start <= back.start)
                lines.pop_back();
            else
                break;
        }

        lines.push_back(l);
    }

    void flush() {
        sort(buffer.begin(), buffer.end(), [](const Line& l1, const Line& l2) {
            return l1.slope > l2.slope;
        });

        vector<Line> oldLines;
        swap(oldLines, lines);

        int i = 0, j = 0;
        while (i < (int)oldLines.size() and j < (int)buffer.size()) {
            if (oldLines[i].slope > buffer[j].slope)
                push(oldLines[i++]);
            else
                push(buffer[j++]);
        }

        while (i < (int)oldLines.size())
            push(oldLines[i++]);
        while (j < (int)buffer.size())
            push(buffer[j++]);

        buffer.clear();
    }

    ll query(ll x) {
        if (buffer.size() > bufSize)
            flush();

        ll ret = (*(upper_bound(lines.begin(), lines.end(), (ldouble)x,
                                [](ldouble val, const Line& l) {
                                    return val < l.start;
                                }) -
                    1))(x);

        for (const Line& l : buffer)
            ret = min(ret, l(x));

        return ret;
    }
};

ll dp[100100], pos[100100];

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;

    for (int i = 2; i <= n; i++) {
        cin >> pos[i];
        pos[i] += pos[i - 1];
    }

    LineContainer lineContainer(sqrt(n));

    ll pre, pace;
    cin >> pre >> pace;
    lineContainer.push(Line(pace, pre));

    for (int i = 2; i < n; i++) {
        cin >> pre >> pace;

        ll totalTime = lineContainer.query(pos[i]);
        lineContainer.push(Line(pace, -pace * pos[i] + totalTime + pre));
    }

    cout << lineContainer.query(pos[n]);

    return 0;
}

백준 17526번: Star Trek [C++]