# Description: Python script for applying ZCA whitening to the CIFAR-10 dataset (and creating the respective LMDB/LevelDB datasets for Caffe)
# Submitted by Master on November 13, 2017

#in the name of God 
#[email protected]
#reading CIFAR-10 and creating an lmdb dataset with preprocessed (ZCA-whitened) images.
import os
import numpy as np
import cPickle
import lmdb
import matplotlib.pyplot as plt
import cv2
from scipy.misc import toimage
import caffe

# Read all CIFAR-10 training and test batches into the corresponding variables.
# Image arrays have the shape [N, 3072]: one flattened image per row.
data_train_size = 50000
data_test_size = 10000
image_size_flattened = 32*32*3 #3072

data_train = np.empty([data_train_size, image_size_flattened], np.uint8)
label_train = np.empty([data_train_size], np.uint8)
data_test = np.empty([data_test_size, image_size_flattened], np.uint8)
label_test = np.empty([data_test_size], np.uint8)

data_batches = np.empty([5, 10000, image_size_flattened], np.uint8)

# Directory holding the downloaded and extracted python batches; can be
# overridden with the CIFAR10_DIR environment variable.
cifar10_dir = os.environ.get('CIFAR10_DIR', 'C:/Users/Master/Desktop/cifar-10-batches-py/')
i = 0  # next free row in data_train
j = 0  # next free slot in data_batches
test_loaded = False

# sorted(): os.listdir order is arbitrary, and the training rows should always be
# data_batch_1 .. data_batch_5 in that order.
for file_name in sorted(os.listdir(cifar10_dir)):
    path = os.path.join(cifar10_dir, file_name)
    if 'data_batch_' in file_name:
        # NOTE: unpickling can run arbitrary code; only load batches from a trusted download.
        with open(path, 'rb') as input_file:
            data = cPickle.load(input_file)
        data_batches[j] = data['data']
        data_train[i:i + 10000, :] = data['data']
        label_train[i:i + 10000] = data['labels']
        i += 10000
        j += 1
    elif 'test_batch' in file_name:
        with open(path, 'rb') as input_file:
            data = cPickle.load(input_file)
        # Fill the preallocated arrays so data_test/label_test keep their uint8 dtype.
        data_test[:] = data['data']
        label_test[:] = data['labels']
        test_loaded = True

# np.empty leaves rows uninitialised, so a missing batch would otherwise go unnoticed.
if j != len(data_batches):
    raise IOError('expected %d data_batch_* files in %s, found %d'
                  % (len(data_batches), cifar10_dir, j))
if not test_loaded:
    raise IOError('test_batch not found in %s' % cifar10_dir)

print ('training and testing sets are read!')
print(data_train.shape)
print(data_test.shape)

def show(img):
    """Display `img` with plt.imshow, then call plt.show().

    Args:
        img: an image array that plt.imshow accepts, e.g. (H, W) or (H, W, 3).
            A flattened CIFAR-10 row must first be reshaped, e.g.
            ``row.reshape(3, 32, 32).transpose(1, 2, 0)``.
    """
    plt.imshow(img)
    plt.show()

def normalize_image(data):
    """Linearly rescale `data` so its minimum maps to 0.0 and its maximum to 1.0.

    Integer input (e.g. uint8 pixels) is converted to float64 first; floating
    input keeps its dtype. A constant input has no range to stretch and
    yields all zeros instead of NaNs.

    Args:
        data: array-like of numbers, any shape (must be non-empty).

    Returns:
        A floating-point ndarray with the same shape as `data`.
    """
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.floating):
        # Under Python 2, `/` on integer arrays is floor division, which would
        # collapse the result to 0s and 1s; work in floating point instead.
        data = data.astype(np.float64)
    min_ = data.min()
    max_ = data.max()
    span = max_ - min_
    if span == 0:
        return np.zeros_like(data)
    return (data - min_) / span

def _center_and_l2_normalize(data, mean):
    """Subtract `mean` from every row of `data`, then scale each row to unit L2 norm.

    Shared by the training-set ZCA fit and the test-set transform so both
    splits go through exactly the same preprocessing.
    """
    centered = np.asarray(data, dtype=np.float64) - mean
    norms = np.sqrt((centered ** 2).sum(axis=1))
    # A row equal to the mean has norm 0; keep it as zeros instead of producing NaNs.
    norms[norms == 0] = 1.0
    return centered / norms[:, None]


def run_zca_stackOverflow(data, epsilon=1e-5):
    """Fit ZCA whitening on `data` and return the whitened data.

    `data` is a 2-D array with one sample per row and one feature (pixel value)
    per column, e.g. the (N, 3072) CIFAR-10 matrix. Rows are zero-centred with
    the per-feature mean and L2-normalised, then multiplied by
    ``W = U diag(1 / sqrt(S + epsilon)) U^T``, where ``U, S`` come from the SVD
    of the (D, D) feature covariance.

    Args:
        data: (N, D) array of samples.
        epsilon: regulariser added to the eigenvalues before the inverse sqrt.

    Returns:
        zca: whitened data, shape (N, D), float64.
        mean: per-feature mean of `data`, shape (D,).
        zca_matrix: symmetric (D, D) whitening matrix; whiten new samples with
            ``apply_zca(new_data, mean, zca_matrix)``.
    """
    print('data shape: {}'.format(data.shape))
    print ('Zero Centring...')
    mean = data.mean(axis=0)
    print('Contrast Normalizing..')
    data = _center_and_l2_normalize(data, mean)
    print ('Calculating Covariance...')
    # rowvar=False: columns are the variables, giving a (D, D) feature covariance.
    # rowvar=True would build an (N, N) image-by-image matrix (50000 x 50000 for
    # the training set), which does not fit in memory and does not whiten features.
    cov = np.cov(data, rowvar=False)
    U, S, _ = np.linalg.svd(cov)  # U is (D, D), S is (D,)

    print('Building ZCA Matrix...')
    # U * v scales column k of U by v[k], i.e. U.dot(np.diag(v)) without building the diagonal.
    zca_matrix = np.dot(U * (1.0 / np.sqrt(S + epsilon)), U.T)
    print ('Applying ZCA to the data...')
    # Rows are samples, so whiten by right-multiplying (zca_matrix is symmetric).
    zca = np.dot(data, zca_matrix)
    return zca, mean, zca_matrix


def apply_zca(data, mean, zca_matrix):
    """Whiten new (M, D) samples with the `mean`/`zca_matrix` fitted by run_zca_stackOverflow."""
    return np.dot(_center_and_l2_normalize(data, mean), zca_matrix)

# In[ ]:

#test and see an example image
img_test = data_train[6].copy()
print('original image shape : {}'.format(img_test.shape))
# Treat the flat row as 3 planes of 32x32 and move channels last (H, W, C) for display.
img_test2 = img_test.reshape(3, 32, 32).transpose(1, 2, 0)
print('reshaped and transposed {}'.format(img_test2.shape))

# In[ ]:

#conducting the zca operation
zca_train, mean, zca_matrix = run_zca_stackOverflow(data_train)

# In[ ]:

#ZCA the test set with the statistics fitted on the training set
zca_test = apply_zca(data_test, mean, zca_matrix)
print('zca train shape: {}'.format(zca_train.shape))
print('zca test shape: {}'.format(zca_test.shape))

# In[ ]:

#Saving the dataset as lmdb for use in caffe
def _write_lmdb(lmdb_file, images, labels, image_shape=(3, 32, 32),
                commit_every=1000, map_size_factor=2):
    """Write each (image, label) pair into the LMDB at `lmdb_file` as a Caffe Datum.

    Args:
        lmdb_file: path of the LMDB environment to create/open.
        images: array with one flattened image per row.
        labels: integer class labels aligned with `images`.
        image_shape: (channels, height, width) each row is reshaped to, because
            caffe.io.array_to_datum requires a 3-D array.
        commit_every: number of records written per LMDB transaction.
        map_size_factor: LMDB map size as a multiple of images.nbytes.
    """
    # Serialized Datums plus LMDB page overhead can exceed the raw array size
    # (e.g. for float32 input), so reserve headroom; a too-small map fails with MapFullError.
    map_size = max(int(images.nbytes * map_size_factor), 10 * 1024 * 1024)
    env = lmdb.open(lmdb_file, map_size=map_size)
    try:
        txn = env.begin(write=True)
        for i in range(len(images)):
            if i % 1000 == 0:
                print(i)
            # int() hands protobuf's label field a plain Python int, not a numpy scalar.
            datum = caffe.io.array_to_datum(images[i].reshape(image_shape), int(labels[i]))
            txn.put('{:0>5d}'.format(i).encode('ascii'), datum.SerializeToString())
            # Nothing reaches the database until a write transaction is committed.
            if (i + 1) % commit_every == 0:
                txn.commit()
                txn = env.begin(write=True)
        # Commit the records written since the last full chunk.
        txn.commit()
    finally:
        env.close()


def Save_lmdb(data_train, data_test):
    """Write the train and test sets to LMDB databases for Caffe.

    Creates 'cifar10_train_lmdb_whitened' and 'cifar10_test_lmdb_whitened' in the
    current working directory. Labels come from the module-level `label_train`
    and `label_test` arrays, matched to the rows of `data_train` / `data_test`.
    """
    print('Outputting training data')
    _write_lmdb('cifar10_train_lmdb_whitened', data_train, label_train)
    print('Outputting test data')
    _write_lmdb('cifar10_test_lmdb_whitened', data_test, label_test)
# In[ ]:

#saving the dataset as leveldb
def _write_leveldb(leveldb_file, images, labels, image_shape=(3, 32, 32), batch_size=1000):
    """Write each (image, label) pair into the LevelDB at `leveldb_file` as a Caffe Datum.

    Args:
        leveldb_file: path of the LevelDB database to create/open.
        images: array with one flattened image per row.
        labels: integer class labels aligned with `images`.
        image_shape: (channels, height, width) each row is reshaped to, because
            caffe.io.array_to_datum requires a 3-D array.
        batch_size: number of records per WriteBatch flushed to disk.
    """
    # Imported lazily: only this LevelDB path needs py-leveldb.
    import leveldb

    db = leveldb.LevelDB(leveldb_file)
    batch = leveldb.WriteBatch()
    for i in range(len(images)):
        if i % 1000 == 0:
            print(i)
        # int() hands protobuf's label field a plain Python int, not a numpy scalar.
        datum = caffe.io.array_to_datum(images[i].reshape(image_shape), int(labels[i]))
        batch.Put('{:0>5d}'.format(i).encode('ascii'), datum.SerializeToString())
        if (i + 1) % batch_size == 0:
            db.Write(batch, sync=True)
            batch = leveldb.WriteBatch()
    # Flush the final, partially filled batch.
    if len(images) % batch_size != 0:
        db.Write(batch, sync=True)


def Save_leveldb(data_train, data_test, direct=''):
    """Write the train and test sets to LevelDB databases for Caffe.

    Args:
        data_train: training images, one flattened image per row.
        data_test: test images, one flattened image per row.
        direct: directory in which to create 'cifar10_train_leveldb_whitened'
            and 'cifar10_test_leveldb_whitened'; '' means the current directory.

    Labels come from the module-level `label_train` and `label_test` arrays.
    """
    print('Outputting training data')
    _write_leveldb(os.path.join(direct, 'cifar10_train_leveldb_whitened'),
                   data_train, label_train)
    print('Outputting test data')
    _write_leveldb(os.path.join(direct, 'cifar10_test_leveldb_whitened'),
                   data_test, label_test)

# In[ ]:

# Write the whitened training and test splits out as LMDB databases.
Save_lmdb(data_train=zca_train, data_test=zca_test)