[LeetCode 1262] Greatest Sum Divisible by Three
Given an array nums of integers, we need to find the maximum possible sum of elements of the array such that it is divisible by three.
Example 1:
Input: nums = [3,6,5,1,8]
Output: 18
Explanation: Pick numbers 3, 6, 1 and 8; their sum is 18 (the maximum sum divisible by 3).
Example 2:
Input: nums = [4]
Output: 0
Explanation: Since 4 is not divisible by 3, do not pick any number.
Example 3:
Input: nums = [1,2,3,4,4]
Output: 12
Explanation: Pick numbers 1, 3, 4 and 4; their sum is 12 (the maximum sum divisible by 3).
Constraints:
1 <= nums.length <= 4 * 10^4
1 <= nums[i] <= 10^4
Incorrect greedy solution
If a number is divisible by 3, always add it to the final sum. Otherwise its remainder is either 1 or 2, and these numbers have to be combined so that their remainders cancel out. One way to do this is to store the remainder-1 and remainder-2 numbers in separate lists, sort each list from largest to smallest, and greedily take three remainder-1 numbers or three remainder-2 numbers at a time, since each such triple sums to a multiple of 3. However, this is incorrect. Consider the counterexample [2,6,2,2,7].
The 6 goes straight into the sum, {7} goes into the remainder-1 list, and {2,2,2} goes into the remainder-2 list. The greedy solution can only use the three 2s, for a total of 6 + 6 = 12. Picking 7 and a single 2 instead gives 6 + 7 + 2 = 15, which is better. So the combination step is not limited to same-remainder triples: a cross pick (one remainder-1 number plus one remainder-2 number) may be needed to get the optimal result.
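The counterexample can be checked mechanically. The sketch below is a hypothetical helper class (GreedyCounterExample, with methods greedy and bruteForce, all names introduced only for illustration): greedy implements the flawed triple-only strategy described above, and bruteForce tries every subset, which is only feasible for tiny inputs.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class GreedyCounterExample {
    // Flawed greedy from the text: keep every multiple of 3, then take
    // remainder-1 and remainder-2 numbers in same-remainder triples, largest first.
    static int greedy(int[] nums) {
        int sum = 0;
        List<Integer> rem1 = new ArrayList<>();
        List<Integer> rem2 = new ArrayList<>();
        for (int x : nums) {
            if (x % 3 == 0) sum += x;
            else if (x % 3 == 1) rem1.add(x);
            else rem2.add(x);
        }
        rem1.sort(Collections.reverseOrder());
        rem2.sort(Collections.reverseOrder());
        // Only complete triples are used; leftover numbers are dropped.
        for (int i = 0; i + 2 < rem1.size(); i += 3) {
            sum += rem1.get(i) + rem1.get(i + 1) + rem1.get(i + 2);
        }
        for (int i = 0; i + 2 < rem2.size(); i += 3) {
            sum += rem2.get(i) + rem2.get(i + 1) + rem2.get(i + 2);
        }
        return sum;
    }

    // Exhaustive search over all 2^n subsets; a reference answer for tiny inputs.
    static int bruteForce(int[] nums) {
        int n = nums.length, best = 0;
        for (int mask = 0; mask < (1 << n); mask++) {
            int s = 0;
            for (int i = 0; i < n; i++) {
                if ((mask & (1 << i)) != 0) s += nums[i];
            }
            if (s % 3 == 0) best = Math.max(best, s);
        }
        return best;
    }

    public static void main(String[] args) {
        int[] nums = {2, 6, 2, 2, 7};
        System.out.println(greedy(nums));     // 12 (6 + 2 + 2 + 2)
        System.out.println(bruteForce(nums)); // 15 (6 + 7 + 2)
    }
}
```

The gap between 12 and 15 is exactly the missed cross pick of 7 and 2.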
When greedy does not work, we should consider a dynamic programming solution.
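dp[i][j]: the maximum sum of a subset of nums[0, i] such that sum % 3 == j. For j = 1 or 2, a value of 0 means no such subset exists yet; dp[i][0] = 0 is always achievable by picking nothing.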
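Depending on nums[i] % 3, we update dp[i][j] using dp[i - 1]: either skip nums[i] and keep dp[i - 1][j], or add nums[i] to the best earlier sum whose remainder, combined with nums[i] % 3, gives j. The final answer is dp[n - 1][0].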
The pitfall here is that, depending on nums[i] % 3 and the target remainder j, nums[i] may or may not be addable. For example, if nums[i] % 3 == 1 and j = 2, we can add nums[i] only if dp[i - 1][1] > 0, meaning that some numbers in nums[0, i - 1] were already picked to form a sum S with S % 3 == 1; only then does adding nums[i] give a larger sum S' with S' % 3 == 2. The > 0 test is a valid reachability check because every nums[i] >= 1, so any reachable remainder-1 or remainder-2 sum is positive.
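Written as a single recurrence, the transition with its reachability condition looks roughly like this (r and k are helper symbols introduced here; k is the remainder the earlier sum must have):

```latex
r = \mathit{nums}[i] \bmod 3, \qquad k = (j - r + 3) \bmod 3

dp[i][j] =
\begin{cases}
\max\bigl(dp[i-1][j],\ dp[i-1][k] + \mathit{nums}[i]\bigr) & \text{if } k = 0 \text{ or } dp[i-1][k] > 0 \\
dp[i-1][j] & \text{otherwise}
\end{cases}

dp[0][j] =
\begin{cases}
\mathit{nums}[0] & \text{if } j = \mathit{nums}[0] \bmod 3 \\
0 & \text{otherwise}
\end{cases}
```

The case k = 0 corresponds to j == r, where dp[i - 1][0] is always usable because the empty subset has sum 0.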
The state transition needs to correctly handle this logic.
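To see what goes wrong without the guard, here is a minimal sketch (the class GuardDemo and its guard flag are hypothetical, added only for this demonstration) that runs the same transition with and without the dp[i - 1][diff] > 0 check.

```java
class GuardDemo {
    // Same DP as the solution; `guard` toggles the reachability check.
    static int maxSumDivThree(int[] nums, boolean guard) {
        int n = nums.length;
        int[][] dp = new int[n][3];
        dp[0][nums[0] % 3] = nums[0]; // base case: nums[0] alone
        for (int i = 1; i < n; i++) {
            int r = nums[i] % 3;
            for (int j = 0; j < 3; j++) {
                int k = (j + 3 - r) % 3; // remainder the earlier sum must have
                // Remainder 0 is always reachable (empty subset); 1 and 2 only if dp > 0.
                boolean reachable = k == 0 || !guard || dp[i - 1][k] > 0;
                int take = reachable ? dp[i - 1][k] + nums[i] : 0;
                dp[i][j] = Math.max(dp[i - 1][j], take);
            }
        }
        return dp[n - 1][0];
    }

    public static void main(String[] args) {
        int[] nums = {4, 4};
        System.out.println(maxSumDivThree(nums, true));  // 0
        System.out.println(maxSumDivThree(nums, false)); // 4, which is not divisible by 3
    }
}
```

On [4, 4], the unguarded version treats the unreachable dp[0][2] = 0 as a real remainder-2 sum, adds 4 to it, and stores 4 in dp[1][0].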
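class Solution {
    public int maxSumDivThree(int[] nums) {
        int n = nums.length;
        // dp[i][j]: max sum of a subset of nums[0, i] with sum % 3 == j (0 = unreachable for j = 1, 2)
        int[][] dp = new int[n][3];
        // Base case: nums[0] alone fills the slot of its remainder; dp[0][0] = 0 is the empty subset.
        if (nums[0] % 3 == 0) {
            dp[0][0] = nums[0];
        } else if (nums[0] % 3 == 1) {
            dp[0][1] = nums[0];
        } else {
            dp[0][2] = nums[0];
        }
        for (int i = 1; i < n; i++) {
            int r = nums[i] % 3;
            for (int j = 0; j < 3; j++) {
                if (j == r) {
                    // Earlier remainder must be 0; dp[i - 1][0] is always valid.
                    dp[i][j] = Math.max(dp[i - 1][j], dp[i - 1][0] + nums[i]);
                } else {
                    // Earlier remainder diff satisfies (diff + r) % 3 == j.
                    int diff = (j + 3 - r) % 3;
                    // Extend dp[i - 1][diff] only if that remainder was actually reached.
                    dp[i][j] = Math.max(dp[i - 1][j], dp[i - 1][diff] > 0 ? dp[i - 1][diff] + nums[i] : 0);
                }
            }
        }
        return dp[n - 1][0];
    }
}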
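Since dp[i] depends only on dp[i - 1], the table can be reduced to three values. The sketch below (hypothetical class SolutionConstantSpace) is a variant rather than the solution above verbatim: it pushes each state forward to (j + x) % 3 and marks unreachable remainders with Integer.MIN_VALUE instead of relying on the > 0 test, which also keeps it correct if zeros were ever allowed in the input. Extra space drops from O(n) to O(1).

```java
class SolutionConstantSpace {
    public int maxSumDivThree(int[] nums) {
        // dp[j]: best sum seen so far with remainder j; MIN_VALUE marks "not reachable".
        int[] dp = {0, Integer.MIN_VALUE, Integer.MIN_VALUE}; // empty subset: sum 0
        for (int x : nums) {
            int[] next = dp.clone(); // option 1: skip x
            for (int j = 0; j < 3; j++) {
                if (dp[j] == Integer.MIN_VALUE) continue; // remainder j not reachable yet
                int nj = (j + x) % 3;
                next[nj] = Math.max(next[nj], dp[j] + x); // option 2: add x
            }
            dp = next;
        }
        return dp[0];
    }

    public static void main(String[] args) {
        SolutionConstantSpace s = new SolutionConstantSpace();
        System.out.println(s.maxSumDivThree(new int[]{3, 6, 5, 1, 8})); // 18
        System.out.println(s.maxSumDivThree(new int[]{4}));             // 0
        System.out.println(s.maxSumDivThree(new int[]{1, 2, 3, 4, 4})); // 12
        System.out.println(s.maxSumDivThree(new int[]{2, 6, 2, 2, 7})); // 15
    }
}
```

Reading from dp and writing to a fresh copy (next) ensures each number is added at most once per step. With the given constraints the largest possible sum is 4 * 10^8, which fits in an int.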