3Sum algorithm - 非常容易理解的实现 (java)

原题重述:(点击图片可以进入来源链接)

这到题目的中文解释是,

输入一个数组,例如{-1 0 1 2 -1 -4},从数组中找三个数(a,b,c),使得其和0,输出所有的(a,b,c)组合。

要求abc不能重复,并且a<=b<=c。

 

 

 

拿到这个题目的时候,其实每个程序猿都能想到如下的算法,也就是暴力破解,其时间复杂度为o(n^3):

1             for(int i=0;i<nums.length;i++){
2                 for(int j=i+1;j<nums.length;j++){
3                     for(int k=j+1;k<nums.length;j++){
4                         if(nums[i]+nums[j]+nums[k]==0){
5                             addResult(nums[i], nums[j], nums[k]); 
6                         }
7                     }
8                 }
9             }

 首先需要对输入的数组进行排序,这样的话由于上面的i<j<k,所以可以保证nums[i]<nums[j]<nums[k]。

 

 

 

其实我的算法的思路就是在暴力破解的基础上进行优化,尽量降低时间复杂度。

 

在java中对数组排序的方法是:Arrays.sort(nums); 

 

第三个循环其实是没有必要的,因为在确定了i,j的值之后,需要寻找的nums[k]的值就已经确定了,即-(nums[i]+nums[j])。

因此无需循环,只需要判断数组剩下的元素中是否存在这个值就可以了。

基于这个思路我构建了一个hashmap作为hash索引表,用于查找每个值出现的位置:(考虑到一个值可能出现在多个位置的情况,用arraylist)

因为nums是已经排序过的,所以索引表中的arraylist也是已排序好的。

HashMap<Integer, ArrayList<Integer>> index = new HashMap<>();

构建这个索引表的代码如下:

 1         for(int i=0;i<nums.length;i++){
 2             int num = nums[i];
 3             if(num==0){
 4                 n++;
 5             }
 6             if(index.get(num)==null){
 7                 
 8                 
 9                 index.put(num, new ArrayList<Integer>());
10             }
11             
12             index.get(num).add(i);
13             
14             
15         }

这里面的n是表示0的个数,如果n>=3,就直接输出一个[0 0 0]了。

 

 

从索引表查询需要的数的方式,我想了很久,最后想到一个很不错的方法:

 

 1                 int p = -(nums[i]+nums[j]);
 2                 if(p<0) continue;
 3                 ArrayList<Integer> in = index.get(p);
 4                 if(in==null) continue;
 5                 if(in.get(in.size()-1)>j){
 6                     if(p>nums[j]){
 7                         addResult(nums[i], nums[j],p); 
 8                     }else if(p>nums[i]){
 9                         addResult(nums[i], p,nums[j]);
10                     }else{
11                         addResult(p,nums[i], nums[j]);
12                     }
13                     
14                 }

第2行,为什么要舍弃p<0的情况?因为要避免重复。如果p也是负数的话,由于nums[i]<nums[j]那么会出现两种情况:

①nums[i]和nums[j]都是正数;

②nums[i]是负数,nums[j]是正数 。

那么在其他的扫描过程中一定会出现:

①那时的nums[i]'=p,p'=nums[i],nums[j]'=nums[j];

②那时的p'=nums[j],nums[i]'=min(nums[i],p),nums[j]'=max(nums[i],p)。

 

第5行in.get(in.size()-1)>j是什麼意思?

我们这个时候是需要找一个k(k>j),使得nums[k]=p,如果有就输出nums[k]。

ArrayList in表示使得nums[k]=p的所有k值,如果最大的k值大于j,那不就表示存在k>j,使得nums[k]=p了吗?

一定要求k>j,因为如下的情况是不符合要求的:

输入 [-1 0 1 2] 不能输出 [-1 -1 2] 因为-1的索引是[0],在遍历时它不满足k>j

 

关于避免重复的问题,

我用addResult函数来避免重复,大家一看应该就懂

 1     HashSet<String> repeat = new HashSet<String>(); // 查重
 2     List<List<Integer>> result = new LinkedList<List<Integer>>();
 3 
 4 
 5     public void addResult(int n1, int n2, int n3) {
 6         String s = n1 + "&"+n2;
 7 
 8         if (!repeat.contains(s)) {
 9             List<Integer> p = new ArrayList<>();
10             p.add(n1);
11             p.add(n2);
12             p.add(n3);
13             result.add(p);
14             repeat.add(s);
15         }
16     }

 

最终详细的代码如下:

  1 import java.util.ArrayList;
  2 import java.util.Arrays;
  3 import java.util.HashMap;
  4 import java.util.HashSet;
  5 import java.util.LinkedList;
  6 import java.util.List;
  7 
  8 /**
  9  * 优化了的o(n^2) 3sum算法 
 10  * @author user
 11  *
 12  */
 13 public class Solution {
 14     HashSet<String> repeat = new HashSet<String>(); // 查重
 15     List<List<Integer>> result = new LinkedList<List<Integer>>();
 16 
 17 
 18     public void addResult(int n1, int n2, int n3) {
 19         String s = n1 + "&"+n2;
 20 
 21         if (!repeat.contains(s)) {
 22             List<Integer> p = new ArrayList<>();
 23             p.add(n1);
 24             p.add(n2);
 25             p.add(n3);
 26             result.add(p);
 27             repeat.add(s);
 28         }
 29     }
 30     
 31     
 32     
 33 
 34     public List<List<Integer>> threeSum(int[] nums) {
 35         if (nums.length < 3) {
 36             return result;
 37         }
 38         Arrays.sort(nums);
 39         if (nums.length == 3) {
 40             if (nums[0] + nums[1] + nums[2] == 0) {
 41 
 42                 addResult(nums[0], nums[1], nums[2]);
 43                 return result;
 44             } else {
 45                 return result;
 46             }
 47         }
 48         HashMap<Integer, ArrayList<Integer>> index = new HashMap<>();
 49         int n=0;
 50         
 51         for(int i=0;i<nums.length;i++){
 52             int num = nums[i];
 53             if(num==0){
 54                 n++;
 55             }
 56             if(index.get(num)==null){
 57                 
 58                 
 59                 index.put(num, new ArrayList<Integer>());
 60             }
 61             
 62             index.get(num).add(i);
 63             
 64             
 65         }
 66         if(n>=3) addResult(0, 0, 0);
 67 
 68         
 69         
 70         
 71         for(int i=0;i<nums.length;i++){
 72             if((nums[i]<0&&nums[nums.length-1]<-i)||(nums[i]>0&&nums[0]>-nums[i])) continue;
 73             
 74             for(int j=i+1;j<nums.length;j++){
 75                 
 76                 
 77                 
 78                 int p = -(nums[i]+nums[j]);
 79                 
 80                 if(p<0) continue;
 81                 ArrayList<Integer> in = index.get(p);
 82                 if(in==null) continue;
 83                 if(in.get(in.size()-1)>j){
 84                     if(p>nums[j]){
 85                         addResult(nums[i], nums[j],p); 
 86                     }else if(p>nums[i]){
 87                         addResult(nums[i], p,nums[j]);
 88                     }else{
 89                         addResult(p,nums[i], nums[j]);
 90                     }
 91                     
 92                 }
 93                 
 94                 
 95                 
 96             }
 97             
 98 
 99             
100             
101         }
102 
103         return result;
104     }
105     
106     
107 
108 
109     
110     public static void main(String[] args) {
111         long m = System.currentTimeMillis();
112         int a[] = {-1,-2,-3,4,1,3,0,3,-2,1,-2,2,-1,1,-5,4,-3};
113 
114         Solution solution = new Solution();
115         System.out.println(solution.threeSum(a).size());
116 
117         long n = System.currentTimeMillis();
118         System.out.println(n - m);
119 
120     }
121 
122 }
View Code

 

posted @ 2016-03-17 09:28  &amp;nbsp;  阅读(1686)  评论(1编辑  收藏  举报