[LeetCode] Find the Minimum Cost Array Permutation

You are given an array nums which is a permutation of [0, 1, 2, ..., n - 1]. The score of any permutation of [0, 1, 2, ..., n - 1] named perm is defined as:

score(perm) = |perm[0] - nums[perm[1]]| + |perm[1] - nums[perm[2]]| + ... + |perm[n - 1] - nums[perm[0]]|

Return the permutation perm which has the minimum possible score. If multiple permutations exist with this score, return the one that is lexicographically smallest among them.

 

Key observations

1. If we fix an answer perm, rotating it does not change the score, because the score is a sum over a cycle. Among all rotations, the one that starts with 0 is lexicographically smallest, so the final answer always starts with 0.
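A minimal sketch of this observation, assuming two hypothetical helpers (score and rotateToZero, which are not part of the solution below): it computes the cyclic score directly from the definition and checks that rotating a permutation so that 0 comes first leaves the score unchanged.

```java
import java.util.Arrays;

class RotationDemo {
    // Cyclic score from the problem statement:
    // sum of |perm[i] - nums[perm[(i + 1) % n]]| over all i.
    static int score(int[] nums, int[] perm) {
        int n = perm.length, total = 0;
        for (int i = 0; i < n; i++) {
            total += Math.abs(perm[i] - nums[perm[(i + 1) % n]]);
        }
        return total;
    }

    // Rotate perm left so that the value 0 comes first (hypothetical helper).
    static int[] rotateToZero(int[] perm) {
        int n = perm.length, start = 0;
        while (perm[start] != 0) start++;
        int[] out = new int[n];
        for (int i = 0; i < n; i++) out[i] = perm[(start + i) % n];
        return out;
    }

    public static void main(String[] args) {
        int[] nums = {3, 4, 2, 0, 1};
        int[] perm = {1, 2, 4, 0, 3};
        int[] rotated = rotateToZero(perm);
        // Both scores are the same because the same cyclic pairs are summed.
        System.out.println(Arrays.toString(rotated) + " " + score(nums, perm) + " " + score(nums, rotated));
    }
}
```

Since every rotation has the same score, only permutations with perm[0] = 0 need to be searched.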

 

 

2. When constructing an answer from the dp state table, breaking ties by greedily picking a bigger number from right to left is INCORRECT. Counterexample:

nums = [3,4,2,0,1]; perm1 = [0,3,1,2,4]; perm2 = [0,2,4,1,3]
score(perm1) = 0 + 1 + 1 + 1 + 1 = 4; score(perm2) = 2 + 1 + 0 + 1 + 0 = 4; 
Greedily picking a bigger number returns perm1 but perm2 is lexicographically smaller than perm1.
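The counterexample can be checked mechanically. The sketch below assumes a hypothetical score helper that follows the definition in the problem statement, and uses Arrays.compare (Java 9+) for the lexicographic comparison.

```java
import java.util.Arrays;

class CounterExampleDemo {
    // Cyclic score: sum of |perm[i] - nums[perm[(i + 1) % n]]| over all i.
    static int score(int[] nums, int[] perm) {
        int n = perm.length, total = 0;
        for (int i = 0; i < n; i++) {
            total += Math.abs(perm[i] - nums[perm[(i + 1) % n]]);
        }
        return total;
    }

    public static void main(String[] args) {
        int[] nums = {3, 4, 2, 0, 1};
        int[] perm1 = {0, 3, 1, 2, 4};
        int[] perm2 = {0, 2, 4, 1, 3};
        System.out.println(score(nums, perm1)); // 4
        System.out.println(score(nums, perm2)); // 4
        // A negative result means perm2 comes first lexicographically.
        System.out.println(Arrays.compare(perm2, perm1) < 0); // true
    }
}
```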
 
Given the minimum score and the dp state table, how do you construct the answer then?
Instead of defining the dp state as dp[i][j] = min(dp[i ^ (1 << j)][k] + abs(k - nums[j])), where number j is appended at the end after k, we define the dp state as:
dp[i][j] = min(dp[i | (1 << k)][k] + abs(j - nums[k])) over every k not yet in mask i, with base case dp[full][j] = abs(j - nums[0]) to close the cycle back to 0.
Here dp[i][j] is the minimum cost of completing the permutation when the numbers in mask i are already placed and j is the last one placed. This way the prefix of the answer is fixed and the dp result decides which number to pick next.
We can then use another table (val in the code below) to track, for each state and last picked number, the smallest next number that still yields the minimum score.
 
import java.util.*;

class Solution {
    private int[][] dp, val;
    public int[] findPermutation(int[] a) {
        int n = a.length;
        dp = new int[1 << n][n];
        val = new int[1 << n][n];
        for(int i = 0; i < dp.length; i++) {
            Arrays.fill(dp[i], -1);
            Arrays.fill(val[i], -1);
        }
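        compute(a, 1, 0);
        // Rebuild the answer from val: ans[0] is already 0, then follow the recorded
        // smallest optimal next number from each (mask, prev) state.
        int[] ans = new int[n];
        int prev = 0, idx = 1;
        for(int mask = 1; Integer.bitCount(mask) < n; mask += (1 << prev)) {
            ans[idx] = val[mask][prev];
            idx++;
            prev = val[mask][prev];
        }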
        return ans;
    }
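    // Minimum cost to finish the cycle when the numbers in currMask are already placed
    // and currV is the last one placed.
    private int compute(int[] a, int currMask, int currV) {
        if(Integer.bitCount(currMask) == a.length) {
            // All numbers placed: close the cycle back to perm[0] = 0.
            return Math.abs(currV - a[0]);
        }
        if(dp[currMask][currV] < 0) {
            dp[currMask][currV] = Integer.MAX_VALUE;
            // Try nextV in increasing order; the strict < below keeps the smallest
            // nextV among ties, which gives the lexicographically smallest answer.
            for(int nextV = 1; nextV < a.length; nextV++) {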
                if((currMask & (1 << nextV)) == 0) {
                    int score = Math.abs(currV - a[nextV]) + compute(a, currMask | (1 << nextV), nextV);
                    if(score < dp[currMask][currV]) {
                        dp[currMask][currV] = score;
                        val[currMask][currV] = nextV;
                    }
                }
            }
        }
        return dp[currMask][currV];
    }
    public static void main(String[] args) {
        Solution solution = new Solution();
        int[] a = new int[]{3,4,2,0,1};
        // Print the result so the run shows the chosen permutation.
        System.out.println(Arrays.toString(solution.findPermutation(a)));
    }
}
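
For small inputs, the DP answer can be cross-checked against a brute force. The sketch below is not part of the solution above; BruteForceCheck and its methods are hypothetical names. It enumerates every permutation that starts with 0 in lexicographic order and keeps the first one with a strictly smaller score, so ties resolve to the lexicographically smallest permutation. It runs in O(n!) time and is only meant for tiny n.

```java
import java.util.Arrays;

class BruteForceCheck {
    private static int bestScore;
    private static int[] best;

    // Returns the lexicographically smallest minimum-score permutation by exhaustive search.
    static int[] findPermutation(int[] nums) {
        int n = nums.length;
        bestScore = Integer.MAX_VALUE;
        best = null;
        int[] perm = new int[n];          // perm[0] stays 0 (rotation argument)
        boolean[] used = new boolean[n];
        used[0] = true;
        dfs(nums, perm, used, 1);
        return best;
    }

    private static void dfs(int[] nums, int[] perm, boolean[] used, int pos) {
        int n = nums.length;
        if (pos == n) {
            int total = 0;
            for (int i = 0; i < n; i++) {
                total += Math.abs(perm[i] - nums[perm[(i + 1) % n]]);
            }
            // Strict < keeps the earliest (lexicographically smallest) optimum.
            if (total < bestScore) {
                bestScore = total;
                best = perm.clone();
            }
            return;
        }
        // Increasing v visits permutations in lexicographic order.
        for (int v = 1; v < n; v++) {
            if (!used[v]) {
                used[v] = true;
                perm[pos] = v;
                dfs(nums, perm, used, pos + 1);
                used[v] = false;
            }
        }
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(findPermutation(new int[]{1, 0, 2}))); // [0, 1, 2]
        System.out.println(Arrays.toString(findPermutation(new int[]{0, 2, 1}))); // [0, 2, 1]
    }
}
```

Comparing its output with Solution.findPermutation on random small permutations is a quick way to catch tie-breaking mistakes like the one in the counterexample.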

 

 

 

posted @ 2024-05-14 23:37  Review->Improve