# Pseudocode for LoRA
```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class LoRAparam(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # A is initialized from a Gaussian, B starts at zero, so B @ A = 0
        # and the parametrized weight equals the original weight at the start.
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out), device=device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank), device=device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B @ A) * scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # nn.Linear stores its weight as (out_features, in_features);
    # B @ A is constructed to match that shape exactly.
    features_in, features_out = layer.weight.shape
    return LoRAparam(features_in, features_out, rank=rank, alpha=lora_alpha, device=device)

# Replace the weight of the model's linear layer with its LoRA-parametrized version
# (`net` is the model being adapted).
parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
```
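For orientation, here is a minimal, self-contained usage sketch. The `DemoNet` model, its layer sizes, and the freezing loop are illustrative assumptions, not part of the original snippet; it reuses the `LoRAparam` class and `linear_layer_parameterization` helper defined above.

```python
# A hypothetical model standing in for `net`; sizes are arbitrary.
class DemoNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(784, 10)

    def forward(self, x):
        return self.linear1(x)

device = 'cpu'
net = DemoNet().to(device)

# Register the LoRA parametrization on the layer's weight.
parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)

# Freeze everything except the LoRA matrices, so only A and B are trained.
for name, param in net.named_parameters():
    if 'lora' not in name:
        param.requires_grad = False

# The parametrization can be toggled off to recover the original weights.
net.linear1.parametrizations["weight"][0].enabled = False
```

After registration, `net.linear1.weight` is computed on the fly as `W + (B @ A) * scale`, while the frozen original weight is kept under `net.linear1.parametrizations.weight.original`.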