I’m working on a PyTorch model where I compute a “global representation” through a forward pipeline. This representation then feeds an extra sampling procedure later in the network. When I recompute the global representation fully (i.e. without checkpointing), everything works fine and gradients flow back correctly. However, when I try to use torch.utils.checkpoint to save memory by recomputing the global representation during the backward pass, I get a runtime error similar to:
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
tensor at position 34:
saved metadata: {'shape': torch.Size([128, 192]), 'dtype': torch.bfloat16, 'device': device(type='mps', index=0)}
recomputed metadata: {'shape': torch.Size([128, 128, 192]), 'dtype': torch.float32, 'device': device(type='mps', index=0)}
... (more tensor mismatches follow) ...
Some details about my setup (a simplified sketch of the checkpoint call follows this list):
- I run on the MPS backend (Apple Silicon) with mixed precision (bfloat16) using autocast.
- The global representation is computed in a module that later feeds into an extra sampling procedure, so gradients must flow back properly.
- Recomputing the global representation fully (i.e. running the entire forward pass twice) is too inefficient, so checkpointing is critical.
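To make that concrete, here is a simplified sketch of roughly how the checkpointed call is structured; `encoder`, `global_head`, and the sizes are placeholders, not my real model:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Placeholder modules standing in for the real pipeline -- the actual model is
# much larger; this only reproduces the structure of the checkpointed call.
encoder = nn.Linear(64, 192).to("mps")
global_head = nn.Sequential(nn.Linear(192, 192), nn.GELU(), nn.Linear(192, 192)).to("mps")

x = torch.randn(128, 64, device="mps")

with torch.autocast(device_type="mps", dtype=torch.bfloat16):
    feats = encoder(x)
    # Checkpoint the global-representation computation so its intermediate
    # activations are recomputed during backward instead of being stored.
    global_repr = checkpoint(global_head, feats, use_reentrant=False)

# Gradients have to flow back through the checkpointed segment.
global_repr.float().sum().backward()
```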
Besides this, I’ve already tried some fixes, such as replacing all in-place operations with their out-of-place equivalents, but that didn’t resolve the issue.
Additionally, I’m using the following line in my Gumbel sampling procedure:
cond_expanded = cond_cont.unsqueeze(1).expand(B, num_samples, -1).reshape(B * num_samples, -1)
I intended this to properly broadcast the condition over multiple Monte Carlo samples. However, I suspect that the unsqueeze/expand/reshape sequence might be contributing to the metadata mismatch between the tensors saved during the forward pass and those recomputed during the backward pass.
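For reference, here is roughly how that broadcast sits in the sampling step, along with a contiguous alternative I’ve been experimenting with; the sizes are illustrative (they just happen to match the 128/128/192 in the error message):

```python
import torch

B, num_samples, D = 128, 128, 192
cond_cont = torch.randn(B, D, device="mps", dtype=torch.bfloat16)

# What I do now: stride-0 broadcast via expand, then flatten samples into the batch dim.
cond_expanded = cond_cont.unsqueeze(1).expand(B, num_samples, -1).reshape(B * num_samples, -1)

# Contiguous alternative: materializes the copy up front instead of relying on
# expand's non-contiguous view.
cond_repeated = cond_cont.repeat_interleave(num_samples, dim=0)

assert torch.equal(cond_expanded, cond_repeated)  # same values, same [B * num_samples, D] shape
```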
I suspect the issue is related either to an interaction between checkpointing and autocast, or to an inadvertent change in tensor dimensions during recomputation. Has anyone encountered a similar problem, or does anyone know how to ensure that the recomputed tensors match the original forward pass (in shape, dtype, and device) while still benefiting from checkpointing? Any suggestions or workarounds that allow efficient memory use without sacrificing gradient flow would be very helpful.
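One workaround I’ve been considering, but haven’t verified, is to re-enter autocast inside the checkpointed function so that the backward-time recompute runs under the same dtype context as the original forward. A minimal sketch with placeholder names (`run_global_head`, `global_head`, and `feats` are illustrative only):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

global_head = nn.Linear(192, 192).to("mps")            # placeholder for the real module
feats = torch.randn(128, 192, device="mps", requires_grad=True)

def run_global_head(t):
    # Re-enter autocast inside the checkpointed function so the backward-time
    # recompute runs under the same bfloat16 context as the original forward.
    with torch.autocast(device_type="mps", dtype=torch.bfloat16):
        return global_head(t)

global_repr = checkpoint(run_global_head, feats, use_reentrant=False)
global_repr.float().sum().backward()
```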
Additional context or sample code snippets can be provided if needed.
(also maybe someone can create a "torch.utils.checkpoint" tag)