原创:生产者和消费者线程内部原理深入解读

在多线程编程中,最经典的莫过于生产者和消费者线程了。比如,写一个简易的spider爬虫系统,生产者负责抓取网页,消费者查询网页内容。从内部深入理解运行机制,将会产生质的提升。最多线程开发时,基本流程是先设计公共类,然后设计任务类,包括生产者和消费者,再设计任务调度类,线程同步工具主要从任务调度类传入到任务类,包括CountDownLatch,Future Task,Semaphore等等。之前的一篇解决计算缓存的博客已经提过了。另外在ThreadPoolExecutor源代码的博客中,已经提到过ArrayBlockingQueue这个数据结构。在高级程序的开发中,比如lucene,更多的用CompletionService,它兼具ThreadPoolExecutor和ArrayBlockingQueue,call调用后产生的Future可以存储在ArrayBlockingQueue中,然后通过take方法取出来。这个工具非常好用!在Concurrent包下,由很多底层基于CAS算法的数据结果,包括Atomic,ConcurrentHashMap,队列主要由先进先出的ArrayBlockingQueue,优先级队列PriorityBlockingQueue等。今天主要讲ArrayBlockingQueue的lock机制和生产者消费者实现机制。先看一个图:


生产者向ArrayBlockingQueue中放入产品,消费者从中取出产品。put和take方法分别解决了线程阻塞的问题,当公共集合中的产品已经达到上限时,生产者线程阻塞,同理,当里面没有产品时,消费者线程阻塞。下面从代码角度分析一下原理:

java中的锁包括悲观锁和乐观锁,分别是synchronized和ReentrantLock。前者的性能逊色于后者,后者更具有可伸缩性和高的性能,因为底层调用了CAS算法。

ReentrantLock 是具有与隐式监视器锁定(使用 synchronized 方法和语句访问)相同的基本行为和语义的 Lock 的实现,但它具有扩展的能力。

    作为额外收获,在竞争条件下,ReentrantLock 的实现要比现在的 synchronized 实现更具有可伸缩性。(有可能在 JVM 的将来版本中改进 synchronized 的竞争性能。)

    这意味着当许多线程都竞争相同锁定时,使用 ReentrantLock 的吞吐量通常要比 synchronized 好。换句话说,当许多线程试图访问 ReentrantLock 保护的共享资源时,JVM 将花费较少的时间来调度线程,而用更多个时间执行线程。

    虽然 ReentrantLock 类有许多优点,但是与同步相比,它有一个主要缺点 -- 它可能忘记释放锁定。建议当获得和释放 ReentrantLock 时使用下列结构:

Lock lock = new ReentrantLock();

...

lock.lock();//底层执行CAS操作

try {

  // perform operations protected by lock

}

catch(Exception ex) {

 // restore invariants

}

finally {

  lock.unlock();//千万别忘了释放锁,在lucene排序的源代码中,当搜索处文档后,执行lock.lock(),然后用PriorityQueue排序,最后lock.unlock()

}

CAS无锁算法

要实现无锁(lock-free)的非阻塞算法有多种实现方法,其中 CAS(比较与交换,Compare and swap) 是一种有名的无锁算法。CAS, CPU指令,在大多数处理器架构,包括IA32、Space中采用的都是CAS指令,CAS的语义是“我认为V的值应该为A,如果是,那么将V的值更新为B,否则不修改并告诉V的值实际为多少”,CAS是项乐观锁 技术,当多个线程尝试使用CAS同时更新同一个变量时,只有其中一个线程能更新变量的值,而其它线程都失败,失败的线程并不会被挂起,而是被告知这次竞争中失败,并可以再次尝试。CAS有3个操作数,内存值V,旧的预期值A,要修改的新值B。当且仅当预期值A和内存值V相同时,将内存值V修改为B,否则什么都不做。CAS无锁算法的C实现如下:

int compare_and_swap (int* reg, int oldval, int newval) 
{
  ATOMIC();
  int old_reg_val = *reg;
  if (old_reg_val == oldval) 
     *reg = newval;
  END_ATOMIC();
  return old_reg_val;
}

CAS(乐观锁算法)的基本假设前提

CAS比较与交换的伪代码可以表示为:

do {   
       备份旧数据;  
       基于旧数据构造新数据;  
} while(!CAS( 内存地址,备份的旧数据,新数据 ))  

就是指当两者进行比较时,如果相等,则证明共享数据没有被修改,替换成新值,然后继续往下运行;如果不相等,说明共享数据已经被修改,放弃已经所做的操作,然后重新执行刚才的操作。容易看出 CAS 操作是基于共享数据不会被修改的假设,采用了类似于数据库的 commit-retry 的模式。当同步冲突出现的机会很少时,这种假设能带来较大的性能提升。

JVM对CAS的支持:AtomicInt, AtomicLong.incrementAndGet()

Java中的原子操作( atomic operations)

原子操作指的是在一步之内就完成而且不能被中断。原子操作在多线程环境中是线程安全的,无需考虑同步的问题。在java中,下列操作是原子操作:

  • all assignments of primitive types except for long and double
  • all assignments of references
  • all operations of java.concurrent.Atomic* classes
  • all assignments to volatile longs and doubles

问题来了,为什么long型赋值不是原子操作呢?例如:

long foo = 65465498L;

实时上java会分两步写入这个long变量,先写32位,再写后32位。这样就线程不安全了。如果改成下面的就线程安全了:

private volatile long foo;

因为volatile内部已经做了synchronized.


以上是补充内容,下面重点研究一下ReentrantLock:

里面有个抽象类Sync,有两个实现:NonfairLock和FairLock,公平锁和非公平锁。在构造器中默认使用非公平锁。关于CLH队列,稍后介绍,先看一下lock方法:lock方法执行逻辑为:当一个线程对公共队列进行操作时,提供底层线程同步支持,先执行CAS算法,如果成功的话,就进行下一步动作,比如take,如果失败了,重新尝试一次,如果成果了,逻辑同上,如果失败了,加入到CLH队列中等待。伪代码如下:

final void lock(){

  if(compareAndSetState(0,1)){

    setExclusiveOwnerThread(Thread.currentThread());//如果成功了,把当前线程设置为线程持有者,别的线程不能执行了

  } else {

    acquire(1);

  }

}

compareAndSetState(0,1)底层调用compareAndSwap(CAS),基本原理这样的:全局变量state初始值为0,当执行lock成功时,把它改为1,当解锁时,再改回0。当由线程进来的时候,先拿0和1进行比较(CAS)。compareAndSetState调用compareAndSwapState(this,sun.misc.unsafe.Unsafe,0,1),这个方法通过unsafe调用本地方法库中的compareAndSwap(CAS)。如果执行CAS失败了,调用acquire(1)方法。传递的参数1意思是先尝试一次CAS,如果还是还是失败的话,加入到CLH中:acquireQueued(addWaiter(Node.ExCLUSIVE),arg)。CLH的本质是一个双向链表,如果你之前自己实现过双向链表的话,理解起来就非常容易了。把Node节点添加到链表中时,仍然考虑线程同步的问题,还是通过CAS算法,先看一看addWaiter代码:

用图形象地表示一下,Node节点封装了图中的线程状态属性,不是真正的队形,所以起名虚拟队列。这些状态属性非常重要,通过判断信号状态的变化来唤醒队列中的等待线程。

 

直接上传草稿了,以后可以直接看这个图节省时间了,关于源代码的详细解读,不写了,没有这个必要,每个人都有自己的理解。最后要提醒的是,如何解决消费者线程无限等待的问题(当其中一个线程把公共队列中的产品用光了,其他的线程通过take方法只能无限等待下去的问题)。解决办法就是加入"毒丸对象",说白了就是这个对象不是真这个的产品,只是一个标识而已,当有线程取到最后一个产品时,加入这个标识,改变循环状态,让循环终止,伪代码如下:

run(){

  boolean done = false;

  while(!done) {

    产品 = queue.take();

    if(产品 == 标识对象){//注意,此时queue中已经空了,所以把标识对象再放回去

      queue.put(标识对象);

      done = true;//改变状态,循环终止

    }

    consumer.consume();//消费者线程消费产品

  }

}

生产者线程向queue中方产品时,最后也要放入"毒丸对象"。Ok,介绍完毕了!上传简易爬虫spider的代码,体会一下这个原理:

入口:

package httpClient;

public class Test {

    public static void main(String[] args) throws InterruptedException {
        UrlHanding urlHanding = new UrlHanding();
        String []seeds = {"http://www.oschina.net/code/explore/achartengine/client/AndroidManifest.xml",
                "http://www.oschina.net/code/explore","http://www.oschina.net/code/explore/achartengine",
                "http://www.oschina.net/code/explore/achartengine/client","http://map.baidu.com"};
        urlHanding.urlHanding(seeds);//加入种子文件
    }
}

任务调度类:

package httpClient;

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
 * 网页抓取
 * @author TongXueQiang
 * @date 2016/05/16
 */
public class UrlHanding {
    private final int THREADS = 10;
    private final ExecutorService producerExecutor = Executors.newSingleThreadExecutor();
    BlockingQueue<Runnable> q = new ArrayBlockingQueue<Runnable>(10);
    private final ExecutorService consumerExecutor = new MyThreadPoolExecutor(12, 30, 1000,TimeUnit.MILLISECONDS, q, new ThreadPoolExecutor.CallerRunsPolicy());
    private final CountDownLatch startLatch = new CountDownLatch(1);
    private final CountDownLatch endLatch = new CountDownLatch(THREADS);
    private static UrlQueue queue;
    
    public void urlHanding(String[] seeds) throws InterruptedException {        
        queue = getUrlQueue();
        System.out.println("处理器数量:"+Runtime.getRuntime().availableProcessors());
        long start = (long) (System.nanoTime() / Math.pow(10, 9));
        
        producerExecutor.execute(new GetSeedUrlTask(queue,seeds,startLatch));        
        producerExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
        producerExecutor.shutdown();
        startLatch.await();
        
        UrlDataHandingTask []url_handings = new UrlDataHandingTask[THREADS];
        for (int i = 0;i < THREADS;i++) {
            url_handings[i] = new UrlDataHandingTask(startLatch,endLatch,queue);
            consumerExecutor.execute(url_handings[i]);            
        }
        consumerExecutor.shutdown();
        startLatch.countDown();
        doSomething();
        endLatch.await();
        
        long end = (long) (System.nanoTime() / Math.pow(10,9) - start);
        System.out.println("耗时: " + end + "秒");
    }

    private void doSomething() {
        
        
    }

    private UrlQueue getUrlQueue() {
        if (queue == null) {
            synchronized(UrlQueue.class){
                if (queue == null) {
                    queue = new UrlQueue();
                    return queue;
                }
            }
        }
        return queue;
    }
}
生产者任务类:

package httpClient;

import java.util.concurrent.CountDownLatch;

public class GetSeedUrlTask implements Runnable {
    private UrlQueue queue;
    private String[] seeds;
    private CountDownLatch startLatch;
    
    public GetSeedUrlTask(UrlQueue queue, String[] seeds,CountDownLatch startLatch) {
        this.queue = queue;
        this.seeds = seeds;
        this.startLatch = startLatch;
    }

    public void addUrl() {
        try {
            for (String url : seeds) {
                queue.addElem(url);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    @Override
    public void run() {
        addUrl();        
        try {
            queue.addElem("www.baidu.com");//鍔犲叆"姣掍父瀵硅薄"
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        startLatch.countDown();
    }

}
消费者任务类:

package httpClient;

import java.util.concurrent.CountDownLatch;

public class UrlDataHandingTask implements Runnable {
    private CountDownLatch startLatch;
    private CountDownLatch endLatch;
    private UrlQueue queue;

    public UrlDataHandingTask(CountDownLatch latch, CountDownLatch endLatch, UrlQueue queue) {
        this.startLatch = latch;
        this.endLatch = endLatch;
        this.queue = queue;        
    }

    /**
     * 下载对应的页面并抽取出链接,放入待处理队列中
     * @param url
     * @throws InterruptedException
     */
    public void dataHanding(String url) throws InterruptedException {
        getHrefOfContent(DownPage.getContentFromUrl(url));
        for (String url0 : VisitedUrlQueue.visitedUrlQueue) {
            System.out.println(url0);
        }
    }

    @Override
    public void run() {
        try {
            startLatch.await();
        } catch (InterruptedException e1) {
            Thread.currentThread().interrupt();
        }

        while (!queue.isEmpty()) {
            try {
                String url = queue.outElem();
                if ("www.baidu.com".equals(url.trim())) {
                    queue.addElem(url);
                    break;
                }
                dataHanding(url);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        endLatch.countDown();
    }

    /**
     * 获取页面源代码中的超链接
     *
     * @param content
     * @throws InterruptedException
     */
    public void getHrefOfContent(String content) throws InterruptedException {
        System.out.println("开始");
        String[] contents = content.split("<a href=\"");
        for (int i = 1; i < contents.length; i++) {
            int endHref = contents[i].indexOf("\"");
            String aHref = FunctionUtils.getHrefOfInOut(contents[i].substring(0, endHref));
            if (aHref != null) {
                String href = FunctionUtils.getHrefOfInOut(aHref);
                if (queue.isContains(href) && !VisitedUrlQueue.isContains(href)
                        && href.indexOf("/code/explore") != -1) {
                    // 放入待抓取队列中
                    queue.addElem(href);
                }
            }
        }
        System.out.println(queue.size() + "--抓取到的连接数");
        System.out.println(VisitedUrlQueue.size() + "--已处理的页面数");
    }

}

utils中的类:

package httpClient;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class FunctionUtils {
    //匹配超链接的正则表达式
    private static String pat = "http://www\\.oschina\\.net/code/explore/.*/\\w+\\.[a-zA-Z]+";
    private static Pattern pattern = Pattern.compile(pat);
    private static BufferedWriter writer = null;
    //爬虫搜索深度
    private static int depth = 0;
    /**
     * 以"/"来分割URL,获得超链接的元素
     * @param url
     * @return
     */
    public static String[] divUrl(String url){
        return url.split("/");
    }    
    /**
     * 判断能否创建文件
     * @param url
     * @return
     */
    public static boolean isCreateFile(String url){
        Matcher matcher = pattern.matcher(url);
        return matcher.matches();
    }
    /**
     * 创建对应文件
     * @param content
     * @param urlPath
     */
    public static void createFile(String content,String urlPath){
        //1.分割url
        String []elems = divUrl(urlPath);
        //2.拼接文件路径
        StringBuffer path = new StringBuffer();
        File file = null;
        for (int i = 1;i < elems.length;i++) {
            if (i != elems.length - 1) {
                path.append(elems[i]);
                path.append(File.separator);
                file = new File("E:" + File.separator + path.toString());
            }
            if (i == elems.length - 1) {
                Pattern pattern = Pattern.compile("\\w+\\.[a-zA-Z]");
                Matcher matcher = pattern.matcher(elems[i]);
                if (matcher.matches()) {
                    if (!file.exists()) {
                        file.mkdirs();
                    }
                    String []fileName = elems[i].split("\\.");
                    file = new File("E:" + File.separator + path.toString() + File.separator + fileName[0] + ".txt");
                }
                System.out.println("文件名称:"+ file.getName().toString());
                    try {
                        writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file)),512);
                        try {
                            writer.write(content);
                            writer.newLine();
                            writer.flush();
                            //writer.close();
                            System.out.println("创建文件成功");
                        } catch (IOException e) {                            
                            e.printStackTrace();
                        }
                        
                    } catch (FileNotFoundException e) {                        
                        e.printStackTrace();
                    }
                
            }
        }
    }
    /**
     * 获取页面的超链接并将其转换为正式的A标签
     * @param href
     * @return
     */
    public static String getHrefOfInOut(String href){
        String resultHref = null;
        if (href.startsWith("http://")) {//判断是否为外部链接
            resultHref = href;
        } else {
            if (href.startsWith("/")) {//为内部链接
                resultHref = "http://www.oschina.net" + href;
            }
        }
        return resultHref;
    }
    /**
     * 截取网页源文件的目标内容
     * @param content
     * @return
     */
    public static String getGoalContent(String content){
        String signContent = content.substring(content.indexOf("<pre class=\""));
        return signContent.substring(signContent.indexOf(">")+1, signContent.indexOf("</pre>"));
    }
    /**
     * 判断网页源文件是否含有目标文件
     * @param content
     * @return
     */
    public static int isHasGoalContent(String content){
        return content.indexOf("<pre class=\"");
    }
}
下载网页类:

package httpClient;

import java.io.IOException;
import org.apache.http.HttpEntity;
import org.apache.http.HttpResponse;
import org.apache.http.client.ClientProtocolException;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.DefaultHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;

public class DownPage {
    /**
     * 根据url抓取网页内容
     * @param url
     * @return
     */
    public static String getContentFromUrl(String url){
        HttpClient client = new DefaultHttpClient();
        HttpGet getHttp = new HttpGet(url);
        String content = null;
        HttpResponse response;
        try {
            response = client.execute(getHttp);
            HttpEntity entity = response.getEntity();
            //把url加入到已抓取队列中
            VisitedUrlQueue.addElem(url);
            //转化为文本信息
            if (entity != null) {
                content = EntityUtils.toString(entity);
                System.out.println(content);
                System.out.println("-------------->>>");
                if (FunctionUtils.isCreateFile(url) && FunctionUtils.isHasGoalContent(content) != -1) {
                    //创建文件
                    FunctionUtils.createFile(FunctionUtils.getGoalContent(content), url);
                }
            }
        } catch(ClientProtocolException e){
            e.printStackTrace();
        }
        catch (IOException e) {            
            e.printStackTrace();
        }
        finally {
            client.getConnectionManager().shutdown();
        }
        return content;
    }
}
外部分装的公共产品类:

package httpClient;

import java.util.Queue;
import java.util.concurrent.ArrayBlockingQueue;

public class UrlQueue {
    public final Queue<String> urlQueue = new ArrayBlockingQueue<String>(10);
    public final int MAX_SIZE = 10000;
    
    public void addElem(String url) throws InterruptedException{
        ((ArrayBlockingQueue<String>) urlQueue).put(url);
    }
    
    public String outElem() throws InterruptedException{
        return ((ArrayBlockingQueue<String>) urlQueue).take();
    }
    
    public boolean isContains(String url){
        return urlQueue.contains(url);
    }
    
    public int size(){
        return urlQueue.size();
    }
    
    public boolean isEmpty(){
        return urlQueue.isEmpty();
    }
}

package httpClient;

import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;

public class VisitedUrlQueue {
    public static Set<String> visitedUrlQueue = new ConcurrentSkipListSet<String>();
    
    public synchronized static void addElem(String url){
        visitedUrlQueue.add(url);
    }
    
    public synchronized static boolean isContains(String url){
        return visitedUrlQueue.contains(url);
    }
    
    public static int size(){
        return visitedUrlQueue.size();
    }
}

另外还有自定义扩展的TheadPoolExecutor:

package httpClient;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Logger;

/**
 * 自定义线程池,实现计时和统计功能,并且自定义有界队列以及饱和策略
 * @author TongXueQiang
 * @date 2016/05/19
 */
public class MyThreadPoolExecutor extends ThreadPoolExecutor {
    private final ThreadLocal<Long> startTime = new ThreadLocal<Long>();
    private final Logger log = Logger.getLogger("MyThreadPoolExecutor");
    private final AtomicLong numTasks = new AtomicLong(1);
    private final AtomicLong totalTime = new AtomicLong();
    
    public MyThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit,
            BlockingQueue<Runnable> workQueue, RejectedExecutionHandler handler) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, handler);        
    }
    /**
     * 任务执行前
     */
    protected void beforeExecute(Thread t,Runnable r){
        super.beforeExecute(t, r);
        log.fine(String.format("Thread %s: start %s",t,r));
        startTime.set((long)(System.nanoTime()/Math.pow(10, 9)));
    }
    /**
     * 任务执行后
     * @param r 任务
     * @param t 执行任务的线程
     */
    protected void afterExecutor(Runnable r,Throwable t){
        try {
            Long endTime = (long) (System.nanoTime() / Math.pow(10,9));
            Long taskTime = endTime - startTime.get();
            numTasks.incrementAndGet();
            totalTime.addAndGet(taskTime);
            log.fine(String.format("Thread %s: end%s,time=%ds", taskTime));
        } finally {
            super.afterExecute(r, t);
        }
    }
    
    protected void terminated () {
        try {
            log.info(String.format("Terminated: avg time=%ds", totalTime.get() / numTasks.get()));
        } finally {
            super.terminated();
        }        
    }
}

佟氏出品,必属精品!坚持研究最底层的技术,坚持理论与编程结合,坚持写高质量代码,努力提升学术水平,尤其是机器学习算法的理论研究与创新,坚持阅读英文原版的学术论文,坚持原创,坚持独立思考!下一篇博客,上传一篇周国平的文章,具有很大的启发性,尤其是在互联网这样一个比较大的浮躁的环境下,所有的程序员和公司都应该深刻反省,什么才叫个性,什么叫成熟。怎样把自己的理念与企业文化融入到产品中,打造个性的产品,创造价值而不只是传播价值!

 

posted @ 2017-02-22 12:36  佟学强  阅读(591)  评论(0编辑  收藏  举报