1. Вступление

Это сообщение является продолжением Deep learning с использованием языка R и библиотеки mxnet. Установка и начало работы. Будет рассмотрено предсказание классов изображений на основе модели, а также работа с итераторами и некоторые другие аспекты.

Полезные ссылки:

End-to-End Deep Learning Tutorial,

https://github.com/dmlc/mxnet/tree/master/docs/tutorials/r,

https://github.com/dmlc/mxnet/tree/master/R-package/vignettes.

По двум последним ссылкам доступна самая актуальная документация и примеры от разработчиков.

2. Создание “словаря синонимов”

В комплекте с готовыми моделями, доступными для скачивания, обычно идет файл synset.txt. Этот файл содержит информацию о соответствии между номером класса и его названием/меткой, например 0 "airplane". При создании бинарного файла с изображениями были также созданы файлы в формате .lst, из которых легко получить нужную нам таблицу:

setwd("~/R/cifar10/")
library(mxnet)
library(imager)
library(abind)
labels <- read.table("cifar_train.lst")

# Оставляем только уникальные значения
labels <- labels[!duplicated(labels$V2), ]

# Разделяем имена файлов и имена папок, которые соответствуют меткам классов
tmp <- strsplit(as.character(labels$V3), split = "/", fixed = TRUE)

# Создаем таблицу с метками и номерами классов, сортируем по возрастанию номеров
class_labels <- data.frame(response = labels$V2,
                           label = sapply(tmp, function(x) x[1]))
class_labels <- class_labels[order(class_labels$response), ]

# Сохраняем в файл "synset.txt" для дальнейшего использования
write.table(class_labels, "synset.txt", row.names = FALSE, col.names = TRUE)
class_labels <- read.table("synset.txt", header = TRUE)

3. Обучение модели

Повторим обучение той же модели с тем же набором данных, что и в прошлый раз. Но теперь укажем размер изображений 32х30, то есть 32 пикселя в высоту и 30 пикселей в ширину (исходные картинки 32х32 будут обрезаться). Это нужно для лучшего понимания того, как правильно указывать размерности на этапе обучения модели и при ее использовании для предсказаний. Значения аргументов kernel, stride, pad задаются всегда в том же формате: сначала высота (y-координата), затем ширина (x-координата); третьим числов в векторе размерностей может быть глубина.

Создаем итераторы:

get_iterator <- function(data_shape, 
                         train_data, 
                         val_data, 
                         batch_size = 128) {
    train <- mx.io.ImageRecordIter(
        path.imgrec = train_data,
        batch.size  = batch_size,
        data.shape  = data_shape,
        rand.crop   = TRUE,
        rand.mirror = TRUE)
  
    val <- mx.io.ImageRecordIter(
        path.imgrec = val_data,
        batch.size  = batch_size,
        data.shape  = data_shape,
        rand.crop   = FALSE,
        rand.mirror = FALSE
        )
 
  return(list(train = train, val = val))
}
data  <- get_iterator(data_shape = c(32, 30, 3), # 32 пикселя в высоту
                      train_data = "/home/andrey/R/cifar10/cifar_train.rec",
                      val_data   = "/home/andrey/R/cifar10/cifar_val.rec",
                      batch_size = 100)
train <- data$train
val   <- data$val

Используем ту же архитектуру Resnet:

conv_factory <- function(data, num_filter, kernel, stride,
                         pad, act_type = 'relu', conv_type = 0) {
    if (conv_type == 0) {
      conv = mx.symbol.Convolution(data = data, num_filter = num_filter,
                                   kernel = kernel, stride = stride, pad = pad)
      bn = mx.symbol.BatchNorm(data = conv)
      act = mx.symbol.Activation(data = bn, act_type = act_type)
      return(act)
    } else if (conv_type == 1) {
      conv = mx.symbol.Convolution(data = data, num_filter = num_filter,
                                   kernel = kernel, stride = stride, pad = pad)
      bn = mx.symbol.BatchNorm(data = conv)
      return(bn)
    }
}

residual_factory <- function(data, num_filter, dim_match) {
  if (dim_match) {
    identity_data = data
    conv1 = conv_factory(data = data, num_filter = num_filter, kernel = c(3, 3),
                         stride = c(1, 1), pad = c(1, 1), act_type = 'relu', conv_type = 0)
    
    conv2 = conv_factory(data = conv1, num_filter = num_filter, kernel = c(3, 3),
                         stride = c(1, 1), pad = c(1, 1), conv_type = 1)
    new_data = identity_data + conv2
    act = mx.symbol.Activation(data = new_data, act_type = 'relu')
    return(act)
  } else {
    conv1 = conv_factory(data = data, num_filter = num_filter, kernel = c(3, 3),
                         stride = c(2, 2), pad = c(1, 1), act_type = 'relu', conv_type = 0)
    conv2 = conv_factory(data = conv1, num_filter = num_filter, kernel = c(3, 3),
                         stride = c(1, 1), pad = c(1, 1), conv_type = 1)
    
    # adopt project method in the paper when dimension increased
    project_data = conv_factory(data = data, num_filter = num_filter, kernel = c(1, 1),
                                stride = c(2, 2), pad = c(0, 0), conv_type = 1)
    new_data = project_data + conv2
    act = mx.symbol.Activation(data = new_data, act_type = 'relu')
    return(act)
  }
}

residual_net <- function(data, n) {
  #fisrt 2n layers
  for (i in 1:n) {
    data = residual_factory(data = data, num_filter = 16, dim_match = TRUE)
  }
  
  
  #second 2n layers
  for (i in 1:n) {
    if (i == 1) {
      data = residual_factory(data = data, num_filter = 32, dim_match = FALSE)
    } else {
      data = residual_factory(data = data, num_filter = 32, dim_match = TRUE)
    }
  }
  #third 2n layers
  for (i in 1:n) {
    if (i == 1) {
      data = residual_factory(data = data, num_filter = 64, dim_match = FALSE)
    } else {
      data = residual_factory(data = data, num_filter = 64, dim_match = TRUE)
    }
  }
  return(data)
}

get_symbol <- function(num_classes = 10) {
  conv <- conv_factory(data = mx.symbol.Variable(name = 'data'), num_filter = 16,
                      kernel = c(3, 3), stride = c(1, 1), pad = c(1, 1),
                      act_type = 'relu', conv_type = 0)
  n <- 3 # set n = 3 means get a model with 3*6+2=20 layers, set n = 9 means 9*6+2=56 layers
  resnet <- residual_net(conv, n) #
  pool <- mx.symbol.Pooling(data = resnet, kernel = c(7, 7), pool_type = 'avg')
  flatten <- mx.symbol.Flatten(data = pool, name = 'flatten')
  fc <- mx.symbol.FullyConnected(data = flatten, num_hidden = num_classes, name = 'fc1')
  softmax <- mx.symbol.SoftmaxOutput(data = fc, name = 'softmax')
  return(softmax)
}

# Сеть для 10 классов
resnet <- get_symbol(10)

Обучаем модель в течение 15 эпох:

model <- mx.model.FeedForward.create(
  symbol             = resnet,
  X                  = train,
  eval.data          = val,
  ctx                = mx.gpu(0),
  eval.metric        = mx.metric.accuracy,
  num.round          = 15,
  learning.rate      = 0.05,
  momentum           = 0.9,
  wd                 = 0.00001,
  kvstore            = "local",
  array.batch.size   = 100,
  epoch.end.callback = NULL,
  batch.end.callback = mx.callback.log.train.metric(150),
  initializer        = mx.init.Xavier(factor_type = "in", magnitude = 2.34),
  optimizer          = "sgd"
)
## Start training with 1 devices
## Batch [150] Train-accuracy=0.2868
## Batch [300] Train-accuracy=0.347766666666667
## [1] Train-accuracy=0.38280701754386
## [1] Validation-accuracy=0.4331
## Batch [150] Train-accuracy=0.527066666666666
## Batch [300] Train-accuracy=0.557066666666667
## [2] Train-accuracy=0.5738
## [2] Validation-accuracy=0.5977
## Batch [150] Train-accuracy=0.637066666666666
## Batch [300] Train-accuracy=0.649399999999999
## [3] Train-accuracy=0.660374999999999
## [3] Validation-accuracy=0.6443
## Batch [150] Train-accuracy=0.704533333333333
## Batch [300] Train-accuracy=0.711766666666666
## [4] Train-accuracy=0.718349999999999
## [4] Validation-accuracy=0.7042
## Batch [150] Train-accuracy=0.7422
## Batch [300] Train-accuracy=0.747533333333334
## [5] Train-accuracy=0.752975000000001
## [5] Validation-accuracy=0.7198
## Batch [150] Train-accuracy=0.7704
## Batch [300] Train-accuracy=0.772866666666667
## [6] Train-accuracy=0.777025
## [6] Validation-accuracy=0.7485
## Batch [150] Train-accuracy=0.79
## Batch [300] Train-accuracy=0.790533333333334
## [7] Train-accuracy=0.794425000000001
## [7] Validation-accuracy=0.772699999999999
## Batch [150] Train-accuracy=0.806733333333333
## Batch [300] Train-accuracy=0.807766666666668
## [8] Train-accuracy=0.8109
## [8] Validation-accuracy=0.7733
## Batch [150] Train-accuracy=0.822799999999999
## Batch [300] Train-accuracy=0.821533333333334
## [9] Train-accuracy=0.825025
## [9] Validation-accuracy=0.7679
## Batch [150] Train-accuracy=0.8332
## Batch [300] Train-accuracy=0.833066666666668
## [10] Train-accuracy=0.834800000000001
## [10] Validation-accuracy=0.7856
## Batch [150] Train-accuracy=0.845333333333333
## Batch [300] Train-accuracy=0.842533333333334
## [11] Train-accuracy=0.843475
## [11] Validation-accuracy=0.7886
## Batch [150] Train-accuracy=0.8502
## Batch [300] Train-accuracy=0.850333333333335
## [12] Train-accuracy=0.851950000000001
## [12] Validation-accuracy=0.7801
## Batch [150] Train-accuracy=0.860866666666666
## Batch [300] Train-accuracy=0.858900000000001
## [13] Train-accuracy=0.8601
## [13] Validation-accuracy=0.8068
## Batch [150] Train-accuracy=0.8648
## Batch [300] Train-accuracy=0.865733333333334
## [14] Train-accuracy=0.86715
## [14] Validation-accuracy=0.8171
## Batch [150] Train-accuracy=0.867466666666666
## Batch [300] Train-accuracy=0.869833333333333
## [15] Train-accuracy=0.871949999999999
## [15] Validation-accuracy=0.8096

4. Предсказания на основе модели

Для работы с изображениями в R будем использовать пакет imager.

В общих чертах процесс описан в руководстве Classify Images with a Pretrained Model, но там есть некоторые нюансы и неточности. Примерами будут служить следующие два изображения:

bird.png

bird.png

deer.jpg

deer.jpg

Разберем операции предварительной обработки подробно для первого изображения:

# Скачиваем или загружаем с диска изображение
im <- load.image("http://kindersay.com/files/images/bird.png")
# Image. Width: 445 pix Height: 355 pix Depth: 1 Colour channels: 3 
# Размерность: ширина на высоту
# Depth - количество кадров, если это видео; для изображений всегда 1 
# Colour channels - 3 цветовых канала (RGB)

shape <- dim(im)
# 445 355   1   3
# Индексация идет сначала по столбцам (ширина), затем по строкам (высота) -
# изображение из линейного вектора формируется именно в таком порядке
# В R матрицы формируются и индексируются в обратном порядке,
# то есть используется так называемый Fortran-style, или порядок column-major:
# a <- 1:4
# dim(a) <- c(2, 2)
# a
#      [,1] [,2]
# [1,]    1    3
# [2,]    2    4

# Меняем размер на 32х30, требуемый для нашей модели
# Обрезка (crop) не используется
resized <- resize(im,  size_x = 30, size_y = 32)
# Image. Width: 30 pix Height: 32 pix Depth: 1 Colour channels: 3 

# Конвертируем в массив
# Если значения для каждого цветового канала заданы в диапазоне [0, 1], 
# то нужно умножить на 255. В нашем случае это не требуется
arr <- as.array(resized) 
# 30 32  1  3 - 30 строк, а не 32 строки, как в изображении
# Произошло транспонирование: строки стали столбцами

# Средние значения для каждого пикселя не отнимаем
# Задаем нужный формат (width, height, channel, num)
dim(arr) <- c(30, 32, 3, 1)

# Предсказываем вероятности и класс
prob <- predict(model, X = arr)
prob
##              [,1]
##  [1,] 0.097901896
##  [2,] 0.002545726
##  [3,] 0.742579162
##  [4,] 0.041871570
##  [5,] 0.016076958
##  [6,] 0.017367113
##  [7,] 0.011065663
##  [8,] 0.066732869
##  [9,] 0.001639599
## [10,] 0.002219437
class_labels$label[prob == max(prob)]
## [1] bird
## Levels: airplane automobile bird cat deer dog frog horse ship truck

Все этапы предварительной обработки можно оформить в виде функции (измененный вариант preproc.image из https://github.com/dmlc/mxnet/blob/master/docs/tutorials/r/classifyRealImageWithPretrainedModel.md):

preproc_image <- function(src,              # URL or file location
                          height,        
                          width,  
                          num_channels = 3, # 3 for RGB, 1 for grayscale
                          mult_by = 1,      # set to 255 for normalized image
                          crop = FALSE) {   # no crop by default
    
    im <- load.image(src)
    
    if (crop) {
        shape <- dim(im)
        short_edge <- min(shape[1:2])
        xx <- floor((shape[1] - short_edge) / 2)
        yy <- floor((shape[2] - short_edge) / 2) 
        im <- crop.borders(im, xx, yy)
    }
    
    resized <- resize(im,  size_x = width, size_y = height)
    arr <- as.array(resized) * mult_by
    dim(arr) <- c(width, height, num_channels, 1)
    return(arr)
} 

Предсказание для картинки с оленем:

arr <- preproc_image("http://kingofwallpapers.com/deer-images/deer-images-007.jpg",
                     height = 32,
                     width = 30)
prob <- predict(model, X = arr)
prob
##               [,1]
##  [1,] 1.974541e-05
##  [2,] 4.638057e-06
##  [3,] 3.291356e-02
##  [4,] 6.472487e-03
##  [5,] 8.688778e-01
##  [6,] 1.603282e-02
##  [7,] 4.853056e-03
##  [8,] 7.081089e-02
##  [9,] 2.876292e-06
## [10,] 1.208015e-05
class_labels$label[prob == max(prob)]
## [1] deer
## Levels: airplane automobile bird cat deer dog frog horse ship truck

Рассмотренную особенность с порядком индексации массивов нужно учитывать и при использовании итераторов по файлам в формате .csv, таких как в этом примере. При создании такого файла исходное изображение (или несколько изображений, которые соответствуют одному наблюдению) превращаются в вектор, вектор становится строкой файла, а затем при обучении модели и при использовании модели для предсказаний задается правильная размерность массива. Это позволяет воспроизвести пространственную структуру входных данных, хранящихся в линейном виде.

Вопросы оптимальной реализации этих операций в данном сообщении не рассматриваются, но при работе с большим количество изображений наверняка пригодятся пакеты типа foreach и doParallel для параллельной обработки, также может быть полезен пакет data.table и консольные утилиты.

Простейший пример обработки нескольких изображений с использованием пакета abind:

image_urls <- c(
    "http://kindersay.com/files/images/bird.png",
    "http://kingofwallpapers.com/deer-images/deer-images-007.jpg"
)

images <- lapply(image_urls,
                 preproc_image,
                 height = 32,
                 width = 30)

images <- do.call(abind, images)

probs <- predict(model, X = images)

probs
##              [,1]         [,2]
##  [1,] 0.097901933 1.974543e-05
##  [2,] 0.002545726 4.638062e-06
##  [3,] 0.742579103 3.291357e-02
##  [4,] 0.041871566 6.472491e-03
##  [5,] 0.016076960 8.688778e-01
##  [6,] 0.017367108 1.603283e-02
##  [7,] 0.011065662 4.853058e-03
##  [8,] 0.066732846 7.081092e-02
##  [9,] 0.001639599 2.876297e-06
## [10,] 0.002219437 1.208016e-05
class_labels$label[apply(probs, 2, function(x) which(x == max(x)))]
## [1] bird deer
## Levels: airplane automobile bird cat deer dog frog horse ship truck

5. Итераторы

Для работы с любыми данными, которые не помещаются в памяти, можно использовать функцию mx.io.CSVIter(). Как понятно из названия, она обрабатывает файлы в формате .csv построчно и конструирует из каждой строки массив (тензор) нужной размерности, которая задается аргументами data.shape для самого набора данных и label.shape для целевой переменной. См. https://github.com/dmlc/mxnet/tree/master/example/kaggle-ndsb2. За создание .csv-файлов там отвечает код на Python. Аналог на R можно написать, взяв за основу представленную выше функцию preproc_image() и заменив dim(arr) <- c(width, height, num_channels, 1) на dim(arr) <- c(width * height * num_channels * 1).

Также есть возможность создавать свои собственные итераторы - см. Custom Iterator Tutorial. Чтобы сделать что-то действительно серьезно отличающееся от представленного варианта, понадобятся знания C++.

6. Доступные слои и функции потерь

Список слоев довольно обширен:

apropos("mx.symbol.")
##   [1] "mx.symbol.abs"                      
##   [2] "mx.symbol.Activation"               
##   [3] "mx.symbol.adam_update"              
##   [4] "mx.symbol.arccos"                   
##   [5] "mx.symbol.arccosh"                  
##   [6] "mx.symbol.arcsin"                   
##   [7] "mx.symbol.arcsinh"                  
##   [8] "mx.symbol.arctan"                   
##   [9] "mx.symbol.arctanh"                  
##  [10] "mx.symbol.argmax"                   
##  [11] "mx.symbol.argmax_channel"           
##  [12] "mx.symbol.argmin"                   
##  [13] "mx.symbol.argsort"                  
##  [14] "mx.symbol.batch_dot"                
##  [15] "mx.symbol.BatchNorm"                
##  [16] "mx.symbol.BlockGrad"                
##  [17] "mx.symbol.broadcast_add"            
##  [18] "mx.symbol.broadcast_axis"           
##  [19] "mx.symbol.broadcast_div"            
##  [20] "mx.symbol.broadcast_equal"          
##  [21] "mx.symbol.broadcast_greater"        
##  [22] "mx.symbol.broadcast_greater_equal"  
##  [23] "mx.symbol.broadcast_hypot"          
##  [24] "mx.symbol.broadcast_lesser"         
##  [25] "mx.symbol.broadcast_lesser_equal"   
##  [26] "mx.symbol.broadcast_maximum"        
##  [27] "mx.symbol.broadcast_minimum"        
##  [28] "mx.symbol.broadcast_minus"          
##  [29] "mx.symbol.broadcast_mul"            
##  [30] "mx.symbol.broadcast_not_equal"      
##  [31] "mx.symbol.broadcast_plus"           
##  [32] "mx.symbol.broadcast_power"          
##  [33] "mx.symbol.broadcast_sub"            
##  [34] "mx.symbol.broadcast_to"             
##  [35] "mx.symbol.Cast"                     
##  [36] "mx.symbol.ceil"                     
##  [37] "mx.symbol.choose_element_0index"    
##  [38] "mx.symbol.clip"                     
##  [39] "mx.symbol.Concat"                   
##  [40] "mx.symbol.Convolution"              
##  [41] "mx.symbol.Correlation"              
##  [42] "mx.symbol.cos"                      
##  [43] "mx.symbol.cosh"                     
##  [44] "mx.symbol.crop"                     
##  [45] "mx.symbol.Crop"                     
##  [46] "mx.symbol.CuDNNBatchNorm"           
##  [47] "mx.symbol.Custom"                   
##  [48] "mx.symbol.Deconvolution"            
##  [49] "mx.symbol.degrees"                  
##  [50] "mx.symbol.dot"                      
##  [51] "mx.symbol.Dropout"                  
##  [52] "mx.symbol.ElementWiseSum"           
##  [53] "mx.symbol.elemwise_add"             
##  [54] "mx.symbol.Embedding"                
##  [55] "mx.symbol.exp"                      
##  [56] "mx.symbol.expand_dims"              
##  [57] "mx.symbol.expm1"                    
##  [58] "mx.symbol.fill_element_0index"      
##  [59] "mx.symbol.fix"                      
##  [60] "mx.symbol.Flatten"                  
##  [61] "mx.symbol.flip"                     
##  [62] "mx.symbol.floor"                    
##  [63] "mx.symbol.FullyConnected"           
##  [64] "mx.symbol.gamma"                    
##  [65] "mx.symbol.gammaln"                  
##  [66] "mx.symbol.Group"                    
##  [67] "mx.symbol.identity"                 
##  [68] "mx.symbol.IdentityAttachKLSparseReg"
##  [69] "mx.symbol.infer.shape"              
##  [70] "mx.symbol.InstanceNorm"             
##  [71] "mx.symbol.L2Normalization"          
##  [72] "mx.symbol.LeakyReLU"                
##  [73] "mx.symbol.LinearRegressionOutput"   
##  [74] "mx.symbol.load"                     
##  [75] "mx.symbol.load.json"                
##  [76] "mx.symbol.log"                      
##  [77] "mx.symbol.log10"                    
##  [78] "mx.symbol.log1p"                    
##  [79] "mx.symbol.log2"                     
##  [80] "mx.symbol.LogisticRegressionOutput" 
##  [81] "mx.symbol.LRN"                      
##  [82] "mx.symbol.MAERegressionOutput"      
##  [83] "mx.symbol.MakeLoss"                 
##  [84] "mx.symbol.max"                      
##  [85] "mx.symbol.max_axis"                 
##  [86] "mx.symbol.min"                      
##  [87] "mx.symbol.min_axis"                 
##  [88] "mx.symbol.nanprod"                  
##  [89] "mx.symbol.nansum"                   
##  [90] "mx.symbol.negative"                 
##  [91] "mx.symbol.norm"                     
##  [92] "mx.symbol.normal"                   
##  [93] "mx.symbol.Pad"                      
##  [94] "mx.symbol.Pooling"                  
##  [95] "mx.symbol.prod"                     
##  [96] "mx.symbol.radians"                  
##  [97] "mx.symbol.Reshape"                  
##  [98] "mx.symbol.rint"                     
##  [99] "mx.symbol.RNN"                      
## [100] "mx.symbol.ROIPooling"               
## [101] "mx.symbol.round"                    
## [102] "mx.symbol.rsqrt"                    
## [103] "mx.symbol.save"                     
## [104] "mx.symbol.SequenceLast"             
## [105] "mx.symbol.SequenceMask"             
## [106] "mx.symbol.SequenceReverse"          
## [107] "mx.symbol.sgd_mom_update"           
## [108] "mx.symbol.sgd_update"               
## [109] "mx.symbol.sign"                     
## [110] "mx.symbol.sin"                      
## [111] "mx.symbol.sinh"                     
## [112] "mx.symbol.slice_axis"               
## [113] "mx.symbol.SliceChannel"             
## [114] "mx.symbol.smooth_l1"                
## [115] "mx.symbol.Softmax"                  
## [116] "mx.symbol.SoftmaxActivation"        
## [117] "mx.symbol.softmax_cross_entropy"    
## [118] "mx.symbol.SoftmaxOutput"            
## [119] "mx.symbol.sort"                     
## [120] "mx.symbol.SpatialTransformer"       
## [121] "mx.symbol.sqrt"                     
## [122] "mx.symbol.square"                   
## [123] "mx.symbol.sum"                      
## [124] "mx.symbol.sum_axis"                 
## [125] "mx.symbol.SVMOutput"                
## [126] "mx.symbol.SwapAxis"                 
## [127] "mx.symbol.tan"                      
## [128] "mx.symbol.tanh"                     
## [129] "mx.symbol.topk"                     
## [130] "mx.symbol.transpose"                
## [131] "mx.symbol.uniform"                  
## [132] "mx.symbol.UpSampling"               
## [133] "mx.symbol.Variable"

Если в названии функции последнее слово начинается с большой буквы - эта функция создает слой; если в конце имеется слово Output, то в этом “символе” есть не только выходной слой нейросети, но и функция потерь вместе со всем необходимым для операций обратного распространения ошибки. Например, функция mx.symbol.SoftmaxActivation() создает слой с активацией softmax, после которого можно указать свою собственную функцию потерь. А если использовать mx.symbol.SoftmaxOutput, то перекрестная энтропия (cross-entropy, в данном случае это синоним logloss) сразу будет использоваться в качестве функции потерь.

Остальные функции отвечают за операции над “символами”, которые являются аналогами соответствующих операций над значениями. Для арифметических операторов +, -, * и / добавлены соответствующие методы.

В руководстве Customized loss function рассматривается создание собственной функции потерь с помощью mx.symbol.MakeLoss(), также полезные материалы есть по этой ссылке. Выглядит это вот так:

data <- mx.symbol.Variable("data")

fc1 <- mx.symbol.FullyConnected(data, num_hidden=1)

lro <- mx.symbol.MakeLoss(mx.symbol.square(mx.symbol.Reshape(fc1, shape = 0) - label))
# Аналог mx.symbol.LinearRegressionOutput()

7. Использование активаций скрытых слоев

Активации скрытых слоев нейросети можно использовать как для визуализации ее работы, так и в качестве признаков для других алгоритмов машинного обучения. Особенно это полезно, когда данных для обучения с нуля своей нейросети не хватает. В таком случае можно взять предобученную сеть для того же класса задач и получить не только финальные предсказания (которые скорее всего будут бесполезными), но и активации, например, последнего полносвязного слоя. Такие активации могут содержать высокоуровневые признаки, релевантные для решаемой задачи.

Руководства от разработчиков на эту тему пока нет, но в обсуждении приводится готовое решение. Модель обучается обычным образом, дополнительно создается executor - объект, параметры которого можно обновить, используя параметры обученной модели. В конце выполняем forward pass и получаем значения на нужных нам слоях:

# Group some output layers for visual analysis
out <- mx.symbol.Group(c(convAct1, poolLayer1, convAct2, poolLayer2, LeNet1))
# Create an executor
executor <- mx.simple.bind(symbol = out, 
                           data = dim(test.array), 
                           ctx = mx.cpu())

# Update parameters
mx.exec.update.arg.arrays(executor, 
                          model$arg.params, 
                          match.name = TRUE)
mx.exec.update.aux.arrays(executor, 
                          model$aux.params, 
                          match.name = TRUE)
# Select data to use
mx.exec.update.arg.arrays(executor, 
                          list(data = mx.nd.array(test.array)), 
                          match.name = TRUE)
# Do a forward pass with the current parameters and data
mx.exec.forward(executor, 
                is.train = FALSE)
names(executor$ref.outputs)

Предобученную модель можно использовать и другим образом: можно дообучить ее на своих данных, подобно тому, как мы продолжаем обучение своей собственной модели с “контрольной точки”. Также можно при этом поменять конфигурацию сети - см. руководство для Python. Наверное, в R проще всего отредактировать .json-файл, содержащий описание архитектуры сети: сохраняем resnet$as.json(), редактируем, загружаем файл с помощью mx.symbol.load().

Продолжение, надеюсь, следует, но уже немного в другом формате.