# mxnet.R
# installation
cran <- getOption("repos")
cran["dmlc"] <- "https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/CRAN/GPU/cu92"
options(repos = cran)
install.packages("mxnet")
require(mxnet)
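# note (added): the repository above points to the GPU build compiled against
# CUDA 9.2; if no compatible GPU/driver is available, Apache MXNet also
# published CPU builds of the R package, and since the training call below
# uses ctx = mx.cpu(), a CPU build is sufficient for this script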
# load data and split
df <-
structure(list(name = c("acebutolol", "acebutolol_ester", "acetylsalic_acid",
"acyclovir", "alprenolol", "alprenolol ester", "aminopyrin",
"artemisinin", "artesunate", "atenolol", "betazolol ester", "betazolol_",
"bremazocine", "caffeine", "chloramphenicol", "chlorothiazide",
"chlorpromazine", "cimetidine", "clonidine", "corticosterone",
"desiprarnine", "dexamethas", "dexamethas_beta_D_glucoside",
"dexamethas_beta_D_glucuronide", "diazepam", "dopamine", "doxorubici",
"erythromycin", "estradiol", "felodipine", "ganciclovir", "griseofulvin",
"hydrochlorothiazide", "hydrocortisone", "ibuprophen", "imipramine",
"indomethacin", "labetalol", "mannitol", "meloxicam", "methanol",
"methotrexate", "methylscopolamine", "metoprolol", "nadolol",
"naproxen", "nevirapine", "nicotine", "olsalazine", "oxprenolol",
"oxprenolol ester", "phencyclidine", "Phenytoin", "pindolol",
"pirenzepine", "piroxicam", "pnu200603", "practolol", "prazocin",
"progesterone", "propranolol", "propranolo_ester", "quinidine",
"ranitidine", "salicylic acid", "scopolamine", "sucrose", "sulfasalazine",
"telmisartan", "terbutaline", "tesosterone", "timolol", "timolol_ester",
"uracil", "urea", "warfarine", "zidovudine"), log_P_eff_exp = c(-5.83,
-4.61, -5.06, -6.15, -4.62, -4.47, -4.44, -4.52, -5.4, -6.44,
-4.81, -4.52, -5.1, -4.41, -4.69, -6.72, -4.7, -5.89, -4.59,
-4.47, -4.67, -4.75, -6.54, -6.12, -4.32, -5.03, -6.8, -5.43,
-4.77, -4.64, -6.27, -4.44, -6.06, -4.66, -4.28, -4.85, -4.69,
-5.03, -6.21, -4.71, -4.58, -5.92, -6.16, -4.59, -5.41, -4.83,
-4.52, -4.71, -6.96, -4.68, -4.51, -4.61, -4.57, -4.78, -6.36,
-4.45, -6.25, -6.05, -4.36, -4.37, -4.58, -4.48, -4.69, -6.31,
-4.79, -4.93, -5.77, -6.33, -4.82, -6.38, -4.34, -4.85, -4.6,
-5.37, -5.34, -4.68, -5.16), log_D = c(-0.09, 1.59, -2.25, -1.8,
1.38, 2.78, 0.63, 2.22, -0.88, -1.81, 0.28, 0.63, 1.66, 0.02,
1.14, -1.15, 1.86, -0.36, 0.78, 1.78, 1.57, 1.89, 0.58, -1.59,
2.58, -0.8, -0.16, 1.26, 2.24, 3.48, -0.87, 2.47, -0.12, 1.48,
0.68, 2.52, 1, 1.24, -2.65, 0.03, -0.7, -2.53, -1.14, 0.51, 0.68,
0.42, 1.81, 0.41, -4.5, 0.45, 1.98, 1.31, 2.26, 0.19, -0.46,
-0.07, -4, -1.4, 1.88, 3.48, 1.55, 3.02, 2.04, -0.12, -1.44,
0.21, -3.34, -0.42, 2.41, -1.07, 3.11, 0.03, 1.74, -1.11, -1.64,
0.64, -0.58), rgyr = c(4.64, 5.12, 3.41, 3.37, 3.68, 3.84, 2.97,
2.75, 4.02, 4.58, 5.41, 5.64, 3.43, 2.47, 3.75, 3.11, 3.74, 4.26,
2.79, 3.68, 3.4, 3.6, 5.67, 5.75, 3.28, 2.67, 4.85, 4.99, 3.44,
3.39, 3.7, 3.37, 3.11, 3.72, 3.45, 3.44, 4.16, 4.61, 2.48, 3.34,
0.84, 5.33, 3.67, 4.59, 4.37, 3.38, 2.94, 2.5, 4.62, 3.63, 3.87,
2.91, 2.97, 3.71, 3.55, 3.17, 3.89, 4.02, 4.96, 3.58, 3.63, 4.13,
3.25, 5.13, 2.14, 3.63, 3.49, 5.68, 5.29, 3.15, 3.33, 4.02, 3.98,
1.84, 1.23, 3.45, 3.14), rgyr_d = c(4.51, 5.03, 3.24, 3.23, 3.69,
3.88, 2.97, 2.75, 3.62, 4.52, 5.27, 5.39, 3.38, 2.47, 3.73, 3.11,
3.69, 4.24, 2.79, 3.71, 3.42, 3.66, 5.28, 5.23, 3.28, 2.68, 4.9,
5.01, 3.44, 3.48, 3.48, 3.37, 3.11, 3.79, 3.36, 3.45, 3.16, 4.46,
2.59, 3.36, 0.84, 5.18, 3.74, 4.53, 4.1, 3.43, 2.94, 2.5, 4.37,
3.56, 3.9, 2.91, 2.97, 3.71, 3.4, 3.26, 3.79, 4.09, 4.99, 3.62,
3.53, 4.06, 3.3, 4.57, 2.14, 3.49, 3.54, 5.53, 5.01, 3.15, 3.33,
4.01, 4.13, 1.84, 1.23, 3.5, 3.13), HCPSA = c(82.88, 77.08, 79.38,
120.63, 38.92, 35.53, 20.81, 54.27, 102.05, 86.82, 43.02, 47.14,
49.56, 45.55, 113.73, 138.76, 4.6, 105.44, 30.03, 75.95, 13.8,
90.74, 163.95, 186.88, 25.93, 75.13, 186.78, 138.69, 44.34, 50.34,
139.45, 67.55, 142.85, 93.37, 39.86, 3.56, 67.13, 93.29, 127.46,
93.21, 25.64, 204.96, 51.29, 44.88, 86.73, 76.98, 36.68, 15.1,
144.08, 48.62, 49.58, 1.49, 65.63, 52.8, 59.71, 99.19, 69.89,
64.79, 86.76, 38.1, 40.42, 36.21, 43.77, 105.15, 61.71, 57.35,
187.69, 133.67, 55.48, 79.52, 42.35, 100.74, 96.25, 66.72, 82.72,
59.47, 96.33), TPSA = c(87.66, 93.73, 89.9, 114.76, 41.49, 47.56,
26.79, 53.99, 100.52, 84.58, 50.72, 56.79, 43.7, 58.44, 115.38,
118.69, 6.48, 88.89, 36.42, 74.6, 15.27, 94.83, 173.98, 191.05,
32.67, 66.48, 206.07, 193.91, 40.46, 64.63, 134.99, 71.06, 118.36,
94.83, 37.3, 6.48, 68.53, 95.58, 121.38, 99.6, 20.23, 210.54,
59.06, 50.72, 81.95, 46.53, 58.12, 16.13, 139.78, 50.72, 56.79,
3.24, 58.2, 57.28, 68.78, 99.6, 91.44, 70.59, 106.95, 34.14,
41.49, 47.56, 45.59, 86.26, 57.53, 62.3, 189.53, 141.31, 72.94,
72.72, 37.3, 79.74, 85.81, 58.2, 69.11, 63.6, 103.59), N_rotb = c(0.31,
0.29, 0.23, 0.21, 0.29, 0.27, 0.17, 0.07, 0.16, 0.29, 0.27, 0.26,
0.15, 0.12, 0.28, 0.08, 0.14, 0.33, 0.08, 0.1, 0.11, 0.13, 0.17,
0.17, 0.06, 0.23, 0.18, 0.21, 0.06, 0.22, 0.25, 0.16, 0.08, 0.12,
0.24, 0.13, 0.19, 0.24, 0.44, 0.16, 0.2, 0.26, 0.16, 0.3, 0.24,
0.19, 0.05, 0.07, 0.27, 0.31, 0.29, 0.04, 0.06, 0.23, 0.08, 0.13,
0.15, 0.29, 0.15, 0.07, 0.22, 0.22, 0.14, 0.33, 0.19, 0.15, 0.28,
0.2, 0.15, 0.29, 0.06, 0.24, 0.23, 0, 0.29, 0.15, 0.18), log_P_eff_calc = c(-5.3,
-4.89, -5.77, -5.91, -4.58, -4.39, -4.63, -4.47, -5.64, -5.85,
-5.2, -5.13, -4.57, -4.89, -5.11, -5.87, -4.38, -5.55, -4.69,
-4.78, -4.46, -4.77, -5.83, -6.55, -4.45, -5.27, -6, -5.13, -4.57,
-4.44, -5.79, -4.59, -5.62, -4.94, -4.78, -4.28, -5, -5.09, -5.87,
-5.27, -4.67, -6.79, -5.37, -4.99, -5.15, -5.09, -4.49, -4.65,
-6.97, -4.84, -4.45, -4.42, -4.6, -5.02, -5.3, -5.31, -6.37,
-5.5, -5.05, -4.54, -4.57, -4.5, -4.46, -5.6, -5.29, -5.07, -6.56,
-6.06, -4.85, -5.36, -4.53, -5.35, -4.82, -5.23, -5.29, -4.95,
-5.43), residuals = c(-0.53, 0.28, 0.71, -0.24, -0.04, -0.08,
0.19, -0.05, 0.24, -0.59, 0.39, 0.61, -0.53, 0.48, 0.42, -0.85,
-0.32, -0.34, 0.1, 0.31, -0.21, 0.02, -0.71, 0.43, 0.13, 0.24,
-0.8, -0.3, -0.2, -0.2, -0.48, 0.15, -0.44, 0.28, 0.5, -0.57,
0.31, 0.06, -0.34, 0.56, 0.09, 0.87, -0.79, 0.4, -0.26, 0.26,
-0.03, -0.06, 0.01, 0.16, -0.06, -0.19, 0.03, 0.24, -1.06, 0.86,
0.12, -0.55, 0.69, 0.17, -0.01, 0.02, -0.23, -0.71, 0.5, 0.14,
0.79, -0.27, 0.03, -1.02, 0.19, 0.5, 0.22, -0.14, -0.05, 0.27,
0.27)), row.names = c(NA, -77L), class = c("tbl_df", "tbl", "data.frame"
))
set.seed(42)
# split into training and validation x and y (x: columns 2-8, y: column 9)
train_ind <- sample(1:77, 60)
x_train <- as.matrix(df[train_ind, 2:8])
y_train <- unlist(df[train_ind, 9])
x_val <- as.matrix(df[-train_ind, 2:8])
y_val <- unlist(df[-train_ind, 9])
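# optional sanity check (added, not in the original script): confirm that the
# feature matrices and target vectors line up before training
stopifnot(nrow(x_train) == length(y_train),
          nrow(x_val) == length(y_val),
          ncol(x_train) == ncol(x_val))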
# define model
data <- mx.symbol.Variable("data")
fc1 <- mx.symbol.FullyConnected(data,
                                num_hidden = 1)
linreg <- mx.symbol.LinearRegressionOutput(fc1)
# define initializer
initializer <- mx.init.normal(sd = 0.1)
# define optimizer algorithm to update weights
optimizer <- mx.opt.create("sgd",
                           learning.rate = 1e-6,
                           momentum = 0.9)
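# alternative (a sketch, not used below): mx.opt.create() also exposes adaptive
# optimizers such as "adam", which typically tolerate a larger learning rate
# than plain SGD with momentum
# optimizer <- mx.opt.create("adam", learning.rate = 1e-3)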
# define logger to keep track of metric updates
logger <- mx.metric.logger()
epoch.end.callback <- mx.callback.log.train.metric(
  period = 4, # number of batches after which the metric is evaluated
  logger = logger)
n_epoch <- 20
# plot our model
graph.viz(linreg)
# train model with parameters
model <- mx.model.FeedForward.create(
  symbol = linreg,
  X = x_train,
  y = y_train,
  ctx = mx.cpu(),
  num.round = n_epoch,
  initializer = initializer,
  optimizer = optimizer,
  eval.data = list(data = x_val, label = y_val),
  eval.metric = mx.metric.rmse,
  array.batch.size = 15,
  epoch.end.callback = epoch.end.callback)
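# follow-up check (added; a sketch assuming the model above trained): predict
# on the hold-out set and compute RMSE by hand to cross-check the logged
# validation metric
preds <- predict(model, x_val, array.layout = "rowmajor")
sqrt(mean((as.numeric(preds) - y_val)^2))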
# collect train and validation RMSE per epoch for plotting
rmse_log <- data.frame(RMSE = c(logger$train, logger$eval),
                       dataset = c(rep("train", length(logger$train)),
                                   rep("val", length(logger$eval))),
                       epoch = 1:n_epoch)
library(ggplot2)
ggplot(rmse_log, aes(epoch, RMSE,
                     group = dataset,
                     colour = dataset)) +
  geom_point() +
  geom_line()
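# optional (added, not part of the original script): persist and reload the
# trained model with mx.model.save()/mx.model.load(), which write and read a
# symbol JSON plus a parameter file
# mx.model.save(model, "mxnet_linreg", iteration = n_epoch)
# model <- mx.model.load("mxnet_linreg", iteration = n_epoch)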