[Baekjoon] 1167 - 트리의 지름(C++)
1. 문제
트리의 지름이란, 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것을 말한다. 트리의 지름을 구하는 프로그램을 작성하시오.
https://www.acmicpc.net/problem/1167
2. 해결 아이디어
1) 각 노드마다 갈 수 있는 최대 cost를 2개까지 저장한 뒤(재귀 방식이라 시간 복잡도 O(n)), 모든 노드의 최대 cost값 2개의 합을 비교하면서 가장 큰 최대 cost의 합을 출력한다. (= 트리의 지름)
한 노드에서 갈 수 있는 가장 먼 거리 (2개)의 합은 지름이다. 따라서 최대 cost의 합은 당연히 트리의 지름이 된다.
트리의 왼쪽 자손과 오른쪽 자손까지의 비용을 각각 더했을 때, 그 두 합이 최대가 되면 그 두 경로는 당연히 지름이 됨을 알 수 있다.
만약 자손이 없는 경우에는 초깃값을 0으로 설정함으로써 문제를 해결할 수 있다.
인접 행렬 방식으로 구현하면 메모리초과가 나서 인접 리스트로 구현했다.
그래서 방문했는지를 표시하는 visited[] 배열, 가중치를 저장하는 weights[][]배열, 노드 연결 관계를 저장하는 node[][]배열, 각 노드마다 2개씩 최대 cost를 저장하는 max_depth[][] 배열 이렇게 3개의 배열이 사용되었다.
2) 정석 풀이 - dfs를 2번 돌리면 된다.
1번째에서 dfs를 통해 얻은 최대 cost를 가지는 node에 가서 한번 더 dfs를 돌려 최대 cost를 구하면 그 값이 바로 트리의 지름이다.
직관적으로는 조금 생각하기 어려웠던 풀이였던 것 같다.
그럼 왜 dfs를 두 번 돌리면 트리의 지름이 나올까?
dfs를 한 번 돌리면 현재 노드에서 가장 먼 노드를 찾을 수 있다.
그리고 현재 노드에서 가장 먼 그 노드까지 가는 길은 반드시 지름의 일부를 포함한다. (1번째로 긴 구간과 2번째로 긴 구간 중 적어도 1개를 포함하므로)
따라서 가장 먼 노드는 지름의 한쪽 끝에 위치하는 노드이다. 그렇기 때문에 그 노드에서 dfs를 써주면 트리의 지름을 찾을 수 있다.
3. 코드
1) 번 풀이
// Baekjoon No. 1167
// Time Complexity
// #
#include <iostream>
#include <vector>
using namespace std;
int dfs(vector<vector<int>>& node, vector<int>& visited, vector<vector<int>>& weights, vector<vector<int>>& max_depth, int nodeIdx);
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
int v, v1, v2, tmp, weight, tree_length;
int i;
cin >> v;
vector<int> visited(v, 0);
// vector<vector<int>> weights(v, vector<int>(v, 0));
vector<vector<int>> weights(v);
vector<vector<int>> node(v);
vector<vector<int>> max_depth(v, vector<int>(2, 0));
for (i = 0; i < v; i++) {
cin >> v1 >> v2 >> weight;
v1--;
v2--;
// node[idx]의 2k-1번째는 nodeIdx, 2k번째는 weight
node[v1].push_back(v2);
weights[v1].push_back(weight);
// node[v1].push_back(weight);
cin >> tmp;
while (tmp > 0) {
v2 = tmp;
v2--;
cin >> weight;
// save
node[v1].push_back(v2);
weights[v1].push_back(weight);
cin >> tmp;
}
}
// solve
/** 풀이법
1) 루트 노드를 하나 정한다
2) 루트 노드에서부터 dfs로 각 노드의 끝까지 탐색, 만약 depth요소보다 크다면 최대 depth에 갱신한다.
3) max_depth는 현재 노드에서 뻗을 수 있는 노드들 중 가장 긴 길이 2개를 저장한다.
4) 각 노드에서의 max_depth 2개의 합들 중(자식 node가 1개여도 초기화가 0으로 되어 있어 상관없음)
가장 큰 합 = 지름이므로 지름 값을 구해 출력한다.
*/
// root node = 0
visited[0] = 1;
dfs(node, visited, weights, max_depth, 0);
// output
tree_length = 0;
for (i = 0; i < v; i++)
if (max_depth[i][0] + max_depth[i][1] > tree_length)
tree_length = max_depth[i][0] + max_depth[i][1];
cout << tree_length;
return 0;
}
// returns the biggest depth current node can go
int dfs(vector<vector<int>>& node, vector<int>& visited, vector<vector<int>>& weights, vector<vector<int>>& max_depth, int nodeIdx){
int size = node[nodeIdx].size(), tmp_depth = 0;
for (int i = 0; i < size; i ++) { // node[][]
if (!visited[node[nodeIdx][i]]) {
visited[node[nodeIdx][i]] = 1;
tmp_depth = dfs(node, visited, weights, max_depth, node[nodeIdx][i]);
if (max_depth[nodeIdx][1] < tmp_depth + weights[nodeIdx][i]) {
max_depth[nodeIdx][1] = tmp_depth + weights[nodeIdx][i];
if (max_depth[nodeIdx][0] < max_depth[nodeIdx][1]) {
int tmp = max_depth[nodeIdx][0];
max_depth[nodeIdx][0] = max_depth[nodeIdx][1];
max_depth[nodeIdx][1] = tmp;
}
}
}
}
return max_depth[nodeIdx][0];
}
2) 번 풀이
// Baekjoon No. 1167
// Time Complexity
// #
#include <iostream>
#include <vector>
using namespace std;
vector<int> dfs(vector<vector<int>>& node, vector<vector<int>>& weights, vector<int>& visited, int nodeIdx, int depth);
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
int v, v1, v2, tmp, weight;
int i;
cin >> v;
vector<int> visited(v, 0);
vector<vector<int>> node(v);
vector<vector<int>> weights(v);
for (i = 0; i < v; i++) {
cin >> v1 >> v2 >> weight >> tmp;
v1--;
v2--;
node[v1].push_back(v2);
weights[v1].push_back(weight);
while (tmp > 0) {
v2 = tmp;
cin >> weight;
cin >> tmp;
v2--;
node[v1].push_back(v2);
weights[v1].push_back(weight);
}
}
// solve
visited[0] = 1;
vector<int> node_weight = dfs(node, weights, visited, 0, 0);
// visited init
for (i = 0; i < v; i++)
visited[i] = 0;
visited[node_weight[0]] = 1;
node_weight = dfs(node, weights, visited, node_weight[0], 0);
// output
cout << node_weight[1];
return 0;
}
vector<int> dfs(vector<vector<int>>& node, vector<vector<int>>& weights, vector<int>& visited, int nodeIdx, int depth) {
int size = node[nodeIdx].size();
vector<int> node_weight = { nodeIdx, depth }, tmp = {nodeIdx, depth};
for (int i = 0; i < size; i++) {
if (!visited[node[nodeIdx][i]]) {
visited[node[nodeIdx][i]] = 1;
tmp = dfs(node, weights, visited, node[nodeIdx][i], depth + weights[nodeIdx][i]);
if (tmp[1] > node_weight[1]) {
node_weight[0] = tmp[0];
node_weight[1] = tmp[1];
}
}
}
return node_weight;
}
4. 배운 점
1) 번 풀이에서 계속 시간초과가 나길래 분명 알고리즘은 같은데 왜 시간 초과가 나지.. 하고 고민했었다.
그런데, 함수 인자로 이차원 배열을 보내줄때 실수로 포인터가 아닌 값으로 보내줘서 생긴 문제였다.
포인터로 보내 주면 속도가 빠르다는 건 알고 있었지만, 이걸로 시간 초과가 날 줄은 몰랐다.