Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Yaoyao Liu
E3BM
Commits
db4ce346
Commit
db4ce346
authored
Apr 11, 2021
by
Yaoyao Liu
Browse files
Update meta_trainer.py
parent
ae4e779a
Changes
1
Hide whitespace changes
Inline
Side-by-side
trainer/meta_trainer.py
View file @
db4ce346
...
...
@@ -352,21 +352,20 @@ class MetaTrainer(object):
vl
=
0
va
=
0
else
:
with
torch
.
no_grad
():
tqdm_gen
=
tqdm
.
tqdm
(
val_loader
)
for
i
,
batch
in
enumerate
(
tqdm_gen
,
1
):
if
torch
.
cuda
.
is_available
():
data
,
_
=
[
_
.
cuda
()
for
_
in
batch
]
else
:
data
=
batch
[
0
]
p
=
args
.
shot
*
args
.
way
data_shot
,
data_query
=
data
[:
p
],
data
[
p
:]
data_shot
=
data_shot
.
unsqueeze
(
0
).
repeat
(
num_gpu
,
1
,
1
,
1
,
1
)
logits
=
model
.
meta_forward
(
data_shot
,
data_query
)
loss
=
F
.
cross_entropy
(
logits
,
label
)
acc
=
count_acc
(
logits
,
label
)
vl
.
add
(
loss
.
item
())
va
.
add
(
acc
)
tqdm_gen
=
tqdm
.
tqdm
(
val_loader
)
for
i
,
batch
in
enumerate
(
tqdm_gen
,
1
):
if
torch
.
cuda
.
is_available
():
data
,
_
=
[
_
.
cuda
()
for
_
in
batch
]
else
:
data
=
batch
[
0
]
p
=
args
.
shot
*
args
.
way
data_shot
,
data_query
=
data
[:
p
],
data
[
p
:]
data_shot
=
data_shot
.
unsqueeze
(
0
).
repeat
(
num_gpu
,
1
,
1
,
1
,
1
)
logits
=
model
.
meta_forward
(
data_shot
,
data_query
)
loss
=
F
.
cross_entropy
(
logits
,
label
)
acc
=
count_acc
(
logits
,
label
)
vl
.
add
(
loss
.
item
())
va
.
add
(
acc
)
vl
=
vl
.
item
()
va
=
va
.
item
()
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment