
The script learn.py is used to learn the optimal value function using a neural network model.

The training algorithm is a value-based version of the DQN algorithm (Mnih et al., Nature 2015).

The script periodically evaluates the performance of the RL policy derived from the current value function.

Training can be interrupted at any time and restarted from a previously saved model.

Outputs generated during training are:
  • *figure_learning_progress.png

    This figure monitors the progress of the learning agent and is updated periodically during training. It shows the evolution of

    • ‘eps’: current value of eps (probability of taking a random action) used when collecting new experience

    • ‘p_not_found’: probability that the source is never found

    • ‘mean’: mean number of steps to find the source, provided it is ever found (if ‘p_not_found’ > 1e-3, the mean is depicted by a cross instead of a dot)

    • ‘rel_mean’: ‘mean’ divided by the mean obtained for the reference policy

    • ‘p50’: number of steps to find the source with 50 % probability

    • ‘p99’: number of steps to find the source with 99 % probability

    as a function of training iterations (top row) and number of transitions used for training (bottom row).

  • *figure_stats_i.png

    A figure is produced every time the learning agent is evaluated (i is an integer). It shows distributions (pdf, cdf, ccdf) and statistics (mean, standard deviation, median, etc.) of the number of steps to find the source.

  • *table_stats.npy

    Numpy array containing the history of performance. Each row corresponds to a training iteration where the performance was evaluated, and the columns contain

    • col 0: training iteration

    • col 1: number of transitions seen

    • col 2: number of transitions generated

    • col 3: eps (probability of taking a random action)

    • col 4: probability that the source is never found

    • col 5: mean number of steps to find the source, provided that the source is ultimately found

    • col 6: error on the mean

    • col 7: number of steps to find the source with 50 % probability

    • col 8: number of steps to find the source with 99 % probability

  • *parameters.txt

    Text file summarizing the parameters used.
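
As an illustration of how the performance table might be inspected after training, the sketch below picks the evaluation with the lowest mean number of steps among those where the source is almost always found. The column order follows the list given for *table_stats.npy; the helper name best_evaluation, the labels in COLUMNS and the file name in the usage comment are hypothetical, and the default threshold of 1e-3 simply mirrors the one used for the cross/dot marker in the learning-progress figure.

```python
# Column order as documented for *table_stats.npy (labels are illustrative).
COLUMNS = (
    "iteration", "transitions_seen", "transitions_generated", "eps",
    "p_not_found", "mean", "mean_err", "p50", "p99",
)


def best_evaluation(table, max_p_not_found=1e-3):
    """Return the evaluation row with the lowest mean, as a dict.

    `table` is any sequence of rows with the nine documented columns
    (for example the 2D array returned by numpy.load).
    Rows with p_not_found above `max_p_not_found` are ignored.
    Returns None if no row qualifies.
    """
    rows = [dict(zip(COLUMNS, (float(v) for v in row))) for row in table]
    reliable = [r for r in rows if r["p_not_found"] <= max_p_not_found]
    if not reliable:
        return None
    return min(reliable, key=lambda r: r["mean"])


# Usage (hypothetical file name, requires numpy):
#   import numpy as np
#   table = np.load("myrun_table_stats.npy")
#   print(best_evaluation(table))
```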

Models are saved periodically:
  • *model

    most recent model

  • *model_bkp_i

    models saved at evaluation points (i is an integer), that is, models whose performance is shown in the figures

Parameters of the script are:
  • Source-tracking POMDP
    • N_DIMS (int > 0)

      number of space dimensions (1D, 2D, …)

      Warning: while the script works for any space dimension, it is very computationally intensive and is generally not usable in more than 2D.

    • LAMBDA_OVER_DX (float >= 1.0)

      dimensionless problem size

    • R_DT (float > 0.0)

      dimensionless source intensity

    • NORM_POISSON (‘Euclidean’, ‘Manhattan’ or ‘Chebyshev’)

      norm used for hit detections, usually ‘Euclidean’

    • N_HITS (int >= 2 or None)

      number of possible hit values, set automatically if None

    • N_GRID (odd int >= 3 or None)

      linear size of the domain, set automatically if None

  • Reinforcement learning
    Neural network architecture (fully connected)
    • FC_LAYERS (int > 0)

      number of hidden layers

    • FC_UNITS (int > 0 or tuple(int > 0))

      number of units per layer

    Stochastic gradient descent
    • BATCH_SIZE (int > 0)

      size of the mini-batch

    • N_GD_STEPS (int > 0)

      number of gradient descent steps per training iteration

    • LEARNING_RATE (0.0 < float < 1.0)

      usual learning rate

    Exploration: eps is the probability of taking a random action
    • E_GREEDY_FLOOR (0.0 <= float <= 1.0)

      floor value of eps

    • E_GREEDY_0 (0.0 <= float <= 1.0)

      initial value of eps

    • E_GREEDY_DECAY (float > 0.0)

      timescale for eps decay, in number of training iterations
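
To make the role of the three exploration parameters concrete, here is a minimal sketch of an eps schedule. It assumes an exponential decay clipped at the floor; the exact functional form used by learn.py is not stated here, so treat the formula, the function name epsilon and the default values as assumptions.

```python
import math


def epsilon(it, e_greedy_0=1.0, e_greedy_floor=0.1, e_greedy_decay=1000.0):
    """Probability of taking a random action at training iteration `it`.

    Assumed schedule: exponential decay from e_greedy_0 with timescale
    e_greedy_decay (in training iterations), never going below
    e_greedy_floor. Default values are hypothetical.
    """
    if e_greedy_decay <= 0.0:
        raise ValueError("e_greedy_decay must be > 0")
    return max(e_greedy_floor, e_greedy_0 * math.exp(-it / e_greedy_decay))
```

With these hypothetical values, eps starts at 1.0 and stays at the floor of 0.1 once the exponential term has dropped below it.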

    Accounting for symmetries:

      whether to average value over symmetric duplicates during evaluation


      whether to augment data by including symmetric duplicates with identical targets during training step


      whether to apply random symmetry transformations when generating the data (no duplicates)

    Experience replay
    • MEMORY_SIZE (int > 0)

      number of transitions (s, s’) to keep in memory

    • REPLAY_NTIMES (int > 0)

      how many times a transition is used for training before being deleted, on average
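
One way to read REPLAY_NTIMES: each training iteration draws BATCH_SIZE × N_GD_STEPS samples from memory, so for every transition to be reused REPLAY_NTIMES times on average, roughly BATCH_SIZE × N_GD_STEPS / REPLAY_NTIMES new transitions have to be generated per iteration. The sketch below illustrates this bookkeeping with a first-in, first-out buffer capped at MEMORY_SIZE and uniform sampling. It is an assumption about how such a memory could work, not the script's actual implementation; ReplayMemory and new_transitions_per_iteration are hypothetical names.

```python
import random
from collections import deque


def new_transitions_per_iteration(batch_size, n_gd_steps, replay_ntimes):
    """Transitions to generate per iteration so that each one is used
    about `replay_ntimes` times on average (assumes uniform sampling)."""
    return max(1, round(batch_size * n_gd_steps / replay_ntimes))


class ReplayMemory:
    """FIFO memory holding at most `memory_size` transitions (s, s')."""

    def __init__(self, memory_size):
        # A deque with maxlen silently drops the oldest entries when full.
        self.buffer = deque(maxlen=memory_size)

    def add(self, transitions):
        self.buffer.extend(transitions)

    def sample(self, batch_size):
        """Draw a mini-batch uniformly at random, without replacement."""
        k = min(batch_size, len(self.buffer))
        return random.sample(list(self.buffer), k)
```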

    Additional DQN algo parameters
    • ALGO_MAX_IT (int > 0)

      max number of training iterations


      how often the target network is updated, in number of training iterations

    • DDQN (bool)

      whether to use Double DQN instead of original DQN

    Evaluation of the RL policy
    • POLICY_REF (int)

      heuristic policy to use for comparison


    • EVALUATE_PERFORMANCE_EVERY (int > 0)

      how often the RL policy is evaluated, in number of training iterations

    • N_RUNS_STATS (int > 0 or None)

      number of episodes used to compute the stats of a policy, set automatically if None

    Monitoring/Saving during the training
    • PRINT_INFO_EVERY (int > 0)

      how often to print info on screen, in number of training iterations

    • SAVE_MODEL_EVERY (int > 0)

      how often to save the current model, in number of training iterations (in addition, model copies will be saved every EVALUATE_PERFORMANCE_EVERY)

  • Criteria for episode termination
    • STOP_t (int > 0 or None)

      maximum number of steps per episode, set automatically if None

    • STOP_p (float ~ 0.0)

      episode stops when the probability that the source has been found is greater than 1 - STOP_p
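
The two termination criteria can be combined as in the following sketch. It assumes that the caller tracks p_found, the probability that the source has been found so far, and it treats stop_t=None as "no step limit" for simplicity (the script instead sets STOP_t automatically in that case). The function name and the default value of stop_p are hypothetical.

```python
def episode_should_stop(t, p_found, stop_t=None, stop_p=1e-6):
    """Return True if the episode should end after step `t`.

    t       : number of steps taken so far
    p_found : probability that the source has been found so far
    stop_t  : maximum number of steps (None means no limit in this sketch)
    stop_p  : stop once p_found exceeds 1 - stop_p
    """
    if stop_t is not None and t >= stop_t:
        return True
    return p_found > 1.0 - stop_p
```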

  • Parallelization
    • N_PARALLEL (int)

      number of episodes computed in parallel when generating new experience or evaluating the RL policy (if <= 0, all available CPUs are used)

      Known bug: for large neural networks, the code may hang if N_PARALLEL > 1; in that case, use N_PARALLEL = 1.
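
The "all available CPUs" convention for N_PARALLEL could be resolved as sketched below. This is an illustration only; resolve_n_parallel is a hypothetical helper, and the script may count CPUs differently.

```python
import os


def resolve_n_parallel(n_parallel):
    """Translate the N_PARALLEL setting into an actual worker count.

    Values <= 0 mean "use all available CPUs"; os.cpu_count() can return
    None on some platforms, in which case a single worker is used.
    """
    if n_parallel <= 0:
        return os.cpu_count() or 1
    return n_parallel
```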

  • Reload an existing model
    • MODEL_PATH (str or None)

      path of the model (neural network) to reload; if None, training starts from scratch

  • Saving
    • RUN_NAME (str or None)

      prefix used for all output files; if None, a timestamp is used
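
Since several parameters carry constraints (odd N_GRID, bounded eps values, and so on), it can help to check a configuration before launching a long training run. The sketch below encodes a subset of the constraints documented above; the RULES table and the check_parameters function are hypothetical helpers and are not part of learn.py.

```python
# Each rule maps a parameter name to (validity test, documented constraint).
RULES = {
    "N_DIMS": (lambda v: isinstance(v, int) and v > 0, "int > 0"),
    "LAMBDA_OVER_DX": (lambda v: v >= 1.0, "float >= 1.0"),
    "R_DT": (lambda v: v > 0.0, "float > 0.0"),
    "NORM_POISSON": (
        lambda v: v in ("Euclidean", "Manhattan", "Chebyshev"),
        "'Euclidean', 'Manhattan' or 'Chebyshev'",
    ),
    "N_HITS": (
        lambda v: v is None or (isinstance(v, int) and v >= 2),
        "int >= 2 or None",
    ),
    "N_GRID": (
        lambda v: v is None or (isinstance(v, int) and v >= 3 and v % 2 == 1),
        "odd int >= 3 or None",
    ),
    "LEARNING_RATE": (lambda v: 0.0 < v < 1.0, "0.0 < float < 1.0"),
    "E_GREEDY_FLOOR": (lambda v: 0.0 <= v <= 1.0, "0.0 <= float <= 1.0"),
    "E_GREEDY_0": (lambda v: 0.0 <= v <= 1.0, "0.0 <= float <= 1.0"),
}


def check_parameters(params):
    """Return a list of messages for parameters violating their constraint.

    Only parameters present in `params` are checked; an empty list means
    that no documented constraint in RULES is violated.
    """
    problems = []
    for name, (is_valid, expected) in RULES.items():
        if name in params and not is_valid(params[name]):
            problems.append(f"{name} = {params[name]!r}: expected {expected}")
    return problems
```

For example, check_parameters({"N_GRID": 4}) reports that N_GRID must be an odd int >= 3 or None, because an even grid size has no central cell.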