文章目录
第3 章预处理图像数据
3.1 数据增强简介
提高模型泛化能力的最重要的三大因素是数据、模型、损失函数,其中数据又是三个因素中最重要的因素,但数据的获取往往不充分或成本比较高。是否有其他方法可以快速又便捷地增加数据量呢?在一些领域是存在的,如在图像识别、语言识别领域,可以通过水平或垂直翻转图像、裁剪、色彩变换、扩展和旋转等数据增强(Data Augmentation)技术来增加数据量。
通过数据增强技术不仅可以扩大训练数据集的规模、降低模型对某些属性的依赖,从而提高模型的泛化能力,也可以对图像进行不同方式的裁剪,使感兴趣的物体出现在不同位置,从而减轻模型对物体出现位置的依赖性,还可以通过调整亮度、色彩等因素来降低模型对色彩的敏感度等。当然,对图像做这些预处理时,不宜使用会改变其类别的转换,如对于手写的数字,如果旋转90度,就有可能把9变成6,或把6变为9。
此外,把随机噪声添加到输入数据或隐藏单元中也是增加数据量的方法之一。
3.2 使用OpenCV实现图像增强
3.2.1导入需要的库
1 2 3 4 5 6 |
import cv2 import numpy as np import os import random from PIL import Image from PIL import ImageEnhance |
3.2.2导入图像数据
1 2 3 4 5 |
image_path='./data/cat/' save_path='./data/save/' im = Image.open(image_path+"cat.1.jpg") im |
运行结果
3.2.3增加高斯噪声
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
# 高斯噪声 def gauss(image_path, save_path): # 设置高斯分布的均值和标准差 mean_list = [-0.2, -0.1, 0, 0.3, 0.6] sigma_list = [40, 50, 60] #读取文件,放在一个列表中 files = os.listdir(image_path) for file in files: #把路径和文件名结合在一起 file_path = os.path.join(image_path, file) #把文件名与扩展名进行拆分 temp = os.path.splitext(file) file_name = temp[0] + '_gauss' + temp[-1] image = cv2.imread(file_path) mean = random.choice(mean_list) sigma = random.choice(sigma_list) gauss = np.random.normal(mean, sigma, image.shape) noisy_image = image + gauss # 将noisy_image中的像素控制在0-255 noisy_image = np.clip(noisy_image, a_min=0, a_max=255) cv2.imwrite(os.path.join(save_path, file_name), noisy_image) |
读取处理后的图像
1 2 3 |
gauss(image_path, save_path) im = Image.open('./data/save/cat.1_gauss.jpg') im |
运行结果
3.2.4图像缩小为0.5倍
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
# 图像高宽缩小为0.5倍 def change_scale_05(image_path, save_path): files = os.listdir(image_path) for file in files: file_path = os.path.join(image_path, file) image = cv2.imread(file_path) image = cv2.resize(image, None, fx=0.5, fy=0.5) temp = os.path.splitext(file) file_name = temp[0] + '_scale_0.5' + temp[-1] cv2.imwrite(os.path.join(save_path, file_name), image) 读取处理后的图像 |
1 2 3 |
change_scale_05(image_path, save_path) im = Image.open('./data/save/cat.1_scale_0.5.jpg') im |
运行结果
3.2.5图像水平翻转
1 2 3 4 5 6 7 8 9 10 11 12 13 |
# 图像水平翻转 def horizontal_flip(image_path, save_path): files = os.listdir(image_path) for file in files: file_path = os.path.join(image_path, file) image = cv2.imread(file_path) image = cv2.flip(image, 1) temp = os.path.splitext(file) file_name = temp[0] + '_horizontal' + temp[-1] cv2.imwrite(os.path.join(save_path, file_name), image) |
读取处理后的图像
1 2 3 |
horizontal_flip(image_path, save_path) im = Image.open('./data/save/cat.1_horizontal.jpg') im |
运行结果
3.2.6图像垂直翻转
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
def rotate_90(image_path, save_path): files = os.listdir(image_path) angel_list = [0, 1] for file in files: file_path = os.path.join(image_path, file) image = cv2.imread(file_path) image = cv2.transpose(image) rotate_angle = random.choice(angel_list) image = cv2.flip(image, rotate_angle) temp = os.path.splitext(file) file_name = temp[0] + '_90' + temp[-1] cv2.imwrite(os.path.join(save_path, file_name), image) |
读取处理后的图像
1 2 3 |
rotate_90(image_path, save_path) im = Image.open('./data/save/cat.1_90.jpg') im |
运行结果
3.2.6增强图像亮度
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
#增强亮度 def bright(image_path, save_path): files = os.listdir(image_path) for file in files: image = Image.open(os.path.join(image_path, file)) image = ImageEnhance.Brightness(image) bright_factor = np.random.randint(5, 19) / 10. image = image.enhance(bright_factor) temp = os.path.splitext(file) file_name = temp[0] + '_bright' + temp[-1] image.save(os.path.join(save_path, file_name)) |
读取处理后的图像
1 2 3 |
bright(image_path, save_path) im = Image.open('./data/save/cat.1_bright.jpg') im |
运行结果
3.2.7混合增强
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
# 混合增强 def hybrid(image_path, save_path): files = os.listdir(image_path) for file in files: image = Image.open(os.path.join(image_path, file)) contrast_factor = np.random.randint(5, 19) / 10. bright_factor = np.random.randint(5, 19) / 10. color_factor = np.random.randint(1, 25) / 10. sharp_factor = np.random.randint(1, 30) / 10. image = ImageEnhance.Contrast(image).enhance(contrast_factor) image = ImageEnhance.Brightness(image).enhance(bright_factor) image = ImageEnhance.Color(image).enhance(color_factor) image = ImageEnhance.Sharpness(image).enhance(sharp_factor) temp = os.path.splitext(file) file_name = temp[0] + '_hybrid' + temp[-1] # cv2.imwrite(os.path.join(save_path, file_name), image) image.save(os.path.join(save_path, file_name)) |
读取处理后的图像
1 2 3 |
hybrid(image_path, save_path) im = Image.open('./data/save/cat.1_hybrid.jpg') im |
运行结果
3.3 图像去雾
图像去雾的研究算法有很多,但是主要分为两类:基于图像增强的去雾算法和基于图像复原的去雾算法。 基于图像增强的去雾算法:基于图像增强的去雾算法出发点是尽量去除图像噪声,提高图像对比度,从而恢复出无雾清晰图像。代表性方法有:直方图均衡化(HLE)、自适应直方图均衡化(AHE)、限制对比度自适应直方图均衡化(CLAHE)、Retinex算法、小波变换、同态滤波等 基于图像复原的去雾算法:这一系列方法基本是基于大气退化模型,进行响应的去雾处理。代表性算法有来自何凯明博士的暗通道去雾算法、基于导向滤波的暗通道去雾算法、Fattal的单幅图像去雾算法(Single image dehazing)、Tan的单一图像去雾算法(Visibility in bad weather from a single image)、Tarel的快速图像恢复算法(Fast visibility restoration from a single color or gray level image)、贝叶斯去雾算法(Single image defogging by multiscale depth fusion),基于大气退化模型的去雾效果普遍好于基于图像增强的去雾算法,后面挑选的传统去雾算法例子也大多是基于图像复原的去雾算法。
这里主要介绍的基于图像增强的图像去雾,在此使用直方图均衡化和局部直方图均衡化进行图像的去雾处理。
3.3.1 显示原图
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import cv2 import matplotlib.pyplot as plt %matplotlib inline plt.rcParams['font.family'] = 'SimHei' # matplotlib 绘图库正常使用中文黑体 #读取图像信息,图像格式为BGR img0 = cv2.imread('./data/shanghai02.jpg') img1 = cv2.resize(img0, dsize = None, fx = 0.6, fy = 0.6) #灰度化:彩色图像转为灰度图像 img2 = cv2.cvtColor(img1,cv2.COLOR_BGR2GRAY) #为便于plt显示,把BGR格式转换为RGB格式 img21 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB) |
3.3.2全局直方图均衡化
1 2 3 4 5 6 7 8 9 10 11 12 |
#全局直方图均衡化 img3 = cv2.equalizeHist(img2) #为便于plt显示,把BGR格式转换为RGB格式 img31 = cv2.cvtColor(img3, cv2.COLOR_BGR2RGB) plt.figure(figsize=(10,5)) plt.subplot(1, 2, 1) plt.title("原图") plt.imshow(img21) plt.subplot(1, 2, 2) plt.title("全局直方图均衡化") plt.imshow(img31) |
运行结果
3.3.3 局部直方图均衡化
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
#局部直方图均衡化 clahe = cv2.createCLAHE (clipLimit = 2.0, tileGridSize=(10, 10)) img4 = clahe.apply(img2) #为便于plt显示,把BGR格式转换为RGB格式 img41 = cv2.cvtColor(img4, cv2.COLOR_BGR2RGB) plt.figure(figsize=(10,5)) plt.subplot(1, 2, 1) plt.title("原图") plt.imshow(img21) plt.subplot(1, 2, 2) plt.title("局部直方图均衡化") plt.imshow(img41) |
运行结果
3.3.4 比较直方图
1 2 3 4 5 6 7 8 |
plt.figure(figsize=(8,4)) plt.subplot(1, 2, 1) plt.plot(hist1, label = "处理前的直方图", color = 'b', linestyle = '--') plt.legend() plt.subplot(1, 2, 2) plt.plot(hist2, label = "局部直方图均衡化", color = 'b', linestyle = '--') plt.legend() plt.show() |
运行结果
最明显的变化就是某一些像素点数比较少的亮度级别消失了,而且图像直方图的变化也没有那么突兀了,图像也就更加清晰了。
3.4 使用PyTorch实现图像增强
使用pytorch中的torchvision模块实现数据增强。
3.4.1 按比例缩放
随机比例缩放主要使用的是 torchvision.transforms.Resize()函数。
1)显示原图。
1 2 3 4 5 6 |
import sys from PIL import Image from torchvision import transforms as trans im = Image.open('./data/cat/cat.1.jpg') im |
运行结果如图3-1所示。
图3-1 小猫原图
2)随机比例缩放。
1 2 3 4 5 |
# 比例缩放 print('原图像大小: {}'.format(im.size)) new_im = trans.Resize((100, 200))(im) print('缩放后大小: {}'.format(new_im.size)) new_im |
运行结果如图3-2所示。
原图像大小: (500, 414)
缩放后大小: (200, 100)
图3-2 缩放后的图像
3.4.2 裁剪
随机裁剪有两种方式,一种是对图像在随机位置进行截取,可传入裁剪大小,使用的函数为torchvision.transforms.RandomCrop();另一种是在中心,按比例裁剪,函数为 torchvision.transforms.CenterCrop()。
1 2 3 |
# 随机裁剪出 200 x 200 的区域 random_im1 = trans.RandomCrop(200)(im) random_im1 |
运行结果如图3-3所示。
图3-3 剪辑后的图像
3.4.3 翻转
翻转猫还是猫,不会改变其类别。通过翻转图像可以增加其多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 torchvision.transforms.RandomHorizontalFlip() 、torchvision.transforms.RandomVerticalFlip()和 torchvision.transforms.RandomRotation()等函数。
1 2 3 |
# 随机竖直翻转 v_flip = trans.RandomVerticalFlip()(im) v_flip |
运行结果如图3-4所示。
图3-4 翻转后的图像
3.4.4改变颜色
除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要是用 torchvision.transforms.ColorJitter() 来实现的。
1 2 3 |
# 改变颜色 color_im = trans.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化 color_im |
运行结果如图3-5所示。
图3-5 改变颜色后的图像
3.4.5组合多种增强方法
我们可用torchvision.transforms.Compose() 函数把以上这些变化组合在一起。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
im_aug = trans.Compose([ tfs.Resize(200), tfs.RandomHorizontalFlip(), tfs.RandomCrop(96), tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5) ]) import matplotlib.pyplot as plt %matplotlib inline nrows = 3 ncols = 3 figsize = (8, 8) _, figs = plt.subplots(nrows, ncols, figsize=figsize) plt.axis('off') for i in range(nrows): for j in range(ncols): figs[i][j].imshow(im_aug(im)) plt.show() |
运行结果如图3-6所示。
图3-6实现图像增强后的部分图像