알고리즘_PS/Baekjoon

[Baekjoon] 1167 - 트리의 지름(C++)

hanseongjun 2022. 9. 30. 17:56

1. 문제

트리의 지름이란, 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것을 말한다. 트리의 지름을 구하는 프로그램을 작성하시오.

https://www.acmicpc.net/problem/1167

 

1167번: 트리의 지름

트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2 ≤ V ≤ 100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. 정점 번호는 1부터 V까지

www.acmicpc.net

2. 해결 아이디어

1) 각 노드마다 갈 수 있는 최대 cost를 2개까지 저장한 뒤(재귀 방식이라 시간 복잡도 O(n)), 모든 노드의 최대 cost값 2개의 합을 비교하면서 가장 큰 최대 cost의 합을 출력한다. (= 트리의 지름)

한 노드에서 갈 수 있는 가장 먼 거리 (2개)의 합은 지름이다. 따라서 최대 cost의 합은 당연히 트리의 지름이 된다.

 

트리의 왼쪽 자손과 오른쪽 자손까지의 비용을 각각 더했을 때, 그 두 합이 최대가 되면 그 두 경로는 당연히 지름이 됨을 알 수 있다.
만약 자손이 없는 경우에는 초깃값을 0으로 설정함으로써 문제를 해결할 수 있다.

 

인접 행렬 방식으로 구현하면 메모리초과가 나서 인접 리스트로 구현했다.

그래서 방문했는지를 표시하는 visited[] 배열, 가중치를 저장하는 weights[][]배열, 노드 연결 관계를 저장하는 node[][]배열, 각 노드마다 2개씩 최대 cost를 저장하는 max_depth[][] 배열 이렇게 3개의 배열이 사용되었다.

 

 

2) 정석 풀이 - dfs를 2번 돌리면 된다.

1번째에서 dfs를 통해 얻은 최대 cost를 가지는 node에 가서 한번 더 dfs를 돌려 최대 cost를 구하면 그 값이 바로 트리의 지름이다.

직관적으로는 조금 생각하기 어려웠던 풀이였던 것 같다.

 

그럼 왜 dfs를 두 번 돌리면 트리의 지름이 나올까?

dfs를 한 번 돌리면 현재 노드에서 가장 먼 노드를 찾을 수 있다.

그리고 현재 노드에서 가장 먼 그 노드까지 가는 길은 반드시 지름의 일부를 포함한다. (1번째로 긴 구간과 2번째로 긴 구간 중 적어도 1개를 포함하므로)

따라서 가장 먼 노드는 지름의 한쪽 끝에 위치하는 노드이다. 그렇기 때문에 그 노드에서 dfs를 써주면 트리의 지름을 찾을 수 있다.

3. 코드

1) 번 풀이

// Baekjoon No. 1167
// Time Complexity
// #

#include <iostream>
#include <vector>
using namespace std;

int dfs(vector<vector<int>>& node, vector<int>& visited, vector<vector<int>>& weights, vector<vector<int>>& max_depth, int nodeIdx);
int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int v, v1, v2, tmp, weight, tree_length;
    int i;
    cin >> v;
    vector<int> visited(v, 0);
    // vector<vector<int>> weights(v, vector<int>(v, 0));
    vector<vector<int>> weights(v);
    vector<vector<int>> node(v);
    vector<vector<int>> max_depth(v, vector<int>(2, 0));
    for (i = 0; i < v; i++) {
        cin >> v1 >> v2 >> weight;
        v1--;
        v2--;
        // node[idx]의 2k-1번째는 nodeIdx, 2k번째는 weight
        node[v1].push_back(v2);
        weights[v1].push_back(weight);
        // node[v1].push_back(weight);
        cin >> tmp;
        while (tmp > 0) {
            v2 = tmp;
            v2--;
            cin >> weight;
            // save
            node[v1].push_back(v2);
            weights[v1].push_back(weight);
            cin >> tmp;
        }
    }

    // solve
    /** 풀이법
    1) 루트 노드를 하나 정한다
    2) 루트 노드에서부터 dfs로 각 노드의 끝까지 탐색, 만약 depth요소보다 크다면 최대 depth에 갱신한다.
    3) max_depth는 현재 노드에서 뻗을 수 있는 노드들 중 가장 긴 길이 2개를 저장한다.
    4) 각 노드에서의 max_depth 2개의 합들 중(자식 node가 1개여도 초기화가 0으로 되어 있어 상관없음)
    가장 큰 합 = 지름이므로 지름 값을 구해 출력한다.
    */
    // root node = 0
    visited[0] = 1;
    dfs(node, visited, weights, max_depth, 0);
    // output
    tree_length = 0;
    for (i = 0; i < v; i++)
        if (max_depth[i][0] + max_depth[i][1] > tree_length)
            tree_length = max_depth[i][0] + max_depth[i][1];

    cout << tree_length;
    return 0;
}

// returns the biggest depth current node can go
int dfs(vector<vector<int>>& node, vector<int>& visited, vector<vector<int>>& weights, vector<vector<int>>& max_depth, int nodeIdx){
    int size = node[nodeIdx].size(), tmp_depth = 0;
    for (int i = 0; i < size; i ++) { // node[][]
        if (!visited[node[nodeIdx][i]]) {
            visited[node[nodeIdx][i]] = 1;
            tmp_depth = dfs(node, visited, weights, max_depth, node[nodeIdx][i]);
            if (max_depth[nodeIdx][1] < tmp_depth + weights[nodeIdx][i]) {
                max_depth[nodeIdx][1] = tmp_depth + weights[nodeIdx][i];
                if (max_depth[nodeIdx][0] < max_depth[nodeIdx][1]) {
                    int tmp = max_depth[nodeIdx][0];
                    max_depth[nodeIdx][0] = max_depth[nodeIdx][1];
                    max_depth[nodeIdx][1] = tmp;
                }
            }
        }
    }
    return max_depth[nodeIdx][0];
}

 

2) 번 풀이

// Baekjoon No. 1167
// Time Complexity
// #

#include <iostream>
#include <vector>
using namespace std;

vector<int> dfs(vector<vector<int>>& node, vector<vector<int>>& weights, vector<int>& visited, int nodeIdx, int depth);
int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int v, v1, v2, tmp, weight;
    int i;
    cin >> v;
    vector<int> visited(v, 0);
    vector<vector<int>> node(v);
    vector<vector<int>> weights(v);
    for (i = 0; i < v; i++) {
        cin >> v1 >> v2 >> weight >> tmp;
        v1--;
        v2--;
        node[v1].push_back(v2);
        weights[v1].push_back(weight);
        while (tmp > 0) {
            v2 = tmp;
            cin >> weight;
            cin >> tmp;
            v2--;
            node[v1].push_back(v2);
            weights[v1].push_back(weight);
        }
    }

    // solve
    visited[0] = 1;
    vector<int> node_weight = dfs(node, weights, visited, 0, 0);
    // visited init
    for (i = 0; i < v; i++)
        visited[i] = 0;
    visited[node_weight[0]] = 1;
    node_weight = dfs(node, weights, visited, node_weight[0], 0);
    
    // output
    cout << node_weight[1];
    return 0;
}
vector<int> dfs(vector<vector<int>>& node, vector<vector<int>>& weights, vector<int>& visited, int nodeIdx, int depth) {
    int size = node[nodeIdx].size();
    vector<int> node_weight = { nodeIdx, depth }, tmp = {nodeIdx, depth};
    for (int i = 0; i < size; i++) {
        if (!visited[node[nodeIdx][i]]) {
            visited[node[nodeIdx][i]] = 1;
            tmp = dfs(node, weights, visited, node[nodeIdx][i], depth + weights[nodeIdx][i]);
            if (tmp[1] > node_weight[1]) {
                node_weight[0] = tmp[0];
                node_weight[1] = tmp[1];
            }
        }
    }
    return node_weight;
}

 

4. 배운 점

1) 번 풀이에서 계속 시간초과가 나길래 분명 알고리즘은 같은데 왜 시간 초과가 나지.. 하고 고민했었다.

그런데, 함수 인자로 이차원 배열을 보내줄때 실수로 포인터가 아닌 값으로 보내줘서 생긴 문제였다.

포인터로 보내 주면 속도가 빠르다는 건 알고 있었지만, 이걸로 시간 초과가 날 줄은 몰랐다.

LIST