optimization - Computing Gradient of Loss w.r.t Learning Rate PyTorch - Stack Overflow

I am building a custom optimizer that samples learning rates from a Dirichlet distribution, whose param

I am building a custom optimizer that samples learning rates from a Dirichlet distribution, whose parameters (alpha) need to be updated in each backpropagation. I've already figured out how to get the loss w.r.t. to these alpha parameters, effectively this would be ∂η/∂α, where η is the learning rate.

However, I need to "connect," for lack of better word, this gradient with that of the loss, effectively ∂L/∂η, such that I can "chain" these gradients together, forming the expression:

∂L/∂η * ∂η/∂α = ∂L/∂α

I can then use this gradient to update the alphas and therefore improve the sampling of the distribution. The problem is I cannot figure out how to get ∂L/∂η. I've tried using the following line:

grad_learning_rate = torch.autograd.grad(loss, self.learning_rate, grad_outputs=torch.tensor(1.0, device=loss.device), retain_graph=True, allow_unused=True)[0]

where loss is passed into the optimizer after each forward pass. But the following error message is returned:

One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

I've attached the model:

class MLP(nn.Module):
    def __init__(self, input_size, output_size, device: torch.device=None):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, 10, dtype=torch.float64)
        self.relu = nn.ReLU()

        # Fit model on cpu or inputted processing unit (xpu).
        self.device = device if device is not None else torch.device('cpu')
        self.to(self.device)

    def forward(self, x):
        x = self.fc1(x)
        return x

And the Optimizer:

class Dart(Optimizer):
'''
Optimizer must receive losses throughout training.
'''
def __init__(self, params, betas=(0.9, 0.999),
             alpha_init=1.0, alpha_lr=0.0001, eps=1e-8, weight_decay=0): 
    defaults = dict(betas=betas, eps=eps, weight_decay=weight_decay)
    super(Dart, self).__init__(params, defaults)
    self.alpha_scaler = alpha_init
    self.alpha_lr = alpha_lr
    self.learning_rate = None
    self.alpha_grads = None
    
def sample_lr_candidates(self, mean=1e-3, std=1e-4, num_samples=(10, 1), min_lr=1e-6, max_lr=1e-1):
    # Sample from a Gaussian distribution
    lr_samples = torch.normal(mean=mean, std=std, size=(num_samples))
    
    # Clip the values to ensure they are within the min_lr and max_lr range
    lr_samples = torch.clamp(lr_samples, min=min_lr, max=max_lr)
    
    return lr_samples.to(torch.float64)

def step(self, loss):
    for group in self.param_groups:  # only one group.
        for p in group['params']:
            if p.grad is None:
                continue

            dim = (10, 784) if p.shape == torch.Size([10, 784]) else (1, 10)
            
            state = self.state[p]  # optimizer class opens 'history' for param.
            input = torch.empty(dim, device='cpu', dtype=torch.float64)

    
            if len(state) == 0:  # Initialize state if not already done.
                state['step'] = 0
                state['lr_candidates'] = self.sample_lr_candidates(num_samples=p.shape) # .to('xpu')
                state['alphas'] = torch.ones_like(input, memory_format=torch.preserve_format) * self.alpha_scaler
            
            state['step'] += 1  
            
            # Enable autograd for alpha updates
            state['alphas'].requires_grad_(True)
            
            # Sample from Dirichlet (WARNING: May not support autograd)
            samples = torch.distributions.Dirichlet(state['alphas']).rsample() # .to('xpu')  # Differentiable
            total = state['alphas'].sum(-1, True).expand_as(state['alphas'])
            grad_samples = torch._dirichlet_grad(samples, state['alphas'], total) # del p (samples) / del alphas
            
            # Compute learning rate
            # print(samples.shape, state['lr_candidates'].shape)
            self.learning_rate = samples * state['lr_candidates'] # .to('xpu')
            self.learning_rate.retain_grad()
            
            # Compute gradient wrt alphas
            grad_learning_rate = torch.autograd.grad(loss, self.learning_rate, grad_outputs=torch.tensor(1.0, device=loss.device), retain_graph=True)[0]
            
            # Update alphas with gradient descent
            state['alphas'] = state['alphas'] - self.alpha_lr * grad_samples * state['lr_candidates'] # need del L/ del n
            self.alpha_grads = state['alphas']

            # Apply weight update
            # print(self.learning_rate.shape, p.grad.shape)
            p.data.sub_(self.learning_rate.squeeze() * p.grad) # .to('xpu')

EDIT: Training function; note how in optimizer.step(loss) the loss found using cross entropy is passed to the optimizer.

def train(self, model: nn.Module, optim: optim.Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]) -> dict:
    model.train()
    training_history = {}

    for epoch in range(self.epochs):
        losses = []
        learning_rates = []
        alpha_grads = []
        accuracies = []
        for images, labels in self.data:
            images = torch.squeeze(images) # .to('xpu')
            labels = torch.squeeze(labels) # .to('xpu')
            
            assert images.shape == torch.Size([128, 784]), f"Images [{images.shape}] not of desired dimensions."
            assert labels.shape == torch.Size([128]), f"Labels [{labels.shape}] not of desired dimensions."
            
            predictions = model(images)
            loss = criterion(predictions, labels)

            optim.zero_grad() # reset gradients
            loss.backward(retain_graph=True) # compute gradients for all 128 samples
            optim.step(loss) # apply weight update and pass loss'

            learning_rates.append(torch.mean(optim.learning_rate.to('cpu')))
            alpha_grads.append(torch.mean(optim.alpha_grads.to('cpu')))
            losses.append(loss.to('cpu'))
        self.store_training_history(history=training_history,
                               epoch_num=epoch,
                               loss=losses,
                               learning_rate=learning_rates,
                               alpha_grads = alpha_grads
                            )
        with torch.no_grad():
            print(f"Completed Epoch: {epoch+1}/{self.epochs}, Loss: {np.mean(losses):.4f}")
    return training_history

发布者:admin,转转请注明出处:http://www.yc00.com/questions/1744838759a4596425.html

相关推荐

发表回复

评论列表(0条)

  • 暂无评论

联系我们

400-800-8888

在线咨询: QQ交谈

邮件:admin@example.com

工作时间:周一至周五,9:30-18:30,节假日休息

关注微信