#VQ-VAE + Transformer for techno music creation. Issues with the Transformer

3 messages · Page 1 of 1 (latest)

worldly cargo
#

I'm currently trying to create an AI to make music based on techno DJ-sets as training data.
I have trained the VQ-VAE and get okay-ish results there: the music is clearly recognizable, and the model generalizes quite well, working even for music quite different from my training dataset. Now I'm building the Transformer, which should predict the next token in order to generate new music. The training data I use for my transformer are token sequences of length 2**15 (each token takes one of 2048 values), and from each of those a random 2048-token subsequence gets chosen. I might increase the sequence length in the future, but for computational speed that's my context length at the moment.

TransformerDec(vocab_size=2048, embed_size=1024, n_layers=6, forward_expansion=4, n_heads=8, pad_idx=-1, dropout=0.3, device=device, max_seq_len=2048)

The transformer seems to work: I tested it with a synthetic dataset. I also once got to 7% accuracy, but haven't been able to reproduce that. I also tried increasing the number of layers and the embedding size, with no improvement. Does anyone have an idea what I am doing wrong? My transformer implementation and training script are attached.

#
# AdamW hyper-parameters: the two beta coefficients and the base learning rate.
b1 = 0.9
b2 = 0.99
lr = 1e-5
optimizer = optim.AdamW(transformer.parameters(), lr=lr, betas=(b1, b2))

# The schedule length is expressed in optimizer steps (one per batch), so the
# total is epochs * batches-per-epoch.
# NOTE(review): with 10000 warm-up steps the schedule only reaches `lr` if the
# run is longer than that - confirm against the dataset size.
total_training_steps = n_epochs * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10000,
    num_training_steps=total_training_steps,
    last_epoch=-1,
)

# Targets equal to -1 contribute nothing to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
logger.info(f"Training started on {device}")

# Bookkeeping shared with the epoch loop.
loss_list: list = []
total_time: float = 0.0

# Make sure the model is in training mode before the first batch.
transformer.train()
def _random_window(tokens, window_len):
    """Return a random contiguous window of at most ``window_len`` columns.

    Batches whose second dimension is not longer than ``window_len`` are
    returned unchanged. One start offset is drawn per call, so every row of
    the batch is cropped at the same position.
    """
    if tokens.size(1) > window_len:
        start_idx = torch.randint(0, tokens.size(1) - window_len + 1, (1,)).item()
        tokens = tokens[:, start_idx:start_idx + window_len]
    return tokens


def _token_accuracy(logits, target, ignore_index):
    """Return the fraction of non-ignored targets whose argmax prediction is right.

    Positions where ``target == ignore_index`` are excluded from both the
    numerator and the denominator, mirroring what the loss ignores. Returns
    0.0 when every position is ignored.
    """
    valid = target != ignore_index
    n_valid = valid.sum().item()
    if n_valid == 0:
        return 0.0
    hits = ((logits.argmax(dim=-1) == target) & valid).sum().item()
    return hits / n_valid


# torch.autocast expects a bare device-type string ("cuda", "cpu"); normalising
# through torch.device also accepts a torch.device object or an indexed string
# such as "cuda:0".
autocast_device_type = torch.device(device).type

for e in range(n_epochs):
    total_loss: float = 0.0
    total_accuracy: float = 0.0
    start_time: float = time.time()

    for b_idx, (x, _) in enumerate(train_dataloader):
        x = _random_window(x, training_seq_len).to(device)
        # Next-token prediction: the target is the input shifted by one.
        inp = x[:, :-1]
        target = x[:, 1:]
        # NOTE(review): if autocast selects float16 on this device, small
        # gradients can underflow without a GradScaler - consider adding one.
        with torch.autocast(autocast_device_type):
            pred = transformer(inp)
            # CrossEntropyLoss wants the class dimension second.
            loss = loss_fn(pred.transpose(1, 2), target)
        accuracy = _token_accuracy(pred.detach(), target, loss_fn.ignore_index)

        optimizer.zero_grad()
        loss.backward()
        # Clip BEFORE the optimizer step; clipping afterwards cannot affect the
        # update that was just applied. clip_grad_norm_ returns the total norm
        # measured before clipping, so no separate manual computation is needed.
        grad_norm = torch.nn.utils.clip_grad_norm_(transformer.parameters(), max_norm=1.0)
        optimizer.step()

        # Accumulate plain floats so no tensor (or autograd history) is kept
        # alive across iterations.
        total_loss += loss.item()
        total_accuracy += accuracy
        if b_idx == 0:
            total_norm = grad_norm.item()
            probs = F.softmax(pred.detach()[:, -1, :], dim=-1)
            logger.info(f"Epoch {e + 1}, Batch {b_idx + 1}: Pred min: {pred.min().item():.4f}, max: {pred.max().item():.4f}, Probs max: {probs.max().item():.4f} Tot. norm unclipped: {total_norm:.3e}")
        scheduler.step()

#
    # Validation pass: same cropping and shifting as training, without
    # gradients. The crop offset is random, so validation numbers carry some
    # noise from epoch to epoch.
    transformer.eval()
    total_val_loss: float = 0.0
    total_val_accuracy: float = 0.0
    with torch.no_grad():
        for x, _ in validation_dataloader:
            x = _random_window(x, training_seq_len).to(device)
            inp = x[:, :-1]
            target = x[:, 1:]
            pred = transformer(inp)
            total_val_loss += loss_fn(pred.transpose(1, 2), target).item()
            total_val_accuracy += _token_accuracy(pred, target, loss_fn.ignore_index)
    transformer.train()

    # Per-epoch timing and averages; the remaining time is extrapolated from
    # the mean epoch duration so far.
    epoch_time = time.time() - start_time
    total_time += epoch_time
    remaining_time = int((total_time / (e + 1)) * (n_epochs - e - 1))
    avg_accuracy = total_accuracy / len(train_dataloader)
    avg_loss = total_loss / len(train_dataloader)
    avg_val_accuracy = total_val_accuracy / len(validation_dataloader)
    avg_val_loss = total_val_loss / len(validation_dataloader)
    logger.info(f"Epoch {e + 1:03d}: Avg. Loss: {avg_loss:.3e} Avg. Accuracy: {avg_accuracy:.3%} Avg. val Loss: {avg_val_loss:.3e} Avg. val Accuracy: {avg_val_accuracy:.3%} Remaining Time: {remaining_time // 3600:02d}h {(remaining_time % 3600) // 60:02d}min {round(remaining_time % 60):02d}s LR: {optimizer.param_groups[0]['lr']:.3e}")

# Persist model and optimizer state together with the number of finished epochs.
torch.save({"transformer": transformer.state_dict(), "optim": optimizer.state_dict(), "epoch": e + 1}, full_model_path)