# GRPOTrainer `compute_loss` function: shape mismatch for `matmul`

quick geode
```
TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in method matmul of type object at 0x7d24f8af6fa0>(*(GradTrackingTensor(lvl=1, value=
    FakeTensor(..., device='cuda:0', size=(1, s3, s2), dtype=torch.float16,
               requires_grad=True)
), GradTrackingTensor(lvl=1, value=
    FakeTensor(..., device='cuda:0', size=(2048, 65537), dtype=torch.float16)
)), **{}): got RuntimeError('a and b must have same reduction dim, but got [s3, s2] X [2048, 65537].')

from user code:
  File "/content/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 346, in accumulate_chunk
    (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/apis.py", line 441, in wrapper
    return eager_transforms.grad_and_value_impl(
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/vmap.py", line 48, in fn
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/eager_transforms.py", line 1364, in grad_and_value_impl
    output = func(*args, **kwargs)
  File "/content/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 294, in compute_loss
    new_logits = torch.matmul(new_hidden_states.to(lm_head.dtype), lm_head.t())

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```

This happened after taking the Gemma 1B GRPO Colab notebook and:
- changing the model to `unsloth/Falcon-H1-1.5B-Deep-Instruct`
- adding `os.environ['TRITON_JIT_DISABLE_OPT'] = '1'`
- running `!uv pip install --no-build-isolation mamba-ssm[causal-conv1d]`

I expected that to give me the fast path on the T4, and it did get as far as training before falling over. I have the full output if anyone wants it.
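
The `mamba-ssm[causal-conv1d]` install is what is meant to enable the fast path. A minimal sketch for confirming that both packages are at least importable in the Colab runtime, assuming their import names are `mamba_ssm` and `causal_conv1d`; `fast_path_modules_available` is a hypothetical helper, not part of Unsloth or transformers. Finding the packages does not prove that their compiled CUDA kernels load on a T4, or that the Falcon-H1 code path actually uses them.

```python
import importlib.util


def fast_path_modules_available(names=("mamba_ssm", "causal_conv1d")):
    """Map each (assumed) import name to whether it can be found.

    importlib.util.find_spec only locates top-level packages; it does not
    import them, so no CUDA extension is loaded by this check.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in fast_path_modules_available().items():
        print(f"{name}: {'found' if ok else 'missing'}")
```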

edgy mauve (BOT)

This error is a tensor shape mismatch in the `torch.matmul` call inside GRPOTrainer's `compute_loss`. The traceback shows `new_hidden_states` with shape `(1, s3, s2)` being multiplied by `lm_head.t()` with shape `(2048, 65537)`, so the last dimension of the hidden states (`s2`) would have to equal 2048, and it does not. In other words, the hidden states reaching the loss do not have the width the `lm_head` weight expects. Unsloth's GRPO implementation is primarily tested on Llama, Gemma, Qwen, and similar architectures; Falcon models may produce hidden states or `lm_head` weights with different shapes, which leads to this kind of mismatch. There is no official support or fix for Falcon models in Unsloth's GRPO notebooks or code, and this shape mismatch is a known problem when using unsupported architectures with GRPOTrainer (source, source, source).
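
The failing check can be reproduced without a GPU. A minimal plain-Python sketch of the shape rule involved, covering only the case in the traceback (an N-D tensor times a 2-D tensor); `matmul_output_shape` is a hypothetical helper, the sequence length 7 and the mismatched width 1024 are made-up values, and only 2048 and 65537 come from the log.

```python
import math


def matmul_output_shape(a_shape, b_shape):
    """Approximate torch.matmul's shape rule for N-D (N >= 2) @ 2-D.

    The last dim of `a` (the reduction dim) must equal the first dim of `b`.
    Batch broadcasting between two N-D operands is not modelled here.
    """
    if len(a_shape) < 2 or len(b_shape) != 2:
        raise ValueError("this sketch only handles N-D (N >= 2) @ 2-D")
    *batch, n, k = a_shape
    k_b, m = b_shape
    if k != k_b:
        # The traceback prints [s3, s2] for a (1, s3, s2) input, consistent
        # with the leading dims of `a` being folded into rows.
        rows = math.prod(batch) * n
        raise ValueError(
            f"a and b must have same reduction dim, but got "
            f"[{rows}, {k}] X [{k_b}, {m}]."
        )
    return (*batch, n, m)


# Hidden width matches lm_head.t(): the logits get the vocab dimension.
print(matmul_output_shape((1, 7, 2048), (2048, 65537)))  # (1, 7, 65537)

# Hidden width does not match: same failure as in the traceback.
try:
    matmul_output_shape((1, 7, 1024), (2048, 65537))
except ValueError as err:
    print(err)  # a and b must have same reduction dim, but got [7, 1024] X [2048, 65537].
```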

To resolve this, either use a supported model (Llama, Gemma, Qwen, etc.) for GRPO training, or manually adapt GRPOTrainer and its loss computation to Falcon's architecture, which Unsloth neither documents nor officially supports. Would you like more detail on how to check or adapt the tensor shapes for Falcon, or guidance on using a supported model?
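
A first shape check, sketched in plain Python: compare `hidden_size` and `vocab_size` from the model's `config.json` (assumed to use the usual Hugging Face key names) with the `lm_head` weight shape and with the last dimension of the hidden states that actually reach `compute_loss`. `check_projection_shapes` is a hypothetical helper; the real shapes would come from printing `.shape` on the tensors, and the config values below are illustrative, chosen to match the traceback rather than read from the Falcon-H1 checkpoint.

```python
import json


def check_projection_shapes(config, lm_head_weight_shape, hidden_last_dim):
    """Return a list of mismatches (empty if config, weight and hidden states agree)."""
    problems = []
    # Weight shape before the .t() in compute_loss: (vocab, hidden).
    vocab, hidden = lm_head_weight_shape
    if config.get("vocab_size") != vocab:
        problems.append(
            f"config vocab_size={config.get('vocab_size')} but lm_head has {vocab} rows"
        )
    if config.get("hidden_size") != hidden:
        problems.append(
            f"config hidden_size={config.get('hidden_size')} but lm_head has {hidden} columns"
        )
    if hidden_last_dim != hidden:
        problems.append(
            f"hidden states end in {hidden_last_dim}, lm_head expects {hidden}"
        )
    return problems


# Illustrative config.json contents (hypothetical values).
cfg = json.loads('{"hidden_size": 2048, "vocab_size": 65537}')
print(check_projection_shapes(cfg, (65537, 2048), 2048))  # []
print(check_projection_shapes(cfg, (65537, 2048), 1024))  # ['hidden states end in 1024, lm_head expects 2048']
```

If the config and weight agree but the hidden states do not, the problem lies in what the model's forward pass hands to the loss rather than in the checkpoint itself.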

Sources: