conditional-flow-matching

1 min read Original article ↗
@torch.no_grad()
def run_flow(flow_model, x_0, t_0, t_1, device='cpu'):
    """Advance samples ``x_0`` along ``flow_model``'s vector field from ``t_0`` to ``t_1``.

    Gradient tracking is disabled for the whole call by ``@torch.no_grad()``.
    The actual stepping is delegated to ``odeint`` (defined elsewhere; assumed
    signature ``odeint(func, x0, t0, t1, phi=...)`` — TODO confirm).

    Args:
        flow_model: callable as ``flow_model(x, time=...)``; its
            ``parameters()`` are forwarded to ``odeint`` as ``phi``.
        x_0: starting samples; the leading dimension is treated as the batch.
        t_0: start time handed to ``odeint``.
        t_1: end time handed to ``odeint``.
        device: device on which the per-sample time tensor is created.

    Returns:
        Whatever ``odeint`` returns for the end point.
    """
    def velocity(t: float, x):
        # Broadcast the scalar time to one entry per leading-dim row of x,
        # since the model takes a per-sample ``time`` tensor.
        per_sample_time = torch.full(x.shape[:1], t, device=device)
        return flow_model(x, time=per_sample_time)

    model_params = flow_model.parameters()
    return odeint(velocity, x_0, t_0, t_1, phi=model_params)


def animate_flow_run(flow_model, X, frames=20, device='cpu'):
    """Render an HTML5 video of samples ``X`` being transported by the flow.

    For each of ``frames`` evenly spaced times in ``[0, 1]``, the flow is
    integrated from time 0 to that time with ``run_flow`` and the result is
    drawn with ``plot_dataset``. Each frame re-integrates from t=0 rather
    than continuing from the previous frame.

    Args:
        flow_model: model forwarded to ``run_flow``.
        X: starting samples. It is indexed as ``X[:, 0]`` and ``X[:, 1]``, so
            presumably a 2-D tensor with (at least) two columns — TODO confirm.
        frames: number of time points (video frames) in ``[0, 1]``.
        device: forwarded to ``run_flow``.

    Returns:
        An ``HTML`` object wrapping the rendered video.

    Raises:
        Whatever ``FuncAnimation`` / ``Animation.to_html5_video`` raise (e.g.
        when no movie writer is available). The figure is closed either way.
    """
    # Bin edges (128 per axis) come from the range of the *initial* samples
    # and are reused for every frame, so all frames share the same binning.
    bins = [
        np.linspace(X[:, 0].min().cpu(), X[:, 0].max().cpu(), 128),
        np.linspace(X[:, 1].min().cpu(), X[:, 1].max().cpu(), 128),
    ]

    def plot_frame(time):
        # Clear the previous frame's drawing before plotting the new one.
        plt.cla()
        samples = run_flow(flow_model, X, 0, time, device=device).cpu()
        plot_dataset(samples, bins=bins, title=f'distribution at time {time:.2f}')

    fig = plt.figure(figsize=(8, 8))
    try:
        anim = FuncAnimation(fig, plot_frame, frames=np.linspace(0, 1, frames))
        return HTML(anim.to_html5_video())
    finally:
        # Close exactly the figure created here (not whatever is "current"),
        # and do so even if rendering fails, so figures do not accumulate.
        plt.close(fig)


# Demo: animate the samples in ``noise`` (defined elsewhere) flowing under an
# ``ExampleFlow`` model (defined elsewhere). The call's return value is an
# HTML video object; presumably this is the last line of a notebook cell so
# that the video is displayed as the cell output — keep it as a bare expression.
animate_flow_run(ExampleFlow(), noise)