[코딩 인터뷰][HackerRank] Candles Counting

※ 이 포스트는 HackerRank 문제인 Candles Counting 해설을 담고 있습니다.
본 풀이는 Optimal하지 않을 수 있으며, HackerRank에서의 Submission 통과만을 보장합니다.


Candles Counting

문제 출처 : https://www.hackerrank.com/challenges/candles-2/problem
카테고리 : DP, Segment Tree
난이도 : Medium (85 point)


문제 설명

양초들이 왼쪽에서 오른쪽으로 줄지어있다. 각각의 양초는 높이 H와 색상 C를 가진다. 여기서 임의의 양초를 골랐을 때, 왼쪽에서 오른쪽으로 그 높이가 쭉 증가하는 것만 골라내고 싶다. 여기에서 또 하나의 조건이 붙는다. 양초의 가능한 색이 K가지라고 할 때, 임의로 고른 양초의 집합에 K개의 색이 적어도 한 번씩은 등장해야한다. 이 조건을 만족하는 집합의 수를 출력하여라.

예를 들어서 K = 3이고 (1, 1), (3, 2), (2, 2), (4, 3)의 양초 (H, C) 들이 있는 상황이라고 생각해보자.
(1, 1) (3, 2) (4, 3)이나 (1, 1) (2, 2) (4, 3)의 두 가지 경우를 꼽을 수 있다.


문제 풀이

일단 컬러를 생각하지 말고 풀어보자. 선택된 양초들이 왼쪽으로부터 오른쪽으로 높이가 증가하는 집합을 ‘증가 집합’이라고 정의하자.

왼쪽부터 하나씩 새로운 양초를 추가한다고 생각해보자. S(X, X)는 높이 X인 양초가 가장 오른쪽에 있는 증가 집합의 가짓수다. 새로운 양초가 추가될 때 어떤 일이 일어나는가. S(X, X)에 변화가 일어난다. 기존에 있던 양초들 중 높이가 X보다 작은 양초들을 마지막으로 하는 증가집합에 새로운 양초 X를 추가해도 여전히 증가집합이다. 따라서 S(X, X)는 자신의 원래 값에 새롭게 S(k, k)를 추가하게 된다. 여기서 k는 0을 포함해 X보다 작은 모든 수다. S(0, 0)은 1로 초기화되는데, 이 값을 더한다는 의미는 마지막에 추가된 높이 X의 양초 하나만으로 구성된 집합이 있다는 말이다. 즉, S(0, 0)은 공집합과 같은 의미를 가진다.

세그먼트 트리를 쓰면 S(0, 0) + S(1, 1) + … + S(X, X)를 좀 더 쉽게 구할 수 있다. log_2 X번만에 구할 수 있다. 세그먼트 트리를 구성하는 방법은 일반적인 세그먼트 트리 항목을 참조하길 바란다.

여기서 컬러를 생각해야한다. 컬러의 종류는 많아야 7가지다. 이를 비트로 나타낸다고 생각해보자. 모든 색이 사용되었다는 의미는 모든 비트가 1로 세팅된 상태라고 생각해보자. 어떤 양초가 컬러 C를 가진다고 한다면 C – 1번째 비트가 1이라고 생각하면 된다. 즉, 2 ** (C – 1)의 값을 가진다고 보면 된다.

어떤 Bit를 가지는 이전 상태에서 컬러 C를 가지는 높이 X의 양초를 추가했다고 하자. 그렇다면 새롭게 만들어지는 비트 NewBit는 bit | 2 ** (C – 1)과 같다. 즉, S(X, X)[ bit | 2 ** (C – 1) ] = S(0, 0)[bit] + S(1, 1)[bit] + … + S(X, X)[bit]이다. 모든 가능한 2 ** K가지 bit에 대해서, 새로 추가된 높이 X의 양초와 그 색상 C에 의해 달라지는 새로운 비트 상태의 S(X, X)값을 기록한다.

최종적으로는 모든 비트가 1로 켜진 bit = 111111인 상태에 대해서, S(1, X)를 리턴하면 된다. S(0, 0)은 공집합이므로 결과 계산에선 빠진다. 사실 있다고 하더라도 모든 비트가 켜진 S(0, 0)은 존재할 수가 없기 때문에 문제는 없을 것 같다.


코드

시간복잡도는 O(H x log H x 2^K)다. H = 5 x 10^4이고 K = 7이므로, 배열 연산이 느린 Python에선 함수로 구분해서 구현하면, 호출 비용이 높아 TLE가 뜬다. 아래코드는 함수 하나에 코드를 다 때려박아서 일부 중복되는 기능이 있다.

#!/bin/python3

import os
import sys

#
# Complete the candlesCounting function below.
#
MOD = 10 ** 9 + 7
S = []

def FindIdx(l, r, v):
    i = 0
    while l < r:
        mid = (l + r) // 2
        # left child
        if mid >= v:
            i = i * 2 + 1
            r = mid
        # right child
        else:
            i = (i + 1) * 2
            l = mid + 1
    return i

def FindSum(l, r, s, t, i):
    global S, MOD

    while True:
        if l == s and r == t:
            return [i]
        mid = (l + r) // 2
        # Index is right side
        if mid < s:
            l, r, s, t, i = mid + 1, r, s, t, (i + 1) * 2
        # Index is left side
        elif t <= mid:
            l, r, s, t, i = l, mid, s, t, i * 2 + 1
        # Partially overlapping
        else:
            return FindSum(l, mid, s, mid, i * 2 + 1) + FindSum(mid + 1, r, mid + 1, t, (i + 1) * 2)

def candlesCounting(k, candles):
    global S, MOD
    #
    # Write your code here.
    #
    N = 5 * (10 ** 4) + 1
    # Init Segment Tree
    S = [[0] * (2 ** k) for i in range (N * 4)]
    K = 2 ** k
    MOD = 10 ** 9 + 7
    bits = [2 ** i for i in range(K)]

    # To generalize, set default for zero value.
    Init = FindIdx(0, N, 0)
    S[Init][0] = 1    
    parent = (Init - 1) // 2

    # Propagate to parent intervals.
    while parent >= 0:
        S[parent][0] = (S[parent][0] + 1) % MOD
        parent = (parent - 1) // 2
    
    # Add each candles by order of appearence and update segment tree
    for i in range(len(candles)):
        (value, bit) = candles[i]
        bit = bits[bit - 1]
        # Find index of segment tree for interval (value, value)
        idx = FindIdx(0, N, value)
        # Find index of segment tree for interval (0, value - 1)
        # It means all the possible sequences less than value.
        idxs = FindSum(0, N, 0, value - 1, 0)
        # For possible bits (000000, ... , 1111111)
        for j in range(K):
            # New bit is calculated by '|' operation
            new_bit = bit | j
            # Calculate Sum of number of sequences which last element is less than value
            Sum = 0
            for k in range(len(idxs)):
                Sum = (Sum + S[idxs[k]][j]) % MOD
            # Add to S(N, N) and corresponding bit
            S[idx][new_bit] = (S[idx][new_bit] + Sum) % MOD

            # Propagate to intervals which includes S(N, N)
            parent = (idx - 1) // 2
            while parent >= 0:
                S[parent][new_bit] = (S[parent][new_bit] + Sum) % MOD
                parent = (parent - 1) // 2

    # Return sum of number of sequences which the greatest value is 1 ~ N and bit is 111111 which means that every color is used at least once.
    Sum = 0
    idxs = FindSum(0, N, 1, N, 0)
    for k in range(len(idxs)):
        Sum = (Sum + S[idxs[k]][K - 1]) % MOD
    return Sum

if __name__ == '__main__':
    fptr = open(os.environ['OUTPUT_PATH'], 'w')

    nk = input().split()

    n = int(nk[0])

    k = int(nk[1])

    candles = []

    for _ in range(n):
        candles.append(list(map(int, input().rstrip().split())))

    result = candlesCounting(k, candles)

    fptr.write(str(result) + '\n')

    fptr.close()

코딩 인터뷰 – Candles Counting