[백준/Baekjoon] 2696번: 중앙값 구하기

2 minute read

1. 문제

어떤 수열을 읽고, 홀수번째 수를 읽을 때 마다, 지금까지 입력받은 값의 중앙값을 출력하는 프로그램을 작성하시오.

예를 들어, 수열이 1,5,4,3,2 이면, 홀수번째 수는 1번째 수, 3번째 수, 5번째 수이고, 1번째 수를 읽었을 때 중앙값은 1, 3번째 수를 읽었을 때는 4, 5번째 수를 읽었을 때는 3이다.

2. 입력

첫째 줄에 테스트 케이스의 개수 T(1<=T<=1,000)가 주어진다. 각 테스트 케이스의 첫째 줄에는 수열의 크기 M(1<=M<=9999, M=홀수)이 주어지고, 그 다음 줄부터 이 수열의 원소가 차례대로 주어진다. 원소는 한 줄에 10개씩 나누어져있고, 32비트 부호있는 정수이다. (대부분의 언어에서 int)

3. 출력

각 테스트 케이스에 대해 첫째 줄에 출력하는 중앙값의 개수를 출력하고, 둘째 줄에는 홀수 번째 수를 읽을 때 마다 구한 중앙값을 차례대로 공백으로 구분하여 출력한다. 이때, 한 줄에 10개씩 출력해야 한다.

4. 예제 입력 1

3
9
1 2 3 4 5 6 7 8 9
9
9 8 7 6 5 4 3 2 1
23
23 41 13 22 -3 24 -31 -11 -8 -7
3 5 103 211 -311 -45 -67 -73 -81 -99
-33 24 56

5. 예제 출력 1

5
1 2 3 4 5
5
9 8 7 6 5
12
23 23 22 22 13 3 5 5 3 -3
-7 -3

6. 풀이

힙(Heap) 은 힙의 특성(최소 힙(Min Heap) 에서는 부모가 항상 자식보다 작거나 같다) 을 만족하는 거의 완전한 트리인 특수한 트리 기반의 자료구조다. 이를 구현하기 위해 heapq 모듈 을 사용한다.

두 개의 힙을 사용하여 중앙값을 구한다. 기본적인 원리는 에디터 문제와 비슷하다고 볼 수 있다. 가운데 값을 중심으로 왼쪽 힙을 최대 힙, 오른쪽 힙을 최소 힙으로 구현한다.

if val > middle:  # 중앙 값보다 크다면
    heapq.heappush(right, val)  # 최소 힙인 오른쪽 힙에 삽입
else:  # 중앙 값보다 작다면
    heapq.heappush(left, -val)  # 최대 힙인 왼쪽 힙에 삽입

그리고 2번 삽입할 때마다 왼쪽 힙과 오른쪽 힙의 길이를 비교한다. 이 중 더 길이가 긴 힙에서 값을 하나 꺼내 이를 중앙값으로 이용하게 된다. 이 때, 왼쪽 힙은 최대 힙으로 구현되었으므로 값을 꺼낼 때 부호가 바뀐다는 것에 주의해야 한다.

if idx % 2 == 0:  # 삽입이 2번 이루어질 때마다 길이를 비교
    if len(left) < len(right):  # 오른쪽 힙이 왼쪽 힙보다 길이가 길다면
        heapq.heappush(left, -middle)  # 기존의 중앙값은 왼쪽 힙에 삽입
        middle = heapq.heappop(right)  # 오른쪽 힙의 값을 중앙값으로 이용
    elif len(left) > len(right):  # 왼쪽 힙이 오른쪽 힙보다 길이가 길다면
        heapq.heappush(right, middle)  # 기존의 중앙값을 오른쪽 힙에 삽십
        middle = -heapq.heappop(left)  # 왼쪽 힙의 값을 중앙값으로 이용
    result.append(middle)

또한 입력에서 시간 초과 방지를 위해 sys 모듈의 sys.stdin을 사용한다. 이유는 이 글 에서 확인할 수 있다.

7. 코드

from sys import stdin
import heapq


def check(data):
    left, right = [], []
    middle = data[0]
    result = [middle]
    for idx, val in enumerate(data[1:], 1):
        if val > middle:
            heapq.heappush(right, val)
        else:
            heapq.heappush(left, -val)
        if idx % 2 == 0:
            if len(left) < len(right):
                heapq.heappush(left, -middle)
                middle = heapq.heappop(right)
            elif len(left) > len(right):
                heapq.heappush(right, middle)
                middle = -heapq.heappop(left)
            result.append(middle)

    print(len(result))

    for i in range(len(result)):
        if (i + 1) != 1 and (i + 1) % 10 == 1:
            print()
        print(result[i], end=' ')
    print()


t = int(stdin.readline().rstrip())

for _ in range(t):
    m = int(input())
    data = []

    if m % 10 == 0:
        for _ in range(m // 10):
            data.extend(list(map(int, stdin.readline().rstrip().split())))
    else:
        for _ in range(m // 10 + 1):
            data.extend(list(map(int, stdin.readline().rstrip().split())))

    check(data)

Leave a comment