ganMXnet.R
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
require("imager")
require("dplyr")
require("readr")
require("mxnet")

source("iterators.R")

### Data import and preparation
# First download the MNIST train data from Kaggle:
# https://www.kaggle.com/c/digit-recognizer/data
train <- read_csv("data/train.csv")
train <- data.matrix(train)

train_data <- train[, -1]
train_data <- t(train_data/255 * 2 - 1)
train_label <- as.integer(train[, 1])
dim(train_data) <- c(28, 28, 1, ncol(train_data))
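# Pixels are rescaled from [0, 255] to [-1, 1] so that real images match the
# range of the generator's final tanh activation, then reshaped to the
# 28 x 28 x 1 x N layout (width, height, channel, batch) expected by MXNet's
# R API.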
### Model parameters
random_dim <- 96
gen_features <- 96
dis_features <- 32
image_depth <- 1
fix_gamma <- T
no_bias <- T
eps <- 1e-05 + 1e-12
batch_size <- 64
### Generator Symbol
data <- mx.symbol.Variable("data")

gen_rand <- mx.symbol.normal(loc = 0, scale = 1, shape = c(1, 1, random_dim, batch_size),
    name = "gen_rand")
gen_concat <- mx.symbol.concat(data = list(data, gen_rand), num.args = 2, name = "gen_concat")

g1 <- mx.symbol.Deconvolution(gen_concat, name = "g1", kernel = c(4, 4),
    num_filter = gen_features * 4, no_bias = no_bias)
gbn1 <- mx.symbol.BatchNorm(g1, name = "gbn1", fix_gamma = fix_gamma, eps = eps)
gact1 <- mx.symbol.Activation(gbn1, name = "gact1", act_type = "relu")

g2 <- mx.symbol.Deconvolution(gact1, name = "g2", kernel = c(3, 3), stride = c(2, 2),
    pad = c(1, 1), num_filter = gen_features * 2, no_bias = no_bias)
gbn2 <- mx.symbol.BatchNorm(g2, name = "gbn2", fix_gamma = fix_gamma, eps = eps)
gact2 <- mx.symbol.Activation(gbn2, name = "gact2", act_type = "relu")

g3 <- mx.symbol.Deconvolution(gact2, name = "g3", kernel = c(4, 4), stride = c(2, 2),
    pad = c(1, 1), num_filter = gen_features, no_bias = no_bias)
gbn3 <- mx.symbol.BatchNorm(g3, name = "gbn3", fix_gamma = fix_gamma, eps = eps)
gact3 <- mx.symbol.Activation(gbn3, name = "gact3", act_type = "relu")

g4 <- mx.symbol.Deconvolution(gact3, name = "g4", kernel = c(4, 4), stride = c(2, 2),
    pad = c(1, 1), num_filter = image_depth, no_bias = no_bias)
G_sym <- mx.symbol.Activation(g4, name = "G_sym", act_type = "tanh")
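# Optional sanity check (a minimal sketch): infer the generator's output shape
# from a 1 x 1 x 10 x batch_size one-hot input to confirm it produces
# 28 x 28 x 1 images, the size the discriminator consumes. It uses the same
# mx.symbol.infer.shape call as the parameter initialisation below.
# G_shapes <- mx.symbol.infer.shape(G_sym, data = c(1, 1, 10, batch_size))
# G_shapes$out.shapes  # expected: 28 28 1 64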
### Discriminator Symbol
data <- mx.symbol.Variable("data")
dis_digit <- mx.symbol.Variable("digit")
label <- mx.symbol.Variable("label")

# Broadcast the one-hot digit so it can be concatenated with the image as
# extra input channels
dis_digit <- mx.symbol.Reshape(data = dis_digit, shape = c(1, 1, 10, batch_size),
    name = "digit_reshape")
dis_digit <- mx.symbol.broadcast_to(data = dis_digit, shape = c(28, 28, 10, batch_size),
    name = "digit_broadcast")

data_concat <- mx.symbol.concat(list(data, dis_digit), num.args = 2, dim = 1,
    name = "dflat_concat")

d1 <- mx.symbol.Convolution(data = data_concat, name = "d1", kernel = c(3, 3),
    stride = c(1, 1), pad = c(0, 0), num_filter = 24, no_bias = no_bias)
dbn1 <- mx.symbol.BatchNorm(d1, name = "dbn1", fix_gamma = fix_gamma, eps = eps)
dact1 <- mx.symbol.LeakyReLU(dbn1, name = "dact1", act_type = "elu", slope = 0.25)
pool1 <- mx.symbol.Pooling(data = dact1, name = "pool1", pool_type = "max",
    kernel = c(2, 2), stride = c(2, 2), pad = c(0, 0))

d2 <- mx.symbol.Convolution(pool1, name = "d2", kernel = c(3, 3), stride = c(2, 2),
    pad = c(0, 0), num_filter = 32, no_bias = no_bias)
dbn2 <- mx.symbol.BatchNorm(d2, name = "dbn2", fix_gamma = fix_gamma, eps = eps)
dact2 <- mx.symbol.LeakyReLU(dbn2, name = "dact2", act_type = "elu", slope = 0.25)

d3 <- mx.symbol.Convolution(dact2, name = "d3", kernel = c(3, 3), stride = c(1, 1),
    pad = c(0, 0), num_filter = 64, no_bias = no_bias)
dbn3 <- mx.symbol.BatchNorm(d3, name = "dbn3", fix_gamma = fix_gamma, eps = eps)
dact3 <- mx.symbol.LeakyReLU(dbn3, name = "dact3", act_type = "elu", slope = 0.25)

d4 <- mx.symbol.Convolution(dact3, name = "d4", kernel = c(4, 4), stride = c(1, 1),
    pad = c(0, 0), num_filter = 64, no_bias = no_bias)
dbn4 <- mx.symbol.BatchNorm(d4, name = "dbn4", fix_gamma = fix_gamma, eps = eps)
dact4 <- mx.symbol.LeakyReLU(dbn4, name = "dact4", act_type = "elu", slope = 0.25)

# pool4 <- mx.symbol.Pooling(data=dact3, name='pool4', pool_type='avg',
# kernel=c(4,4), stride=c(1,1), pad=c(0,0))

dflat <- mx.symbol.Flatten(dact4, name = "dflat")
dfc <- mx.symbol.FullyConnected(data = dflat, name = "dfc", num_hidden = 1, no_bias = F)
D_sym <- mx.symbol.LogisticRegressionOutput(data = dfc, label = label, name = "D_sym")
### Graph
input_shape_G <- c(1, 1, 10, batch_size)
input_shape_D <- c(28, 28, 1, batch_size)

graph.viz(G_sym, type = "graph", direction = "LR")
graph.viz(D_sym, type = "graph", direction = "LR")
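# The graph.viz calls render the generator and discriminator graphs for visual
# inspection; they are optional and mainly useful in an interactive session.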
### Training module for GAN
# Change this to mx.gpu() when running on a GPU machine.
devices <- mx.cpu()

data_shape_G <- c(1, 1, 10, batch_size)
data_shape_D <- c(28, 28, 1, batch_size)
digit_shape_D <- c(10, batch_size)

# Binary accuracy: compare the rounded sigmoid output against the 0/1 labels
mx.metric.binacc <- mx.metric.custom("binacc", function(label, pred) {
  res <- mean(label == round(pred))
  return(res)
})

# Binary cross-entropy (negative log-likelihood)
mx.metric.logloss <- mx.metric.custom("logloss", function(label, pred) {
  res <- -mean(label * log(pred) + (1 - label) * log(1 - pred))
  return(res)
})
### Define iterators
iter_G <- G_iterator(batch_size = batch_size)
iter_D <- D_iterator(batch_size = batch_size)
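# G_iterator and D_iterator come from iterators.R (not shown here). From the
# way their batches are consumed below, iter_G is expected to yield a one-hot
# digit condition as "data" (1 x 1 x 10 x batch_size) plus the raw "digit"
# values, while iter_D is expected to yield real images as "data"
# (28 x 28 x 1 x batch_size) together with their one-hot "digit"
# (10 x batch_size) and a "label" array.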
exec_G <- mx.simple.bind(symbol = G_sym, data = data_shape_G, ctx = devices, grad.req = "write")
exec_D <- mx.simple.bind(symbol = D_sym, data = data_shape_D, digit = digit_shape_D,
    ctx = devices, grad.req = "write")

### Initialize parameters - TODO: personalise each layer
initializer <- mx.init.Xavier(rnd_type = "gaussian", factor_type = "avg", magnitude = 3)

arg_param_ini_G <- mx.init.create(initializer = initializer,
    shape.array = mx.symbol.infer.shape(G_sym, data = data_shape_G)$arg.shapes,
    ctx = devices)
aux_param_ini_G <- mx.init.create(initializer = initializer,
    shape.array = mx.symbol.infer.shape(G_sym, data = data_shape_G)$aux.shapes,
    ctx = devices)
arg_param_ini_D <- mx.init.create(initializer = initializer,
    shape.array = mx.symbol.infer.shape(D_sym, data = data_shape_D, digit = digit_shape_D)$arg.shapes,
    ctx = devices)
aux_param_ini_D <- mx.init.create(initializer = initializer,
    shape.array = mx.symbol.infer.shape(D_sym, data = data_shape_D, digit = digit_shape_D)$aux.shapes,
    ctx = devices)

mx.exec.update.arg.arrays(exec_G, arg_param_ini_G, match.name = TRUE)
mx.exec.update.aux.arrays(exec_G, aux_param_ini_G, match.name = TRUE)
mx.exec.update.arg.arrays(exec_D, arg_param_ini_D, match.name = TRUE)
mx.exec.update.aux.arrays(exec_D, aux_param_ini_D, match.name = TRUE)

input_names_G <- mxnet:::mx.model.check.arguments(G_sym)
input_names_D <- mxnet:::mx.model.check.arguments(D_sym)
### Initialize optimizers
optimizer_G <- mx.opt.create(name = "adadelta", rho = 0.92, epsilon = 1e-06, wd = 0,
    rescale.grad = 1/batch_size, clip_gradient = 1)
updater_G <- mx.opt.get.updater(optimizer = optimizer_G, weights = exec_G$ref.arg.arrays,
    ctx = devices)

optimizer_D <- mx.opt.create(name = "adadelta", rho = 0.92, epsilon = 1e-06, wd = 0,
    rescale.grad = 1/batch_size, clip_gradient = 1)
updater_D <- mx.opt.get.updater(optimizer = optimizer_D, weights = exec_D$ref.arg.arrays,
    ctx = devices)

### Initialize metrics
metric_G <- mx.metric.binacc
metric_G_value <- metric_G$init()
metric_D <- mx.metric.binacc
metric_D_value <- metric_D$init()

iteration <- 1
iter_G$reset()
iter_D$reset()
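# Training loop: each iteration trains the discriminator on a fake batch
# (label 0) and a real batch (label 1), then rebinds D_sym with label 1 on the
# fake images to obtain gradients with respect to the generator's output,
# which are backpropagated through the generator.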
for (iteration in 1:2400) {

  iter_G$iter.next()
  iter_D$iter.next()

  ### Random input to Generator to produce fake sample
  G_values <- iter_G$value()
  G_data <- G_values[input_names_G]
  mx.exec.update.arg.arrays(exec_G, arg.arrays = G_data, match.name = TRUE)
  mx.exec.forward(exec_G, is.train = T)

  ### Feed Discriminator with generated (fake) images and real images
  D_data_fake <- exec_G$ref.outputs$G_sym_output
  D_digit_fake <- G_values$data %>% mx.nd.Reshape(shape = c(-1, batch_size))

  D_values <- iter_D$value()
  D_data_real <- D_values$data
  D_digit_real <- D_values$digit

  ### Train loop on fake
  mx.exec.update.arg.arrays(exec_D, arg.arrays = list(data = D_data_fake, digit = D_digit_fake,
      label = mx.nd.array(rep(0, batch_size))), match.name = TRUE)
  mx.exec.forward(exec_D, is.train = T)
  mx.exec.backward(exec_D)
  update_args_D <- updater_D(weight = exec_D$ref.arg.arrays, grad = exec_D$ref.grad.arrays)
  mx.exec.update.arg.arrays(exec_D, update_args_D, skip.null = TRUE)

  metric_D_value <- metric_D$update(label = as.array(mx.nd.array(rep(0, batch_size))),
      pred = as.array(exec_D$ref.outputs[["D_sym_output"]]), metric_D_value)

  ### Train loop on real
  mx.exec.update.arg.arrays(exec_D, arg.arrays = list(data = D_data_real, digit = D_digit_real,
      label = mx.nd.array(rep(1, batch_size))), match.name = TRUE)
  mx.exec.forward(exec_D, is.train = T)
  mx.exec.backward(exec_D)
  update_args_D <- updater_D(weight = exec_D$ref.arg.arrays, grad = exec_D$ref.grad.arrays)
  mx.exec.update.arg.arrays(exec_D, update_args_D, skip.null = TRUE)

  metric_D_value <- metric_D$update(label = as.array(mx.nd.array(rep(1, batch_size))),
      pred = as.array(exec_D$ref.outputs[["D_sym_output"]]), metric_D_value)

  ### Update Generator weights - use a separate executor for writing data gradients
  exec_D_back <- mxnet:::mx.symbol.bind(symbol = D_sym, arg.arrays = exec_D$arg.arrays,
      aux.arrays = exec_D$aux.arrays, grad.reqs = rep("write", length(exec_D$arg.arrays)),
      ctx = devices)
  mx.exec.update.arg.arrays(exec_D_back, arg.arrays = list(data = D_data_fake,
      digit = D_digit_fake, label = mx.nd.array(rep(1, batch_size))), match.name = TRUE)
  mx.exec.forward(exec_D_back, is.train = T)
  mx.exec.backward(exec_D_back)
  D_grads <- exec_D_back$ref.grad.arrays$data
  mx.exec.backward(exec_G, out_grads = D_grads)
  update_args_G <- updater_G(weight = exec_G$ref.arg.arrays, grad = exec_G$ref.grad.arrays)
  mx.exec.update.arg.arrays(exec_G, update_args_G, skip.null = TRUE)

  ### Update metrics
  # metric_G_value <- metric_G$update(values[[label_name]],
  #   exec_G$ref.outputs[[output_name]], metric_G_value)

  if (iteration %% 25 == 0) {
    D_metric_result <- metric_D$get(metric_D_value)
    cat(paste0("[", iteration, "] ", D_metric_result$name, ": ", D_metric_result$value,
        "\n"))
  }

  if (iteration == 1 | iteration %% 100 == 0) {
    metric_D_value <- metric_D$init()
    # Plot a 3 x 3 grid of the current generator output
    par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
    for (i in 1:9) {
      img <- as.array(exec_G$ref.outputs$G_sym_output)[, , , i]
      plot(as.cimg(img), axes = F)
    }
    print(as.numeric(as.array(G_values$digit)))
    print(as.numeric(as.array(D_values$label)))
  }
}
### Save the symbols and parameters
ifelse(!dir.exists(file.path(".", "models")), dir.create(file.path(".", "models")),
    "Folder already exists")

mx.symbol.save(D_sym, filename = "models/D_sym_model_v1.json")
mx.nd.save(exec_D$arg.arrays, filename = "models/D_arg_params_v1.params")
mx.nd.save(exec_D$aux.arrays, filename = "models/D_aux_params_v1.params")

mx.symbol.save(G_sym, filename = "models/G_sym_model_v1.json")
mx.nd.save(exec_G$arg.arrays, filename = "models/G_arg_params_v1.params")
mx.nd.save(exec_G$aux.arrays, filename = "models/G_aux_params_v1.params")
### Inference
G_sym <- mx.symbol.load("models/G_sym_model_v1.json")
G_arg_params <- mx.nd.load("models/G_arg_params_v1.params")
G_aux_params <- mx.nd.load("models/G_aux_params_v1.params")

# Generate a batch conditioned on the digit 9
digit <- mx.nd.array(rep(9, times = batch_size))
data <- mx.nd.one.hot(indices = digit, depth = 10)
data <- mx.nd.reshape(data = data, shape = c(1, 1, -1, batch_size))

exec_G <- mx.simple.bind(symbol = G_sym, data = data_shape_G, ctx = devices, grad.req = "null")
mx.exec.update.arg.arrays(exec_G, G_arg_params, match.name = TRUE)
mx.exec.update.arg.arrays(exec_G, list(data = data), match.name = TRUE)
mx.exec.update.aux.arrays(exec_G, G_aux_params, match.name = TRUE)

mx.exec.forward(exec_G, is.train = F)

par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
for (i in 1:9) {
  img <- as.array(exec_G$ref.outputs$G_sym_output)[, , , i]
  plot(as.cimg(img), axes = F)
}
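# Optional: write the 3 x 3 grid of generated digits to disk with the base
# png graphics device (a minimal sketch; the output path is illustrative).
# png("models/generated_digit_9.png", width = 600, height = 600)
# par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
# for (i in 1:9) {
#   img <- as.array(exec_G$ref.outputs$G_sym_output)[, , , i]
#   plot(as.cimg(img), axes = F)
# }
# dev.off()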