Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Yaoyao Liu
E3BM
Commits
4047a7a1
Verified
Commit
4047a7a1
authored
Apr 13, 2021
by
Yaoyao Liu
Browse files
Update the pre-train code
parent
db4ce346
Changes
3
Show whitespace changes
Inline
Side-by-side
main.py
View file @
4047a7a1
...
...
@@ -45,6 +45,7 @@ parser.add_argument('-way', type=int, default=5)
parser
.
add_argument
(
'-shot'
,
type
=
int
,
default
=
5
)
parser
.
add_argument
(
'-query'
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
'-val_episode'
,
type
=
int
,
default
=
3000
)
parser
.
add_argument
(
'-val_epoch'
,
type
=
int
,
default
=
40
)
parser
.
add_argument
(
'-backbone'
,
type
=
str
,
default
=
'resnet12'
,
choices
=
[
'wrn'
,
'resnet12'
])
parser
.
add_argument
(
'-dropout'
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
'-save_all'
,
action
=
'store_true'
,
help
=
'save models on each epoch'
)
...
...
model/meta_model.py
View file @
4047a7a1
...
...
@@ -192,8 +192,8 @@ class MetaModel(nn.Module):
embedding_query
=
self
.
encoder
(
data_query
)
embedding_shot
=
self
.
encoder
(
data_shot
)
embedding_shot
=
self
.
normalize_feature
(
embedding_shot
)
embedding_query
=
self
.
normalize_feature
(
embedding_query
)
embedding_shot
=
self
.
normalize_feature
(
embedding_shot
)
embedding_query
=
self
.
normalize_feature
(
embedding_query
)
with
torch
.
no_grad
():
if
self
.
args
.
shot
==
1
:
...
...
@@ -228,3 +228,40 @@ class MetaModel(nn.Module):
basestep_value_list
.
append
(
generated_basestep_weights
)
return
total_logits
def
preval_forward
(
self
,
data_shot
,
data_query
):
data_query
=
data_query
.
squeeze
(
0
)
data_shot
=
data_shot
.
squeeze
(
0
)
embedding_query
=
self
.
encoder
(
data_query
)
embedding_shot
=
self
.
encoder
(
data_shot
)
embedding_shot
=
self
.
normalize_feature
(
embedding_shot
)
embedding_query
=
self
.
normalize_feature
(
embedding_query
)
with
torch
.
no_grad
():
if
self
.
args
.
shot
==
1
:
proto
=
embedding_shot
else
:
proto
=
self
.
fusion
(
embedding_shot
)
self
.
base_learner
.
fc1_w
.
data
=
proto
fast_weights
=
self
.
base_learner
.
vars
batch_shot
=
embedding_shot
batch_label
=
self
.
label_shot
logits_q
=
self
.
base_learner
(
embedding_query
,
fast_weights
)
total_logits
=
0.0
*
logits_q
for
k
in
range
(
0
,
self
.
update_step
):
batch_shot
=
embedding_shot
batch_label
=
self
.
label_shot
logits
=
self
.
base_learner
(
batch_shot
,
fast_weights
)
*
self
.
args
.
temperature
loss
=
F
.
cross_entropy
(
logits
,
batch_label
)
grad
=
torch
.
autograd
.
grad
(
loss
,
fast_weights
)
fast_weights
=
list
(
map
(
lambda
p
:
p
[
1
]
-
0.1
*
p
[
0
],
zip
(
grad
,
fast_weights
)))
logits_q
=
self
.
base_learner
(
embedding_query
,
fast_weights
)
logits_q
=
logits_q
*
self
.
args
.
temperature
total_logits
+=
logits_q
return
total_logits
\ No newline at end of file
trainer/meta_trainer.py
View file @
4047a7a1
...
...
@@ -306,6 +306,12 @@ class MetaTrainer(object):
global_count
=
0
writer
=
SummaryWriter
(
osp
.
join
(
args
.
save_path
,
'tf'
))
label
=
torch
.
arange
(
args
.
way
).
repeat
(
args
.
query
)
if
torch
.
cuda
.
is_available
():
label
=
label
.
type
(
torch
.
cuda
.
LongTensor
)
else
:
label
=
label
.
type
(
torch
.
LongTensor
)
SLEEP
(
args
)
for
epoch
in
range
(
1
,
args
.
max_epoch
+
1
):
print
(
args
.
save_path
)
...
...
@@ -348,7 +354,7 @@ class MetaTrainer(object):
vl
=
Averager
()
va
=
Averager
()
if
epoch
<
args
.
val_epoch
:
if
epoch
<
args
.
val_epoch
:
vl
=
0
va
=
0
else
:
...
...
@@ -360,8 +366,8 @@ class MetaTrainer(object):
data
=
batch
[
0
]
p
=
args
.
shot
*
args
.
way
data_shot
,
data_query
=
data
[:
p
],
data
[
p
:]
data_shot
=
data_shot
.
unsqueeze
(
0
).
repeat
(
num_gpu
,
1
,
1
,
1
,
1
)
logits
=
model
.
meta
_forward
(
data_shot
,
data_query
)
data_shot
=
data_shot
.
unsqueeze
(
0
).
repeat
(
args
.
num_gpu
,
1
,
1
,
1
,
1
)
logits
=
model
.
preval
_forward
(
data_shot
,
data_query
)
loss
=
F
.
cross_entropy
(
logits
,
label
)
acc
=
count_acc
(
logits
,
label
)
vl
.
add
(
loss
.
item
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment