[CodeForces 466C] Number of Ways
Given an array of n integers, count the number of ways to split it into 3 contiguous parts so that all three parts have the same sum. No part may be empty.
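To make the counting rule concrete, here is a naive O(n^2) reference that is not part of the original solution. It tries every pair of cut points and compares the three part sums using prefix sums. The class and method names (NaiveSplitCount, count) are hypothetical.

```java
// Hypothetical brute-force reference (not the author's code): tries every pair of cut points.
final class NaiveSplitCount {
    // Counts pairs (j, i) with 0 <= j < i <= n - 2 such that
    // a[0..j], a[j+1..i] and a[i+1..n-1] all have the same sum.
    static long count(int[] a) {
        int n = a.length;
        long[] prefix = new long[n + 1]; // prefix[k] = a[0] + ... + a[k-1]
        for (int k = 0; k < n; k++) {
            prefix[k + 1] = prefix[k] + a[k];
        }
        long total = prefix[n];
        long ways = 0;
        for (int j = 0; j < n; j++) {              // 1st part is a[0..j]
            for (int i = j + 1; i <= n - 2; i++) { // 2nd part is a[j+1..i], 3rd is a[i+1..n-1]
                long first = prefix[j + 1];
                long second = prefix[i + 1] - prefix[j + 1];
                long third = total - prefix[i + 1];
                if (first == second && second == third) {
                    ways++;
                }
            }
        }
        return ways;
    }

    public static void main(String[] args) {
        System.out.println(count(new int[]{1, 2, 3, 0, 3})); // 2
        System.out.println(count(new int[]{0, 1, -1, 0}));   // 1
    }
}
```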
Algorithm: O(N) runtime
1. If total sum % 3 != 0, return 0.
2. Maintain psOneThird, the number of prefixes seen so far whose sum is 1/3 of the total sum. Start the current sum at a[0], and start psOneThird at 1 if a[0] equals 1/3 of the total sum, otherwise at 0.
3. Visit the elements from index 1 to n - 2. For each a[i], add it to the current sum, then:
(a). If the current sum is 2/3 of the total sum, a 2nd part can end at a[i], and the unvisited elements a[i+1..n-1] sum to the remaining 1/3. Each of the psOneThird earlier prefixes with sum 1/3 gives a distinct end for the 1st part, so there are exactly psOneThird valid partitions whose 2nd part ends at a[i]. Add psOneThird to the answer.
(b). If the current sum is 1/3 of the total sum, increment psOneThird.
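A minimal sketch of steps 1 to 3 that prints the state after each visited index, assuming the input is an int[]. The helper SplitCountTrace.countWays is hypothetical and for illustration only; its n < 3 guard is an addition. On {1, 2, 3, 0, 3} the total is 9, so 1/3 of it is 3, and the two valid splits are [1, 2 | 3 | 0, 3] and [1, 2 | 3, 0 | 3].

```java
// Hypothetical tracing helper for steps 1-3 (illustration only).
final class SplitCountTrace {
    static long countWays(int[] a) {
        int n = a.length;
        long total = 0;
        for (int x : a) total += x;
        // Step 1 (plus a guard: fewer than 3 elements cannot form 3 non-empty parts).
        if (n < 3 || total % 3 != 0) return 0;
        long third = total / 3;
        long curr = a[0];
        long psOneThird = (curr == third) ? 1 : 0; // step 2: the prefix a[0..0]
        long ans = 0;
        for (int i = 1; i <= n - 2; i++) {            // step 3
            curr += a[i];
            if (curr == 2 * third) ans += psOneThird; // (a) 2nd part ends at a[i]
            if (curr == third) psOneThird++;          // (b) 1st part may end at a[i]
            System.out.printf("i=%d a[i]=%d currSum=%d psOneThird=%d ans=%d%n",
                    i, a[i], curr, psOneThird, ans);
        }
        return ans;
    }

    public static void main(String[] args) {
        System.out.println(countWays(new int[]{1, 2, 3, 0, 3}));
        // Prints:
        // i=1 a[i]=2 currSum=3 psOneThird=1 ans=0
        // i=2 a[i]=3 currSum=6 psOneThird=1 ans=1
        // i=3 a[i]=0 currSum=6 psOneThird=1 ans=2
        // 2
    }
}
```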
Why does the above algorithm work?
Every valid partition has its 2nd part ending at one of a[1] through a[n - 2]: it cannot end at a[0], because the non-empty 1st part must come before it, and it cannot end at a[n - 1], because the non-empty 3rd part must come after it. Partitions whose 2nd parts end at different indices are different partitions, so the answer is the sum, over each such index i, of the number of valid partitions whose 2nd part ends at a[i]. When adding a[i] makes the current sum 2/3 of the total, the 3rd part a[i+1..n-1] automatically sums to 1/3. With the end of the 2nd part fixed at a[i], a partition is determined by where the 1st part ends: any earlier index j < i whose prefix sum is 1/3 works, since the 2nd part a[j+1..i] then also sums to 2/3 - 1/3 = 1/3. That number of earlier 1/3 prefixes is exactly psOneThird at this moment.
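One way to gain confidence in this counting argument is to compare the linear scan against a brute force that enumerates both cut points directly. The sketch below does that on small random arrays; the class and method names, the seed, and the value range are arbitrary choices for illustration. If the argument holds, it prints mismatches: 0.

```java
import java.util.Random;

// Hypothetical cross-check harness (illustration only).
final class SplitCountCrossCheck {
    // Linear scan: for each possible end i of the 2nd part, add the number of
    // earlier 1st-part ends (prefixes summing to 1/3).
    static long fast(int[] a) {
        int n = a.length;
        long total = 0;
        for (int x : a) total += x;
        if (n < 3 || total % 3 != 0) return 0;
        long third = total / 3, curr = a[0], ans = 0;
        long psOneThird = (curr == third) ? 1 : 0;
        for (int i = 1; i <= n - 2; i++) {
            curr += a[i];
            if (curr == 2 * third) ans += psOneThird;
            if (curr == third) psOneThird++;
        }
        return ans;
    }

    // Brute force: try every (end j of 1st part, end i of 2nd part) with j < i <= n - 2.
    static long brute(int[] a) {
        int n = a.length;
        long ways = 0;
        for (int j = 0; j < n; j++) {
            for (int i = j + 1; i <= n - 2; i++) {
                long s1 = 0, s2 = 0, s3 = 0;
                for (int k = 0; k <= j; k++) s1 += a[k];
                for (int k = j + 1; k <= i; k++) s2 += a[k];
                for (int k = i + 1; k < n; k++) s3 += a[k];
                if (s1 == s2 && s2 == s3) ways++;
            }
        }
        return ways;
    }

    // Random arrays of length 1..8 with values in [-2, 2]; small values make
    // equal-sum splits (including zero-sum ones) common.
    static int mismatches(long seed, int trials) {
        Random rnd = new Random(seed);
        int bad = 0;
        for (int t = 0; t < trials; t++) {
            int n = 1 + rnd.nextInt(8);
            int[] a = new int[n];
            for (int k = 0; k < n; k++) a[k] = rnd.nextInt(5) - 2;
            if (fast(a) != brute(a)) bad++;
        }
        return bad;
    }

    public static void main(String[] args) {
        System.out.println("mismatches: " + mismatches(42L, 10_000));
    }
}
```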
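One key note: check whether the current sum is 2/3 of the total before checking whether it is 1/3. The order matters only when the total sum is 0, because then 1/3 and 2/3 of the total are both 0. Checking 1/3 first would count a[i] as an end of the 1st part and immediately pair it with a 2nd part that also ends at a[i], i.e., an empty 2nd part.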
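The effect of the check order shows up on an all-zero array. This sketch (hypothetical class CheckOrderDemo, with a flag twoThirdsFirst added only to switch the order) runs the scan both ways on {0, 0, 0, 0}, whose 3 valid splits correspond to choosing 2 of the 3 gaps between elements.

```java
// Hypothetical demo: the same scan with the two checks in either order.
final class CheckOrderDemo {
    static long countWays(int[] a, boolean twoThirdsFirst) {
        int n = a.length;
        long total = 0;
        for (int x : a) total += x;
        if (n < 3 || total % 3 != 0) return 0;
        long third = total / 3, curr = a[0], ans = 0;
        long psOneThird = (curr == third) ? 1 : 0;
        for (int i = 1; i <= n - 2; i++) {
            curr += a[i];
            if (twoThirdsFirst) {
                // Correct order: only 1st-part ends strictly before i are counted.
                if (curr == 2 * third) ans += psOneThird;
                if (curr == third) psOneThird++;
            } else {
                // Wrong order: when third == 0, index i is counted as a 1st-part end
                // and then paired with a 2nd part ending at the same index (empty).
                if (curr == third) psOneThird++;
                if (curr == 2 * third) ans += psOneThird;
            }
        }
        return ans;
    }

    public static void main(String[] args) {
        int[] zeros = {0, 0, 0, 0};
        System.out.println(countWays(zeros, true));  // 3
        System.out.println(countWays(zeros, false)); // 5 (overcounts)
    }
}
```

With the wrong order, each visited index adds one extra "partition" whose 2nd part is empty, so the count rises from 3 to 5. When the total is nonzero, the two orders agree.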
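private static void solve(int q, FastScanner in, PrintWriter out) {
    // FastScanner is the author's input helper (not shown); q is not used here.
    int n = in.nextInt();
    int[] a = new int[n];
    long sum = 0; // long: the sum of n ints can overflow int
    for (int i = 0; i < n; i++) {
        a[i] = in.nextInt();
        sum += a[i];
    }
    if (sum % 3 != 0) {
        // Step 1: three equal sums are impossible.
        out.println(0);
    }
    else {
        long currSum = a[0], oneThirdSum = sum / 3, ans = 0;
        // Step 2: the prefix a[0..0] may already sum to 1/3.
        int psOneThird = (a[0] == oneThirdSum ? 1 : 0);
        // Step 3: i is the last index of the 2nd part; i <= n - 2 keeps the 3rd part non-empty.
        for (int i = 1; i < n - 1; i++) {
            currSum += a[i];
            // (a) must run before (b) so that a[i] is never both a 1st-part and a 2nd-part end.
            if (currSum == oneThirdSum * 2) {
                ans += psOneThird;
            }
            // (b) a[0..i] sums to 1/3: one more candidate end for the 1st part.
            if (currSum == oneThirdSum) {
                psOneThird++;
            }
        }
        out.println(ans);
    }
    out.close();
}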
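Because FastScanner is not shown, the method above does not run on its own. A self-contained variant could look like the following sketch, which swaps in java.io.StreamTokenizer for input and moves the counting into a separate countWays method so it can be tested without I/O. The class name NumberOfWays and the n < 3 guard are additions, not part of the original code.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.StreamTokenizer;

// Hypothetical self-contained variant of solve() (illustration only).
final class NumberOfWays {
    // Same counting logic as solve(), separated from I/O.
    static long countWays(int[] a) {
        long sum = 0;
        for (int x : a) sum += x;
        if (a.length < 3 || sum % 3 != 0) return 0;
        long oneThirdSum = sum / 3, currSum = a[0], ans = 0;
        int psOneThird = (a[0] == oneThirdSum) ? 1 : 0;
        for (int i = 1; i < a.length - 1; i++) {
            currSum += a[i];
            if (currSum == oneThirdSum * 2) ans += psOneThird; // 2nd part ends at a[i]
            if (currSum == oneThirdSum) psOneThird++;          // 1st part may end at a[i]
        }
        return ans;
    }

    public static void main(String[] args) throws IOException {
        // StreamTokenizer stands in for the unshown FastScanner helper.
        StreamTokenizer st = new StreamTokenizer(
                new BufferedReader(new InputStreamReader(System.in)));
        st.nextToken();
        int n = (int) st.nval;
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            st.nextToken();
            a[i] = (int) st.nval; // every int value is exactly representable as a double
        }
        PrintWriter out = new PrintWriter(System.out);
        out.println(countWays(a));
        out.close();
    }
}
```

For example, the input "5" followed by "1 2 3 0 3" on standard input should print 2.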