Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jiangxin Dong
SVMAP
Commits
7c08d83c
Commit
7c08d83c
authored
Jun 17, 2021
by
Jiangxin Dong
Browse files
Upload New File
parent
edac7512
Changes
1
Show whitespace changes
Inline
Side-by-side
model/deblur.py
0 → 100644
View file @
7c08d83c
import
torch.nn
as
nn
import
torch
import
torch.nn.functional
as
F
import
numpy
as
np
def
make_model
(
args
,
parent
=
False
):
return
DEBLUR
(
args
)
class
DEBLUR
(
nn
.
Module
):
def
__init__
(
self
,
args
):
super
(
DEBLUR
,
self
).
__init__
()
ksize
=
5
n_kernel
=
5
self
.
device
=
"cuda"
self
.
ksize
=
ksize
self
.
n_kernel
=
n_kernel
self
.
sigma
=
args
.
sigma_for_initialization
nets
=
[]
nets
.
extend
([
nn
.
Conv2d
(
args
.
n_colors
,
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
ReLU
(
True
)
])
for
i
in
range
(
4
):
nets
.
extend
([
nn
.
Conv2d
(
64
,
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
ReLU
(
True
)
])
nets
.
extend
([
nn
.
Conv2d
(
64
,
ksize
*
ksize
*
n_kernel
*
args
.
n_colors
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
])
self
.
nets1
=
nn
.
Sequential
(
*
nets
)
nets
=
[]
nets
.
extend
([
nn
.
Conv2d
(
args
.
n_colors
,
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
ReLU
(
True
)
])
for
i
in
range
(
4
):
nets
.
extend
([
nn
.
Conv2d
(
64
,
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
ReLU
(
True
)
])
nets
.
extend
([
nn
.
Conv2d
(
64
,
ksize
*
ksize
*
n_kernel
*
args
.
n_colors
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
])
self
.
nets2
=
nn
.
Sequential
(
*
nets
)
n_kernel_data
=
3
nets
=
[]
nets
.
extend
([
nn
.
Conv2d
(
1
*
args
.
n_colors
,
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
ReLU
(
True
)
])
for
i
in
range
(
4
):
nets
.
extend
([
nn
.
Conv2d
(
64
,
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
ReLU
(
True
)
])
nets
.
extend
([
nn
.
Conv2d
(
64
,
ksize
*
ksize
*
n_kernel_data
*
args
.
n_colors
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
])
self
.
nets1_data
=
nn
.
Sequential
(
*
nets
)
nets
=
[]
nets
.
extend
([
nn
.
Conv2d
(
1
*
args
.
n_colors
,
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
ReLU
(
True
)
])
for
i
in
range
(
4
):
nets
.
extend
([
nn
.
Conv2d
(
64
,
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
ReLU
(
True
)
])
nets
.
extend
([
nn
.
Conv2d
(
64
,
ksize
*
ksize
*
n_kernel_data
*
args
.
n_colors
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
])
self
.
nets2_data
=
nn
.
Sequential
(
*
nets
)
self
.
g1_kernel
=
torch
.
from_numpy
(
np
.
array
([[
0
,
0
,
0
],
[
0
,
-
1
,
1
],
[
0
,
0
,
0
]],
dtype
=
"float32"
).
reshape
((
1
,
1
,
3
,
3
))).
to
(
self
.
device
)
self
.
g2_kernel
=
torch
.
from_numpy
(
np
.
array
([[
0
,
0
,
0
],
[
0
,
-
1
,
0
],
[
0
,
1
,
0
]],
dtype
=
"float32"
).
reshape
((
1
,
1
,
3
,
3
))).
to
(
self
.
device
)
def
kernel_conv3
(
self
,
x
,
kernel
,
n_kernel
):
b
,
c
,
h
,
w
=
x
.
size
()
psize
=
self
.
ksize
//
2
x_pad
=
F
.
pad
(
x
,
(
psize
,
psize
,
psize
,
psize
))
x_list
=
[]
for
i
in
range
(
self
.
ksize
):
for
j
in
range
(
self
.
ksize
):
xp
=
x_pad
[:,
:,
i
:
i
+
h
,
j
:
j
+
w
]
x_list
.
append
(
torch
.
unsqueeze
(
xp
,
2
))
x_repeat
=
torch
.
cat
(
x_list
,
dim
=
2
)
x_list
=
[
x_repeat
for
_
in
range
(
n_kernel
)]
x_stack
=
torch
.
stack
(
x_list
,
dim
=
2
)
kernel_stack
=
kernel
.
view
(
b
,
c
,
n_kernel
,
self
.
ksize
*
self
.
ksize
,
h
,
w
)
result
=
x_stack
*
kernel_stack
result
=
torch
.
mean
(
result
,
dim
=
3
)
return
result
def
kernel_conv_transpose3
(
self
,
x
,
kernel
,
n_kernel
):
b
,
c
,
n_kernel
,
h
,
w
=
x
.
size
()
psize
=
self
.
ksize
//
2
x_pad
=
F
.
pad
(
x
,
(
psize
,
psize
,
psize
,
psize
))
x_list2
=
[]
for
nk
in
range
(
n_kernel
):
x_list
=
[]
for
i
in
range
(
self
.
ksize
):
for
j
in
range
(
self
.
ksize
):
x_list
.
append
(
x_pad
[:,
:,
nk
:
nk
+
1
,
i
:
i
+
h
,
j
:
j
+
w
])
x_repeat
=
torch
.
cat
(
x_list
,
dim
=
2
)
x_list2
.
append
(
x_repeat
)
x_stack
=
torch
.
stack
(
x_list2
,
dim
=
2
)
kernel_stack
=
kernel
.
view
(
b
,
c
,
n_kernel
,
self
.
ksize
*
self
.
ksize
,
h
,
w
)
idx
=
[
idxx
for
idxx
in
range
(
self
.
ksize
*
self
.
ksize
-
1
,
-
1
,
-
1
)]
result
=
x_stack
*
kernel_stack
[:,
:,
:,
idx
,
:,
:]
result
=
torch
.
mean
(
result
,
dim
=
3
)
return
result
def
auto_crop_kernel
(
self
,
kernel
):
end
=
0
for
i
in
range
(
kernel
.
size
()[
2
]):
if
kernel
[
0
,
0
,
end
,
0
]
==
-
1
:
break
end
+=
1
kernel
=
kernel
[:,
:,
:
end
,
:
end
]
return
kernel
def
conv_func
(
self
,
input
,
kernel
,
padding
=
'same'
):
b
,
c
,
h
,
w
=
input
.
size
()
_
,
_
,
ksize
,
ksize
=
kernel
.
size
()
if
padding
==
'same'
:
pad
=
ksize
//
2
elif
padding
==
'valid'
:
pad
=
0
else
:
raise
Exception
(
"not support padding flag!"
)
conv_result
=
[]
for
i
in
range
(
c
):
conv_result
.
append
(
F
.
conv2d
(
input
[:,
i
:
i
+
1
,
:,
:],
kernel
,
bias
=
None
,
stride
=
1
,
padding
=
pad
))
conv_result_tensor
=
torch
.
cat
(
conv_result
,
dim
=
1
)
return
conv_result_tensor
def
dual_conv
(
self
,
input
,
kernel
,
mask
,
coefficient
):
kernel_numpy
=
kernel
.
cpu
().
numpy
()[:,
:,
::
-
1
,
::
-
1
]
kernel_numpy
=
np
.
ascontiguousarray
(
kernel_numpy
)
kernel_flip
=
torch
.
from_numpy
(
kernel_numpy
).
to
(
self
.
device
)
result
=
self
.
conv_func
(
input
,
kernel_flip
,
padding
=
'same'
)
result
=
self
.
conv_func
(
result
*
mask
*
coefficient
,
kernel
,
padding
=
'same'
)
return
result
def
dual_conv_grad
(
self
,
input
,
kernel
):
kernel_numpy
=
kernel
.
cpu
().
numpy
()[:,
:,
::
-
1
,
::
-
1
]
kernel_numpy
=
np
.
ascontiguousarray
(
kernel_numpy
)
kernel_flip
=
torch
.
from_numpy
(
kernel_numpy
).
to
(
self
.
device
)
result
=
self
.
conv_func
(
input
,
kernel_flip
,
padding
=
'same'
)
return
result
def
vector_inner_product3
(
self
,
x1
,
x2
):
b
,
c
,
h
,
w
=
x1
.
size
()
x1
=
x1
.
view
(
b
,
c
,
-
1
)
x2
=
x2
.
view
(
b
,
c
,
-
1
)
re
=
x1
*
x2
re
=
torch
.
sum
(
re
,
dim
=
2
)
re
=
re
.
view
(
b
,
c
,
1
,
1
)
return
re
def
deconv_func
(
self
,
input
,
input_ori
,
kernel
,
alpha
,
beta1
,
beta2
):
# , beta11, beta22, beta12
b
,
c
,
h
,
w
=
input
.
size
()
assert
b
==
1
,
"only support one image deconv operation!"
kernel
=
self
.
auto_crop_kernel
(
kernel
)
kernel
=
torch
.
from_numpy
(
np
.
ascontiguousarray
(
kernel
.
cpu
().
numpy
()[:,
:,
::
-
1
,
::
-
1
])).
to
(
self
.
device
)
kb
,
kc
,
ksize
,
ksize
=
kernel
.
size
()
psize
=
ksize
//
2
assert
kb
==
b
,
"kernel batch must be equal to input batch!"
assert
kc
==
1
,
"kernel channel must be 1!"
assert
ksize
%
2
==
1
,
"only support odd kernel size!"
mask
=
torch
.
zeros_like
(
alpha
).
to
(
self
.
device
)
mask
[:,
:,
psize
:
-
psize
,
psize
:
-
psize
]
=
1.
mask_beta
=
torch
.
ones_like
(
alpha
).
to
(
self
.
device
)
alphanew
=
torch
.
ones_like
(
alpha
).
to
(
self
.
device
)
x
=
input
b
=
self
.
conv_func
(
input_ori
*
mask
,
kernel
,
padding
=
'same'
)
sigma
=
self
.
sigma
Ax
=
self
.
dual_conv
(
x
,
kernel
,
mask
,
alphanew
)
Ax
=
Ax
+
sigma
*
self
.
dual_conv
(
x
,
self
.
g1_kernel
,
mask_beta
,
beta1
)
\
+
sigma
*
self
.
dual_conv
(
x
,
self
.
g2_kernel
,
mask_beta
,
beta2
)
r
=
b
-
Ax
for
i
in
range
(
25
):
rho
=
self
.
vector_inner_product3
(
r
,
r
)
if
i
==
0
:
p
=
r
else
:
beta
=
rho
/
rho_1
p
=
r
+
beta
*
p
Ap
=
self
.
dual_conv
(
p
,
kernel
,
mask
,
alphanew
)
Ap
=
Ap
+
sigma
*
self
.
dual_conv
(
p
,
self
.
g1_kernel
,
mask_beta
,
beta1
)
\
+
sigma
*
self
.
dual_conv
(
p
,
self
.
g2_kernel
,
mask_beta
,
beta2
)
q
=
Ap
alp
=
rho
/
self
.
vector_inner_product3
(
p
,
q
)
x
=
x
+
alp
*
p
r
=
r
-
alp
*
q
rho_1
=
rho
deconv_result
=
x
return
deconv_result
def
deconv_func2
(
self
,
input
,
input_ori
,
kernel
,
filters
,
filters_data
,
alpha
,
beta1
,
beta2
):
# , beta11, beta22, beta12
b
,
c
,
h
,
w
=
input
.
size
()
assert
b
==
1
,
"only support one image deconv operation!"
kernel
=
self
.
auto_crop_kernel
(
kernel
)
kernel
=
torch
.
from_numpy
(
np
.
ascontiguousarray
(
kernel
.
cpu
().
numpy
()[:,
:,
::
-
1
,
::
-
1
])).
to
(
self
.
device
)
kb
,
kc
,
ksize
,
ksize
=
kernel
.
size
()
psize
=
ksize
//
2
assert
kb
==
b
,
"kernel batch must be equal to input batch!"
assert
kc
==
1
,
"kernel channel must be 1!"
assert
ksize
%
2
==
1
,
"only support odd kernel size!"
mask
=
torch
.
zeros_like
(
alpha
).
to
(
self
.
device
)
mask
[:,
:,
psize
:
-
psize
,
psize
:
-
psize
]
=
1.
Fy
=
self
.
kernel_conv3
(
input_ori
,
filters_data
,
3
)
FtFy
=
self
.
kernel_conv_transpose3
(
Fy
,
filters_data
,
3
)
FtFy_sum
=
torch
.
sum
(
FtFy
,
dim
=
2
)
b
=
self
.
conv_func
(
FtFy_sum
*
mask
,
kernel
,
padding
=
'same'
)
x
=
input
Kx
=
self
.
dual_conv_grad
(
x
,
kernel
)
FKx
=
self
.
kernel_conv3
(
Kx
,
filters_data
,
3
)
FtFKx
=
self
.
kernel_conv_transpose3
(
FKx
,
filters_data
,
3
)
FtFKx_sum
=
torch
.
sum
(
FtFKx
,
dim
=
2
)
Ax
=
self
.
conv_func
(
FtFKx_sum
*
mask
,
kernel
,
padding
=
'same'
)
Gx
=
self
.
kernel_conv3
(
x
,
filters
,
5
)
GtGx
=
self
.
kernel_conv_transpose3
(
Gx
,
filters
,
5
)
GtGx_sum
=
torch
.
sum
(
GtGx
,
dim
=
2
)
Ax
=
Ax
+
GtGx_sum
r
=
b
-
Ax
for
i
in
range
(
5
):
rho
=
self
.
vector_inner_product3
(
r
,
r
)
if
i
==
0
:
p
=
r
else
:
beta
=
rho
/
rho_1
p
=
r
+
beta
*
p
Kp
=
self
.
dual_conv_grad
(
p
,
kernel
)
FKp
=
self
.
kernel_conv3
(
Kp
,
filters_data
,
3
)
FtFKp
=
self
.
kernel_conv_transpose3
(
FKp
,
filters_data
,
3
)
FtFKp_sum
=
torch
.
sum
(
FtFKp
,
dim
=
2
)
Ap
=
self
.
conv_func
(
FtFKp_sum
*
mask
,
kernel
,
padding
=
'same'
)
Gp
=
self
.
kernel_conv3
(
p
,
filters
,
5
)
GtGp
=
self
.
kernel_conv_transpose3
(
Gp
,
filters
,
5
)
GtGp_sum
=
torch
.
sum
(
GtGp
,
dim
=
2
)
Ap
=
Ap
+
GtGp_sum
q
=
Ap
alp
=
rho
/
self
.
vector_inner_product3
(
p
,
q
)
x
=
x
+
alp
*
p
r
=
r
-
alp
*
q
rho_1
=
rho
deconv_result
=
x
return
deconv_result
def
forward
(
self
,
input
,
kernel
):
b
,
c
,
h
,
w
=
input
.
size
()
_
,
_
,
ksize
,
ksize
=
kernel
.
size
()
psize
=
ksize
//
2
input_pad
=
F
.
pad
(
input
,
(
psize
,
psize
,
psize
,
psize
),
mode
=
'replicate'
)
alpha
=
torch
.
ones_like
(
input_pad
).
to
(
self
.
device
)
beta1
=
torch
.
ones_like
(
input_pad
).
to
(
self
.
device
)
beta2
=
torch
.
ones_like
(
input_pad
).
to
(
self
.
device
)
deconv_list
=
[]
for
j
in
range
(
b
):
deconv_list
.
append
(
self
.
deconv_func
(
input_pad
[
j
:
j
+
1
,
:,
:,
:],
input_pad
[
j
:
j
+
1
,
:,
:,
:],
kernel
[
j
:
j
+
1
,
:,
:,
:],
alpha
[
j
:
j
+
1
,
:,
:,
:],
beta1
[
j
:
j
+
1
,
:,
:,
:],
beta2
[
j
:
j
+
1
,
:,
:,
:]))
deconv
=
torch
.
cat
(
deconv_list
,
dim
=
0
)
filters
=
self
.
nets1
(
deconv
)
filters_data
=
self
.
nets1_data
(
deconv
)
deconv_list
=
[]
for
j
in
range
(
b
):
deconv_list
.
append
(
self
.
deconv_func2
(
deconv
[
j
:
j
+
1
,
:,
:,
:],
input_pad
[
j
:
j
+
1
,
:,
:,
:],
kernel
[
j
:
j
+
1
,
:,
:,
:],
filters
[
j
:
j
+
1
,
:,
:,
:],
filters_data
[
j
:
j
+
1
,
:,
:,
:],
alpha
[
j
:
j
+
1
,
:,
:,
:],
beta1
[
j
:
j
+
1
,
:,
:,
:],
beta2
[
j
:
j
+
1
,
:,
:,
:]))
deconv
=
torch
.
cat
(
deconv_list
,
dim
=
0
)
filters
=
self
.
nets2
(
deconv
)
filters_data
=
self
.
nets2_data
(
deconv
)
deconv_list
=
[]
for
j
in
range
(
b
):
deconv_list
.
append
(
self
.
deconv_func2
(
deconv
[
j
:
j
+
1
,
:,
:,
:],
input_pad
[
j
:
j
+
1
,
:,
:,
:],
kernel
[
j
:
j
+
1
,
:,
:,
:],
filters
[
j
:
j
+
1
,
:,
:,
:],
filters_data
[
j
:
j
+
1
,
:,
:,
:],
alpha
[
j
:
j
+
1
,
:,
:,
:],
beta1
[
j
:
j
+
1
,
:,
:,
:],
beta2
[
j
:
j
+
1
,
:,
:,
:]))
deconv
=
torch
.
cat
(
deconv_list
,
dim
=
0
)
result
=
deconv
[:,
:,
psize
:
-
psize
,
psize
:
-
psize
]
return
result
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment