Minimax Optimal Online Imitation Learning via Replay Estimation

With Gokul Swamy, Matthew Peng, Sanjiban Choudhury, J. Andrew Bagnell, Zhiwei Steven Wu, Jiantao Jiao, and Kannan Ramchandran

Prior work in imitation learning has shown that in the infinite-sample regime, moment/distribution matching achieves the same value as the expert policy. However, in the finite-sample regime, even in the absence of optimization error, empirical variance can lead to a performance gap that scales as \(O(H^2/N)\) for behavior cloning and \(O(H \sqrt{1/N})\) for online moment matching, where \(H\) is the horizon length and \(N\) is the size of the expert dataset. The quadratic dependence on \(H\) for behavior cloning arises from the phenomenon of “error compounding”: mistakes made by the learner early in an episode propagate to later timesteps. Online moment matching, on the other hand, suffers from high variance because it matches only an empirical distribution estimate and ignores the specific actions the expert played at each state.
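To make the trade-off concrete, directly comparing the two bounds above (ignoring constants and logarithmic factors) shows which regime favors which method:

\[
\frac{H^2}{N} \;\le\; H\sqrt{\frac{1}{N}} \quad\Longleftrightarrow\quad H \le \sqrt{N} \quad\Longleftrightarrow\quad N \ge H^2,
\]

so the behavior cloning bound is tighter when the expert dataset is large relative to the horizon (\(N \gtrsim H^2\)), while the online moment matching bound is tighter in the small-sample regime (\(N \lesssim H^2\)).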

In this paper, we introduce the technique of replay estimation, which gets the best of both worlds. By repeatedly executing cached expert actions in a stochastic simulator, we reduce empirical variance and compute a smoother estimate of the expert's visitation distribution to match, so we no longer suffer the high variance incurred by standard moment matching. Our algorithm, Replay Estimation (RE), is based on the following simple four-step approach (a code sketch follows the list):

1. Train a Behavior Cloning (BC) policy on the expert dataset.

2. Train a state-uncertainty estimate for BC on the same dataset: when the uncertainty at a state is high, BC is inaccurate there; when it is low, BC is accurate.

3. Roll out BC in the environment, weighting the observed trajectories by the state-uncertainty of the visited states.

4. Carry out moment matching against this weighted distribution.
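Below is a minimal sketch of the four steps in code. It is not the paper's implementation: the helpers `train_bc`, `train_uncertainty`, and `moment_match` are hypothetical callables supplied by the user, the environment is assumed to expose a gym-style `reset`/`step` interface, and trajectory weighting is instantiated as weighting each visited state by its own uncertainty.

```python
import numpy as np

def replay_estimation(expert_dataset, env, train_bc, train_uncertainty,
                      moment_match, n_rollouts=100, horizon=1000):
    """Sketch of the four-step Replay Estimation recipe (hypothetical helpers)."""
    # Step 1: fit a behavior cloning policy to the expert (state, action) pairs.
    bc_policy = train_bc(expert_dataset)

    # Step 2: fit a per-state uncertainty estimate for the BC policy
    # (e.g. the margin-based 0/1 rule described below).
    uncertainty = train_uncertainty(bc_policy, expert_dataset)

    # Step 3: roll out BC in the simulator, recording each visited state
    # together with its estimated uncertainty as a weight.
    states, weights = [], []
    for _ in range(n_rollouts):
        s = env.reset()
        for _ in range(horizon):
            a = bc_policy(s)
            states.append(s)
            weights.append(uncertainty(s))
            s, _, done, _ = env.step(a)
            if done:
                break

    # Step 4: moment matching against the uncertainty-weighted
    # estimate of the expert's visitation distribution.
    return moment_match(states, np.asarray(weights, dtype=float))
```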

In theory, the uncertainty estimate can simply be instantiated as \(1\) for states that the BC classifier classifies with low margin (i.e., the two highest logits are close to each other), and \(0\) otherwise. In other words, low-margin states are uncertain and high-margin states are certain.
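Here is a sketch of this margin-based rule, assuming access to a function `logits_fn` returning the BC classifier's per-action logits at a state; the threshold value is an illustrative choice, not one prescribed by the paper.

```python
import numpy as np

def margin_uncertainty(logits_fn, margin_threshold=0.1):
    """Return a 0/1 state-uncertainty function based on the BC classifier's margin."""
    def uncertainty(state):
        logits = np.asarray(logits_fn(state))
        top_two = np.sort(logits)[-2:]        # the two largest logits
        margin = top_two[1] - top_two[0]      # gap between best and runner-up
        # Low margin => the classifier is unsure => uncertainty 1; else 0.
        return 1.0 if margin < margin_threshold else 0.0
    return uncertainty
```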

In the presence of general function approximation, we prove a meta-theorem reducing the performance gap of our approach to the parameter estimation error for offline classification (i.e., learning the expert policy). In the case of linear function approximation, the meta-theorem shows that Replay Estimation incurs a performance gap of \(O(d^{5/4} H^{3/2}/N)\), under significantly weaker assumptions than prior work. Instantiated on several continuous control tasks, our algorithm outperforms existing approaches across a range of dataset sizes, especially when the number of expert samples is heavily constrained.
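Side by side, the rates stated in this post (ignoring constants and logarithmic factors) are

\[
\underbrace{O\!\left(\frac{H^2}{N}\right)}_{\text{behavior cloning}}, \qquad
\underbrace{O\!\left(H\sqrt{\frac{1}{N}}\right)}_{\text{online moment matching}}, \qquad
\underbrace{O\!\left(\frac{d^{5/4} H^{3/2}}{N}\right)}_{\text{Replay Estimation (linear)}},
\]

so Replay Estimation retains the \(1/N\) dependence of behavior cloning while reducing the horizon exponent from \(2\) to \(3/2\), at the price of a polynomial dependence on the dimension \(d\).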

Figure (Comparison on Noisy Hopper): comparison of algorithms on a noisy version of the “Hopper” PyBullet stochastic control task. RE shines most in the low-data regime.