Cross-entropy method in R

The cross-entropy method is a general approach to optimization that relies on two simple ideas. In the context of finding the maximum of a scalar-valued function, this means:

  1. Generate some random parameters and evaluate the function there.
  2. Use the parameters with the best function values to generate new candidates, and repeat.

One simple way of generating such random parameters is to fit a normal distribution at every iteration. We choose a subset of "elite parameters" (a fixed fraction of the candidates we tried, namely those with the highest function values), compute their mean and covariance, and use them to sample a new population of parameters.
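The elite step can be sketched in a few lines of base R. This is an illustrative sketch under assumptions: the objective g (maximized at c(1, -1)), the population size of 50 and the elite size of 10 (20%) are all made up for the example.

```r
set.seed(1)
# Hypothetical toy objective (assumption): maximized at c(1, -1)
g <- function(theta) -sum((theta - c(1, -1))^2)

# 1. Sample a population of 50 candidates (one per row) from a standard normal
population <- matrix(rnorm(50 * 2), ncol = 2)
scores <- apply(population, 1, g)

# 2. Keep the top 20% of candidates as the elite set
n_elite <- 10
elite <- population[order(scores, decreasing = TRUE)[1:n_elite], , drop = FALSE]

# 3. Refit the normal distribution: mean vector and covariance matrix of the elite
new_mean <- colMeans(elite)
new_cov <- cov(elite)
print(new_mean)  # typically moved from the origin toward c(1, -1)
print(new_cov)
```

Sampling from the refitted normal, for example with MASS::mvrnorm(n, new_mean, new_cov), then produces the next population.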

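This kind of optimization method is useful, for instance, in reinforcement learning. There it can search directly over the parameters of a policy using only the rewards that policy obtains, without computing gradients.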

As a toy example, suppose we want to maximize the following function, which is minus the squared distance between theta and a fixed vector solution:

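# the function we want to optimize: minus the squared distance to solution,
# so its maximum value (0) is attained at theta = solution
solution <- c(0.5, 0.1, -0.3)
f <- function(theta){
  reward <- -sum((solution - theta)^2)
  return(reward)
}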
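A quick sanity check of this objective helps to know what the optimizer should find. At solution the function equals 0, its maximum. At the origin it equals minus the squared norm of solution, -(0.25 + 0.01 + 0.09) = -0.35. The sketch below repeats the definition so that it runs on its own.

```r
solution <- c(0.5, 0.1, -0.3)
f <- function(theta) -sum((solution - theta)^2)

f(solution)                    # 0, the maximum value
f(c(0, 0, 0))                  # -0.35
f(solution + c(0.1, 0, 0))     # any move away from solution lowers the value
```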

Then the cross-entropy algorithm in this case is:


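library(MASS) # provides mvrnorm() for multivariate normal sampling

# theta_mean: initial mean vector; theta_std: initial covariance matrix
# (despite its name, theta_std holds a covariance, not a standard deviation)
cem <- function(f, n_iter, theta_mean, theta_std, batch_size=25, elite_frac=0.2){
  dim_theta <- length(theta_mean)
  n_elite <- as.integer(batch_size * elite_frac)
  for(it in 1:n_iter){
    # Sample batch_size parameter vectors (one per row) from N(theta_mean, theta_std)
    thetas <- mvrnorm(n = batch_size, mu = as.vector(theta_mean), Sigma = theta_std)
    rewards <- apply(thetas, 1, f)
    # Get elite parameters: the n_elite rows with the highest reward
    elite_inds <- order(rewards, decreasing = TRUE)[1:n_elite]
    elite_thetas <- thetas[elite_inds, , drop = FALSE]
    # Update theta_mean, theta_std; the small identity term keeps the
    # covariance positive definite and preserves some exploration
    theta_mean <- colMeans(elite_thetas)
    theta_std <- 0.01*diag(dim_theta) + 0.99*cov(elite_thetas)
  }
  # Return the final mean (the estimate of the maximizer) and covariance
  list(theta_mean = theta_mean, theta_std = theta_std)
}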
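
The covariance update, 0.01*diag(dim_theta) + 0.99*cov(elite_thetas), matters when the elite set is small. The sample covariance of n vectors has rank at most n - 1, so it is singular whenever the number of elite vectors does not exceed the dimension. Every eigenvalue of the blended matrix is 0.01 + 0.99*λ, where λ ≥ 0 is an eigenvalue of the sample covariance, so it never falls below 0.01 and sampling never collapses along any direction. A small base-R illustration follows. The elite set of two points in three dimensions is made up for the example.

```r
# Hypothetical elite set (assumption): 2 parameter vectors in 3 dimensions
elite_thetas <- rbind(c(0.4, 0.2, -0.1),
                      c(0.6, 0.0, -0.5))
dim_theta <- ncol(elite_thetas)

raw_cov <- cov(elite_thetas)                       # rank 1, hence singular
reg_cov <- 0.01 * diag(dim_theta) + 0.99 * raw_cov # blended as in cem()

eigen(raw_cov, symmetric = TRUE)$values  # two eigenvalues are (numerically) zero
eigen(reg_cov, symmetric = TRUE)$values  # all eigenvalues are at least 0.01
```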


We then call it like this:

set.seed(42) # make the random sampling reproducible
dim_theta <- length(solution)
theta_mean <- rep(0, dim_theta) # start the search at the origin
theta_std <- diag(dim_theta)    # initial covariance: the identity matrix
batch_size <- 25 # number of samples per batch
elite_frac <- 0.2 # fraction of samples used as elite set
cem(f, 300, theta_mean = theta_mean, theta_std = theta_std,
    batch_size = batch_size, elite_frac = elite_frac)
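If installing MASS is not an option, a common simplification is to fit an independent (diagonal) normal distribution instead of a full-covariance one, so that each coordinate is sampled with rnorm(). The function cem_diag below is a hypothetical variant written for this illustration, not the implementation above. It applies the same 0.01/0.99 blend to the per-coordinate variances and also returns the best reward of each iteration, so convergence can be inspected.

```r
# Hypothetical diagonal-covariance variant of CEM (base R only)
cem_diag <- function(f, n_iter, mu, sigma, batch_size = 25, elite_frac = 0.2) {
  d <- length(mu)
  n_elite <- as.integer(batch_size * elite_frac)
  best <- numeric(n_iter)
  for (it in 1:n_iter) {
    # Column j holds batch_size draws from N(mu[j], sigma[j]^2)
    thetas <- matrix(rnorm(batch_size * d,
                           mean = rep(mu, each = batch_size),
                           sd = rep(sigma, each = batch_size)),
                     nrow = batch_size)
    rewards <- apply(thetas, 1, f)
    best[it] <- max(rewards)
    elite <- thetas[order(rewards, decreasing = TRUE)[1:n_elite], , drop = FALSE]
    # Refit per-coordinate mean and standard deviation on the elite set,
    # blending each variance with 0.01 as cem() does with the identity matrix
    mu <- colMeans(elite)
    sigma <- sqrt(0.01 + 0.99 * apply(elite, 2, var))
  }
  list(mean = mu, sd = sigma, best_reward = best)
}

solution <- c(0.5, 0.1, -0.3)
f <- function(theta) -sum((solution - theta)^2)

set.seed(42)
res <- cem_diag(f, 100, mu = rep(0, 3), sigma = rep(1, 3))
res$mean                  # should be close to solution
tail(res$best_reward, 1)  # should be close to the maximum value 0
```

The diagonal model ignores correlations between parameters. That costs nothing on this separable objective but can slow convergence when the parameters interact.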

Author: Pablo Maldonado

Data Scientist, applied mathematician and aspiring drummer. On a quest to find out how math can make the world a better place.
