# diffusion scheduler from OpenAI
## Initialization
# An empty respacing string keeps the default schedule
# (default: 1000 steps, linear noise schedule).
diffusion = create_diffusion(timestep_respacing="")
## train
# x: batch of shape (N, C, T, H, W).
# Draw one timestep per sample, uniformly from [0, diffusion.num_timesteps).
t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=accelerator.device)
# Internally this runs: model_output = model(x_t, t, **model_kwargs)
loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
# Reduce the "loss" entry of the returned dict to a scalar for backprop.
loss = loss_dict["loss"].mean()
## Inference
# Separate scheduler for sampling, respaced with "250" (fewer steps than the
# 1000-step training default) -- confirm exact respacing semantics in
# create_diffusion.
diffusion_ = create_diffusion(timestep_respacing="250")
model_kwargs = dict(
    encoder_hidden_states=cond,
    attention_mask=None,
    encoder_attention_mask=cond_mask,
)
sample_fn = model_.forward
# z supplies the sample shape and is passed as the third positional argument,
# presumably the starting noise for the reverse process -- verify against
# p_sample_loop's signature.
samples = diffusion_.p_sample_loop(
    sample_fn,
    z.shape,
    z,
    clip_denoised=False,
    model_kwargs=model_kwargs,
    progress=True,
    device=accelerator.device,
)