Verified Commit 4047a7a1 authored by Yaoyao Liu

Update the pre-train code: add a `-val_epoch` argument that delays validation, add a `preval_forward` pass for validating the pre-trained model, use `args.num_gpu` in the validation loop, and clean up spacing.

parent db4ce346
......@@ -45,6 +45,7 @@ parser.add_argument('-way', type=int, default=5)
 parser.add_argument('-shot', type=int, default=5)
 parser.add_argument('-query', type=int, default=16)
 parser.add_argument('-val_episode', type=int, default=3000)
+parser.add_argument('-val_epoch', type=int, default=40)
 parser.add_argument('-backbone', type=str, default='resnet12', choices=['wrn', 'resnet12'])
 parser.add_argument('-dropout', type=float, default=0.5)
 parser.add_argument('-save_all',action='store_true',help='save models on each epoch')
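The new -val_epoch argument sets the first epoch at which validation runs (see the `epoch < args.val_epoch` guard further down). A minimal sketch of that gating, with a hypothetical run_validation hook standing in for the real validation loop and an illustrative max_epoch default:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-val_epoch', type=int, default=40)
parser.add_argument('-max_epoch', type=int, default=100)  # illustrative default
args = parser.parse_args([])  # parse defaults only, for illustration

for epoch in range(1, args.max_epoch + 1):
    if epoch < args.val_epoch:
        continue  # skip costly episodic validation during early pre-training
    # run_validation(epoch)  # hypothetical hook; the real loop is shown below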
......
......@@ -192,8 +192,8 @@ class MetaModel(nn.Module):
         embedding_query = self.encoder(data_query)
         embedding_shot = self.encoder(data_shot)
-        embedding_shot=self.normalize_feature(embedding_shot)
-        embedding_query=self.normalize_feature(embedding_query)
+        embedding_shot = self.normalize_feature(embedding_shot)
+        embedding_query = self.normalize_feature(embedding_query)
         with torch.no_grad():
             if self.args.shot==1:
......@@ -227,4 +227,41 @@ class MetaModel(nn.Module):
                 combination_value_list.append(generated_combination_weights)
                 basestep_value_list.append(generated_basestep_weights)
         return total_logits
+
+    def preval_forward(self, data_shot, data_query):
+        data_query=data_query.squeeze(0)
+        data_shot = data_shot.squeeze(0)
+        embedding_query = self.encoder(data_query)
+        embedding_shot = self.encoder(data_shot)
+        embedding_shot = self.normalize_feature(embedding_shot)
+        embedding_query = self.normalize_feature(embedding_query)
+        with torch.no_grad():
+            if self.args.shot==1:
+                proto = embedding_shot
+            else:
+                proto=self.fusion(embedding_shot)
+        self.base_learner.fc1_w.data = proto
+        fast_weights = self.base_learner.vars
+        batch_shot = embedding_shot
+        batch_label = self.label_shot
+        logits_q = self.base_learner(embedding_query, fast_weights)
+        total_logits = 0.0 * logits_q
+        for k in range(0, self.update_step):
+            batch_shot = embedding_shot
+            batch_label = self.label_shot
+            logits = self.base_learner(batch_shot, fast_weights) * self.args.temperature
+            loss = F.cross_entropy(logits, batch_label)
+            grad = torch.autograd.grad(loss, fast_weights)
+            fast_weights = list(map(lambda p: p[1] - 0.1 * p[0], zip(grad, fast_weights)))
+            logits_q = self.base_learner(embedding_query, fast_weights)
+            logits_q = logits_q * self.args.temperature
+            total_logits += logits_q
+        return total_logits
\ No newline at end of file
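The added preval_forward performs a MAML-style inner loop: it initializes the base learner's classifier weights from the support prototype, then takes fixed-rate SGD steps on copies of the weights while accumulating temperature-scaled query logits. A self-contained sketch of that fast-weight pattern, with a toy bias-free linear classifier standing in for self.base_learner; the episode sizes, feature dim, and step count are illustrative, while the 0.1 inner-loop learning rate, temperature scaling, and logit accumulation follow the diff:

import torch
import torch.nn.functional as F

# Functional stand-in for self.base_learner: weights are passed in explicitly,
# so each gradient step produces new "fast" weights instead of mutating a module.
def base_learner(x, weights):
    return F.linear(x, weights[0])

way, shot, query, dim = 5, 5, 16, 64   # illustrative episode sizes
temperature, update_step = 1.0, 5      # stand-ins for args.temperature / self.update_step

fast_weights = [torch.randn(way, dim, requires_grad=True)]  # plays self.base_learner.vars
embedding_shot = torch.randn(way * shot, dim)               # support-set embeddings
embedding_query = torch.randn(way * query, dim)             # query-set embeddings
label_shot = torch.arange(way).repeat_interleave(shot)      # plays self.label_shot

total_logits = 0.0 * base_learner(embedding_query, fast_weights)
for _ in range(update_step):
    logits = base_learner(embedding_shot, fast_weights) * temperature
    loss = F.cross_entropy(logits, label_shot)
    grad = torch.autograd.grad(loss, fast_weights)
    # One SGD step on the fast weights; 0.1 is the fixed inner-loop lr from the diff.
    fast_weights = [w - 0.1 * g for g, w in zip(grad, fast_weights)]
    total_logits += base_learner(embedding_query, fast_weights) * temperature

Summing the query logits over all update steps, rather than keeping only the last step's, is the same ensembling-over-steps choice made in meta_forward above.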
......@@ -306,6 +306,12 @@ class MetaTrainer(object):
         global_count = 0
         writer = SummaryWriter(osp.join(args.save_path,'tf'))
         label = torch.arange(args.way).repeat(args.query)
+        if torch.cuda.is_available():
+            label = label.type(torch.cuda.LongTensor)
+        else:
+            label = label.type(torch.LongTensor)
+
+        SLEEP(args)
         for epoch in range(1, args.max_epoch + 1):
             print (args.save_path)
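For context, a minimal sketch of the query-label layout this new device branch operates on, using the -way and -query defaults from the argument parser above; the else-branch keeps CPU-only runs working with the same dtype:

import torch

way, query = 5, 16                       # defaults of -way and -query above
label = torch.arange(way).repeat(query)  # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, ...]

# Same integer dtype on either device.
if torch.cuda.is_available():
    label = label.type(torch.cuda.LongTensor)
else:
    label = label.type(torch.LongTensor)

print(label.shape)  # torch.Size([80]): one label per query image, way * query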
......@@ -348,7 +354,7 @@ class MetaTrainer(object):
             vl = Averager()
             va = Averager()
-            if epoch <args.val_epoch:
+            if epoch < args.val_epoch:
                 vl=0
                 va=0
             else:
......@@ -360,8 +366,8 @@ class MetaTrainer(object):
                     data = batch[0]
                     p = args.shot * args.way
                     data_shot, data_query = data[:p], data[p:]
-                    data_shot = data_shot.unsqueeze(0).repeat(num_gpu, 1, 1, 1, 1)
-                    logits = model.meta_forward(data_shot, data_query)
+                    data_shot = data_shot.unsqueeze(0).repeat(args.num_gpu, 1, 1, 1, 1)
+                    logits = model.preval_forward(data_shot, data_query)
                     loss = F.cross_entropy(logits, label)
                     acc = count_acc(logits, label)
                     vl.add(loss.item())
......
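Finally, a sketch of the episode split this validation loop performs; the unsqueeze/repeat gives the support set a leading replica dimension, presumably one copy per GPU for a data-parallel encoder (the GPU count and 84x84 image size here are illustrative, not taken from the diff):

import torch

way, shot, query, num_gpu = 5, 5, 16, 2              # argument defaults; 2 GPUs assumed
data = torch.randn(way * (shot + query), 3, 84, 84)  # one episode of 84x84 images

p = shot * way
data_shot, data_query = data[:p], data[p:]
# Replicate the support set once per GPU; the query set keeps its flat layout.
data_shot = data_shot.unsqueeze(0).repeat(num_gpu, 1, 1, 1, 1)

print(data_shot.shape)   # torch.Size([2, 25, 3, 84, 84])
print(data_query.shape)  # torch.Size([80, 3, 84, 84])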