x-flux/train_flux_deepspeed_controlnet.py
Lines 213 to 251 in 4749542
```python
x_1 = vae.encode(img.to(accelerator.device).to(torch.float32))
inp = prepare(t5=t5, clip=clip, img=x_1, prompt=prompts)
x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
bs = img.shape[0]
t = torch.sigmoid(torch.randn((bs,), device=accelerator.device))
x_0 = torch.randn_like(x_1).to(accelerator.device)
print(t.shape, x_1.shape, x_0.shape)
x_t = (1 - t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2])) * x_1 + t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2]) * x_0
bsz = x_1.shape[0]
guidance_vec = torch.full((x_t.shape[0],), 4, device=x_t.device, dtype=x_t.dtype)
block_res_samples = controlnet(
    img=x_t.to(weight_dtype),
    img_ids=inp['img_ids'].to(weight_dtype),
    controlnet_cond=control_image.to(weight_dtype),
    txt=inp['txt'].to(weight_dtype),
    txt_ids=inp['txt_ids'].to(weight_dtype),
    y=inp['vec'].to(weight_dtype),
    timesteps=t.to(weight_dtype),
    guidance=guidance_vec.to(weight_dtype),
)
# Predict the noise residual and compute loss
model_pred = dit(
    img=x_t.to(weight_dtype),
    img_ids=inp['img_ids'].to(weight_dtype),
    txt=inp['txt'].to(weight_dtype),
    txt_ids=inp['txt_ids'].to(weight_dtype),
    block_controlnet_hidden_states=[
        sample.to(dtype=weight_dtype) for sample in block_res_samples
    ],
    y=inp['vec'].to(weight_dtype),
    timesteps=t.to(weight_dtype),
    guidance=guidance_vec.to(weight_dtype),
)
loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean")
```
In the paper, x_0 is noise and x_1 is data. The path is

x_t = t * x_1 + (1 - t) * x_0,

i.e. it starts from x_0 at t = 0 and reaches x_1 at t = 1, and the loss is

loss = || (x_1 - x_0) - pred(x_t) ||^2.

In your code, x_0 is also noise and x_1 is also data, the same as in the paper, but

x_t = (1 - t) * x_1 + t * x_0,

which is not equal to x_t = t * x_1 + (1 - t) * x_0, and then the loss is

loss = || (x_0 - x_1) - pred(x_t) ||^2.

Why do you do this?
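For reference, a minimal sketch (plain NumPy, with `x0`/`x1` standing in for the noise/data latents) checking one way the two conventions relate: the code's interpolation is the paper's path with time reversed (t → 1 − t), and reversing time negates the path's velocity, which would flip the regression target from x_1 − x_0 to x_0 − x_1. This is only an observation about the formulas, not a claim about the authors' intent:

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.normal(size=(4, 16))  # noise, as in both the paper and the code
x1 = rng.normal(size=(4, 16))  # data
t = rng.uniform(size=(4, 1))   # per-sample timestep in (0, 1)

def xt_paper(s):
    # Paper convention: x_s = s * x1 + (1 - s) * x0,
    # velocity target d x_s / ds = x1 - x0
    return s * x1 + (1 - s) * x0

# Code convention: x_t = (1 - t) * x1 + t * x0, trained against x0 - x1
xt_code = (1 - t) * x1 + t * x0

# The code's point on the path equals the paper's point at reversed time 1 - t,
# and d/dt xt_paper(1 - t) = -(x1 - x0) = x0 - x1, matching the code's target.
assert np.allclose(xt_code, xt_paper(1 - t))
```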