Commit db4ce346 authored by Yaoyao Liu's avatar Yaoyao Liu
Browse files

Update meta_trainer.py

parent ae4e779a
...@@ -352,21 +352,20 @@ class MetaTrainer(object): ...@@ -352,21 +352,20 @@ class MetaTrainer(object):
vl=0 vl=0
va=0 va=0
else: else:
with torch.no_grad(): tqdm_gen = tqdm.tqdm(val_loader)
tqdm_gen = tqdm.tqdm(val_loader) for i, batch in enumerate(tqdm_gen, 1):
for i, batch in enumerate(tqdm_gen, 1): if torch.cuda.is_available():
if torch.cuda.is_available(): data, _ = [_.cuda() for _ in batch]
data, _ = [_.cuda() for _ in batch] else:
else: data = batch[0]
data = batch[0] p = args.shot * args.way
p = args.shot * args.way data_shot, data_query = data[:p], data[p:]
data_shot, data_query = data[:p], data[p:] data_shot = data_shot.unsqueeze(0).repeat(num_gpu, 1, 1, 1, 1)
data_shot = data_shot.unsqueeze(0).repeat(num_gpu, 1, 1, 1, 1) logits = model.meta_forward(data_shot, data_query)
logits = model.meta_forward(data_shot, data_query) loss = F.cross_entropy(logits, label)
loss = F.cross_entropy(logits, label) acc = count_acc(logits, label)
acc = count_acc(logits, label) vl.add(loss.item())
vl.add(loss.item()) va.add(acc)
va.add(acc)
vl = vl.item() vl = vl.item()
va = va.item() va = va.item()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment