Python读入CIFAR-10数据库

CIFAR-10数据库由60000个32$\times$32的彩色图像组成,一般50000个用作训练,10000个用作测试。这60000个图片共包含了10类。CIFAR-10的库作者已经将数据打包成了一定的标准格式方便各位筒子的使用。在Python下,这个数据库被划分为了6个batch,每个batch包含一万张图像,并以numpy array的形式存储。这个10000$\times$3072大小的array,每一行的数据代表了一张图像,其中,这一行中3072个数据的前1024个数据为该32$\times$32图像的red通道,而随后的1024个数据为green通道,最后的1024个数据为blue通道。

在Python中读入这些数据的方式如下:

1
2
3
4
5
6
def unpickle(file):
import cPickle
fo = open(file, 'rb')
dict = cPickle.load(fo)
fo.close()
return dict

CS231n这个课程中提供的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
import cPickle as pickle
import numpy as np
import os

def load_CIFAR_batch(filename):
""" load single batch of cifar """
with open(filename, 'rb') as f:
datadict = pickle.load(f)
X = datadict['data']
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
Y = np.array(Y)
return X, Y

利用Cpickle读入数据后,对数据进行了转换:

1
X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")

reshape用于对数组尺寸进行改变,上面reshape(10000, 3, 32, 32)首先将这个10000$\times$3072数组均分成10000份,也就是1行1份,然后每份在继续分成3份,也就是R、G、B一个通道一份,然后再将这一份分为32$\times$32表示。我们可以用下图来理解这一步骤,图中,蓝色圆圈内数字为相应的某维度/轴。

Python Reshape and Transpose.jpg-47.7kB

而transpose(0,2,3,1)函数的意义则是生成矩阵进行转置,新矩阵的0,1,2,3维度/轴分别为原矩阵的0,2,3,1维度/轴,从下面的图中我们就可以清楚看到矩阵维度/轴的变换。

Python Reshape and Transpose.png-47.5kB

经过这样变换,我们可以很方便的获取想要的图像和图像中对应像素的值。