DBSRCNN.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu Nov 8 12:35:32 2018
  4. @author: fdrea
  5. """
  6. from __future__ import print_function
  7. import keras
  8. from keras.layers import Input, Convolution2D, Conv2DTranspose, merge, Conv2D
  9. from keras.models import Model
  10. from keras.callbacks import LearningRateScheduler
  11. from keras import backend as K
  12. from keras.optimizers import Adam
  13. import matplotlib.pyplot as plt
  14. import h5py
  15. import math
  16. import sys
  17. from math import sqrt
  18. from keras.callbacks import ModelCheckpoint
  19. import generate_data as pd
  20. import pandas
  21. import numpy
  22. import cv2
  23. #for save excel:
  24. import xlwt
  25. from tempfile import TemporaryFile
  26. scale = 2
  27. batch_size = 64
  28. nb_epoch = 80
  29. def PSNRLoss(y_true, y_pred):
  30. """
  31. PSNR is Peek Signal to Noise Ratio, which is similar to mean squared error.
  32. It can be calculated as
  33. PSNR = 20 * log10(MAXp) - 10 * log10(MSE)
  34. When providing an unscaled input, MAXp = 255. Therefore 20 * log10(255)== 48.1308036087.
  35. However, since we are scaling our input, MAXp = 1. Therefore 20 * log10(1) = 0.
  36. Thus we remove that component completely and only compute the remaining MSE component.
  37. """
  38. #psnr= k.reduce_mean(k.square(y_pred - y_true))
  39. return 10.0 * K.log(1.0 / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
  40. def step_decay(epoch):
  41. initial_lrate = 0.001
  42. drop = 0.1
  43. epochs_drop = 20
  44. lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop))
  45. return lrate
  46. def SRCNN_model():
  47. '''
  48. _input = Input(shape=(None, None, 1), name='input')
  49. EES = Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu')(_input)
  50. EES = Conv2DTranspose(filters=32, kernel_size=(14, 14), strides=(2, 2), padding='same', activation='relu')(EES)
  51. out = Conv2D(filters=1, kernel_size=(5, 5), strides=(1, 1), activation='relu', padding='same')(EES)
  52. model = Model(input=_input, output=out)
  53. '''
  54. x = Input(shape = (None, None, 1), name='input')
  55. c1 = Convolution2D(32, (9, 9), padding="same", kernel_initializer="he_normal", activation="relu")(x)
  56. c2 = Convolution2D(32, (5, 5), padding="same", kernel_initializer="he_normal", activation="relu")(c1)
  57. m= keras.layers.concatenate([c1, c2])
  58. c3 = Convolution2D(32, (5, 5), padding="same", kernel_initializer="he_normal", activation="relu")(m)
  59. c4 = Convolution2D(32, (5, 5), padding="same", kernel_initializer="he_normal", activation="relu")(c3)
  60. c5 = Convolution2D(32, (5, 5), padding="same", kernel_initializer="he_normal", activation="relu")(c4)
  61. c6 = Convolution2D(32, (5, 5), padding="same", kernel_initializer="he_normal", activation="relu")(c5)
  62. c7 = Convolution2D(32, (5, 5), padding="same", kernel_initializer="he_normal", activation="relu")(c6)
  63. c8 = Convolution2D(32, (5, 5), padding="same", kernel_initializer="he_normal", activation="relu")(c7)
  64. c9 = Convolution2D(32, (5, 5), padding="same", kernel_initializer="he_normal", activation="relu")(c8)
  65. c10 = Convolution2D(1, (5, 5), padding="same", kernel_initializer="he_normal")(c9)
  66. model = Model(inputs = x, outputs = c10)
  67. return model
  68. def SRCNN_train():
  69. '''
  70. EES = model_EES16()
  71. EES.compile(optimizer=adam(lr=0.0003), loss='mse')
  72. print (EES.summary())
  73. data, label = pd.read_training_data("./train1.h5")
  74. val_data, val_label = pd.read_training_data("./val.h5")
  75. checkpoint = ModelCheckpoint("EES2_check.h5", monitor='val_loss', verbose=1, save_best_only=True,
  76. save_weights_only=False, mode='min')
  77. callbacks_list = [checkpoint]
  78. history_callback = EES.fit(data, label, batch_size=64, validation_data=(val_data, val_label),
  79. callbacks=callbacks_list, shuffle=True, nb_epoch=10, verbose=1)
  80. pandas.DataFrame(history_callback.history).to_csv("history.csv")
  81. EES.save_weights("EES2_final.h5")
  82. '''
  83. model = SRCNN_model()
  84. ##compile
  85. adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
  86. model.compile(loss='mean_squared_error', metrics=[PSNRLoss], optimizer = adam) #'mse'
  87. print (model.summary())
  88. #data, label = pd.read_training_data("./data/train91.h5")
  89. #data, label = pd.read_training_data("./data/train_aug.h5")
  90. data, label = pd.read_training_data("./data/train291.h5")
  91. val_data, val_label = pd.read_training_data("./data/val.h5")
  92. # save the best model weights
  93. ModelCheckpoint("SRCNN_check.h5", monitor='val_loss', verbose=0, save_best_only=True,
  94. save_weights_only=False, mode='min')
  95. # learning schedule callback
  96. lrate = LearningRateScheduler(step_decay)
  97. print('lrate=',lrate)
  98. #callbacks_list = [lrate]
  99. history = model.fit(data, label, batch_size=batch_size, epochs=nb_epoch, callbacks = [lrate],
  100. verbose=1, validation_data=(val_data, val_label))
  101. print(history.history.keys())
  102. # to save history
  103. pandas.DataFrame(history.history).to_csv("history.csv")
  104. #save model and weights
  105. json_string = model.to_json()
  106. open('DBSRCNN_model.json','w').write(json_string)
  107. model.save_weights('DBSRCNN_model_weights.h5')
  108. model.save('dbsrcnn_model.h5')
  109. # summarize history for Peak signal to noise ratio (PSNR)
  110. plt.figure()
  111. plt.plot(history.history['PSNRLoss'])
  112. plt.plot(history.history['val_PSNRLoss'])
  113. plt.title('Peak Signal to Noise Ratio')
  114. plt.ylabel('PSNR/dB')
  115. plt.xlabel('Epoch')
  116. plt.legend(['Train', 'Test'], loc='lower right')
  117. #plt.show()
  118. plt.grid(True,which="both",ls="-")
  119. plt.savefig('Epoch and PSNR SR.png')
  120. # summarize history for loss
  121. plt.figure()
  122. plt.plot(history.history['loss'])
  123. plt.plot(history.history['val_loss'])
  124. plt.title('Model Loss')
  125. plt.ylabel('Loss')
  126. plt.xlabel('Epoch')
  127. plt.legend(['Train', 'Test'], loc='upper right')
  128. #plt.show()
  129. plt.grid(True,which="both",ls="-")
  130. plt.savefig('Epoch and Loss SR.png')
  131. '''
  132. # for save excel file:
  133. book = xlwt.Workbook()
  134. sheet1 = book.add_sheet('sheet1')
  135. for i,e in enumerate(history.history['val_PSNRLoss']):
  136. sheet1.write(i,1,e)
  137. name = "PSNR_Validation.xls"
  138. book.save(name)
  139. book.save(TemporaryFile())
  140. '''
  141. ###====================================================================================================
  142. def SRCNN_predict():
  143. #IMG_NAME = "./butterfly_GT.bmp"
  144. #down_NAME = "down.jpg"
  145. #INPUT_NAME = "input.jpg"
  146. IMG_NAME = sys.argv[1]
  147. down_NAME = sys.argv[2]
  148. INPUT_NAME = sys.argv[3]
  149. OUTPUT_NAME = sys.argv[4]+'_output.jpg'
  150. b = 0.1 #blur sigma
  151. label = cv2.imread(IMG_NAME)
  152. shape = label.shape
  153. img = cv2.resize(label, (shape[1] // scale, shape[0] // scale), cv2.INTER_CUBIC)
  154. cv2.imwrite(down_NAME, img)
  155. img = cv2.resize(img, (img.shape[1] * scale, img.shape[0] * scale), cv2.INTER_CUBIC)
  156. img = cv2.GaussianBlur(img, (0,0), sigmaX = b)
  157. cv2.imwrite(INPUT_NAME, img)
  158. SRCNN = SRCNN_model()
  159. #SRCNN.load_weights("SRCNN_check.h5")
  160. SRCNN.load_weights("DBSRCNN_model_weights.h5")
  161. img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
  162. Y = numpy.zeros((1, img.shape[0], img.shape[1], 1))
  163. Y[0, :, :, 0] = img[:, :, 0].astype(float) / 255.
  164. img = cv2.cvtColor(label, cv2.COLOR_BGR2YCrCb)
  165. pre = SRCNN.predict(Y, batch_size=1) * 255.
  166. pre[pre[:] > 255] = 255
  167. pre = numpy.uint8(pre)
  168. img[:, :, 0] = pre[0, :, :, 0]
  169. img = cv2.cvtColor(img, cv2.COLOR_YCrCb2BGR)
  170. cv2.imwrite(OUTPUT_NAME, img)
  171. # psnr calculation:
  172. im1 = cv2.imread(IMG_NAME, cv2.IMREAD_COLOR)
  173. im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2YCrCb)
  174. im2 = cv2.imread(INPUT_NAME, cv2.IMREAD_COLOR)
  175. im2 = cv2.cvtColor(im2, cv2.COLOR_BGR2YCrCb)
  176. im2 = cv2.resize(im2, (img.shape[1], img.shape[0]))
  177. cv2.imwrite("Bicubic.jpg", cv2.cvtColor(im2, cv2.COLOR_YCrCb2BGR))
  178. im3 = cv2.imread(OUTPUT_NAME, cv2.IMREAD_COLOR)
  179. im3 = cv2.cvtColor(im3, cv2.COLOR_BGR2YCrCb)
  180. print ("Bicubic:")
  181. print (cv2.PSNR(im1[:, :, 0], im2[:, :, 0]))
  182. #m = PSNRLoss(im1.astype('float32')[:, :, 0], im2.astype('float32')[:, :, 0])
  183. #print (m)
  184. print ("SRCNN:")
  185. print (cv2.PSNR(im1[:, :, 0], im3[:, :, 0]))
  186. ###======================================================================================================
  187. # Function by gcalmettes from http://stackoverflow.com/questions/11159436/multiple-figures-in-a-single-window
  188. def plot_figures(figures, nrows=1, ncols=1, titles=False):
  189. """Plot a dictionary of figures.
  190. Parameters
  191. ----------
  192. figures : <title, figure> dictionary
  193. ncols : number of columns of subplots wanted in the display
  194. nrows : number of rows of subplots wanted in the figure
  195. """
  196. fig, axeslist = plt.subplots(ncols=ncols, nrows=nrows)
  197. for ind, title in enumerate(sorted(figures.keys(), key=lambda s: int(s[3:]))):
  198. axeslist.ravel()[ind].imshow(figures[title], cmap=plt.gray())
  199. if titles:
  200. axeslist.ravel()[ind].set_title(title)
  201. for ind in range(nrows*ncols):
  202. axeslist.ravel()[ind].set_axis_off()
  203. if titles:
  204. plt.tight_layout()
  205. plt.show()
  206. def get_dim(num):
  207. """
  208. Simple function to get the dimensions of a square-ish shape for plotting
  209. num images
  210. """
  211. s = sqrt(num)
  212. if round(s) < s:
  213. return (int(s), int(s)+1)
  214. else:
  215. return (int(s)+1, int(s)+1)
  216. def feature_map_visilization(model, _input):
  217. # Get the convolutional layers
  218. conv_layers = [layer for layer in model.layers if isinstance(layer, Conv2D)]
  219. # Use a keras function to extract the conv layer data
  220. convout_func = K.function([model.layers[0].input, K.learning_phase()], [layer.output for layer in conv_layers])
  221. conv_imgs_filts = convout_func([_input, 0])
  222. # Also get the prediction so we know what we predicted
  223. predictions = model.predict(_input)
  224. imshow = plt.imshow # alias
  225. # Show the original image
  226. plt.title("Image used:")
  227. imshow(_input[0, :, :, 0], cmap='gray')
  228. plt.tight_layout()
  229. plt.show()
  230. # Plot the filter images
  231. for i, conv_imgs_filt in enumerate(conv_imgs_filts):
  232. conv_img_filt = conv_imgs_filt[0]
  233. print("Visualizing Convolutions Layer %d" % i)
  234. # Get it ready for the plot_figures function
  235. fig_dict = {'flt{0}'.format(i): conv_img_filt[:, :, i] for i in range(conv_img_filt.shape[-1])}
  236. plot_figures(fig_dict, *get_dim(len(fig_dict)))
  237. cv2.waitKey(0)
  238. def vilization_and_show():
  239. model = SRCNN_model()
  240. #model.load_weights("SRCNN_check.h5")
  241. model.load_weights("DBSRCNN_model_weights.h5")
  242. #IMG_NAME = "comic.bmp"
  243. IMG_NAME = "./butterfly_GT.bmp"
  244. INPUT_NAME = "input.jpg"
  245. img = cv2.imread(IMG_NAME)
  246. shape = img.shape
  247. img = cv2.resize(img, (shape[1] // 2, shape[0] // 2), cv2.INTER_CUBIC)
  248. cv2.imwrite(INPUT_NAME, img)
  249. img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
  250. Y = numpy.zeros((1, img.shape[0], img.shape[1], 1))
  251. Y[0, :, :, 0] = img[:, :, 0]
  252. feature_map_visilization(model, Y)
  253. ####===================================================================================================
  254. if __name__ == "__main__":
  255. #SRCNN_train()
  256. SRCNN_predict()
  257. #vilization_and_show()