Commit db4ce346 authored by Yaoyao Liu's avatar Yaoyao Liu
Browse files

Update meta_trainer.py

parent ae4e779a
......@@ -352,21 +352,20 @@ class MetaTrainer(object):
vl=0
va=0
else:
with torch.no_grad():
tqdm_gen = tqdm.tqdm(val_loader)
for i, batch in enumerate(tqdm_gen, 1):
if torch.cuda.is_available():
data, _ = [_.cuda() for _ in batch]
else:
data = batch[0]
p = args.shot * args.way
data_shot, data_query = data[:p], data[p:]
data_shot = data_shot.unsqueeze(0).repeat(num_gpu, 1, 1, 1, 1)
logits = model.meta_forward(data_shot, data_query)
loss = F.cross_entropy(logits, label)
acc = count_acc(logits, label)
vl.add(loss.item())
va.add(acc)
tqdm_gen = tqdm.tqdm(val_loader)
for i, batch in enumerate(tqdm_gen, 1):
if torch.cuda.is_available():
data, _ = [_.cuda() for _ in batch]
else:
data = batch[0]
p = args.shot * args.way
data_shot, data_query = data[:p], data[p:]
data_shot = data_shot.unsqueeze(0).repeat(num_gpu, 1, 1, 1, 1)
logits = model.meta_forward(data_shot, data_query)
loss = F.cross_entropy(logits, label)
acc = count_acc(logits, label)
vl.add(loss.item())
va.add(acc)
vl = vl.item()
va = va.item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment