[BOJ] 19566 - 수열의 구간 평균 (C++)
1. 문제
https://www.acmicpc.net/problem/19566
길이가 n인 수열이 주어질 때, 이 수열에서 평균이 k인 구간의 개수를 구하면 되는 문제이다.
2. 해결 방법
1. 먼저 평균의 의미를 다시 생각해보자.
평균 = 원소들의 합 / 원소들의 개수이다.
양변에 원소들의 개수를 곱하면, 다음과 같이 쓸 수 있다.
구간합 = 평균 * 구간길이 이다.
즉, 구간합을 원소들의 개수로 나눠 이를 평균과 비교할 것이 아니라, 구간합이 구간길이 * 평균과 같은지만 비교해주면 된다는 것이다.
2. 어떤 구간이 평균이 k일지 (구간합이 구간길이 * k일지) 생각해보자.
먼저 brute force 방식으로 수열에서 나올 수 있는 모든 구간을 구해보는 방법이 있다.
이는 다음과 같은 시간복잡도를 가진다.
우리는 n개의 수열 에서 1개, 2개, .., n개를 골라야 하므로, 결국 2^n-1 에서 1을 빼준 2^n-1 - 1번을 반복한다.
결국 시간복잡도는 O(2^n)이 된다.
이리저리 방법을 생각해봐도, 이 시간복잡도를 O(nlogn) 이하로 낮출 방법은 보이지 않지만, 9084번에서의 풀이를 떠올려 보면 그래도 힌트를 얻을 수 있다. (주어진 금액을 만들 수 있는 모든 동전 조합의 개수를 세는 문제)
이때는 시간복잡도를 O(n)만에 풀 수 있게 하기 위해 dp를 이용했는데, 먼저 가지고 있는 동전의 종류로 바깥의 반복문을 돌면서 안쪽의 반복문은 dp에서 경우의 수를 세는 memory배열을 채워간다.
그렇게 현재 동전의 종류에서는 이전에 반복문을 돌았던 모든 동전의 경우의 수를 덧칠하는 방식으로 진행하면 결국 모든 동전을 사용하는 조합을 거친 셈이 되는 문제였다.
이 아이디어를 이번에도 사용한다.
3. 해결법
// 멤버 변수들 초기화
int n, k;
// n, k 입력
vector<long long> arr(n, 0);
long long ans = 0, tmp;
map<long long, long long> m;
먼저 배열에는 누적합을 저장한다. 누적합을 저장하는 이유는 구간합을 빨리 구하기 위해서도 있고, 이번에 사용할 알고리즘이 누적합과 성격이 잘 맞기 때문이다.
누적합을 저장했다면, 반복문을 n번 반복하면서 i번째 원소까지의 구간합에서 구간길이 * k를 뺀 값을 tmp에 저장한다.
tmp = Si - i * k (1 <= i <= n)
> 그리고 map 자료형으로 맵 m을 하나 만들고, 이렇게 구해진 tmp값의 인덱스값을 정답에 더한다.
ans += m[tmp]++;
그리고, m[tmp]값을 1 증가시켜준다.
이게 어떤 의미를 가지냐면, 먼저 tmp에는 현재 구간합이 평균 * 구간길이보다 얼마 초과했는지를 저장한다.
그래서 m에는 이 초과된 케이스의 개수가 저장되는데, 만약 1 초과했다면, map[1] 값은 맨 처음에 0이었다가 1로 증가되는 것이다.
그래서 다음 번에 만약 어떤 구간에서 다시 초과된 값이 1인 케이스가 나온다면, (i2라고 하자) 이 둘의 구간합을 서로 빼준다면 초과된 양이 상쇄되어
Si2 - Si = (i2 - i) * k,
구간길이 = i2 - i이므로, 구간평균 = k가 된다.
평균이 k가 되는 케이스가 나오는 것이다.
또한, 이제까지 나왔던 모든 케이스들의 조합도 추가해줘야 하기에, m에 이제까지 나왔던 값을 저장해 두고, 이를 ans에 더해주고, m[tmp]값을 1 증가시키는 (케이스가 하나 늘었다는 뜻) 것이다.
그리고 맨 마지막에 가서는 ans에 m[0]값을 한번 더해주는데, 이건 마지막에 가서 전체 구간에서의 값을 한번 더해줘야 하기 때문이다.
3. 배운 점, 배울 점
사실 이번 문제를 풀긴 했는데, 친구의 도움이 있었기에 풀 수 있었지, 자료도 별로 없는데다 처음 보는 유형이라(트리를 이용한 맵) 풀이 방법이 도저히 떠올리기 어려웠던 문제이다.
특히 조합을 처리하는 아이디어에 맵을 쓰는 아이디어를 종합해 줘야 해서 더 어려웠던 것 같다.
코드 작성 시에는 각 원소가 10^9까지의 크기를 가지는 수열의 구간합을 다루다보니 오버플로가 나지 않게 long long형을 적절히 사용해줘야 한다.