Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jiangxin Dong
SVMAP
Commits
397b5eb9
Commit
397b5eb9
authored
Jun 17, 2021
by
Jiangxin Dong
Browse files
Upload New File
parent
f44473e5
Changes
1
Hide whitespace changes
Inline
Side-by-side
loss/__init__.py
0 → 100644
View file @
397b5eb9
import
os
from
importlib
import
import_module
import
matplotlib
matplotlib
.
use
(
'Agg'
)
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
torch
import
torch.nn
as
nn
class
Loss
(
nn
.
modules
.
loss
.
_Loss
):
def
__init__
(
self
,
args
,
ckp
):
super
(
Loss
,
self
).
__init__
()
print
(
'Preparing loss function:'
)
self
.
n_GPUs
=
args
.
n_GPUs
self
.
loss
=
[]
self
.
loss_module
=
nn
.
ModuleList
()
for
loss
in
args
.
loss
.
split
(
'+'
):
print
(
loss
)
weight
,
loss_type
=
loss
.
split
(
'*'
)
if
loss_type
==
'MSE'
:
loss_function
=
nn
.
MSELoss
()
elif
loss_type
==
'L1'
:
loss_function
=
nn
.
L1Loss
()
elif
loss_type
==
'Huber'
:
loss_function
=
nn
.
SmoothL1Loss
()
elif
loss_type
.
find
(
'GAN'
)
>=
0
:
module
=
import_module
(
'loss.adversarial'
)
loss_function
=
getattr
(
module
,
'Adversarial'
)(
args
,
loss_type
)
self
.
loss
.
append
({
'type'
:
loss_type
,
'weight'
:
float
(
weight
),
'function'
:
loss_function
}
)
if
loss_type
.
find
(
'GAN'
)
>=
0
:
self
.
loss
.
append
({
'type'
:
'DIS'
,
'weight'
:
1
,
'function'
:
None
})
if
len
(
self
.
loss
)
>
1
:
self
.
loss
.
append
({
'type'
:
'Total'
,
'weight'
:
0
,
'function'
:
None
})
for
l
in
self
.
loss
:
if
l
[
'function'
]
is
not
None
:
print
(
'{:.3f} * {}'
.
format
(
l
[
'weight'
],
l
[
'type'
]))
self
.
loss_module
.
append
(
l
[
'function'
])
self
.
log
=
torch
.
Tensor
()
device
=
torch
.
device
(
'cpu'
if
args
.
cpu
else
'cuda'
)
self
.
loss_module
.
to
(
device
)
if
args
.
precision
==
'half'
:
self
.
loss_module
.
half
()
if
not
args
.
cpu
and
args
.
n_GPUs
>
1
:
self
.
loss_module
=
nn
.
DataParallel
(
self
.
loss_module
,
range
(
args
.
n_GPUs
)
)
if
args
.
load
!=
'.'
:
self
.
load
(
ckp
.
dir
,
cpu
=
args
.
cpu
)
def
forward
(
self
,
sr
,
hr
):
losses
=
[]
for
i
,
l
in
enumerate
(
self
.
loss
):
if
l
[
'function'
]
is
not
None
:
loss
=
l
[
'function'
](
sr
,
hr
)
effective_loss
=
l
[
'weight'
]
*
loss
losses
.
append
(
effective_loss
)
self
.
log
[
-
1
,
i
]
+=
effective_loss
.
item
()
elif
l
[
'type'
]
==
'DIS'
:
self
.
log
[
-
1
,
i
]
+=
self
.
loss
[
i
-
1
][
'function'
].
loss
loss_sum
=
sum
(
losses
)
if
len
(
self
.
loss
)
>
1
:
self
.
log
[
-
1
,
-
1
]
+=
loss_sum
.
item
()
return
loss_sum
def
step
(
self
):
for
l
in
self
.
get_loss_module
():
if
hasattr
(
l
,
'scheduler'
):
l
.
scheduler
.
step
()
def
start_log
(
self
):
self
.
log
=
torch
.
cat
((
self
.
log
,
torch
.
zeros
(
1
,
len
(
self
.
loss
))))
def
end_log
(
self
,
n_batches
):
self
.
log
[
-
1
].
div_
(
n_batches
)
def
display_loss
(
self
,
batch
):
n_samples
=
batch
+
1
log
=
[]
for
l
,
c
in
zip
(
self
.
loss
,
self
.
log
[
-
1
]):
log
.
append
(
'[{}: {:.4f}]'
.
format
(
l
[
'type'
],
c
/
n_samples
))
return
''
.
join
(
log
)
def
plot_loss
(
self
,
apath
,
epoch
):
axis
=
np
.
linspace
(
1
,
epoch
,
epoch
)
for
i
,
l
in
enumerate
(
self
.
loss
):
label
=
'{} Loss'
.
format
(
l
[
'type'
])
fig
=
plt
.
figure
()
plt
.
title
(
label
)
plt
.
plot
(
axis
,
self
.
log
[:,
i
].
numpy
(),
label
=
label
)
plt
.
legend
()
plt
.
xlabel
(
'Epochs'
)
plt
.
ylabel
(
'Loss'
)
plt
.
grid
(
True
)
plt
.
savefig
(
'{}/loss_loss_{}.pdf'
.
format
(
apath
,
l
[
'type'
]))
plt
.
close
(
fig
)
def
get_loss_module
(
self
):
if
self
.
n_GPUs
==
1
:
return
self
.
loss_module
else
:
return
self
.
loss_module
.
module
def
save
(
self
,
apath
):
torch
.
save
(
self
.
state_dict
(),
os
.
path
.
join
(
apath
,
'loss.pt'
))
torch
.
save
(
self
.
log
,
os
.
path
.
join
(
apath
,
'loss_log.pt'
))
def
load
(
self
,
apath
,
cpu
=
False
):
if
cpu
:
kwargs
=
{
'map_location'
:
lambda
storage
,
loc
:
storage
}
else
:
kwargs
=
{}
self
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
apath
,
'loss.pt'
),
**
kwargs
))
self
.
log
=
torch
.
load
(
os
.
path
.
join
(
apath
,
'loss_log.pt'
))
for
l
in
self
.
loss_module
:
if
hasattr
(
l
,
'scheduler'
):
for
_
in
range
(
len
(
self
.
log
)):
l
.
scheduler
.
step
()
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment