Data Augmentation Code Implementations: Plug and Play (continuously updated)


Preface

This is my first blog post. It serves both as a record of my learning process and, hopefully, as a way for the open-source spirit to keep motivating me to study. Some of the code here was written entirely by hand; other parts were found in repositories on GitHub and adapted to my own needs based on my understanding. The focus of this post is plug-and-play code with no complicated setup. If there are any mistakes in the code, corrections and feedback are very welcome.

Reference GitHub repository: DataAugmentationForObjectDetection/data_aug at master · Paperspace/DataAugmentationForObjectDetection · GitHub

Library Imports

import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
import xml.dom.minidom as minidom

Original Image

2007_000032.jpg

1. Rotation

This simply rotates the image by 90, 180, or 270 degrees. Note that the image shape may change after rotation (width and height are swapped for 90- and 270-degree rotations).

def rotate_img(img_path, new_img_path, angle):
    img = cv2.imread(img_path)
    # angle is 1, 2 or 3, meaning a clockwise rotation of 90, 180 or 270 degrees
    if angle == 1:
        rotate = cv2.rotate(img, rotateCode=cv2.ROTATE_90_CLOCKWISE)
    elif angle == 2:
        rotate = cv2.rotate(img, rotateCode=cv2.ROTATE_180)
    elif angle == 3:
        rotate = cv2.rotate(img, rotateCode=cv2.ROTATE_90_COUNTERCLOCKWISE)
    else:
        return
    # Write the rotated image directly to disk
    cv2.imwrite(new_img_path, rotate)

Rotated 180 degrees

image.png
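Note that rotate_img only rotates the pixels and does not touch the VOC XML annotations. Below is a minimal sketch, not part of the referenced repo, of how the boxes could be remapped for a 90-degree clockwise rotation, mirroring the XML handling used by the flip functions later in this post (the helper name rotate_boxes_90cw is my own):

# Sketch: remap VOC boxes for a 90-degree clockwise rotation.
# For an image of height h, a point (x, y) maps to (h - y + 1, x) in 1-based VOC coordinates.
def rotate_boxes_90cw(xml_path, new_xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()
    size = root.find('size')
    h = int(size.find('height').text)
    for obj in root.findall('object'):
        bbox = obj.find('bndbox')
        x1 = float(bbox.find('xmin').text)
        y1 = float(bbox.find('ymin').text)
        x2 = float(bbox.find('xmax').text)
        y2 = float(bbox.find('ymax').text)
        # the new xmin/xmax come from the old y values, the new ymin/ymax from the old x values
        bbox.find('xmin').text = str(int(h - y2 + 1))
        bbox.find('ymin').text = str(int(x1))
        bbox.find('xmax').text = str(int(h - y1 + 1))
        bbox.find('ymax').text = str(int(x2))
    # width and height swap after a 90-degree rotation
    w_elem, h_elem = size.find('width'), size.find('height')
    w_elem.text, h_elem.text = h_elem.text, w_elem.text
    tree.write(new_xml_path)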

2. Horizontal + Vertical Flip

def a_MirrorImg(img_path, img_write_path, xml_path=None, new_xml_path=None):
    img = cv2.imread(img_path)
    mirror_img = cv2.flip(img, -1)  # flip both horizontally and vertically
    cv2.imwrite(img_write_path, mirror_img)
    if xml_path:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)
        objects = root.findall("object")
        for obj in objects:
            bbox = obj.find('bndbox')
            x1 = float(bbox.find('xmin').text)
            y1 = float(bbox.find('ymin').text)
            x2 = float(bbox.find('xmax').text)
            y2 = float(bbox.find('ymax').text)

            x1 = w - x1 + 1
            x2 = w - x2 + 1

            y1 = h - y1 + 1
            y2 = h - y2 + 1

            assert x1 > 0
            assert x2 > 0
            assert y1 > 0
            assert y2 > 0

            bbox.find('xmin').text = str(int(x2))
            bbox.find('xmax').text = str(int(x1))
            bbox.find('ymin').text = str(int(y2))
            bbox.find('ymax').text = str(int(y1))

        tree.write(new_xml_path)  # save the modified XML file

image.png

3. Vertical Flip

def v_MirrorImg(img_path, img_write_path, xml_path=None, new_xml_path=None):
    img = cv2.imread(img_path)
    mirror_img = cv2.flip(img, 0)  # flip vertically
    cv2.imwrite(img_write_path, mirror_img)
    if xml_path:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        size = root.find('size')
        h = int(size.find('height').text)
        objects = root.findall("object")
        for obj in objects:
            bbox = obj.find('bndbox')
            y1 = float(bbox.find('ymin').text)
            y2 = float(bbox.find('ymax').text)

            y1 = h - y1 + 1
            y2 = h - y2 + 1

            assert y1 > 0
            assert y2 > 0

            bbox.find('ymin').text = str(int(y2))
            bbox.find('ymax').text = str(int(y1))

        tree.write(new_xml_path)  # save the modified XML file

image.png

4. Horizontal Flip

# Horizontal mirror flip
def h_MirrorImg(img_path, img_write_path, xml_path=None, new_xml_path=None):
    img = cv2.imread(img_path)
    mirror_img = cv2.flip(img, 1)  # flip horizontally
    cv2.imwrite(img_write_path, mirror_img)
    if xml_path:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        objects = root.findall("object")
        for obj in objects:
            bbox = obj.find('bndbox')
            x1 = float(bbox.find('xmin').text)
            x2 = float(bbox.find('xmax').text)
            x1 = w - x1 + 1
            x2 = w - x2 + 1

            assert x1 > 0
            assert x2 > 0

            bbox.find('xmin').text = str(int(x2))
            bbox.find('xmax').text = str(int(x1))

        tree.write(new_xml_path)  # save the modified XML file

image.png
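A quick way to apply the three flip functions above to a single image/XML pair could look like this (the output file names are placeholders for illustration only):

# Sketch: generate all three flipped versions of one VOC sample
h_MirrorImg("2007_000032.jpg", "h_flip.jpg", "2007_000032.xml", "h_flip.xml")
v_MirrorImg("2007_000032.jpg", "v_flip.jpg", "2007_000032.xml", "v_flip.xml")
a_MirrorImg("2007_000032.jpg", "hv_flip.jpg", "2007_000032.xml", "hv_flip.xml")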

5. Brightness Adjustment (linear scaling)

# Brightness adjustment
# alpha > 1 brightens the image, alpha < 1 darkens it; one decimal place of precision is enough
def getColorImg(img_path, new_img_path, alpha=1.4, beta=0, xml_path=None, new_xml_path=None):
    img = cv2.imread(img_path)
    colored_img = np.uint8(np.clip((alpha * img + beta), 0, 255))
    cv2.imwrite(new_img_path, colored_img)
    if xml_path:
        # brightness does not move the boxes, so the XML is copied unchanged
        tree = ET.parse(xml_path)
        tree.write(new_xml_path)

image.png
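getColorImg applies a linear transform (alpha * img + beta). If true gamma correction is wanted instead, a minimal sketch using a lookup table could look like this (the helper name adjust_gamma and its default value are my own, not from the referenced repo):

# Sketch: gamma correction via a lookup table.
# With out = (in / 255) ** (1 / gamma) * 255, gamma > 1 brightens and gamma < 1 darkens.
def adjust_gamma(img_path, new_img_path, gamma=1.5):
    img = cv2.imread(img_path)
    inv_gamma = 1.0 / gamma
    # precompute the mapping for every possible pixel value 0..255
    table = np.array([((i / 255.0) ** inv_gamma) * 255 for i in range(256)]).astype(np.uint8)
    cv2.imwrite(new_img_path, cv2.LUT(img, table))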

6. Gaussian Noise

def gaussian_noise(img_path, new_img_path, mean=0, sigma=0.1, xml_path=None, new_xml_path=None):
    '''
    Add Gaussian noise to the image.
    A mean of 0 keeps the overall brightness unchanged, while the variance controls the
    noise strength: the larger the variance / standard deviation, the stronger the noise.

    Parameters:
        img   : original image
        mean  : mean of the Gaussian noise
        sigma : standard deviation, in normalised [0, 1] intensity units
    Output:
        gaussian_out : the noisy image, written to new_img_path
    '''
    img = cv2.imread(img_path)
    # Normalise pixel values to [0, 1]
    img = img / 255
    # Generate the Gaussian noise
    noise = np.random.normal(mean, sigma, img.shape)
    # Add the noise to the image
    gaussian_out = img + noise
    # Clip values above 1 to 1 and below 0 to 0
    gaussian_out = np.clip(gaussian_out, 0, 1)
    # Scale back to the 0-255 range
    gaussian_out = np.uint8(gaussian_out * 255)
    # Rescale the noise itself to 0-255 (optional)
    # noise = np.uint8(noise*255)
    cv2.imwrite(new_img_path, gaussian_out)
    if xml_path:
        tree = ET.parse(xml_path)
        tree.write(new_xml_path)

image.png

7. Salt-and-Pepper Noise

def sp_noise(img_path, new_img_path, prob, xml_path=None, new_xml_path=None):
    '''
    Add salt-and-pepper noise.
    prob: proportion of noisy pixels
    '''
    image = cv2.imread(img_path)
    output = np.zeros(image.shape, np.uint8)  # empty array of uint8 values in 0-255
    thres = 1 - prob

    for i in range(image.shape[0]):  # iterate over rows
        for j in range(image.shape[1]):  # iterate over columns; (i, j) is one pixel position
            rdn = random.random()  # draw a random float in [0, 1)
            if rdn < prob:  # below prob: pepper pixel (with prob = 1 the whole image turns black)
                output[i][j] = 0  # pepper
            elif rdn > thres:
                output[i][j] = 255  # salt
            else:  # rdn within [prob, thres]: keep the original pixel
                output[i][j] = image[i][j]

    cv2.imwrite(new_img_path, output)
    if xml_path:
        tree = ET.parse(xml_path)
        tree.write(new_xml_path)

image.png
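The per-pixel Python loop in sp_noise is slow on large images. A vectorized NumPy sketch that should behave the same way (whole pixels set to black or white; my own variant rather than code from the referenced repo) is:

# Sketch: vectorized salt-and-pepper noise, same behaviour as sp_noise but without Python loops
def sp_noise_fast(img_path, new_img_path, prob):
    image = cv2.imread(img_path)
    output = image.copy()
    # one random number per pixel position, shared across the colour channels
    rdn = np.random.random(image.shape[:2])
    output[rdn < prob] = 0          # pepper
    output[rdn > 1 - prob] = 255    # salt
    cv2.imwrite(new_img_path, output)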

8. Rotation by an Arbitrary Angle

An affine matrix can rotate the image by any angle around any point. This augmentation is generally only used for image classification models; it is not well suited to detection datasets whose annotations are axis-aligned rather than rotated bounding boxes.

def all_angle_rotate_img(img_path, new_img_path, angle=30):
    '''
    img   --image
    angle --rotation angle
    return--rotated img
    '''
    img = cv2.imread(img_path)
    h, w = img.shape[:2]
    rotate_center = (w / 2, h / 2)  # rotation centre; here the image centre, but any point can be used
    # Build the rotation matrix:
    # arg 1: rotation centre
    # arg 2: rotation angle; positive = counter-clockwise, negative = clockwise
    # arg 3: isotropic scale factor; 1.0 keeps the original size, 2.0 doubles it, 0.5 halves it
    M = cv2.getRotationMatrix2D(rotate_center, angle, 1.0)
    # Compute the new image bounds
    # new_w = int(h * np.abs(M[0, 1]) + w * np.abs(M[0, 0]))
    # new_h = int(h * np.abs(M[0, 0]) + w * np.abs(M[0, 1]))
    # Adjust the rotation matrix to account for the translation
    # M[0, 2] += (new_w - w) / 2
    # M[1, 2] += (new_h - h) / 2
    # Apply the affine transform
    rotated_img = cv2.warpAffine(img, M, (w, h))
    cv2.imwrite(new_img_path, rotated_img)

image.png
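The commented-out lines above show how to avoid cropping the corners: enlarge the canvas and shift the rotation matrix accordingly. A sketch of that variant, assembled from those commented lines (the function name rotate_img_keep_all is my own), is:

# Sketch: rotate around the centre while enlarging the canvas so nothing is cropped
def rotate_img_keep_all(img_path, new_img_path, angle=30):
    img = cv2.imread(img_path)
    h, w = img.shape[:2]
    M = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)
    # new bounding size of the rotated image
    new_w = int(h * np.abs(M[0, 1]) + w * np.abs(M[0, 0]))
    new_h = int(h * np.abs(M[0, 0]) + w * np.abs(M[0, 1]))
    # shift the rotation matrix so the result is centred on the new canvas
    M[0, 2] += (new_w - w) / 2
    M[1, 2] += (new_h - h) / 2
    rotated = cv2.warpAffine(img, M, (new_w, new_h))
    cv2.imwrite(new_img_path, rotated)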

Utility Functions

1. get_boxes

Given the path to an XML file, returns the coordinates of every annotated object in the file together with the set of object categories present in the image.

def get_boxes(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()
    objects = root.findall("object")
    boxes = []
    categories = []
    # Collect all categories and deduplicate them to build a category dictionary
    for obj in objects:
        category_elem = obj.find('name')
        category = category_elem.text
        categories.append(category)
    key = list(set(categories))
    values = [value for value in range(len(key))]
    # Build the dictionary, e.g. {'aeroplane': 0, 'person': 1}
    # (note: set() does not preserve order, so the index of each class may differ between runs)
    category_dict = dict(zip(key, values))
    for obj in objects:
        bbox = obj.find('bndbox')
        x1 = float(bbox.find('xmin').text)
        y1 = float(bbox.find('ymin').text)
        x2 = float(bbox.find('xmax').text)
        y2 = float(bbox.find('ymax').text)
        category_elem = obj.find('name')
        category = category_elem.text
        # the fifth column is the class index
        boxes.append([x1, y1, x2, y2, category_dict[category]])
    # Return the bounding boxes and the list of category names
    return np.array(boxes), key

2. draw_rect

Takes the image path, the bounding-box information obtained from get_boxes, and the box colour; by default the boxes are drawn in white.

def draw_rect(img_path, cords, color=None):
    """Draw rectangles on the image

    Parameters
    ----------

    img_path : str
        Path to the image to draw on

    cords: numpy.ndarray
        Numpy array containing bounding boxes of shape `N X 4` where N is the
        number of bounding boxes and the bounding boxes are represented in the
        format `x1 y1 x2 y2`

    Returns
    -------

    numpy.ndarray
        numpy image with bounding boxes drawn on it

    """
    im = cv2.imread(img_path)
    im = im.copy()

    cords = cords[:, :4]
    cords = cords.reshape(-1, 4)
    if not color:
        color = [255, 255, 255]
    for cord in cords:
        pt1, pt2 = (cord[0], cord[1]), (cord[2], cord[3])

        pt1 = int(pt1[0]), int(pt1[1])
        pt2 = int(pt2[0]), int(pt2[1])

        im = cv2.rectangle(im.copy(), pt1, pt2, color, int(max(im.shape[:2]) / 200))
    # Return the image with the boxes drawn on it
    return im

3. bbox_area

Returns the area of each bounding box.

def bbox_area(bbox):
    return (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1])

4. clip_box

Takes the transformed bounding boxes together with the image boundary to clip against. A box is kept only if the area remaining after clipping is at least alpha times its original area; otherwise it is dropped. The larger alpha is, the more likely a box is to be discarded.

Returns the clipped bounding boxes and the class indices of the boxes that survive clipping.

def clip_box(bbox, clip_box, alpha):
    """Clip the bounding boxes to the borders of an image

    Parameters
    ----------

    bbox: numpy.ndarray
        Numpy array containing bounding boxes of shape `N X 5` where N is the
        number of bounding boxes and the bounding boxes are represented in the
        format `x1 y1 x2 y2 class`

    clip_box: numpy.ndarray
        An array of shape (4,) specifying the diagonal co-ordinates of the image
        The coordinates are represented in the format `x1 y1 x2 y2`

    alpha: float
        If the fraction of a bounding box left in the image after being clipped is
        less than `alpha` the bounding box is dropped.

    Returns
    -------

    numpy.ndarray
        Numpy array containing **clipped** bounding boxes of shape `N X 5` where N is the
        number of bounding boxes left after being clipped and the bounding boxes are represented in the
        format `x1 y1 x2 y2 class`, where the last column is the class index

    """

    ar_ = (bbox_area(bbox))
    x_min = np.maximum(bbox[:, 0], clip_box[0]).reshape(-1, 1)
    y_min = np.maximum(bbox[:, 1], clip_box[1]).reshape(-1, 1)
    x_max = np.minimum(bbox[:, 2], clip_box[2]).reshape(-1, 1)
    y_max = np.minimum(bbox[:, 3], clip_box[3]).reshape(-1, 1)

    bbox = np.hstack((x_min, y_min, x_max, y_max, bbox[:, 4:]))
    # fraction of the area lost after clipping
    delta_area = ((ar_ - bbox_area(bbox)) / ar_)

    mask = (delta_area < (1 - alpha)).astype(int)

    bbox = bbox[mask == 1, :]

    num = bbox[:, -1]
    bb = bbox[:, :-1]
    # return the clipped boxes and their class indices
    return bb, num
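A small worked example of the drop rule (the numbers are chosen for illustration only): with alpha = 0.25, a box is kept only if it retains at least 25% of its original area after clipping.

# Sketch: clip_box keeps a box only if delta_area < 1 - alpha, i.e. the remaining fraction > alpha
boxes = np.array([[-50.0, 10.0, 50.0, 60.0, 0],    # half of this box lies outside -> 50% remains, kept
                  [-90.0, 10.0, 10.0, 60.0, 1]])   # only 10% remains -> dropped at alpha = 0.25
clipped, class_idx = clip_box(boxes, [0, 0, 100, 100], 0.25)
print(clipped)    # [[ 0. 10. 50. 60.]]
print(class_idx)  # [0.]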

5. save_xml_as_path

Takes the path where the new XML file should be saved, the image path, the bounding-box coordinates, and the category names of the objects that survive clipping.

def save_xml_as_path(new_xml_path, img_path, bboxes, categories):
    root = ET.Element("annotation")
    img = cv2.imread(img_path)
    h, w, c = img.shape
    # Create the child elements (folder / filename / path are left empty here)
    folder = ET.SubElement(root, "folder")

    filename = ET.SubElement(root, "filename")

    path = ET.SubElement(root, "path")

    source = ET.SubElement(root, "source")
    size = ET.SubElement(root, "size")
    segmented = ET.SubElement(root, "segmented")
    segmented.text = '0'
    database = ET.SubElement(source, "database")
    database.text = "Unknown"
    # Image size taken from the actual image
    width = ET.SubElement(size, "width")
    width.text = str(w)
    height = ET.SubElement(size, "height")
    height.text = str(h)
    depth = ET.SubElement(size, "depth")
    depth.text = str(c)

    for index, category in enumerate(categories):
        # One object element is added per entry in bboxes / categories
        object = ET.SubElement(root, "object")
        name = ET.SubElement(object, "name")
        name.text = category
        pose = ET.SubElement(object, "pose")
        pose.text = 'Unspecified'
        truncated = ET.SubElement(object, "truncated")
        truncated.text = '0'
        difficult = ET.SubElement(object, "difficult")
        difficult.text = '0'
        bndbox = ET.SubElement(object, "bndbox")
        # bndbox structure
        xmin = ET.SubElement(bndbox, "xmin")
        xmin.text = str(bboxes[index][0])
        ymin = ET.SubElement(bndbox, "ymin")
        ymin.text = str(bboxes[index][1])
        xmax = ET.SubElement(bndbox, "xmax")
        xmax.text = str(bboxes[index][2])
        ymax = ET.SubElement(bndbox, "ymax")
        ymax.text = str(bboxes[index][3])

    # Build the XML tree
    tree = ET.ElementTree(root)

    # Convert the tree to a string
    xml_str = ET.tostring(root)

    # Pretty-print the XML string
    dom = minidom.parseString(xml_str)
    formatted_xml = dom.toprettyxml(indent="  ")
    print(formatted_xml)
    # Write the formatted XML to file
    with open(new_xml_path, "w") as file:
        file.write(formatted_xml)

9. Random Translation

class RandomTranslate(object):
    """Randomly translates the image


    Bounding boxes with less than 25% of their area remaining in the
    transformed image are dropped. The resolution is maintained, and the
    remaining area, if any, is filled with black.

    Parameters
    ----------
    translate: float or tuple(float)
        if **float**, the image is translated by a fraction of its size drawn
        randomly from the range (-`translate`, `translate`). If **tuple**,
        `translate` is drawn randomly from the values specified by the
        tuple

    Returns
    -------

    numpy.ndarray
        Translated image in the numpy format of shape `HxWxC`

    numpy.ndarray
        Transformed bounding box co-ordinates of the format `n x 4` where n is
        the number of bounding boxes and 4 represents `x1,y1,x2,y2` of the box

    """

    def __init__(self, translate=0.2, diff=False):
        # diff=False (the default) applies the same shift along width and height
        self.translate = translate

        if type(self.translate) == tuple:
            assert len(self.translate) == 2, "Invalid range"
            assert 0 < self.translate[0] < 1
            assert 0 < self.translate[1] < 1

        else:
            assert self.translate > 0 and self.translate < 1
            self.translate = (-self.translate, self.translate)

        self.diff = diff

    def __call__(self, img_path, new_img_path, xml_path=None, new_xml_path=None):
        # Choose a random amount to translate by
        img = cv2.imread(img_path)
        img_shape = img.shape

        # translate the image

        # percentage of the image dimensions to translate by
        translate_factor_x = random.uniform(*self.translate)
        translate_factor_y = random.uniform(*self.translate)

        if not self.diff:
            translate_factor_y = translate_factor_x

        canvas = np.zeros(img_shape).astype(np.uint8)

        corner_x = int(translate_factor_x * img.shape[1])
        corner_y = int(translate_factor_y * img.shape[0])

        # change the origin to the top-left corner of the translated box
        orig_box_cords = [max(0, corner_y), max(corner_x, 0), min(img_shape[0], corner_y + img.shape[0]),
                          min(img_shape[1], corner_x + img.shape[1])]

        mask = img[max(-corner_y, 0):min(img.shape[0], -corner_y + img_shape[0]),
               max(-corner_x, 0):min(img.shape[1], -corner_x + img_shape[1]), :]
        canvas[orig_box_cords[0]:orig_box_cords[2], orig_box_cords[1]:orig_box_cords[3], :] = mask
        img = canvas
        # Save the translated image
        cv2.imwrite(new_img_path, img)

        if xml_path:
            # If an XML path is given, transform the boxes as well
            bboxes, categories = get_boxes(xml_path)
            print("All categories:", categories)
            # Only the first four columns are shifted; the class column stays unchanged
            bboxes[:, :4] += [corner_x, corner_y, corner_x, corner_y]
            bboxes, index = clip_box(bboxes, [0, 0, img_shape[1], img_shape[0]], 0.25)
            classes_num = list(map(int, index))
            print("Class indices after clipping:", classes_num)
            class_names = [categories[idx] for idx in classes_num]
            print("Classes after clipping:", class_names)
            # After clipping, the set of objects may have changed
            print("Transformed boxes:", bboxes)
            # Save the transformed boxes as a new XML file
            save_xml_as_path(new_xml_path, img_path, bboxes, class_names)
            print("XML file saved")
            # The transformed boxes could also be returned to verify the result
        # return bboxes
        

Usage Example

if __name__ == '__main__':
    img_path = "2007_000032.jpg"
    xml_path = "2007_000032.xml"
    new_img_path = '1.jpg'
    new_xml_path = '1.xml'

    trans = RandomTranslate(translate=0.4, diff=True)
    # Calling the instance invokes the __call__ magic method
    trans(img_path, new_img_path, xml_path, new_xml_path)
    # Test: read the boxes back from the new XML and draw them
    # (skip the steps below if no XML conversion was performed)
    bb, key = get_boxes(new_xml_path)
    plot_img = draw_rect(new_img_path, bb)
    cv2.imshow("img", plot_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

image.png

10. Random Scaling

Random scaling only rescales the image content; the output image keeps its original resolution. When scaling down, the leftover area is filled with black; when scaling up, the content is cropped back to the original size.

RandomScale is used in the same way as RandomTranslate above, and it relies on the same helper functions (clip_box, get_boxes, etc.).

class RandomScale(object):
    """Randomly scales an image


    Bounding boxes with less than 25% of their area remaining in the
    transformed image are dropped. The resolution is maintained, and the
    remaining area, if any, is filled with black.

    Parameters
    ----------
    scale: float or tuple(float)
        if **float**, the image is scaled by a factor drawn
        randomly from the range (1 - `scale`, 1 + `scale`). If **tuple**,
        the `scale` is drawn randomly from the values specified by the
        tuple

    Returns
    -------

    numpy.ndarray
        Scaled image in the numpy format of shape `HxWxC`

    numpy.ndarray
        Transformed bounding box co-ordinates of the format `n x 4` where n is
        the number of bounding boxes and 4 represents `x1,y1,x2,y2` of the box

    """

    def __init__(self, scale=0.2, diff=False):
        self.scale = scale

        if type(self.scale) == tuple:
            assert len(self.scale) == 2, "Invalid range"
            assert self.scale[0] > -1, "Scale factor can't be less than -1"
            assert self.scale[1] > -1, "Scale factor can't be less than -1"
        else:
            assert self.scale > 0, "Please input a positive float"
            self.scale = (max(-1, -self.scale), self.scale)

        self.diff = diff

    def __call__(self, img_path, new_img_path, xml_path=None, new_xml_path=None):
        img = cv2.imread(img_path)
        # Choose a random factor to scale by

        img_shape = img.shape
        print("Original image shape:", img_shape)
        if self.diff:
            scale_x = random.uniform(*self.scale)
            scale_y = random.uniform(*self.scale)
        else:
            scale_x = random.uniform(*self.scale)
            scale_y = scale_x

        resize_scale_x = 1 + scale_x
        resize_scale_y = 1 + scale_y

        img = cv2.resize(img, None, fx=resize_scale_x, fy=resize_scale_y)

        canvas = np.zeros(img_shape, dtype=np.uint8)

        y_lim = int(min(resize_scale_y, 1) * img_shape[0])
        x_lim = int(min(resize_scale_x, 1) * img_shape[1])

        canvas[:y_lim, :x_lim, :] = img[:y_lim, :x_lim, :]

        img = canvas
        print("Image shape after random scaling:", img.shape)
        # Save the scaled image
        cv2.imwrite(new_img_path, img)
        if xml_path:
            bboxes, categories = get_boxes(xml_path)
            print("All categories:", categories)

            bboxes[:, :4] *= [resize_scale_x, resize_scale_y, resize_scale_x, resize_scale_y]
            bboxes, index = clip_box(bboxes, [0, 0, 1 + img_shape[1], img_shape[0]], 0.25)

            classes_num = list(map(int, index))
            print("Class indices after clipping:", classes_num)
            class_names = [categories[idx] for idx in classes_num]
            print("Classes after clipping:", class_names)
            # After clipping, the set of objects may have changed
            print("Transformed boxes:", bboxes)
            # Save the transformed boxes as a new XML file
            save_xml_as_path(new_xml_path, img_path, bboxes, class_names)
            print("XML file saved")
        # return img, bboxes

image.png
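For completeness, a usage sketch mirroring the earlier RandomTranslate example (the output file names and the scale value are placeholders for illustration):

if __name__ == '__main__':
    scaler = RandomScale(scale=0.3, diff=True)
    scaler("2007_000032.jpg", "scaled.jpg", "2007_000032.xml", "scaled.xml")
    # Optional check: draw the boxes read back from the new XML
    bb, key = get_boxes("scaled.xml")
    plot_img = draw_rect("scaled.jpg", bb)
    cv2.imshow("img", plot_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()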