#Expected attn_mask dtype to be bool or float or to match query dtype, but got attn_mask.dtype: struc

keen ferry
#

I've been trying to resolve a bug in the FluxRegionAttention custom node for a few weeks now. I've looked everywhere for information, updated the nodes and ComfyUI, and changed the code to cast attn_mask to the expected dtype, but it still doesn't work.

prime plume
#

looks like device precision / quantization is not right

#

the attention mask is bf16, but the query it is applied to is fp16

#

attention masks are either true/false or 1/0

keen ferry
prime plume
#

you may need to figure out where the datatype desync is happening

#

the issue is that the model is receiving a tensor of a different datatype
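
The rule spelled out in the error message can be sketched as a small check. This is only an illustration of the rule as worded in the message, not PyTorch's actual validation code; `mask_dtype_ok` is a hypothetical helper name, and "float" is assumed to mean torch.float32.

```python
import torch


def mask_dtype_ok(mask_dtype: torch.dtype, query_dtype: torch.dtype) -> bool:
    """Mirror the rule in the error message: bool, float, or same as the query."""
    # Assumption: "float" in the message means torch.float32.
    return mask_dtype in (torch.bool, torch.float32) or mask_dtype == query_dtype


# The combination discussed in this thread: bf16 mask, fp16 query.
print(mask_dtype_ok(torch.bfloat16, torch.float16))  # False
print(mask_dtype_ok(torch.bool, torch.float16))      # True
print(mask_dtype_ok(torch.float16, torch.float16))   # True
```

A check like this, placed right before the attention call, would point at the first place where the mask and the query disagree.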

#

Pad mask for xformers to reduce allocations during inference

    device = torch.device('cuda')
    attn_dtype = torch.bfloat16 if model_management.should_use_bf16(device=device) else torch.float16
#

that's where it sets it

keen ferry
prime plume
#

just change that specific line to torch.float16

#

and potentially add prints for the dtype of the tensors the node receives as input

keen ferry
# prime plume just change that specific line to torch.float16

I changed the parameters in the node.py file under custom_nodes\ComfyUI-FluxRegionAttention. The error remains the same. It occurs in SamplerCustomAdvanced, so maybe I'm changing it in the wrong place? Could you please post exactly what the changed lines should look like? (screenshot of how I tried to fix it).

prime plume
#

would be easier had you just posted a workflow json

keen ferry
prime plume
#

i was wrong...

#

how is it even supposed to work?

#

okay.. it needed comfy update, never mind

#

@keen ferry i need your

keen ferry
prime plume
#

hm... it ran fine, nothing exploded, painted a brick wall

keen ferry
#

I've updated all the nodes and comfy itself

prime plume
#

what's in the console log?

keen ferry
prime plume
#

I did add a print in ComfyUI\custom_nodes\ComfyUI-FluxRegionAttention\node.py line 43 before sdp call

#

print("attn_mask", attn_mask.shape, attn_mask.dtype, 'q', q.shape, q.dtype)

#

and both the mask and the q are bfloat16

#

what's your gpu?

#

anyway, i see you've created a ticket on comfy github, so hopefully they can answer

keen ferry
#

RTX2060 Mobile

keen ferry
prime plume
#

this part

#
            print(f'Aplying attention masks: {attn_mask.shape}')
            L, _ = attn_mask.shape
            H = 24  # 24 heads for FLUX models
            pad = 8 - L % 8
            
            # print(f'Attention mask memory padded by: {pad}')
            if pad != 8:
                # TODO: take dtype from memory_management computational_type
                mask_out = torch.empty([bs, H, L + pad, L + pad],
                                       dtype=torch.bfloat16, device=device)
                mask_out[:, :, :L, :L] = attn_mask
                # print(f'Attention mask memory padded to: {mask_out.shape}')
                attn_mask = mask_out[:, :, :L, :L]
            else:
                mask_out = torch.empty([bs, H, L, L],
                                       dtype=torch.bfloat16, device=device)
                mask_out[:, :, :, :] = attn_mask
                attn_mask = mask_out
#

it sets the attention mask to bfloat16

#

maybe add attn_mask = attn_mask.to(torch.float16) after
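
A minimal sketch of that cast, assuming the mask and the query tensor q are both in scope at that point in the node. `align_mask_to_query` is a hypothetical helper, not part of the node. It casts to q.dtype instead of hardcoding torch.float16, so it would also hold on GPUs where the model does run in bf16.

```python
import torch


def align_mask_to_query(attn_mask: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Return attn_mask in a dtype that matches the query tensor."""
    if attn_mask.dtype == torch.bool:
        # Boolean masks are accepted as-is, per the error message.
        return attn_mask
    # .to() returns the same tensor when the dtype already matches.
    return attn_mask.to(q.dtype)


# Stand-in tensors on CPU; the real ones live on the GPU inside the node.
q = torch.zeros(1, 2, 4, 8, dtype=torch.float16)
mask = torch.zeros(4, 4, dtype=torch.bfloat16)
mask = align_mask_to_query(mask, q)
print(mask.dtype)  # torch.float16
```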

#

"RTX 2060 does not support bfloat16 compilation natively"

#

so yeah, that's the problem

#

UI correctly sets precision to float16

#

this shitty node does not do it correctly

#

they probably just forgot to use the attn_dtype they computed in attn_dtype = torch.bfloat16 if model_management.should_use_bf16(device=device) else torch.float16

#

so instead of dtype=torch.bfloat16 it should be dtype=attn_dtype
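
Put together, the padding step with that change might look roughly like this. It is a sketch under assumptions, not the node's actual code: `pad_attention_mask` is a hypothetical function, the two branches of the quoted snippet are folded into one, and attn_dtype is set by hand here where the node would take it from model_management.should_use_bf16.

```python
import torch


def pad_attention_mask(attn_mask, bs, heads, dtype, device="cpu"):
    """Expand an [L, L] mask to [bs, heads, L, L] in the requested dtype.

    The backing buffer is padded up to a multiple of 8, as in the quoted snippet.
    """
    L, _ = attn_mask.shape
    pad = (8 - L % 8) % 8  # 0 when L is already a multiple of 8
    mask_out = torch.empty([bs, heads, L + pad, L + pad], dtype=dtype, device=device)
    # Broadcasts over batch and heads, and casts to the buffer's dtype.
    mask_out[:, :, :L, :L] = attn_mask
    # Return a view onto the padded buffer, not a fresh copy.
    return mask_out[:, :, :L, :L]


# Assumption: in the node this value comes from the should_use_bf16 check.
attn_dtype = torch.float16
mask = pad_attention_mask(torch.zeros(5, 5), bs=1, heads=24, dtype=attn_dtype)
print(mask.shape, mask.dtype)
```

Because the dtype is passed in, the mask follows whatever precision the UI picked for the model, which is the behavior the hardcoded torch.bfloat16 was missing.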