4월 17, 2024

[백준] 16638번 괄호 추가하기 2 비트마스크로 풀어보기

1. 문제

1) 링크

www.acmicpc.net/problem/16638

2) 문제

길이가 N인 수식이 있다. 수식은 0보다 크거나 같고, 9보다 작거나 같은 정수와 연산자(+, -, ×)로 이루어져 있다. 곱하기의 연산자 우선순위가 더하기와 빼기보다 높기 때문에, 곱하기를 먼저 계산 해야 한다. 수식을 계산할 때는 왼쪽에서부터 순서대로 계산해야 한다. 예를 들어, 3+8×7-9×2의 결과는 41이다.

수식에 괄호를 추가하면, 괄호 안에 들어있는 식은 먼저 계산해야 한다. 단, 괄호 안에는 연산자가 하나만 들어 있어야 한다. 예를 들어, 3+8×7-9×2에 괄호를 (3+8)×7-(9×2)와 같이 추가했으면, 식의 결과는 59가 된다. 하지만, 중첩된 괄호는 사용할 수 없다. 즉, 3+((8×7)-9)×2, 3+((8×7)-(9×2))은 모두 괄호 안에 괄호가 있기 때문에, 올바른 식이 아니다.

수식이 주어졌을 때, 괄호를 적절히 추가해 만들 수 있는 식의 결과의 최댓값을 구하는 프로그램을 작성하시오. 추가하는 괄호 개수의 제한은 없으며, 추가하지 않아도 된다.

3) 입력

첫째 줄에 수식의 길이 N(1 ≤ N ≤ 19)가 주어진다. 둘째 줄에는 수식이 주어진다. 수식에 포함된 정수는 모두 0보다 크거나 같고, 9보다 작거나 같다. 문자열은 정수로 시작하고, 연산자와 정수가 번갈아가면서 나온다. 연산자는 +, -, * 중 하나이다. 여기서 *는 곱하기 연산을 나타내는 × 연산이다. 항상 올바른 수식만 주어지기 때문에, N은 홀수이다.

4) 출력

첫째 줄에 괄호를 적절히 추가해서 얻을 수 있는 결과의 최댓값을 출력한다. 정답은 231보다 작고, -231보다 크다.


2. 풀이

더 자세한 입출력 예시는 위 백준 링크에서 확인할 수 있다. 

이 문제는 먼저 Class를 하나 더 만들어주는 것이 편하다. class Calc를 하나 만들어주고 instance로 num과 op를 가지고 있도록 만들어준다. 여기서 op는 operator의 약자로 숫자면 0, 더하기면 1, 빼기면 2, 곱하기면 3을 가지게 만들어준다. 

 

즉 아래와 같은 형태인 것이다.

class Calc{
    int num, op;
    Calc(int num, int op) {
        this.num = num;
        this.op = op;
    }
}

그런 다음에 비트마스크를 활용하여 괄호가 올 수 있는 모든 경우를 체크할 것인데, 이 문제는 괄호 안에 하나의 연산자밖에 존재하지 않고 중첩이 불가능하므로 오히려 쉬운 문제이다. 연산자의 개수는 (n-1)/2개이므로 연산자의 개수를 기준으로 비트마스크를 해주면 된다. 즉 for문의 형태가 아래와 같은 식인 것이다.

int m = (n-1)/2; //연산자의 개수
for(int i=0; i<(1<<m); i++){
            boolean possible = true;
            for (int j=0; j<m-1; j++) {
                if ((i&(1<<j)) > 0 && (i&(1<<(j+1))) > 0) {
                    possible = false; //중첩 괄호 확인
                }
            }
            if (!possible) continue;
            
            }

이런식으로 for문이 돌면 중첩괄호가 아닌 모든 괄호의 경우를 체크할 수 있고 이제는 괄호가 있는 경우를 먼저 계산해준다. 이 문제는 순서가 괄호가 있는 수 먼저 계산 -> 곱하기 먼저 계산 -> 나머지 계산 이런 식으로 진행되어야 한다. 

 

괄호를 먼저 계산하면, 원래 있는 수 배열이 훼손될 수 있기 때문에 tmp라는 새로운 배열을 하나 더 만들어주고 괄호를 계산해준다. 아래 코드는 괄호를 계산하는 부분의 코드이다.

Calc[] tmp=new Calc[n]; //tmp 배열에 옮기기
            for (int j=0; j<n; j++) {
                tmp[j] = new Calc(a[j].num, a[j].op);
            }
            for(int j=0; j<m; j++){
                if ((i&(1<<j))>0){ //괄호가 있으면 
                    int k=2*j+1; //실제 괄호의 위치
                     if (tmp[k].op == 1) { //더하기
                         tmp[k-1].num += tmp[k+1].num;
                        tmp[k].op = -1;
                        tmp[k+1].num = 0;
                    } else if (tmp[k].op == 2) { //빼기
                        tmp[k-1].num -= tmp[k+1].num;
                        tmp[k].op = -1;
                        tmp[k+1].num = 0;
                    } else if (tmp[k].op == 3) { //곱하기
                        tmp[k-1].num *= tmp[k+1].num;
                        tmp[k].op = -1;
                        tmp[k+1].num = 0;
                    }
                }
            }

다음에 *, +, -을 더 계산해야 하기 때문에 괄호로 이미 계산한 연산자의 경우 op의 값으로 -1을 가지게 업데이트 시켜주어 다음번 계산 시에 고려하지 않게 한다. 

 

이렇게 되었다면 괄호 부분의 숫자가 다 계산이 된 것이다. 이제 곱하기 부분을 먼저 계산해주고, 그 다음에는 순차적으로 하나씩 계산해주어서 최댓값을 찾아주면 된다. 

 


3. 코드

이 모든 것을 종합한 전체 Java code는 아래와 같다.

import java.util.*;
class Calc{
    int num, op;
    Calc(int num, int op) {
        this.num = num;
        this.op = op;
    }
}
public class Main{
    public static void main(String[] args){
          Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        String s = sc.next();
        Calc[] a = new Calc[n];
         for (int i=0; i<n; i++) {
            if (i%2 == 0) {
                a[i] = new Calc(s.charAt(i)-'0', 0);
            } else {
                int op = 1; //+일 경우
                if (s.charAt(i) == '-') {
                    op = 2;
                } else if (s.charAt(i) == '*') {
                    op = 3;
                }
                a[i] = new Calc(0, op);
            }
        }
        int m = (n-1)/2; //연산자의 개수
        int ans = -2147483648; //가장 최소값
        for(int i=0; i<(1<<m); i++){
            boolean possible = true;
            for (int j=0; j<m-1; j++) {
                if ((i&(1<<j)) > 0 && (i&(1<<(j+1))) > 0) {
                    possible = false; //중첩 괄호 확인
                }
            }
            if (!possible) continue;
            Calc[] tmp=new Calc[n]; //tmp 배열에 옮기기
            for (int j=0; j<n; j++) {
                tmp[j] = new Calc(a[j].num, a[j].op);
            }
            for(int j=0; j<m; j++){
                if ((i&(1<<j))>0){ //괄호가 있으면 
                    int k=2*j+1; //실제 괄호의 위치
                     if (tmp[k].op == 1) { //더하기
                         tmp[k-1].num += tmp[k+1].num;
                        tmp[k].op = -1;
                        tmp[k+1].num = 0;
                    } else if (tmp[k].op == 2) { //빼기
                        tmp[k-1].num -= tmp[k+1].num;
                        tmp[k].op = -1;
                        tmp[k+1].num = 0;
                    } else if (tmp[k].op == 3) { //곱하기
                        tmp[k-1].num *= tmp[k+1].num;
                        tmp[k].op = -1;
                        tmp[k+1].num = 0;
                    }
                }
            }
            //괄호 계산 완료
            ArrayList<Calc> c=new ArrayList<>();
            for(int j=0; j<n; j++){
                if (j%2==0){ //숫자일 경우
                    c.add(tmp[j]);
                }else if (tmp[j].op==-1){
                 j++; //이미 괄호로 처리한 것
                }
                    else{
                    //우선 곱하기만 먼저 계산
                    if (tmp[j].op==3){
                        int num=c.get(c.size()-1).num* tmp[j+1].num;
                        c.remove(c.size()-1);
                        c.add(new Calc(num, 0));
                        j += 1;
                    }
                        else{
                            c.add(tmp[j]);
                        }
                }
            }
            Calc b[] = c.toArray(new Calc[c.size()]);
            int m2 = (b.length-1)/2;
            int val = b[0].num;
            for (int j=0; j<m2; j++) {
                int k = 2*j+1;
                if (b[k].op == 1) {
                    val += b[k+1].num;
                } else if (b[k].op == 2) {
                    val -= b[k+1].num;
                } else if (b[k].op == 3) {
                    val *= b[k+1].num;
                }
            }
            if (ans < val) {
                ans = val;
            }
        }
          System.out.println(ans);
    }
}

조금 긴 코드이지만 하나씩 이해해보면 어려움이 없을 것이다.