自定义阻塞队列

71 阅读1分钟

自定义阻塞队列

 package thread01;

import javax.sound.midi.SoundbankResource;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Created by andy on 2020/10/9.
 */
/**
 * A bounded FIFO blocking queue backed by a singly linked list.
 *
 * <p>Uses the same two-lock layout as {@link java.util.concurrent.LinkedBlockingQueue}:
 * <ul>
 *   <li>{@link #put} holds {@code putLock} and only touches the tail ({@code last}).</li>
 *   <li>{@link #take} holds {@code takeLock} and only touches the head ({@code head}).</li>
 * </ul>
 * This lets one producer and one consumer proceed at the same time. The two sides communicate
 * through the atomic {@code count}. A producer links its node before incrementing
 * {@code count}, and a consumer reads {@code head.next} only after it has observed
 * {@code count > 0}.
 *
 * <p>{@code head} is always a sentinel node with a {@code null} item. The first real element
 * is {@code head.next}, and when the queue is empty {@code head == last}.
 *
 * <p>Created by andy on 2020/10/9.
 *
 * @param <E> the element type
 */
public class MyLinkedBlockingQueue<E> {

    /** Number of elements currently in the queue; read and updated by both sides. */
    private final AtomicInteger count = new AtomicInteger();

    /** Maximum number of elements; validated to be positive in the constructor. */
    private final int capacity;

    /** Sentinel node; {@code head.next} is the first element. Accessed only under takeLock. */
    private Node<E> head;

    /** Last node of the list (the sentinel when empty). Accessed only under putLock. */
    private Node<E> last;

    /** Lock held by {@link #put}. */
    private final ReentrantLock putLock = new ReentrantLock();

    /** Lock held by {@link #take}. */
    private final ReentrantLock takeLock = new ReentrantLock();

    // These initializers must stay below the lock fields: field initializers run in textual order.
    /** Waited on by producers while the queue is full. */
    private final Condition notFull = putLock.newCondition();

    /** Waited on by consumers while the queue is empty. */
    private final Condition notEmpty = takeLock.newCondition();

    /**
     * Creates an empty queue that holds at most {@code capacity} elements.
     *
     * @param capacity maximum number of elements the queue can hold
     * @throws IllegalArgumentException if {@code capacity} is not positive
     */
    public MyLinkedBlockingQueue(int capacity) {
        // A zero capacity would block every put forever. A negative one would never
        // compare equal to count, silently making the queue unbounded.
        if (capacity <= 0) {
            throw new IllegalArgumentException("capacity must be positive: " + capacity);
        }
        this.capacity = capacity;
        last = head = new Node<E>(null);
    }

    /** Singly linked list node. */
    static class Node<E> {
        E item;
        Node<E> next;

        Node(E x) {
            item = x;
        }
    }

    /**
     * Appends {@code e} to the tail, blocking while the queue is full.
     *
     * @param e the element to add
     * @throws InterruptedException if interrupted while acquiring the lock or waiting for space
     */
    public void put(E e) throws InterruptedException {
        final int c;
        final Node<E> node = new Node<E>(e);
        putLock.lockInterruptibly();
        try {
            // Loop rather than "if": Condition.await may return spuriously.
            while (count.get() == capacity) {
                notFull.await();
            }
            enqueue(node);
            c = count.getAndIncrement();
            // Still room left: wake another waiting producer.
            if (c + 1 < capacity) {
                notFull.signal();
            }
        } finally {
            putLock.unlock();
        }
        // This step is important: if the queue was empty, consumers may be waiting,
        // so notify them that an element is now available.
        if (c == 0) {
            signalNotEmpty();
        }
    }

    /**
     * Removes and returns the head element, blocking while the queue is empty.
     *
     * @return the element that was at the head of the queue
     * @throws InterruptedException if interrupted while acquiring the lock or waiting for an element
     */
    public E take() throws InterruptedException {
        final E x;
        final int c;
        takeLock.lockInterruptibly();
        try {
            while (count.get() == 0) {
                notEmpty.await();
            }
            x = dequeue();
            c = count.getAndDecrement();
            // More elements remain: wake another waiting consumer.
            if (c > 1) {
                notEmpty.signal();
            }
        } finally {
            takeLock.unlock();
        }
        // The queue was full before this take, so producers may be waiting for space.
        if (c == capacity) {
            signalNotFull();
        }
        return x;
    }

    /**
     * Returns the number of elements currently in the queue. With other threads active,
     * the value may already be stale when the caller reads it.
     *
     * @return the current element count
     */
    public int size() {
        return count.get();
    }

    /** Signals a waiting consumer; called from put, which does not otherwise hold takeLock. */
    private void signalNotEmpty() {
        takeLock.lock();
        try {
            notEmpty.signal();
        } finally {
            takeLock.unlock();
        }
    }

    /** Signals a waiting producer; called from take, which does not otherwise hold putLock. */
    private void signalNotFull() {
        putLock.lock();
        try {
            notFull.signal();
        } finally {
            putLock.unlock();
        }
    }

    /** Links {@code node} after the current tail. Caller must hold putLock. */
    private void enqueue(Node<E> node) {
        last = last.next = node;
    }

    /**
     * Unlinks the first element and returns its item. Caller must hold takeLock, and the
     * queue must be non-empty.
     */
    private E dequeue() {
        Node<E> h = head;
        Node<E> first = h.next;
        h.next = h; // self-link the old sentinel so it no longer references live nodes (help GC)
        head = first; // the first data node becomes the new sentinel
        E x = first.item;
        first.item = null;
        return x;
    }

    /**
     * Demo: one producer and one consumer exchanging ten integers through a capacity-1 queue.
     */
    public static void main(String[] args) throws InterruptedException {
        final MyLinkedBlockingQueue<Integer> queue = new MyLinkedBlockingQueue<Integer>(1);
        Thread producer = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    for (int i = 0; i < 10; i++) {
                        queue.put(i);
                        System.out.println("生产元素: " + i);
                        Thread.sleep(100);
                    }
                } catch (InterruptedException e) {
                    // Restore the interrupt status instead of swallowing it.
                    Thread.currentThread().interrupt();
                }
            }
        });
        Thread consumer = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    for (int i = 0; i < 10; i++) {
                        System.out.println("消费元素: " + queue.take());
                        Thread.sleep(100);
                    }
                } catch (InterruptedException e) {
                    // Restore the interrupt status instead of swallowing it.
                    Thread.currentThread().interrupt();
                }
            }
        });
        producer.start();
        consumer.start();
        producer.join();
        consumer.join();
    }
}