generate_data.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu Nov 8 22:24:55 2018
  4. @author: fdrea
  5. """
  6. import os
  7. import cv2
  8. import h5py
  9. import numpy
  10. from scipy import arange
  11. import matplotlib.pyplot as plt
  12. from PIL import Image
  13. #DATA_PATH = "./data/Trainingset/" # for train
  14. #DATA_PATH = "./data/Train_291_aug/" # for train 291 augmentation
  15. #DATA_PATH = "./data/Train_291/" # for train 201 images
  16. DATA_PATH = "./data/Test/Set5/" # for validation
  17. #DATA_PATH = "./data/Set1/" # for try
  18. Random_Crop = 30 # number of random patches
  19. Patch_size = 32
  20. label_size = 32
  21. scale = 2
  22. blurring_levels = 40
  23. #b= 2
  24. def prepare_training_data():
  25. names = os.listdir(DATA_PATH)
  26. names = sorted(names)
  27. nums = names.__len__()
  28. count = 0
  29. #data = numpy.zeros((nums * Random_Crop, 1, Patch_size, Patch_size), dtype=numpy.double)
  30. #label = numpy.zeros((nums * Random_Crop, 1, label_size, label_size), dtype=numpy.double)
  31. # number of saved batches in h5 file
  32. data = numpy.zeros((nums * Random_Crop * blurring_levels, 1, Patch_size, Patch_size), dtype=numpy.double)
  33. label = numpy.zeros((nums * Random_Crop * blurring_levels, 1, label_size, label_size), dtype=numpy.double)
  34. #imshow = plt.imshow
  35. for i in range(nums):
  36. print('i = ', i)
  37. name = DATA_PATH + names[i]
  38. hr_img = cv2.imread(name, cv2.IMREAD_COLOR)
  39. shape = hr_img.shape
  40. hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2YCrCb)
  41. hr_img = hr_img[:, :, 0]
  42. # produce Random_Crop random coordinate to crop training img
  43. if(min(shape[0], shape[1]) - label_size < 0):
  44. continue
  45. Points_x = numpy.random.randint(0, min(shape[0], shape[1]) - label_size, Random_Crop)
  46. Points_y = numpy.random.randint(0, min(shape[0], shape[1]) - label_size, Random_Crop)
  47. for b in arange(0.1, 4.1, 0.1):
  48. b = round(b, 2)
  49. print('b = ', b)
  50. for j in range(Random_Crop):
  51. print('j = ', j)
  52. hr_patch = hr_img[Points_x[j]: Points_x[j] + label_size, Points_y[j]: Points_y[j] + label_size]
  53. #cv2.imshow("hr", hr_patch)
  54. #imshow(hr_patch)
  55. #plt.show()
  56. lr_patch = cv2.resize(hr_patch, (label_size // scale, label_size // scale), cv2.INTER_CUBIC)
  57. #cv2.imshow("down", lr_patch)
  58. #imshow(lr_patch)
  59. #plt.show()
  60. lr_patch = cv2.resize(lr_patch , (lr_patch.shape[1] * scale, lr_patch.shape[0] * scale), cv2.INTER_CUBIC)
  61. #cv2.imshow("bicubic", lr_patch)
  62. lr_patch = cv2.GaussianBlur(lr_patch, (0,0), sigmaX = b) # to blur
  63. #cv2.imshow("blur_lr", lr_patch)
  64. #imshow(lr_patch)
  65. #plt.show()
  66. lr_patch = lr_patch.astype(float) / 255.
  67. #imshow(lr_patch)
  68. #plt.show()
  69. hr_patch = hr_patch.astype(float) / 255.
  70. #imshow(hr_patch)
  71. #plt.show()
  72. data[count, 0, :, :] = lr_patch
  73. label[count, 0, :, :] = hr_patch
  74. #cv2.imshow("lr/255", lr_patch)
  75. #cv2.imshow("hr/255", hr_patch)
  76. #cv2.waitKey(0)
  77. count= count+1
  78. print('number of samples', count)
  79. return data, label
  80. def write_hdf5(data, labels, output_filename):
  81. """
  82. This function is used to save image data and its label(s) to hdf5 file.
  83. output_file.h5,contain data and label
  84. """
  85. x = data.astype(numpy.float32)
  86. y = labels.astype(numpy.float32)
  87. with h5py.File(output_filename, 'w') as h:
  88. h.create_dataset('data', data=x, shape=x.shape)
  89. h.create_dataset('label', data=y, shape=y.shape)
  90. # h.create_dataset()
  91. def read_training_data(file):
  92. with h5py.File(file, 'r') as hf:
  93. data = numpy.array(hf.get('data'))
  94. label = numpy.array(hf.get('label'))
  95. train_data = numpy.transpose(data, (0, 2, 3, 1))
  96. train_label = numpy.transpose(label, (0, 2, 3, 1))
  97. return train_data, train_label
  98. if __name__ == "__main__":
  99. data, label = prepare_training_data()
  100. #write_hdf5(data, label, "try_train.h5")
  101. # for training 91 images:
  102. #write_hdf5(data, label, "train91.h5")
  103. # for training 291 images:
  104. #write_hdf5(data, label, "train291.h5")
  105. # for training 291 aug images:
  106. #write_hdf5(data, label, "train_aug.h5")
  107. # for validation Set5:
  108. write_hdf5(data, label, "val.h5")
  109. #_, _a = read_training_data("train.h5")
  110. #_, _a = read_training_data("test.h5")