# mxnet.R (2.3 KB)
  # Linear regression with MXNet: install/attach dependencies, split the data,
  # fit a single fully connected layer with SGD, and plot train/validation RMSE.

  # ---- Installation and dependencies ----
  # requireNamespace() only tests whether mxnet is available; the package is
  # then attached unconditionally with library(). With the previous
  # require()-based check a fresh install left mxnet unattached, so every
  # later mx.* call failed on the first run.
  if (!requireNamespace("mxnet", quietly = TRUE)) {
    cran <- getOption("repos")
    cran["dmlc"] <- "https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/CRAN/GPU/cu92"
    options(repos = cran)
    install.packages("mxnet")
  }
  library(mxnet)
  library(ggplot2)

  # ---- Load data and split ----
  # Base readRDS() is used because readr (the home of read_rds()) is never
  # attached in this script; both read the same .rds format.
  df <- readRDS("data.rds")
  set.seed(100)

  # Sample the training rows from the actual number of rows instead of a
  # hard-coded 1:77; for a 77-row data set the draw is identical.
  n_train <- 60L
  n_obs <- nrow(df)
  stopifnot(n_obs > n_train) # at least one row must remain for validation
  train_ind <- sample(seq_len(n_obs), n_train)

  # Columns 2:8 are used as predictors, column 9 as the response.
  x_train <- as.matrix(df[train_ind, 2:8])
  y_train <- unlist(df[train_ind, 9])
  x_val <- as.matrix(df[-train_ind, 2:8])
  y_val <- unlist(df[-train_ind, 9])

  # NOTE(review): this iterator is built but never passed to the training call
  # below, which receives x_train / y_train directly -- confirm whether it is
  # still needed.
  iter_train_data <- mx.io.arrayiter(
    t(x_train),
    y_train,
    batch.size = 15,
    shuffle = TRUE
  )

  # ---- Define model ----
  # One fully connected unit feeding a linear-regression output.
  data <- mx.symbol.Variable("data")
  fc1 <- mx.symbol.FullyConnected(data, num_hidden = 1)
  linreg <- mx.symbol.LinearRegressionOutput(fc1)

  # Weight initializer.
  initializer <- mx.init.normal(sd = 0.1)

  # Optimizer algorithm used to update the weights.
  optimizer <- mx.opt.create(
    "sgd",
    learning.rate = 1e-6,
    momentum = 0.9
  )

  # Logger that accumulates the metric values reported by the callback.
  logger <- mx.metric.logger()
  epoch.end.callback <- mx.callback.log.train.metric(
    period = 4, # number of batches after which the metric is evaluated
    logger = logger
  )
  n_epoch <- 20

  # Plot the model graph.
  graph.viz(linreg)

  # ---- Train model ----
  model <- mx.model.FeedForward.create(
    symbol = linreg,
    X = x_train,
    y = y_train,
    ctx = mx.cpu(),
    num.round = n_epoch,
    initializer = initializer,
    optimizer = optimizer,
    eval.data = list(data = x_val, label = y_val),
    eval.metric = mx.metric.rmse,
    array.batch.size = 15,
    epoch.end.callback = epoch.end.callback
  )

  # ---- Plot train and validation loss curves ----
  # The epoch index is built from the logged lengths instead of relying on
  # data.frame() silently recycling 1:n_epoch across both series.
  n_logged_train <- length(logger$train)
  n_logged_eval <- length(logger$eval)
  rmse_log <- data.frame(
    RMSE = c(logger$train, logger$eval),
    dataset = rep(c("train", "val"), times = c(n_logged_train, n_logged_eval)),
    epoch = c(seq_len(n_logged_train), seq_len(n_logged_eval))
  )

  ggplot(rmse_log, aes(epoch, RMSE, group = dataset, colour = dataset)) +
    geom_point() +
    geom_line()
  65. geom_line()