[백준] 1167번 트리의 지름 구해보기
1. 문제
1) 링크
2) 문제
트리의 지름이란, 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것을 말한다. 트리의 지름을 구하는 프로그램을 작성하시오.
3) 입력
트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2≤V≤100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. (정점 번호는 1부터 V까지 매겨져 있다고 생각한다)
먼저 정점 번호가 주어지고, 이어서 연결된 간선의 정보를 의미하는 정수가 두 개씩 주어지는데, 하나는 정점번호, 다른 하나는 그 정점까지의 거리이다. 예를 들어 네 번째 줄의 경우 정점 3은 정점 1과 거리가 2인 간선으로 연결되어 있고, 정점 4와는 거리가 3인 간선으로 연결되어 있는 것을 보여준다. 각 줄의 마지막에는 -1이 입력으로 주어진다. 주어지는 거리는 모두 10,000 이하의 자연수이다.
4) 출력
첫째 줄에 트리의 지름을 출력한다.
더 자세한 문제의 조건을 보기 위해서는 위 백준 링크를 클릭해보자
2. 풀이
이 문제는 Post-Order를 사용해서 풀 수 있다.
이 문제는 height를 저장해놓는 배열을 만든 다음에 DFS를 실시하면 효과적으로 문제를 풀 수 있다. 여기서 height를 저장하는 것은 만약 루트 노드를 거쳐가는 경우, height과 루트노드와 서브노드 사이의 거리를 더하기 때문이다.
dfs 함수가 return 해주어야 할 것을 생각해보면 크게 두가지가 있다.
첫번째로는 현재까지의 최대노드를 루트노드로 했을 때 최대 지름, 그리고 두 번째로는 현재 노드의 높이 이렇게 두 가지를 리턴해주어야 한다. 따라서 이 두 가지를 인스턴스 멤버로 가지고 있는 class Pair를 하나 만들어준다. 왜냐하면 java에서 return 값은 두개가 될 수 없기 때문에 class를 만들어서 객체로 return을 해주려고 한다.
class Pair{
int meter;
int height;
Pair(int meter, int height){
this.meter=meter;
this.height=height;
}
}
이런식으로 meter는 현재까지의 지름의 길이, height는 루트에서 노드까지의 높이를 이야기하는 것이다.
그리고 또 다른 class로 문제 입력에서 종착 노드와 그 노드까지의 거리가 주어졌으므로 Edge라는 class를 하나 만들어준다.
class Edge{
int destination;
int distance;
Edge(int destination, int distance){
this.destination=destination;
this.distance=distance;
}
}
여기서 destination은 한 노드와 연결되어 있는 노드 번호, 그리고 distance는 그 노드와의 거리를 뜻한다.
dfs 함수는 인접리스트를 for문으로 돌면서 재귀적으로 함수를 호출해주게 되는데,
만약 height와 관련된 ArrayList의 크기가 1개라면,
위와 같은 이미지이므로 배열의 첫번째 요소에 1과 2 사이의 지름이 담겨 있을 것이다. 따라서 이 경우에는 배열의 첫 번째 요소를 살펴보아야 한다.
하지만 대부분의 경우,
위와 같은 이미지의 트리구조를 가지고 있기 때문에 이 경우에는 2와 1 사이의 지름과 1과 3 사이의 지름을 더한 것이 가장 긴 지름의 길이가 된다. 따라서 이 경우에는 가장 긴 지름 두개를 합친 값을 가장 긴 지름의 후보가 된다. 물론 현재 루트를 거치지 않고 서브트리에서 나온 값이 최대가 될 수 있기 때문에 이와 비교하는 작업이 한 번 필요하다.
이런식으로 구해서 dfs의 함수 부분 코드는 아래와 같다.
static Pair dfs(int x){
ArrayList<Integer> heights = new ArrayList<>();
check[x]=true;
int cur=0; //루트를 거치지 않고 서브트리의 지름의 최대를 구하기 위함
for(Edge ed: arr[x]){
int y=ed.destination;
int dist=ed.distance;
if (!check[y]){
Pair p=dfs(y);
if (p.meter>cur){
cur=p.meter;
}
heights.add(p.height+dist); //루트를 거쳐가는 경우를 생각
}
}
Collections.sort(heights, Collections.reverseOrder()); //거꾸로
int ans=cur; //지금까지는 서브트리의 최대 지름값이 있는데 heights 중 max 값과 비교하는 것 필요
int height=0;
if (heights.size()>=1){
height=heights.get(0);
if (ans<height) ans=height;
}
if (heights.size() >= 2) {
int tmp = heights.get(0) + heights.get(1);
if (ans < tmp) {
ans = tmp;
}
}
return new Pair(ans, height);
}
즉 위 코드의 핵심은 (자식 노드에서 현재 노드와의 거리 + 다른 자식노드에서 현재 노드와의 거리) 와 자식 노드를 루트로 하는 서브트리에서의 지름과 무엇이 최대값인지를 비교해주는 것이다.
3. 코드
이런식으로 구한 전체 코드를 살펴보면 아래와 같다.
import java.util.*;
class Pair{
int meter;
int height;
Pair(int meter, int height){
this.meter=meter;
this.height=height;
}
}
class Edge{
int destination;
int distance;
Edge(int destination, int distance){
this.destination=destination;
this.distance=distance;
}
}
public class Main{
public static boolean check[];
public static ArrayList<Edge>arr[];
static Pair dfs(int x){
ArrayList<Integer> heights = new ArrayList<>();
check[x]=true;
int cur=0; //루트를 거치지 않고 서브트리의 지름의 최대를 구하기 위함
for(Edge ed: arr[x]){
int y=ed.destination;
int dist=ed.distance;
if (!check[y]){
Pair p=dfs(y);
if (p.meter>cur){
cur=p.meter;
}
heights.add(p.height+dist); //루트를 거쳐가는 경우를 생각
}
}
Collections.sort(heights, Collections.reverseOrder()); //거꾸로
int ans=cur;
int height=0;
if (heights.size()>=1){
height=heights.get(0);
if (ans<height) ans=height;
}
if (heights.size() >= 2) {
int tmp = heights.get(0) + heights.get(1);
if (ans < tmp) {
ans = tmp;
}
}
return new Pair(ans, height);
}
public static void main(String[] args){
Scanner sc=new Scanner(System.in);
int n=sc.nextInt();
arr=(ArrayList<Edge>[])new ArrayList[n+1];
for(int i=1; i<=n; i++){
arr[i]=new ArrayList<Edge>();
}
check=new boolean[n+1];
for(int i=0; i<n; i++){
int x=sc.nextInt();
while(true){
int y=sc.nextInt();
if (y==-1){
break;
}
int z=sc.nextInt();
arr[x].add(new Edge(y,z));
}
}
Pair answer=dfs(1);
System.out.println(answer.meter);
}
}