https://oj.leetcode.com/problems/merge-k-sorted-lists/
Merge k sorted linked lists and return it as one sorted list. Analyze and describe its complexity.
解题思路:
首先想到的思路很generic,遍历k个链表的第一个节点,找出最小的那个,加入待返回的链表,同时这个节点往后一个,其他不动,然后再这样比较。在这个过程中,如果哪个链表的节点已经到了最后,就在lists里面删去他,这样直到lists为空就可以了。
这个过程中需要注意几点。第一是,在lists的循环内,进行remove操作,是有问题的。因为已经改变了lists的size,这样i其实已经往后走一个了,必须也i--。否则,就要使用iterator,这个迭代器可以避免这个问题。
第二就是,lists中获取某个节点,往后迭代,还要重新写回lists中,需要用到List.set(index, Object)的方法。
类似于lists.get(begin) = merge2Lists(lists.get(begin), lists.get(end));的写法有什么不对?因为左侧不是一个variable,而是一个对象,怎么能给他赋值?这时一个很初级却比较容易犯的错误。
假设有k个链表,最大的链表有n个元素,这个算法是需要k*nk的时间的,会超时。
/** * Definition for singly-linked list. * public class ListNode { * int val; * ListNode next; * ListNode(int x) { * val = x; * next = null; * } * } */ public class Solution { public ListNode mergeKLists(List<ListNode> lists) { ListNode mergedNode = new ListNode(Integer.MAX_VALUE); ListNode headNode = mergedNode; while(lists.size() > 0){ int minNodeIndex = 0; for(int i = 0; i < lists.size(); i++){ if(lists.get(i) == null){ lists.remove(i); i--; continue; } int min = Integer.MAX_VALUE; if(lists.get(i).val < min){ minNodeIndex = i; } } if(lists.size() == 0){ break; } mergedNode.next = lists.get(minNodeIndex); mergedNode = mergedNode.next; lists.set(minNodeIndex, lists.get(minNodeIndex).next); } return headNode.next; } }
然后想到,可以把这kn个节点全部放入数组中,排序,然后形成新的链表,时间复杂度为O(knlog(kn)),空间复杂度为O(kn)。
也可以采用堆而不是数组的数据结构。建造一个最小堆,堆的大小为k。将本次遍历的k个元素放入堆,取出堆顶的元素,也就是当前最小的元素,然后和上面的步骤一样。因为堆的取出顶元素的时间为O(1),insert时间为O(logn),n为堆的大小。这样可以把kn*n的时间节省到kn*logk,空间复杂为O(k)。
在java语言中,PriorityQueue的类就是最小堆的实现。在实际实现的过程中,需要重写compareTo的方法,有两种。一种是让ListNode类去实现Comparable的接口,然后override它的compareTo()方法。这种方法要求修改ListNode类,不太可行。第二种就是自己写一个Compare implements Comparator<T>的类,去override它的compare()方法。然后建立PriorityQueue的时候,用new PriorityQueue(size, new Compare())去构造一个最小堆。
借助内部类,可以去尝试实现第二种方法,下面是具体的代码。
/** * Definition for singly-linked list. * public class ListNode { * int val; * ListNode next; * ListNode(int x) { * val = x; * next = null; * } * } */ public class Solution { public class Compare implements Comparator<ListNode>{ public int compare(ListNode l1, ListNode l2){ return l1.val - l2.val; } } public ListNode mergeKLists(List<ListNode> lists) { if(lists.size() == 0){ return null; } ListNode headNode = new ListNode(0); ListNode iterateNode = headNode; PriorityQueue<ListNode> queue = new PriorityQueue<ListNode>(lists.size(), new Compare()); for(int i = 0; i < lists.size(); i++){ if(lists.get(i) == null){ lists.remove(i); i--; continue; } queue.add(lists.get(i)); } while(queue.size() > 0){ iterateNode.next = queue.poll(); iterateNode = iterateNode.next; if(iterateNode.next != null){ queue.add(iterateNode.next); } } return headNode.next; } }
上面的方法是AC的。
又想到借助于merge2Lists的方法,从第一个链表开始,每次merge下一个,两两归并,也就是在外面套了一层循环。很可惜,这种解法也是超时的。
/** * Definition for singly-linked list. * public class ListNode { * int val; * ListNode next; * ListNode(int x) { * val = x; * next = null; * } * } */ public class Solution { public ListNode mergeKLists(List<ListNode> lists) { ListNode headNode = null; for(int i = 0; i < lists.size(); i++){ headNode = merge2Lists(headNode, lists.get(i)); } return headNode; } public ListNode merge2Lists(ListNode l1, ListNode l2) { if(l1 == null){ return l2; } if(l2 == null){ return l1; } ListNode dummy = new ListNode(0); ListNode returnNode = dummy; while(l1 != null && l2 != null){ if(l1.val < l2.val){ dummy.next = l1; dummy = dummy.next; l1 = l1.next; }else { dummy.next = l2; dummy = dummy.next; l2 = l2.next; } } if(l1 == null){ dummy.next = l2; } if(l2 == null){ dummy.next = l1; } return returnNode.next; } }
为了解决超时的方法,这里使用两两归并排序的方法。使用两个指针,指向头尾节点,然后对他们排序,将排序得到的链表放入首节点的位置,然后两个节点往中间靠,直到首尾指针相遇。这时尾指针不动,首指针从0再开始循环,重复上述过程,直到尾指针到0。这时返回首元素,就是归并好的链表。
这个方法的时间复杂度为O(kn*logk)。
/** * Definition for singly-linked list. * public class ListNode { * int val; * ListNode next; * ListNode(int x) { * val = x; * next = null; * } * } */ public class Solution { public ListNode mergeKLists(List<ListNode> lists) { if(lists.size() == 0){ return null; } int end = lists.size() - 1; int begin = 0; while(end > 0){ begin = 0; while(begin < end){ lists.set(begin, merge2Lists(lists.get(begin), lists.get(end))); begin++; end--; } } return lists.get(0); } public ListNode merge2Lists(ListNode l1, ListNode l2) { if(l1 == null){ return l2; } if(l2 == null){ return l1; } ListNode dummy = new ListNode(0); ListNode returnNode = dummy; while(l1 != null && l2 != null){ if(l1.val < l2.val){ dummy.next = l1; dummy = dummy.next; l1 = l1.next; }else { dummy.next = l2; dummy = dummy.next; l2 = l2.next; } } if(l1 == null){ dummy.next = l2; } if(l2 == null){ dummy.next = l1; } return returnNode.next; } }
总结一下,这道题有两种解法,一种是利用最小堆的数据结构,另一种利用merge2Lists的方法,而且在中间采用了二分排序,每次两两merge,也可以解决问题。
这两种解法虽然时间复杂度都是O(kn*logk),但是思路上完全是两个方向,值得好好体会。特别是在解决一个问题,时间复杂度很高的时候,如果去降低它。我们看到可以借助更好的数据结构,比如堆,但是要花费一些额外的空间,可是能够很简便的解决比较复杂的问题。在不允许使用这种已有的数据结构时,多数是可以采用折半或者两两归并的方法,借助已有的方法,取解决问题,将线性的时间复杂度降低到对数时间。
参考文章:
http://bangbingsyb.blogspot.jp/2014/11/leetcode-merge-k-sorted-lists.html
http://fmarss.blogspot.jp/2014/09/leetcode-solution_11.html
http://blog.csdn.net/linhuanmars/article/details/19899259
https://oj.leetcode.com/discuss/26/is-the-complexity-o-kn
https://oj.leetcode.com/discuss/23855/java-solution-without-recursion-feel-free-to-comment
update 2015/06/21:
三刷,上面归并算法的一个递归的写法
/** * Definition for singly-linked list. * public class ListNode { * int val; * ListNode next; * ListNode(int x) { val = x; } * } */ public class Solution { public ListNode mergeKLists(ListNode[] lists) { if(lists.length == 0) { return null; } return mergeHelper(lists, 0, lists.length - 1); } public ListNode mergeHelper(ListNode[] lists, int start, int end) { if(start == end) { return lists[start]; } if(start > end) { return null; } int mid = start + (end - start) / 2; ListNode left = mergeHelper(lists, start, mid); ListNode right = mergeHelper(lists, mid + 1, end); return mergeTwoLists(left, right); } public ListNode mergeTwoLists(ListNode l1, ListNode l2) { ListNode dummy = new ListNode(0); ListNode res = dummy; while(l1 != null && l2 != null) { if(l1.val <= l2.val) { dummy.next = l1; l1 = l1.next; dummy = dummy.next; } else { dummy.next = l2; l2 = l2.next; dummy = dummy.next; } } if(l1 == null) { dummy.next = l2; } else { dummy.next = l1; } return res.next; } }