3Sum algorithm - 非常容易理解的实现 (java)
原题重述:(点击图片可以进入来源链接)
这到题目的中文解释是,
输入一个数组,例如{-1 0 1 2 -1 -4},从数组中找三个数(a,b,c),使得其和0,输出所有的(a,b,c)组合。
要求abc不能重复,并且a<=b<=c。
拿到这个题目的时候,其实每个程序猿都能想到如下的算法,也就是暴力破解,其时间复杂度为o(n^3):
1 for(int i=0;i<nums.length;i++){ 2 for(int j=i+1;j<nums.length;j++){ 3 for(int k=j+1;k<nums.length;j++){ 4 if(nums[i]+nums[j]+nums[k]==0){ 5 addResult(nums[i], nums[j], nums[k]); 6 } 7 } 8 } 9 }
首先需要对输入的数组进行排序,这样的话由于上面的i<j<k,所以可以保证nums[i]<nums[j]<nums[k]。
其实我的算法的思路就是在暴力破解的基础上进行优化,尽量降低时间复杂度。
在java中对数组排序的方法是:Arrays.sort(nums);
第三个循环其实是没有必要的,因为在确定了i,j的值之后,需要寻找的nums[k]的值就已经确定了,即-(nums[i]+nums[j])。
因此无需循环,只需要判断数组剩下的元素中是否存在这个值就可以了。
基于这个思路我构建了一个hashmap作为hash索引表,用于查找每个值出现的位置:(考虑到一个值可能出现在多个位置的情况,用arraylist)
因为nums是已经排序过的,所以索引表中的arraylist也是已排序好的。
HashMap<Integer, ArrayList<Integer>> index = new HashMap<>();
构建这个索引表的代码如下:
1 for(int i=0;i<nums.length;i++){ 2 int num = nums[i]; 3 if(num==0){ 4 n++; 5 } 6 if(index.get(num)==null){ 7 8 9 index.put(num, new ArrayList<Integer>()); 10 } 11 12 index.get(num).add(i); 13 14 15 }
这里面的n是表示0的个数,如果n>=3,就直接输出一个[0 0 0]了。
从索引表查询需要的数的方式,我想了很久,最后想到一个很不错的方法:
1 int p = -(nums[i]+nums[j]); 2 if(p<0) continue; 3 ArrayList<Integer> in = index.get(p); 4 if(in==null) continue; 5 if(in.get(in.size()-1)>j){ 6 if(p>nums[j]){ 7 addResult(nums[i], nums[j],p); 8 }else if(p>nums[i]){ 9 addResult(nums[i], p,nums[j]); 10 }else{ 11 addResult(p,nums[i], nums[j]); 12 } 13 14 }
第2行,为什么要舍弃p<0的情况?因为要避免重复。如果p也是负数的话,由于nums[i]<nums[j]那么会出现两种情况:
①nums[i]和nums[j]都是正数;
②nums[i]是负数,nums[j]是正数 。
那么在其他的扫描过程中一定会出现:
①那时的nums[i]'=p,p'=nums[i],nums[j]'=nums[j];
②那时的p'=nums[j],nums[i]'=min(nums[i],p),nums[j]'=max(nums[i],p)。
第5行in.get(in.size()-1)>j是什麼意思?
我们这个时候是需要找一个k(k>j),使得nums[k]=p,如果有就输出nums[k]。
ArrayList in表示使得nums[k]=p的所有k值,如果最大的k值大于j,那不就表示存在k>j,使得nums[k]=p了吗?
一定要求k>j,因为如下的情况是不符合要求的:
输入 [-1 0 1 2] 不能输出 [-1 -1 2] 因为-1的索引是[0],在遍历时它不满足k>j
关于避免重复的问题,
我用addResult函数来避免重复,大家一看应该就懂
1 HashSet<String> repeat = new HashSet<String>(); // 查重 2 List<List<Integer>> result = new LinkedList<List<Integer>>(); 3 4 5 public void addResult(int n1, int n2, int n3) { 6 String s = n1 + "&"+n2; 7 8 if (!repeat.contains(s)) { 9 List<Integer> p = new ArrayList<>(); 10 p.add(n1); 11 p.add(n2); 12 p.add(n3); 13 result.add(p); 14 repeat.add(s); 15 } 16 }
最终详细的代码如下:
1 import java.util.ArrayList; 2 import java.util.Arrays; 3 import java.util.HashMap; 4 import java.util.HashSet; 5 import java.util.LinkedList; 6 import java.util.List; 7 8 /** 9 * 优化了的o(n^2) 3sum算法 10 * @author user 11 * 12 */ 13 public class Solution { 14 HashSet<String> repeat = new HashSet<String>(); // 查重 15 List<List<Integer>> result = new LinkedList<List<Integer>>(); 16 17 18 public void addResult(int n1, int n2, int n3) { 19 String s = n1 + "&"+n2; 20 21 if (!repeat.contains(s)) { 22 List<Integer> p = new ArrayList<>(); 23 p.add(n1); 24 p.add(n2); 25 p.add(n3); 26 result.add(p); 27 repeat.add(s); 28 } 29 } 30 31 32 33 34 public List<List<Integer>> threeSum(int[] nums) { 35 if (nums.length < 3) { 36 return result; 37 } 38 Arrays.sort(nums); 39 if (nums.length == 3) { 40 if (nums[0] + nums[1] + nums[2] == 0) { 41 42 addResult(nums[0], nums[1], nums[2]); 43 return result; 44 } else { 45 return result; 46 } 47 } 48 HashMap<Integer, ArrayList<Integer>> index = new HashMap<>(); 49 int n=0; 50 51 for(int i=0;i<nums.length;i++){ 52 int num = nums[i]; 53 if(num==0){ 54 n++; 55 } 56 if(index.get(num)==null){ 57 58 59 index.put(num, new ArrayList<Integer>()); 60 } 61 62 index.get(num).add(i); 63 64 65 } 66 if(n>=3) addResult(0, 0, 0); 67 68 69 70 71 for(int i=0;i<nums.length;i++){ 72 if((nums[i]<0&&nums[nums.length-1]<-i)||(nums[i]>0&&nums[0]>-nums[i])) continue; 73 74 for(int j=i+1;j<nums.length;j++){ 75 76 77 78 int p = -(nums[i]+nums[j]); 79 80 if(p<0) continue; 81 ArrayList<Integer> in = index.get(p); 82 if(in==null) continue; 83 if(in.get(in.size()-1)>j){ 84 if(p>nums[j]){ 85 addResult(nums[i], nums[j],p); 86 }else if(p>nums[i]){ 87 addResult(nums[i], p,nums[j]); 88 }else{ 89 addResult(p,nums[i], nums[j]); 90 } 91 92 } 93 94 95 96 } 97 98 99 100 101 } 102 103 return result; 104 } 105 106 107 108 109 110 public static void main(String[] args) { 111 long m = System.currentTimeMillis(); 112 int a[] = {-1,-2,-3,4,1,3,0,3,-2,1,-2,2,-1,1,-5,4,-3}; 113 114 Solution solution = new Solution(); 115 System.out.println(solution.threeSum(a).size()); 116 117 long n = System.currentTimeMillis(); 118 System.out.println(n - m); 119 120 } 121 122 }