Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jiangxin Dong
SVMAP
Commits
06587888
Commit
06587888
authored
Jun 17, 2021
by
Jiangxin Dong
Browse files
Upload New File
parent
86fb1b67
Changes
1
Show whitespace changes
Inline
Side-by-side
logger/logger.py
0 → 100644
View file @
06587888
import
torch
import
imageio
import
numpy
as
np
import
os
import
datetime
import
matplotlib
matplotlib
.
use
(
'Agg'
)
from
matplotlib
import
pyplot
as
plt
class
Logger
:
def
__init__
(
self
,
args
):
self
.
args
=
args
self
.
psnr_log
=
torch
.
Tensor
()
self
.
loss_log
=
torch
.
Tensor
()
if
args
.
load
==
'.'
:
if
args
.
save
==
'.'
:
args
.
save
=
datetime
.
datetime
.
now
().
strftime
(
'%Y%m%d_%H:%M'
)
self
.
dir
=
'experiment/'
+
args
.
save
# args.save = 'save_path'
else
:
self
.
dir
=
'experiment/'
+
args
.
load
if
not
os
.
path
.
exists
(
self
.
dir
):
args
.
load
=
'.'
else
:
self
.
loss_log
=
torch
.
load
(
self
.
dir
+
'/loss_log.pt'
)
self
.
psnr_log
=
torch
.
load
(
self
.
dir
+
'/psnr_log.pt'
)
print
(
'Continue from epoch {}...'
.
format
(
len
(
self
.
psnr_log
)))
if
args
.
reset
:
os
.
system
(
'rm -rf {}'
.
format
(
self
.
dir
))
args
.
load
=
'.'
'''
if not os.path.exists(self.dir):
os.makedirs(self.dir)
if not os.path.exists(self.dir + '/model'):
os.makedirs(self.dir + '/model')
if not os.path.exists(self.dir + '/result/' + self.args.data_test):
print("Creating dir for saving images...", self.dir + '/result/' + self.args.data_test)
os.makedirs(self.dir + '/result/' + self.args.data_test)
print('Save Path : {}'.format(self.dir)) # Save Path : experiment/save_path
open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
self.log_file = open(self.dir + '/log.txt', open_type)
with open(self.dir + '/config.txt', open_type) as f:
f.write('From epoch {}...'.format(len(self.psnr_log)) + '
\n\n
')
for arg in vars(args):
f.write('{}: {}
\n
'.format(arg, getattr(args, arg)))
f.write('
\n
')
'''
def
write_log
(
self
,
log
):
print
(
log
)
self
.
log_file
.
write
(
log
+
'
\n
'
)
def
save
(
self
,
trainer
,
epoch
,
is_best
):
trainer
.
model
.
save
(
self
.
dir
,
epoch
,
is_best
)
torch
.
save
(
self
.
psnr_log
,
os
.
path
.
join
(
self
.
dir
,
'psnr_log.pt'
))
torch
.
save
(
trainer
.
optimizer
.
state_dict
(),
os
.
path
.
join
(
self
.
dir
,
'optimizer.pt'
))
trainer
.
loss
.
save
(
self
.
dir
)
trainer
.
loss
.
plot_loss
(
self
.
dir
,
epoch
)
self
.
plot_psnr_log
(
epoch
)
def
save_images
(
self
,
filename
,
save_list
):
if
self
.
args
.
task
==
'Deblurring'
:
idx
=
0
f
=
filename
[
idx
][
0
].
split
(
'.'
)
filename
=
'./result/{}'
.
format
(
f
[
0
])
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
filename
)):
os
.
makedirs
(
os
.
path
.
dirname
(
filename
))
if
self
.
args
.
model
==
'deblur'
:
postfix
=
[
'DEBLUR'
]
for
img
,
post
in
zip
(
save_list
,
postfix
):
img
=
img
[
0
].
data
.
mul
(
255
/
self
.
args
.
rgb_range
)
img
=
np
.
transpose
(
img
.
cpu
().
numpy
(),
(
1
,
2
,
0
)).
astype
(
'uint8'
)
imageio
.
imwrite
(
'{}{}.png'
.
format
(
filename
,
post
),
img
)
def
start_log
(
self
,
train
=
True
):
if
train
:
self
.
loss_log
=
torch
.
cat
((
self
.
loss_log
,
torch
.
zeros
(
1
)))
else
:
self
.
psnr_log
=
torch
.
cat
((
self
.
psnr_log
,
torch
.
zeros
(
1
)))
def
report_log
(
self
,
item
,
train
=
True
):
if
train
:
self
.
loss_log
[
-
1
]
+=
item
else
:
self
.
psnr_log
[
-
1
]
+=
item
def
end_log
(
self
,
n_div
,
train
=
True
):
if
train
:
self
.
loss_log
[
-
1
].
div_
(
n_div
)
else
:
self
.
psnr_log
[
-
1
].
div_
(
n_div
)
def
plot_loss_log
(
self
,
epoch
):
axis
=
np
.
linspace
(
1
,
epoch
,
epoch
)
fig
=
plt
.
figure
()
plt
.
title
(
'Loss Graph'
)
plt
.
plot
(
axis
,
self
.
loss_log
.
numpy
())
plt
.
legend
()
plt
.
xlabel
(
'Epochs'
)
plt
.
ylabel
(
'Loss'
)
plt
.
grid
(
True
)
plt
.
savefig
(
os
.
path
.
join
(
self
.
dir
,
'loss.pdf'
))
plt
.
close
(
fig
)
def
plot_psnr_log
(
self
,
epoch
):
axis
=
np
.
linspace
(
1
,
epoch
,
epoch
)
fig
=
plt
.
figure
()
plt
.
title
(
'PSNR Graph'
)
plt
.
plot
(
axis
,
self
.
psnr_log
.
numpy
())
plt
.
legend
()
plt
.
xlabel
(
'Epochs'
)
plt
.
ylabel
(
'PSNR'
)
plt
.
grid
(
True
)
plt
.
savefig
(
os
.
path
.
join
(
self
.
dir
,
'psnr.pdf'
))
plt
.
close
(
fig
)
def
done
(
self
):
print
(
''
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment