@torch.no_grad()
def run_flow(flow_model, x_0, t_0, t_1, device='cpu'):
    """Integrate ``x_0`` from ``t_0`` to ``t_1`` using ``flow_model`` as the ODE right-hand side.

    The whole call runs with gradient tracking disabled.

    Args:
        flow_model: Callable invoked as ``flow_model(x, time=...)``; it must
            also expose ``parameters()``, which is forwarded to ``odeint``.
        x_0: Initial state handed to ``odeint``.
            # assumes a tensor whose leading dimension is the batch — TODO confirm
        t_0: Integration start time.
        t_1: Integration end time.
        device: Device on which the per-sample time tensor is created.

    Returns:
        Whatever ``odeint`` returns for these arguments (presumably the state
        at ``t_1`` — verify against the ``odeint`` implementation in use).
    """

    def ode_rhs(t: float, x):
        # One time entry per element of x's leading dimension, all equal to t.
        batch_times = torch.full(x.shape[:1], t, device=device)
        return flow_model(x, time=batch_times)

    return odeint(ode_rhs, x_0, t_0, t_1, phi=flow_model.parameters())


def animate_flow_run(flow_model, X, frames=20, device='cpu'):
    """Render an animation of ``X`` being pushed through the flow from time 0 to 1.

    Args:
        flow_model: Model forwarded to ``run_flow``.
        X: Samples to transport; columns 0 and 1 are used to size the bins.
        frames: Number of evenly spaced time points in ``[0, 1]`` to draw.
        device: Device forwarded to ``run_flow``.

    Returns:
        ``HTML(...)`` wrapping the animation's HTML5 video markup.
    """
    # 128 evenly spaced bin positions per axis, spanning the range of X's
    # first two columns; shared by every frame so the axes stay comparable.
    bins = [
        np.linspace(X[:, axis].min().cpu(), X[:, axis].max().cpu(), 128)
        for axis in range(2)
    ]

    def plot_frame(time):
        # Clear the previous frame before drawing the new one.
        plt.cla()
        transported = run_flow(flow_model, X, 0, time, device=device)
        plot_dataset(
            transported.cpu(),
            bins=bins,
            title=f'distribution at time {time:.2f}',
        )

    fig = plt.figure(figsize=(8, 8))
    frame_times = np.linspace(0, 1, frames)
    animation = FuncAnimation(fig, plot_frame, frames=frame_times)
    html = HTML(animation.to_html5_video())
    # Close the current figure once the video has been rendered.
    plt.close()
    return html


animate_flow_run(ExampleFlow(), noise)