用java做物品识别和姿态识别

3,258 阅读4分钟

前言

之前搞的语音识别,突然发现浏览器就有接口可以直接用,而且识别又快又准,参考:使用 JavaScript 的 SpeechRecognition API 实现语音识别_speechrecognition js-CSDN博客

进入正题

这个功能首先要感谢一下作者常康,仓库地址(gitee.com/agriculture… 这个项目很早之前就关注了,最近这段时间正好要用才真正实践了一下,只是初步测试了一下,在性能方面还需要进一步测试,本人电脑就很拉识别就很卡。

先看效果

20240912_090041 00_00_00-00_00_30.gif

20240912_091337 00_00_00-00_00_08 00_00_00-00_00_30.gif

改动

主要对姿态识别做了一些小改动,将原来的图片识别改成视频识别,如果要调用摄像头,将 video.open(0); 这行代码的注释放开即可

package cn.ck;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import cn.ck.config.PEConfig;
import cn.ck.domain.KeyPoint;
import cn.ck.domain.PEResult;
import cn.ck.utils.Letterbox;
import nu.pattern.OpenCV;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.highgui.HighGui;
import org.opencv.imgproc.Imgproc;
import org.opencv.videoio.VideoCapture;
import org.opencv.videoio.Videoio;

import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/*
 *  姿态识别,可以识别动作等等.,比如跳绳技术
 */
public class PoseEstimation {

    static {
        // 加载opencv动态库
        //System.load(ClassLoader.getSystemResource("lib/opencv_java470-无用.dll").getPath());
        OpenCV.loadLocally();
    }

    public static void main(String[] args) throws OrtException {

        String model_path = "src\main\resources\model\yolov7-w6-pose-nms.onnx";
        // 加载ONNX模型
        OrtEnvironment environment = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        OrtSession session = environment.createSession(model_path, sessionOptions);
        // 输出基本信息
        session.getInputInfo().keySet().forEach(x -> {
            try {
                System.out.println("input name = " + x);
                System.out.println(session.getInputInfo().get(x).getInfo().toString());
            } catch (OrtException e) {
                throw new RuntimeException(e);
            }
        });

        VideoCapture video = new VideoCapture();

//        video.open(0);  //获取电脑上第0个摄像头

        //可以把识别后的视频在通过rtmp转发到其他流媒体服务器,就可以远程预览视频后视频,需要使用ffmpeg将连续图片合成flv 等等,很简单。
        if (!video.isOpened()) {
            System.err.println("打开视频流失败,未检测到监控,请先用vlc软件测试链接是否可以播放!,下面试用默认测试视频进行预览效果!");
            video.open("video/test2.mp4");
        }
        // 跳帧检测,一般设置为3,毫秒内视频画面变化是不大的,快了无意义,反而浪费性能
        int detect_skip = 4;

        // 跳帧计数
        int detect_skip_index = 1;

        // 最新一帧也就是上一帧推理结果
        float[][] outputData   = null;

        //当前最新一帧。上一帧也可以暂存一下
        Mat img = new Mat();


// 在这里先定义下线的粗细、关键的半径(按比例设置大小粗细比较好一些)
        int minDwDh = Math.min((int)video.get(Videoio.CAP_PROP_FRAME_WIDTH), (int)video.get(Videoio.CAP_PROP_FRAME_HEIGHT));
        int thickness = minDwDh / PEConfig.lineThicknessRatio;
        int radius = minDwDh / PEConfig.dotRadiusRatio;
            // 转换颜色空间
            Mat image = new Mat();

            // 图像预处理
            Letterbox letterbox = new Letterbox();
            letterbox.setNewShape(new Size(960, 960));
            letterbox.setStride(64);


// 使用多线程和GPU可以提升帧率,线上项目必须多线程!!!,一个线程拉流,将图像存到[定长]队列或数组或者集合,一个线程模型推理,中间通过变量或者队列交换数据,代码示例仅仅使用单线程
        while (video.read(img)) {
            if ((detect_skip_index % detect_skip == 0) || outputData == null) {
                Imgproc.cvtColor(img, image, Imgproc.COLOR_BGR2RGB);
                image = letterbox.letterbox(image);
                int rows = letterbox.getHeight();
                int cols = letterbox.getWidth();
                int channels = image.channels();
                // 将图像转换为模型输入格式
                float[] pixels = new float[channels * rows * cols];
                for (int i = 0; i < rows; i++) {
                    for (int j = 0; j < cols; j++) {
                        double[] pixel = image.get(j, i);
                        for (int k = 0; k < channels; k++) {
                            pixels[rows * cols * k + j * cols + i] = (float) pixel[k] / 255.0f;
                        }
                    }
                }
                detect_skip_index = 1;
                OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), new long[]{1L, (long) channels, (long) rows, (long) cols});
                OrtSession.Result output = session.run(Collections.singletonMap(session.getInputInfo().keySet().iterator().next(), tensor));

                // 处理输出结果并绘制
               outputData = ((float[][]) output.get(0).getValue());
            }else{
                detect_skip_index = detect_skip_index + 1;
            }
            double ratio = letterbox.getRatio();
            double dw =letterbox.getDw();
            double dh = letterbox.getDh();
            List<PEResult> peResults = new ArrayList<>();
            for (float[] outputDatum : outputData) {
                PEResult result = new PEResult(outputDatum);
                if (result.getScore() > PEConfig.personScoreThreshold) {
                    peResults.add(result);
                }
            }

            // 对结果进行非极大值抑制
            peResults = nms(peResults, PEConfig.IoUThreshold);

            for (PEResult peResult: peResults) {
                System.out.println(peResult);
                // 画框
                Point topLeft = new Point((peResult.getX0()-dw)/ratio, (peResult.getY0()-dh)/ratio);
                Point bottomRight = new Point((peResult.getX1()-dw)/ratio, (peResult.getY1()-dh)/ratio);
                // Imgproc.rectangle(img, topLeft, bottomRight, new Scalar(255,0,0), thickness);
                List<KeyPoint> keyPoints = peResult.getKeyPointList();
                // 画点
                keyPoints.forEach(keyPoint->{
                    if (keyPoint.getScore()>PEConfig.keyPointScoreThreshold) {
                        Point center = new Point((keyPoint.getX()-dw)/ratio, (keyPoint.getY()-dh)/ratio);
                        Scalar color = PEConfig.poseKptColor.get(keyPoint.getId());
                        Imgproc.circle(img, center, radius, color, -1); //-1表示实心
                    }
                });
                // 画线
                for (int i = 0; i< PEConfig.skeleton.length; i++){
                    int indexPoint1 = PEConfig.skeleton[i][0]-1;
                    int indexPoint2 = PEConfig.skeleton[i][1]-1;
                    if ( keyPoints.get(indexPoint1).getScore()>PEConfig.keyPointScoreThreshold &&
                            keyPoints.get(indexPoint2).getScore()>PEConfig.keyPointScoreThreshold ) {
                        Scalar coler = PEConfig.poseLimbColor.get(i);
                        Point point1 = new Point(
                                (keyPoints.get(indexPoint1).getX()-dw)/ratio,
                                (keyPoints.get(indexPoint1).getY()-dh)/ratio
                        );
                        Point point2 = new Point(
                                (keyPoints.get(indexPoint2).getX()-dw)/ratio,
                                (keyPoints.get(indexPoint2).getY()-dh)/ratio
                        );
                        Imgproc.line(img, point1, point2, coler, thickness);
                    }
                }
            }
            //服务器部署:由于服务器没有桌面,所以无法弹出画面预览,主要注释一下代码
            HighGui.imshow("result", img);

            // 多次按任意按键关闭弹窗画面,结束程序
            if(HighGui.waitKey(1) != -1){
                break;
            }
        }

        HighGui.destroyAllWindows();
        video.release();
        System.exit(0);

    }

    public static List<PEResult> nms(List<PEResult> boxes, float iouThreshold) {
        // 根据score从大到小对List进行排序
        boxes.sort((b1, b2) -> Float.compare(b2.getScore(), b1.getScore()));
        List<PEResult> resultList = new ArrayList<>();
        for (int i = 0; i < boxes.size(); i++) {
            PEResult box = boxes.get(i);
            boolean keep = true;
            // 从i+1开始,遍历之后的所有boxes,移除与box的IOU大于阈值的元素
            for (int j = i + 1; j < boxes.size(); j++) {
                PEResult otherBox = boxes.get(j);
                float iou = getIntersectionOverUnion(box, otherBox);
                if (iou > iouThreshold) {
                    keep = false;
                    break;
                }
            }
            if (keep) {
                resultList.add(box);
            }
        }
        return resultList;
    }
    private static float getIntersectionOverUnion(PEResult box1, PEResult box2) {
        float x1 = Math.max(box1.getX0(), box2.getX0());
        float y1 = Math.max(box1.getY0(), box2.getY0());
        float x2 = Math.min(box1.getX1(), box2.getX1());
        float y2 = Math.min(box1.getY1(), box2.getY1());
        float intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
        float box1Area = (box1.getX1() - box1.getX0()) * (box1.getY1() - box1.getY0());
        float box2Area = (box2.getX1() - box2.getX0()) * (box2.getY1() - box2.getY0());
        float unionArea = box1Area + box2Area - intersectionArea;
        return intersectionArea / unionArea;
    }
}

姿态识别模型提取链接, 通过网盘分享的文件:yolov7-w6-pose-nms.onnx 链接: pan.baidu.com/s/1UdAUPWr1… 提取码: du6y

后言

就像原作者说的,不是每个同学都会python,不是每个项目都是python语言开发,不是每个岗位都会深度学习。
希望java在AI领域能有更好的发展