Why is the training loss different from the paper? #156

@Johnson-yue

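```python
import torch
import torch.nn.functional as F
from einops import rearrange

# Excerpt from the training loop: vae, t5, clip, prepare, controlnet, dit,
# accelerator, weight_dtype, img, prompts and control_image are defined
# elsewhere in the training script.

# Encode images to VAE latents (x_1 = data) and build the text/image conditioning
x_1 = vae.encode(img.to(accelerator.device).to(torch.float32))
inp = prepare(t5=t5, clip=clip, img=x_1, prompt=prompts)
# Pack 2x2 latent patches into tokens: (b, c, H, W) -> (b, (H/2)*(W/2), c*4)
x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
bs = img.shape[0]
# One timestep per sample in (0, 1): sigmoid of a standard normal (logit-normal)
t = torch.sigmoid(torch.randn((bs,), device=accelerator.device))
# x_0 = Gaussian noise with the same shape as the packed latents
x_0 = torch.randn_like(x_1).to(accelerator.device)
print(t.shape, x_1.shape, x_0.shape)
# x_t = (1 - t) * x_1 + t * x_0, with t expanded to the shape of x_1
x_t = (1 - t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2])) * x_1 + t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2]) * x_0
bsz = x_1.shape[0]
guidance_vec = torch.full((x_t.shape[0],), 4, device=x_t.device, dtype=x_t.dtype)
# ControlNet residuals that are injected into the transformer blocks
block_res_samples = controlnet(
    img=x_t.to(weight_dtype),
    img_ids=inp['img_ids'].to(weight_dtype),
    controlnet_cond=control_image.to(weight_dtype),
    txt=inp['txt'].to(weight_dtype),
    txt_ids=inp['txt_ids'].to(weight_dtype),
    y=inp['vec'].to(weight_dtype),
    timesteps=t.to(weight_dtype),
    guidance=guidance_vec.to(weight_dtype),
)
# Predict the regression target and compute the loss
model_pred = dit(
    img=x_t.to(weight_dtype),
    img_ids=inp['img_ids'].to(weight_dtype),
    txt=inp['txt'].to(weight_dtype),
    txt_ids=inp['txt_ids'].to(weight_dtype),
    block_controlnet_hidden_states=[
        sample.to(dtype=weight_dtype) for sample in block_res_samples
    ],
    y=inp['vec'].to(weight_dtype),
    timesteps=t.to(weight_dtype),
    guidance=guidance_vec.to(weight_dtype),
)
# Target is x_0 - x_1 (noise minus data)
loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean")
```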
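The `repeat` in the `x_t` line only expands `t` to the shape of `x_1`, so per sample the interpolation is just `x_t = (1 - t) * x_1 + t * x_0`. A minimal NumPy sketch of that equivalence, assuming random arrays with small hypothetical shapes in place of real latents, and using `np.tile` to mimic `torch.Tensor.repeat` (which follows `np.tile` semantics):

```python
import numpy as np

rng = np.random.default_rng(0)
bs, seq_len, dim = 2, 3, 4  # hypothetical shapes: batch, tokens, channels

x_1 = rng.standard_normal((bs, seq_len, dim))  # stands in for packed image latents
x_0 = rng.standard_normal((bs, seq_len, dim))  # Gaussian noise
t = rng.random(bs)                             # one timestep per sample

# Same expansion as t.unsqueeze(1).unsqueeze(2).repeat(1, seq_len, dim)
t_rep = np.tile(t[:, None, None], (1, seq_len, dim))
x_t_repeat = (1 - t_rep) * x_1 + t_rep * x_0

# Broadcasting: t with shape (bs, 1, 1) is stretched over tokens and channels
t_b = t[:, None, None]
x_t_broadcast = (1 - t_b) * x_1 + t_b * x_0

print(np.allclose(x_t_repeat, x_t_broadcast))
```

In PyTorch the broadcasting form would be `t[:, None, None]` (or `t.view(-1, 1, 1)`), which avoids materializing the repeated tensor.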

In the paper, x_0 is noise and x_1 is data:

x_t = t * x_1 + (1 - t) * x_0, i.e. the path starts at x_0 (t = 0) and ends at x_1 (t = 1)
loss = || (x_1 - x_0) - pred(x_t) ||^2
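As a minimal sketch of the paper's convention, assuming plain Python lists as stand-ins for tensors and hypothetical helper names `interp_paper` and `target_paper`: the target x_1 - x_0 is the time derivative of x_t, so integrating it from t = 0 to t = 1 carries the noise sample to the data sample.

```python
def interp_paper(x0, x1, t):
    """Paper convention: x_t = t * x1 + (1 - t) * x0 (x0 noise, x1 data)."""
    return [t * b + (1 - t) * a for a, b in zip(x0, x1)]


def target_paper(x0, x1):
    """Regression target: the velocity d x_t / dt = x1 - x0."""
    return [b - a for a, b in zip(x0, x1)]


x0 = [0.5, -1.2, 0.3]   # hypothetical noise sample
x1 = [2.0, 0.0, -1.0]   # hypothetical data sample

# Finite-difference check that the target is the time derivative of x_t
t, h = 0.3, 1e-6
fd = [(p - q) / h for p, q in zip(interp_paper(x0, x1, t + h), interp_paper(x0, x1, t))]
print(fd, target_paper(x0, x1))

# Euler integration from t = 0 (noise) to t = 1 (data) with the exact velocity
steps = 10
x = list(x0)
for _ in range(steps):
    v = target_paper(x0, x1)
    x = [xi + vi / steps for xi, vi in zip(x, v)]
print(x)  # close to x1
```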

In your code, x_0 is noise and x_1 is data, the same as in the paper, but:

x_t = (1 - t) * x_1 + t * x_0, which is not equal to x_t = t * x_1 + (1 - t) * x_0
loss = || (x_0 - x_1) - pred(x_t) ||^2
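A minimal sketch comparing the two conventions, with the same assumptions (plain lists, hypothetical helpers `interp_paper` and `interp_code`): substituting s = 1 - t maps the code's x_t onto the paper's path, and the code's target x_0 - x_1 is then the derivative of x_t with respect to the code's own t, so sampling would run from t = 1 (noise) down to t = 0 (data).

```python
def interp_paper(x0, x1, t):
    # Paper convention: x0 (noise) at t = 0, x1 (data) at t = 1
    return [t * b + (1 - t) * a for a, b in zip(x0, x1)]


def interp_code(x0, x1, t):
    # Code convention: x_t = (1 - t) * x1 + t * x0, so data at t = 0, noise at t = 1
    return [(1 - t) * b + t * a for a, b in zip(x0, x1)]


x0 = [0.5, -1.2, 0.3]   # hypothetical noise sample
x1 = [2.0, 0.0, -1.0]   # hypothetical data sample
t = 0.3

# Same point on the same straight line, with time reversed (s = 1 - t)
same_point = all(
    abs(p - q) < 1e-12
    for p, q in zip(interp_code(x0, x1, t), interp_paper(x0, x1, 1 - t))
)

# The code's target x0 - x1 is d x_t / dt under the code's own t
h = 1e-6
fd = [(p - q) / h for p, q in zip(interp_code(x0, x1, t + h), interp_code(x0, x1, t))]
target_code = [a - b for a, b in zip(x0, x1)]

# Euler sampling from t = 1 (noise) down to t = 0 (data): dt is negative each step
steps = 10
ts = [1 - i / steps for i in range(steps + 1)]  # 1.0, 0.9, ..., 0.0
x = list(x0)
for t_cur, t_next in zip(ts[:-1], ts[1:]):
    x = [xi + (t_next - t_cur) * vi for xi, vi in zip(x, target_code)]
print(same_point, x)  # x ends close to x1
```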

Why did you do it this way?