Welcome to ShenZhenJia Knowledge Sharing Community for programmer and developer-Open, Learning and Share

So I am trying to figure out how to compute the accuracy of a BandRNN.

BandRNN is a diagonal RNN model with a different number of connections per neuron, where C denotes the number of connections per neuron.
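One way to read "C connections per neuron" is to mask the recurrent weight matrix so that each hidden unit only receives input from C neighboring units, with C = 1 reducing to a purely diagonal recurrence. The sketch below illustrates that reading only. The names `band_mask` and `BandRNNCell`, the wrap-around at the edges and the tanh update are assumptions for illustration, not the model used in the question.

```python
import torch


def band_mask(hidden_size: int, c: int) -> torch.Tensor:
    """Hypothetical band layout: boolean (hidden_size, hidden_size) mask.

    Row i keeps columns i - c // 2, ..., i - c // 2 + c - 1 (wrapping around
    at the edges), so every neuron has exactly c recurrent inputs when
    c <= hidden_size. c == 1 gives a purely diagonal recurrence.
    """
    mask = torch.zeros(hidden_size, hidden_size, dtype=torch.bool)
    for i in range(hidden_size):
        for offset in range(-(c // 2), c - c // 2):
            mask[i, (i + offset) % hidden_size] = True
    return mask


class BandRNNCell(torch.nn.Module):
    """Hypothetical single-step cell: h_t = tanh(W_in x_t + (W_rec * mask) h_{t-1})."""

    def __init__(self, input_size: int, hidden_size: int, c: int):
        super().__init__()
        self.input_proj = torch.nn.Linear(input_size, hidden_size)
        self.recurrent = torch.nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.1)
        # Stored as a buffer so it moves with .to(device) but is not trained.
        self.register_buffer("mask", band_mask(hidden_size, c))

    def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:
        # Zero out recurrent weights outside the band before using them.
        w = self.recurrent * self.mask
        return torch.tanh(self.input_proj(x_t) + h_prev @ w.T)
```

Masking on every forward pass (rather than zeroing the weights once) keeps the out-of-band entries at zero even after optimizer updates.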

My current model training is as follows:

import torch
import torch.nn as nn
from datetime import datetime

model = ModelLSTM(m, k).to(device)

model.train()

opt = torch.optim.Adam(model.parameters(), lr=args.lr)

best_test = 1e7
best_validation = 1e7

for ep in range(1, args.epochs + 1):

    init_time = datetime.now()
    processed = 0
    step = 1

    for batch_idx, (batch_x, batch_y, len_batch) in enumerate(train_loader):
        batch_x, batch_y, len_batch = batch_x.to(device), batch_y.to(device), len_batch.to(device)

        opt.zero_grad()

        logits = model(batch_x)

        loss = model.loss(logits, batch_y, len_batch)

        acc = sum(logits == batch_y) * 1.0 / len(logits)
        print(acc)

        loss.backward()

        if args.clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        opt.step()

        processed += len(batch_x)
        step += 1
        print("   batch_idx {} Loss: {:.2f} ".format(batch_idx, loss))

    print("Epoch {}, LR {:.5f} Loss: {:.2f} ".format(ep, opt.param_groups[0]['lr'], loss))

And my test code is as follows:

model.eval()
with torch.no_grad():

    for batch_x, batch_y, len_batch in test_loader:
        batch_x, batch_y, len_batch = batch_x.to(device), batch_y.to(device), len_batch.to(device)
        logits = model(batch_x)
        loss_test = model.loss(logits, batch_y, len_batch)

        acc = sum(logits == batch_y) * 1.0 / len(logits)

    for batch_x, batch_y, len_batch in val_loader:
        batch_x, batch_y, len_batch = batch_x.to(device), batch_y.to(device), len_batch.to(device)
        logits = model(batch_x)
        loss_val = model.loss(logits, batch_y, len_batch)

    if loss_val < best_validation:
        best_validation = loss_val.item()
        best_test = loss_test.item()

print()
print("Val:  Loss: {:.2f} Best: {:.2f}".format(loss_val, best_validation))
print("Test: Loss: {:.2f} Best: {:.2f}".format(loss_test, best_test))
print()

model.train()
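Both evaluation loops in this code reassign `acc`, `loss_test` and `loss_val` on every batch, so the printed values describe only the last batch of each loader. A minimal sketch of a dataset-level alternative follows. The helper name `evaluate` is hypothetical, and it assumes that `model(batch_x)` returns `(batch, num_classes)` logits with one integer label per example, that the predicted class is the index of the largest logit, and that `model.loss` returns a mean over the batch.

```python
import torch


def evaluate(model, loader, device):
    """Hypothetical helper: return (mean loss, accuracy) over every batch in `loader`.

    Assumes model(batch_x) gives logits of shape (batch, num_classes),
    batch_y holds integer class indices of shape (batch,), and
    model.loss(logits, batch_y, len_batch) returns a per-batch mean loss.
    """
    was_training = model.training
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for batch_x, batch_y, len_batch in loader:
            batch_x, batch_y, len_batch = batch_x.to(device), batch_y.to(device), len_batch.to(device)
            logits = model(batch_x)
            n = batch_y.size(0)
            # Weight each batch-mean loss by its size so a short final batch is not over-counted.
            total_loss += model.loss(logits, batch_y, len_batch).item() * n
            # Predicted class = index of the largest logit.
            correct += (logits.argmax(dim=-1) == batch_y).sum().item()
            seen += n
    model.train(was_training)  # restore whatever mode the model was in
    return total_loss / seen, correct / seen
```

It would replace the two loops with `loss_val, acc_val = evaluate(model, val_loader, device)` and `loss_test, acc_test = evaluate(model, test_loader, device)`; since both values are already plain floats, the best-model bookkeeping no longer needs `.item()`.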

I am struggling to come up with a way to compute the accuracy of this model and would appreciate any suggestions. Thank you.

Question from: https://stackoverflow.com/questions/66065431/compute-accuracy-of-band-rnn

1 Answer

I believe this line in your code is already attempting to calculate accuracy:

acc = sum(logits == batch_y) * 1.0 / len(logits)
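As written, that line compares raw floating-point scores with integer class indices, so it does not count correct predictions. A small illustration with made-up tensors (4 examples, 3 classes, assuming one label per example) shows the two ways it goes wrong: the shapes usually fail to broadcast, and when they do broadcast, the formula yields a vector instead of a single accuracy.

```python
import torch

# Made-up scores for a batch of 4 examples and 3 classes.
logits = torch.tensor([[2.0, 0.1, -1.0],
                       [0.3, 1.5, 0.2],
                       [0.0, 0.2, 3.0],
                       [1.2, 0.8, 0.1]])
labels = torch.tensor([0, 1, 1, 0])  # integer class indices, shape (4,)

try:
    logits == labels  # (4, 3) vs (4,): trailing sizes 3 and 4 do not broadcast
    shapes_broadcast = True
except RuntimeError as err:
    shapes_broadcast = False
    print("comparison failed:", err)

# When the batch size happens to equal the number of classes, broadcasting succeeds
# silently and the formula returns one value per column instead of a scalar.
square_logits, square_labels = logits[:3], labels[:3]  # shapes (3, 3) and (3,)
per_position = sum(square_logits == square_labels) * 1.0 / len(square_logits)
print(per_position.shape)  # torch.Size([3])
```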

However, you probably want to take the argmax of the logits (the index of the highest-scoring class) before comparing with the labels:

preds = logits.argmax(dim=-1)  # index of the highest-scoring class
acc = (preds == batch_y).float().mean()  # fraction of correct predictions
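Because `model.loss` also takes `len_batch`, the model may produce one prediction per time step of a padded sequence; that is an assumption, since the question does not show `ModelLSTM`. In that case padded steps should not count toward accuracy. A minimal sketch, assuming batch-first logits of shape `(batch, time, num_classes)`, targets of shape `(batch, time)`, and `len_batch` holding the number of real steps per sequence (the helper name `masked_accuracy` is hypothetical):

```python
import torch


def masked_accuracy(logits: torch.Tensor, targets: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Accuracy over valid (non-padded) time steps only.

    Hypothetical layout: logits (batch, time, num_classes), targets (batch, time),
    lengths (batch,) = number of leading real steps in each sequence.
    """
    preds = logits.argmax(dim=-1)  # (batch, time)
    steps = torch.arange(targets.size(1), device=targets.device)  # (time,)
    valid = steps.unsqueeze(0) < lengths.unsqueeze(1)  # (batch, time), True on real steps
    correct = (preds == targets) & valid
    return correct.sum().float() / valid.sum().float()


# Tiny made-up example: steps beyond each sequence's length are ignored.
example_logits = torch.nn.functional.one_hot(
    torch.tensor([[0, 1, 2], [1, 1, 0]]), num_classes=3
).float()
example_targets = torch.tensor([[0, 1, 0], [1, 0, 0]])
example_lengths = torch.tensor([2, 1])
print(masked_accuracy(example_logits, example_targets, example_lengths).item())  # 1.0
```

Inside the loop it would be called as `acc = masked_accuracy(logits, batch_y, len_batch)`. Without the mask, the same example scores 4/6 because padded positions are counted. If the model instead emits one label per sequence, the two-line version above is enough.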
