3월 10, 2024

[백준] 1167번 트리의 지름 구해보기

1. 문제

1) 링크

www.acmicpc.net/problem/1167

2) 문제

트리의 지름이란, 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것을 말한다. 트리의 지름을 구하는 프로그램을 작성하시오.

3) 입력

트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2≤V≤100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. (정점 번호는 1부터 V까지 매겨져 있다고 생각한다)

먼저 정점 번호가 주어지고, 이어서 연결된 간선의 정보를 의미하는 정수가 두 개씩 주어지는데, 하나는 정점번호, 다른 하나는 그 정점까지의 거리이다. 예를 들어 네 번째 줄의 경우 정점 3은 정점 1과 거리가 2인 간선으로 연결되어 있고, 정점 4와는 거리가 3인 간선으로 연결되어 있는 것을 보여준다. 각 줄의 마지막에는 -1이 입력으로 주어진다. 주어지는 거리는 모두 10,000 이하의 자연수이다.

4) 출력

첫째 줄에 트리의 지름을 출력한다.

 

더 자세한 문제의 조건을 보기 위해서는 위 백준 링크를 클릭해보자

 


2. 풀이

이 문제는 Post-Order를 사용해서 풀 수 있다. 

이 문제는 height를 저장해놓는 배열을 만든 다음에 DFS를 실시하면 효과적으로 문제를 풀 수 있다. 여기서 height를 저장하는 것은 만약 루트 노드를 거쳐가는 경우, height과 루트노드와 서브노드 사이의 거리를 더하기 때문이다.

 

dfs 함수가 return 해주어야 할 것을 생각해보면 크게 두가지가 있다.

첫번째로는 현재까지의 최대노드를 루트노드로 했을 때 최대 지름, 그리고 두 번째로는 현재 노드의 높이 이렇게 두 가지를 리턴해주어야 한다. 따라서 이 두 가지를 인스턴스 멤버로 가지고 있는 class Pair를 하나 만들어준다. 왜냐하면 java에서 return 값은 두개가 될 수 없기 때문에 class를 만들어서 객체로 return을 해주려고 한다.

 

class Pair{
    int meter;
    int height;
    Pair(int meter, int height){
        this.meter=meter;
        this.height=height;
    }
}

이런식으로 meter는 현재까지의 지름의 길이, height는 루트에서 노드까지의 높이를 이야기하는 것이다.


 

그리고 또 다른 class로 문제 입력에서 종착 노드와 그 노드까지의 거리가 주어졌으므로 Edge라는 class를 하나 만들어준다.

class Edge{
    int destination;
    int distance;
    Edge(int destination, int distance){
        this.destination=destination;
        this.distance=distance;
    }
}

여기서 destination은 한 노드와 연결되어 있는 노드 번호, 그리고 distance는 그 노드와의 거리를 뜻한다. 


dfs 함수는 인접리스트를 for문으로 돌면서 재귀적으로 함수를 호출해주게 되는데,

만약 height와 관련된 ArrayList의 크기가 1개라면, 


ArrayList 크기가 1일 경우

 

위와 같은 이미지이므로 배열의 첫번째 요소에 1과 2 사이의 지름이 담겨 있을 것이다. 따라서 이 경우에는 배열의 첫 번째 요소를 살펴보아야 한다. 

 

하지만 대부분의 경우,


ArrayList 크기가 2 이상일 경우

 

위와 같은 이미지의 트리구조를 가지고 있기 때문에 이 경우에는 2와 1 사이의 지름과 1과 3 사이의 지름을 더한 것이 가장 긴 지름의 길이가 된다. 따라서 이 경우에는 가장 긴 지름 두개를 합친 값을 가장 긴 지름의 후보가 된다. 물론 현재 루트를 거치지 않고 서브트리에서 나온 값이 최대가 될 수 있기 때문에 이와 비교하는 작업이 한 번 필요하다. 

 


이런식으로 구해서 dfs의 함수 부분 코드는 아래와 같다.

static Pair dfs(int x){
        ArrayList<Integer> heights = new ArrayList<>();
        check[x]=true;
        int cur=0; //루트를 거치지 않고 서브트리의 지름의 최대를 구하기 위함
        for(Edge ed: arr[x]){
            int y=ed.destination;
            int dist=ed.distance;
            if (!check[y]){
                Pair p=dfs(y);
                if (p.meter>cur){
                    cur=p.meter;
                }
                heights.add(p.height+dist); //루트를 거쳐가는 경우를 생각
            }
        }
         Collections.sort(heights, Collections.reverseOrder()); //거꾸로
        int ans=cur; //지금까지는 서브트리의 최대 지름값이 있는데 heights 중 max 값과 비교하는 것 필요
        int height=0;
        if (heights.size()>=1){
           height=heights.get(0);
            
            if (ans<height) ans=height;
        }
        if (heights.size() >= 2) {
            int tmp = heights.get(0) + heights.get(1);
            if (ans < tmp) {
                ans = tmp;
            }
        }
   
        return new Pair(ans, height);
        
    }

즉 위 코드의 핵심은 (자식 노드에서 현재 노드와의 거리 + 다른 자식노드에서 현재 노드와의 거리) 와 자식 노드를 루트로 하는 서브트리에서의 지름과 무엇이 최대값인지를 비교해주는 것이다. 


3. 코드

이런식으로 구한 전체 코드를 살펴보면 아래와 같다.

import java.util.*;
class Pair{
    int meter;
    int height;
    Pair(int meter, int height){
        this.meter=meter;
        this.height=height;
    }
}

class Edge{
    int destination;
    int distance;
    Edge(int destination, int distance){
        this.destination=destination;
        this.distance=distance;
    }
}
public class Main{
    public static boolean check[];
    public static ArrayList<Edge>arr[];
    static Pair dfs(int x){
        ArrayList<Integer> heights = new ArrayList<>();
        check[x]=true;
        int cur=0; //루트를 거치지 않고 서브트리의 지름의 최대를 구하기 위함
        for(Edge ed: arr[x]){
            int y=ed.destination;
            int dist=ed.distance;
            if (!check[y]){
                Pair p=dfs(y);
                if (p.meter>cur){
                    cur=p.meter;
                }
                heights.add(p.height+dist); //루트를 거쳐가는 경우를 생각
            }
        }
         Collections.sort(heights, Collections.reverseOrder()); //거꾸로
        int ans=cur;
        int height=0;
        if (heights.size()>=1){
           height=heights.get(0);
            
            if (ans<height) ans=height;
        }
        if (heights.size() >= 2) {
            int tmp = heights.get(0) + heights.get(1);
            if (ans < tmp) {
                ans = tmp;
            }
        }
   
        return new Pair(ans, height);
        
    }
    public static void main(String[] args){
        Scanner sc=new Scanner(System.in);
        int n=sc.nextInt();
        arr=(ArrayList<Edge>[])new ArrayList[n+1];
        for(int i=1; i<=n; i++){
            arr[i]=new ArrayList<Edge>();
        }
        check=new boolean[n+1];
        for(int i=0; i<n; i++){
            int x=sc.nextInt();
            while(true){
                int y=sc.nextInt();
                if (y==-1){
                    break;
                }
                int z=sc.nextInt();
                arr[x].add(new Edge(y,z));
            }
        }
        Pair answer=dfs(1);
        System.out.println(answer.meter);
    }
}