函数式接口

72 阅读5分钟

下面用几个封装的工具类感受一下函数式接口的作用。

1. 模拟实现 Stream流

源码:

/**
 * A tiny, eager imitation of {@code java.util.stream.Stream} that shows how functional interfaces
 * ({@link Function}, {@link Predicate}, {@link Consumer}, {@link Supplier}, {@link BiConsumer},
 * {@link BinaryOperator}, {@link Comparator}) let callers plug behaviour into generic algorithms.
 *
 * <p>Every intermediate operation ({@link #filter}, {@link #map}, {@link #sort}, {@link #distinct})
 * runs immediately and wraps a freshly built collection; terminal operations iterate the backing
 * collection directly. Unlike a JDK stream, an instance may be consumed any number of times.
 *
 * @param <T> element type
 */
public class Stream<T> {

    /** Backing elements. Stored by reference (not copied), so later changes to it are visible here. */
    private final Collection<T> collection;

    private Stream(Collection<T> collection) {
        this.collection = collection;
    }

    /**
     * Wraps {@code collection} without copying it.
     *
     * @param collection source elements
     * @param <T>        element type
     * @return a stream over {@code collection}
     */
    public static <T> Stream<T> of(Collection<T> collection) {
        return new Stream<>(collection);
    }

    /**
     * Returns a randomly chosen non-null element. A shuffled copy is searched, so the backing
     * collection is left untouched.
     *
     * @return some non-null element, or {@link Optional#empty()} if there is none
     */
    public Optional<T> findAny() {
        List<T> shuffled = new ArrayList<>(collection);
        Collections.shuffle(shuffled);
        for (T t : shuffled) {
            if (t != null) {
                return Optional.of(t);
            }
        }
        return Optional.empty();
    }

    /**
     * Sums {@code function} applied to every element, using {@code int} arithmetic (may overflow).
     *
     * @param function maps an element to the value to add; must not return {@code null}
     * @return the sum, or 0 for an empty stream
     */
    public int sum(Function<T, Integer> function) {
        int sum = 0;
        for (T t : collection) {
            sum += function.apply(t);
        }
        return sum;
    }

    /**
     * Returns the smallest element according to {@code comparator}; on ties the first one wins.
     *
     * @param comparator ordering to use
     * @return the minimum element
     * @throws IndexOutOfBoundsException if the stream is empty
     */
    public T min(Comparator<T> comparator) {
        List<T> tempList = new ArrayList<>(collection);
        T min = tempList.get(0);
        for (T t : tempList) {
            if (comparator.compare(t, min) < 0) {
                min = t;
            }
        }
        return min;
    }

    /**
     * Returns the largest element according to {@code comparator}; on ties the first one wins.
     *
     * @param comparator ordering to use
     * @return the maximum element
     * @throws IndexOutOfBoundsException if the stream is empty
     */
    public T max(Comparator<T> comparator) {
        List<T> tempList = new ArrayList<>(collection);
        T max = tempList.get(0);
        for (T t : tempList) {
            if (comparator.compare(t, max) > 0) {
                max = t;
            }
        }
        return max;
    }

    /**
     * Averages {@code function} applied to every element.
     *
     * @param function maps an element to the value to average; must not return {@code null}
     * @return the arithmetic mean, or 0.0 for an empty stream
     */
    public double avg(Function<T, Integer> function) {
        double avg = 0;
        int count = this.count();
        for (T t : collection) {
            avg += function.apply(t) * 1.0 / count;
        }
        return avg;
    }

    /** @return the number of elements in the backing collection */
    public int count() {
        return collection.size();
    }

    /**
     * Removes duplicates (per {@code equals}/{@code hashCode}), keeping the first occurrence of each
     * element in encounter order, like {@code java.util.stream.Stream#distinct}.
     *
     * @return a stream over the unique elements
     */
    public Stream<T> distinct() {
        // LinkedHashSet rather than HashSet so the result follows input order instead of hash order
        Set<T> unique = new java.util.LinkedHashSet<>(collection);
        return of(unique);
    }

    /** @return {@code true} if at least one element satisfies {@code predicate}; {@code false} if empty */
    public boolean anyMatch(Predicate<T> predicate) {
        for (T t : collection) {
            if (predicate.test(t)) {
                return true;
            }
        }
        return false;
    }

    /** @return {@code true} if every element satisfies {@code predicate}; {@code true} if empty */
    public boolean allMatch(Predicate<T> predicate) {
        for (T t : collection) {
            if (!predicate.test(t)) {
                return false;
            }
        }
        return true;
    }

    /** @return {@code true} if no element satisfies {@code predicate}; {@code true} if empty */
    public boolean noneMatch(Predicate<T> predicate) {
        for (T t : collection) {
            if (predicate.test(t)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the first non-null element in the backing collection's iteration order.
     *
     * @return that element, or {@link Optional#empty()} if there is none
     */
    public Optional<T> findFirst() {
        for (T t : collection) {
            if (t != null) {
                return Optional.of(t);
            }
        }
        return Optional.empty();
    }

    /**
     * Groups elements by {@code keyMapper}; within a group elements keep encounter order.
     * {@code keyMapper} is invoked exactly once per element.
     *
     * @param keyMapper computes the group key of an element ({@code null} keys are allowed)
     * @param <K>       key type
     * @return a mutable map from key to the elements having that key
     */
    public <K> Map<K, List<T>> groupBy(Function<T, K> keyMapper) {
        Map<K, List<T>> groups = new HashMap<>();
        for (T t : collection) {
            // compute the key once; the old containsKey/get/put sequence evaluated it twice per element
            groups.computeIfAbsent(keyMapper.apply(t), k -> new ArrayList<>()).add(t);
        }
        return groups;
    }

    /**
     * Sorts a copy of the elements; the backing collection is not modified.
     *
     * @param comparator ordering to use
     * @return a stream over the sorted copy
     */
    public Stream<T> sort(Comparator<T> comparator) {
        List<T> sorted = new ArrayList<>(collection);
        sorted.sort(comparator);
        return of(sorted);
    }

    /**
     * Keeps only the elements that satisfy {@code predicate}.
     *
     * @param predicate inclusion test
     * @return a stream over the matching elements, in encounter order
     */
    public Stream<T> filter(Predicate<T> predicate) {
        List<T> kept = new ArrayList<>();
        for (T t : collection) {
            if (predicate.test(t)) {
                kept.add(t);
            }
        }
        return of(kept);
    }

    /**
     * Transforms every element with {@code function}.
     *
     * @param function element transformation
     * @param <V>      result element type
     * @return a stream over the transformed elements, in encounter order
     */
    public <V> Stream<V> map(Function<T, V> function) {
        List<V> mapped = new ArrayList<>(collection.size());
        for (T t : collection) {
            mapped.add(function.apply(t));
        }
        return of(mapped);
    }

    /**
     * Passes every element to {@code consumer} in encounter order.
     *
     * @param consumer action to run per element
     */
    public void forEach(Consumer<T> consumer) {
        for (T t : collection) {
            consumer.accept(t);
        }
    }

    /**
     * Builds a {@link HashMap} from the elements; on duplicate keys the later element's value wins.
     *
     * @param keyMapper   computes the map key
     * @param valueMapper computes the map value
     * @param <K>         key type
     * @param <V>         value type
     * @return a new mutable map
     */
    public <K, V> Map<K, V> collectToMap(Function<T, K> keyMapper, Function<T, V> valueMapper) {
        Map<K, V> map = new HashMap<>();
        for (T t : collection) {
            map.put(keyMapper.apply(t), valueMapper.apply(t));
        }
        return map;
    }

    /**
     * Mutable reduction: creates a container with {@code supplier} and folds every element into it.
     *
     * @param supplier   creates the result container
     * @param biConsumer adds one element to the container
     * @param <C>        container type
     * @return the filled container
     */
    public <C> C collect(Supplier<C> supplier, BiConsumer<C, T> biConsumer) {
        C container = supplier.get();
        for (T t : collection) {
            biConsumer.accept(container, t);
        }
        return container;
    }

    /**
     * Left fold starting from {@code init}.
     *
     * @param init           initial accumulator value (returned unchanged for an empty stream)
     * @param binaryOperator combines the accumulator with the next element
     * @return the final accumulator value
     */
    public T reduce(T init, BinaryOperator<T> binaryOperator) {
        T result = init;
        for (T t : collection) {
            result = binaryOperator.apply(result, t);
        }
        return result;
    }

}

使用:

/**
 * Exercises the hand-written {@code Stream} utility: aggregation, matching/finding, distinct,
 * grouping, filtering, mapping, reduction, sorting and collecting. Results are printed, not asserted.
 */
public class StreamTest {

    /** Max/min student by age, age sum, age average and count over 10 mocked students. */
    @Test
    public void testMinMaxAvgSumCount() {
        List<Student> students = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            Student student = MockService.mock(Student.class);
            students.add(student);
        }
        Student maxAgeStudent = Stream.of(students).max(Comparator.comparing(Student::getAge));
        Student minAgeStudent = Stream.of(students).min(Comparator.comparing(Student::getAge));
        int ageSum = Stream.of(students).sum(Student::getAge);
        double ageAvg = Stream.of(students).avg(Student::getAge);
        int count = Stream.of(students).count();
        System.out.println(maxAgeStudent);
        System.out.println(minAgeStudent);
        System.out.println(ageSum);
        System.out.println(ageAvg);
        System.out.println(count);
    }

    /** anyMatch / allMatch / noneMatch / findFirst / findAny. */
    @Test
    public void testFind() {
        List<Integer> list = Arrays.asList(1, 2, 3, 4, 5);
        boolean hasZero = Stream.of(list).anyMatch(i -> i == 0);
        System.out.println(hasZero);
        boolean allGreaterThan0 = Stream.of(list).allMatch(i -> i > 0);
        System.out.println(allGreaterThan0);
        System.out.println(Stream.of(list).filter(i -> i == 2).findFirst().orElse(3));
        boolean allNotZero = Stream.of(list).noneMatch(i -> i == 0);
        System.out.println(allNotZero);
        int randomIntThatIsGreaterThan0 = Stream.of(list).filter(i -> i > 0).findAny().orElse(0);
        System.out.println(randomIntThatIsGreaterThan0);
    }

    /** Prints each distinct value once. */
    @Test
    public void testDistinct() {
        List<Integer> list = Arrays.asList(1, 2, 3, 4, 4, 4, 5, 6, 6, 6, 67);
        Stream.of(list).distinct().forEach(System.out::println);
    }

    /** Groups values by parity (key 0 = even, key 1 = odd). */
    @Test
    public void testGroupBy() {
        List<Integer> list = Arrays.asList(1, 2, 3, 4, 4, 4, 5, 6, 6, 6, 67);
        Map<Integer, List<Integer>> resMap = Stream.of(list).groupBy(i -> i % 2);
        resMap.forEach((k, v) -> System.out.println(k + " : " + v));
    }

    /** Keeps values greater than 5. */
    @Test
    public void testFilter() {
        List<Integer> list = Arrays.asList(1, 2, 3, 4, 4, 4, 5, 6, 6, 6, 67);
        Stream.of(list).filter(i -> i > 5).forEach(System.out::println);
    }

    /** Doubles every value. */
    @Test
    public void testMap() {
        List<Integer> list = Arrays.asList(1, 2, 3, 4, 4, 4, 5, 6, 6, 6, 67);
        Stream.of(list).map(i -> i * 2).forEach(System.out::println);
    }

    /** Sum, product, max and min expressed as folds. */
    @Test
    public void testReduce() {
        List<Integer> list = Arrays.asList(1, 2, 3, 4, 4, 4, 5, 6, 6, 6, 67);
        System.out.println("求和 : " + Stream.of(list).reduce(0, Integer::sum));
        System.out.println("求积 : " + Stream.of(list).reduce(1, (a, b) -> a * b));
        System.out.println("求最大值 : " + Stream.of(list).reduce(Integer.MIN_VALUE, Integer::max));
        System.out.println("求最小值 : " + Stream.of(list).reduce(Integer.MAX_VALUE, Integer::min));
    }

    /** Sorts in descending order. */
    @Test
    public void testSort() {
        List<Integer> list = Arrays.asList(1, 2, 3, 4, 4, 4, 5, 6, 6, 6, 67);
        // reverseOrder() instead of (a, b) -> b - a: the subtraction overflows for operands far apart
        Stream.of(list).sort(Comparator.reverseOrder()).forEach(System.out::println);
    }

    /** Collects into a HashSet, a StringBuilder, a StringJoiner and a Map. */
    @Test
    public void testCollect() {
        List<Integer> list = Arrays.asList(1, 2, 3, 4, 4, 4, 5, 6, 6, 6, 67);
        HashSet<Integer> resultSet = Stream.of(list).collect(
            HashSet::new,
            HashSet::add
        );
        resultSet.forEach(System.out::println);
        StringBuilder res = Stream.of(list).collect(
            StringBuilder::new,
            StringBuilder::append
        );
        System.out.println(res);

        StringJoiner stringJoiner = Stream.of(list).collect(
            () -> new StringJoiner("-"),
            (str, i) -> str.add(i.toString())
        );

        System.out.println(stringJoiner.toString());

        Map<Integer, Integer> resMap = Stream.of(list).collectToMap(
            k -> k * 2,
            Function.identity()
        );
        System.out.println(resMap);
    }
}

2. Redis 缓存穿透工具类

源码

/**
 * Read-through Redis cache helper that guards against cache penetration: when the loader returns
 * {@code null}, the sentinel string {@code "null"} is cached so repeated lookups of a missing record
 * do not keep hitting the backing store.
 *
 * <p>NOTE(review): all callers share one {@code Jedis} connection. Jedis connections are documented
 * as not thread-safe, so concurrent use should go through a connection pool — confirm how this
 * class is used before calling it from multiple threads.
 */
public class RedisQueryUitls {

    private final static Jedis JEDIS = new Jedis("localhost", 6379);

    /** Sentinel stored in Redis meaning "the loader returned null for this key". */
    private final static String EMPTY_VALUE_STR = "null";


    /**
     * Looks up {@code key} in Redis; on a miss calls {@code supplier} and caches its result
     * (or the null sentinel) with no expiry.
     *
     * @param key      Redis key
     * @param supplier loader invoked on a cache miss; may return {@code null}
     * @param clazz    type used to deserialize the cached JSON
     * @param <T>      value type
     * @return the cached or freshly loaded value, or {@code null} if the loader found nothing
     */
    public static <T> T query(String key,
                              Supplier<T> supplier,
                              Class<T> clazz) {
        return query(key, null, null, supplier, clazz);
    }


    /**
     * Looks up {@code key} in Redis; on a miss calls {@code supplier} and caches its result
     * (or the null sentinel), optionally with a time-to-live.
     *
     * @param key        Redis key
     * @param expireTime time-to-live amount, or {@code null} to cache without expiry; must be positive.
     *                   Sub-second values are rounded up to 1 second because Redis expiries
     *                   set here are whole seconds and must be at least 1
     * @param timeUnit   unit of {@code expireTime}; required when {@code expireTime} is non-null
     * @param supplier   loader invoked on a cache miss; may return {@code null}
     * @param clazz      type used to deserialize the cached JSON
     * @param <T>        value type
     * @return the cached or freshly loaded value, or {@code null} if the loader found nothing
     * @throws NullPointerException     if {@code expireTime} is set but {@code timeUnit} is null
     * @throws IllegalArgumentException if {@code expireTime} is not positive
     */
    public static <T> T query(String key,
                              Long expireTime, TimeUnit timeUnit,
                              Supplier<T> supplier,
                              Class<T> clazz)
    {
        // validate before touching Redis or the loader, so a bad TTL never costs a DB round-trip
        if (expireTime != null) {
            Objects.requireNonNull(timeUnit, "timeUnit must be provided when expireTime is set");
            if (expireTime <= 0) {
                throw new IllegalArgumentException("expireTime must be positive: " + expireTime);
            }
        }

        String cached = JEDIS.get(key);
        if (EMPTY_VALUE_STR.equals(cached)) {
            // the loader previously returned null for this key
            return null;
        }
        if (cached != null && !cached.isEmpty()) {
            return JSON.parseObject(cached, clazz);
        }

        T result = supplier.get();
        String value = Objects.nonNull(result) ? JSON.toJSONString(result) : EMPTY_VALUE_STR;
        if (expireTime != null) {
            // toSeconds truncates: 500 ms would become 0, which SETEX rejects after the load already ran
            long seconds = Math.max(1L, timeUnit.toSeconds(expireTime));
            JEDIS.setex(key, seconds, value);
        } else {
            JEDIS.set(key, value);
        }
        return result;
    }

}

使用

public class RedisQueryUtilsTest {

    @Test
    public void test() {
        int queryId = 1;
        Student student = RedisQueryUitls.query(
            "student:" + queryId,
            60L, TimeUnit.SECONDS,
            () -> mockGetStudentFromDB(queryId),
            Student.class
        );
        assert (Objects.isNull(student)) || student.getId() == queryId;
    }

    private Student mockGetStudentFromDB(int id) {
        Student student = new Student();
        student.setId(id);
        student.setName("张三");
        student.setAge(18);
        student.setSex(1);
        return student;
    }

}

3. 并发工具

该工具类将一个大任务拆分成 n 个小任务，交给线程池中的多个线程并行执行，然后通过 CountDownLatch 等待所有子任务完成，再汇总结果。

源码

package com.hdu.parallelUtils;

import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;


@Slf4j
public class ParallelExecuteTemplate {


    private static final Executor EXECUTOR = Executors.newFixedThreadPool(10);


    /**
     * Splits {@code list} into consecutive chunks of at most {@code limitSize} elements, dispatches
     * each {@code [start, end)} range to {@code asyncConsume}, and blocks until the shared latch
     * (sized to the number of chunks) reaches zero. A list no larger than {@code limitSize} is
     * consumed synchronously on the calling thread instead.
     *
     * <p>NOTE(review): assumes {@code asyncConsume} runs the chunk on {@code EXECUTOR} and counts the
     * latch down even when the chunk fails; otherwise {@code await} never returns — confirm.
     *
     * @param list      items to process
     * @param limitSize maximum chunk size; must be positive
     * @param function  consumer invoked once per chunk
     * @param <K>       unused; kept so call sites with explicit type arguments still compile
     * @param <V>       unused; kept for the same reason
     * @param <Q>       element type
     * @throws IllegalArgumentException if {@code limitSize <= 0}
     * @throws RuntimeException         wrapping the {@link InterruptedException} if interrupted while waiting
     */
    public static <K, V, Q> void parallelConsume(List<Q> list, int limitSize, ParallelConsume<Q> function) {
        if (limitSize <= 0) {
            throw new IllegalArgumentException("limitSize must be positive: " + limitSize);
        }

        if (list.size() <= limitSize) {
            // small input: run inline and stop; falling through would consume the same elements again
            function.execute(list);
            return;
        }

        int total = list.size();
        // ceil(total / limitSize) without the overflow risk of (total + limitSize - 1)
        int chunkCount = total / limitSize + (total % limitSize == 0 ? 0 : 1);
        CountDownLatch latch = new CountDownLatch(chunkCount);

        long startTime = System.currentTimeMillis();
        for (int start = 0; start < total; start += limitSize) {
            // clamp so the final, possibly shorter, chunk never reaches past the end of the list
            int end = Math.min(start + limitSize, total);
            asyncConsume(list, start, end, function, latch);
        }
        try {
            latch.await();
            log.info("paramConsume success, cost : {} ms", System.currentTimeMillis() - startTime);
        } catch (InterruptedException e) {
            // restore the interrupt flag before propagating so callers can still observe it
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }


    @SuppressWarnings("all")
    public static <K, V, Q> Map<K, V> parallelGetMap(List<Q> list, int limitSize, ParallelGetMap<K, V, Q> function) {

        if (list.size() <= limitSize) {
            return new ConcurrentHashMap<>(function.execute(list));
        }

        int total = list.size();
        int threadNum = (total + limitSize - 1) / limitSize;
        CountDownLatch latch = new CountDownLatch(threadNum);
        Map<K, V> result = new ConcurrentHashMap<>();

        int start = 0;
        int end = 0;
        long curTime = System.currentTimeMillis();
        while (total > 0) {
            end = end + limitSize;

使用

import com.hdu.parallelUtils.ParallelExecuteTemplate;
import com.hdu.mockService.service.MockService;
import com.hdu.Student;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Demonstrates {@link ParallelExecuteTemplate} on 1000 mocked students, splitting the id range
 * 1..1000 into 20 chunks of 50 ids each.
 */
@Slf4j
public class ParallelUtilsTest {

    /** Shared fixture: 1000 mocked students, built once for the whole class. */
    public static List<Student> mockStudents = new ArrayList<>();

    static {
        // mock student data
        for (int i = 0; i < 1000; i++) {
            Student student = MockService.mock(Student.class);
            mockStudents.add(student);
        }
    }

    /**
     * Splits ids 1..1000 into 20 groups processed in parallel (1000 / 50 = 20); every student whose
     * original id falls in the range gets that id doubled exactly once.
     */
    @Test
    public void testParallelConsume() {
        List<Integer> ids = new ArrayList<>();
        for (int i = 1; i <= 1000; i++) {
            ids.add(i);
        }
        // Index students by their ORIGINAL id before any chunk runs. Matching against live ids would let an
        // already-doubled id (1 -> 2) match again later in the same chunk or in another thread's chunk, so a
        // student could be doubled several times and the outcome would depend on thread timing.
        Map<Integer, List<Student>> studentsByOriginalId = new HashMap<>();
        for (Student mockStudent : mockStudents) {
            studentsByOriginalId.computeIfAbsent(mockStudent.getId(), k -> new ArrayList<>()).add(mockStudent);
        }
        try {
            ParallelExecuteTemplate.parallelConsume(
                    ids,
                    50,
                    idList -> {
                        for (int id : idList) {
                            List<Student> matches = studentsByOriginalId.get(id);
                            if (matches != null) {
                                for (Student mockStudent : matches) {
                                    mockStudent.setId(id * 2);
                                }
                            }
                        }
                    }
            );
            mockStudents.forEach(student -> log.info("student : {}", student));
        } finally {
            // restore the shared fixture so the other tests see the original ids whatever the run order
            studentsByOriginalId.forEach((originalId, students) -> students.forEach(s -> s.setId(originalId)));
        }
    }

    /**
     * Splits ids 1..1000 into 20 groups processed in parallel (1000 / 50 = 20); each group looks up
     * its students and the results are merged into one {@code List<Student>}.
     */
    @Test
    public void testParallelGetList() {
        List<Integer> ids = new ArrayList<>();
        for (int i = 1; i <= 1000; i++) {
            ids.add(i);
        }
        List<Student> students = ParallelExecuteTemplate.parallelGetList(
                ids,
                50,
                idList -> {
                    List<Student> studentList = new ArrayList<>();
                    // int (not Integer) so getId() == id is a numeric comparison even if getId() returns Integer
                    for (int id : idList) {
                        for (Student mockStudent : mockStudents) {
                            if (mockStudent.getId() == id) {
                                studentList.add(mockStudent);
                            }
                        }
                    }
                    return studentList;
                }
        );
        int total = students.size();
        log.info("total query size : {}", total);
    }

    /**
     * Splits ids 1..1000 into 20 groups processed in parallel (1000 / 50 = 20); each group looks up
     * its students and maps id to student names ({@code <id, List<Name>>}), then the maps are merged.
     */
    @Test
    public void testParallelGetMap() {
        List<Integer> ids = new ArrayList<>();
        for (int i = 1; i <= 1000; i++) {
            ids.add(i);
        }
        Map<Integer, List<String>> id2studentName = ParallelExecuteTemplate.parallelGetMap(
                ids,
                50,
                subIdList -> {
                    Map<Integer, List<String>> id2names = new HashMap<>();
                    // int (not Integer) so getId() == id is a numeric comparison even if getId() returns Integer
                    for (int id : subIdList) {
                        for (Student mockStudent : mockStudents) {
                            if (mockStudent.getId() == id) {
                                id2names.computeIfAbsent(id, k -> new ArrayList<>()).add(mockStudent.getName());
                            }
                        }
                    }
                    return id2names;
                }
        );
        int total = 0;
        for (List<String> names : id2studentName.values()) {
            total += names.size();
        }
        log.info("total query size : {}", total);
        log.info("{}", id2studentName);
    }
}

4. 源码

functionDemo: functionDemo (gitee.com)