# Commit 7c08d83c, authored by Jiangxin Dong (parent edac7512)
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

def make_model(args, parent=False):
    return DEBLUR(args)

class DEBLUR(nn.Module):
    def __init__(self, args):
        super(DEBLUR, self).__init__()
        ksize = 5
        n_kernel = 5
        self.device = "cuda"  # device is hard-coded; the fixed gradient kernels below are created on it
        self.ksize = ksize
        self.n_kernel = n_kernel
        self.sigma = args.sigma_for_initialization

        # nets1 / nets2: predict per-pixel regularization filters
        # (n_kernel filters of size ksize x ksize for each color channel).
        nets = []
        nets.extend([
            nn.Conv2d(args.n_colors, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True)
        ])
        for i in range(4):
            nets.extend([
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True)
            ])
        nets.extend([
            nn.Conv2d(64, ksize * ksize * n_kernel * args.n_colors, kernel_size=3, stride=1, padding=1)
        ])
        self.nets1 = nn.Sequential(*nets)

        nets = []
        nets.extend([
            nn.Conv2d(args.n_colors, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True)
        ])
        for i in range(4):
            nets.extend([
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True)
            ])
        nets.extend([
            nn.Conv2d(64, ksize * ksize * n_kernel * args.n_colors, kernel_size=3, stride=1, padding=1)
        ])
        self.nets2 = nn.Sequential(*nets)

        # nets1_data / nets2_data: predict per-pixel data-term filters
        # (n_kernel_data filters of size ksize x ksize for each color channel).
        n_kernel_data = 3
        nets = []
        nets.extend([
            nn.Conv2d(1 * args.n_colors, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True)
        ])
        for i in range(4):
            nets.extend([
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True)
            ])
        nets.extend([
            nn.Conv2d(64, ksize * ksize * n_kernel_data * args.n_colors, kernel_size=3, stride=1, padding=1)
        ])
        self.nets1_data = nn.Sequential(*nets)

        nets = []
        nets.extend([
            nn.Conv2d(1 * args.n_colors, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True)
        ])
        for i in range(4):
            nets.extend([
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True)
            ])
        nets.extend([
            nn.Conv2d(64, ksize * ksize * n_kernel_data * args.n_colors, kernel_size=3, stride=1, padding=1)
        ])
        self.nets2_data = nn.Sequential(*nets)

        # Fixed horizontal / vertical gradient filters used by the first-stage regularizer.
        self.g1_kernel = torch.from_numpy(
            np.array([[0, 0, 0], [0, -1, 1], [0, 0, 0]], dtype="float32").reshape((1, 1, 3, 3))).to(self.device)
        self.g2_kernel = torch.from_numpy(
            np.array([[0, 0, 0], [0, -1, 0], [0, 1, 0]], dtype="float32").reshape((1, 1, 3, 3))).to(self.device)

    def kernel_conv3(self, x, kernel, n_kernel):
        # Apply n_kernel spatially-varying ksize x ksize filters to x.
        # x: (b, c, h, w); kernel: (b, c * n_kernel * ksize * ksize, h, w).
        # Returns a tensor of shape (b, c, n_kernel, h, w).
        b, c, h, w = x.size()
        psize = self.ksize // 2
        x_pad = F.pad(x, (psize, psize, psize, psize))
        x_list = []
        for i in range(self.ksize):
            for j in range(self.ksize):
                xp = x_pad[:, :, i:i + h, j:j + w]
                x_list.append(torch.unsqueeze(xp, 2))
        x_repeat = torch.cat(x_list, dim=2)
        x_list = [x_repeat for _ in range(n_kernel)]
        x_stack = torch.stack(x_list, dim=2)
        kernel_stack = kernel.view(b, c, n_kernel, self.ksize * self.ksize, h, w)
        result = x_stack * kernel_stack
        result = torch.mean(result, dim=3)
        return result

    def kernel_conv_transpose3(self, x, kernel, n_kernel):
        # Transposed counterpart of kernel_conv3: applies the spatially-varying
        # filters with the tap order reversed. x: (b, c, n_kernel, h, w).
        b, c, n_kernel, h, w = x.size()
        psize = self.ksize // 2
        x_pad = F.pad(x, (psize, psize, psize, psize))
        x_list2 = []
        for nk in range(n_kernel):
            x_list = []
            for i in range(self.ksize):
                for j in range(self.ksize):
                    x_list.append(x_pad[:, :, nk:nk + 1, i:i + h, j:j + w])
            x_repeat = torch.cat(x_list, dim=2)
            x_list2.append(x_repeat)
        x_stack = torch.stack(x_list2, dim=2)
        kernel_stack = kernel.view(b, c, n_kernel, self.ksize * self.ksize, h, w)
        idx = [idxx for idxx in range(self.ksize * self.ksize - 1, -1, -1)]
        result = x_stack * kernel_stack[:, :, :, idx, :, :]
        result = torch.mean(result, dim=3)
        return result

    def auto_crop_kernel(self, kernel):
        # Kernels may be padded with -1 entries; crop to the valid (non-padded) region.
        end = 0
        for i in range(kernel.size()[2]):
            if kernel[0, 0, end, 0] == -1:
                break
            end += 1
        kernel = kernel[:, :, :end, :end]
        return kernel

    def conv_func(self, input, kernel, padding='same'):
        # Convolve each channel of `input` separately with the single-channel `kernel`.
        b, c, h, w = input.size()
        _, _, ksize, ksize = kernel.size()
        if padding == 'same':
            pad = ksize // 2
        elif padding == 'valid':
            pad = 0
        else:
            raise Exception("unsupported padding flag!")
        conv_result = []
        for i in range(c):
            conv_result.append(F.conv2d(input[:, i:i + 1, :, :], kernel, bias=None, stride=1, padding=pad))
        conv_result_tensor = torch.cat(conv_result, dim=1)
        return conv_result_tensor

    def dual_conv(self, input, kernel, mask, coefficient):
        # Convolve with the flipped kernel, re-weight by mask * coefficient, then
        # convolve with the original kernel; this forms one term of the
        # normal-equation operator used by the conjugate-gradient solvers below.
        kernel_numpy = kernel.cpu().numpy()[:, :, ::-1, ::-1]
        kernel_numpy = np.ascontiguousarray(kernel_numpy)
        kernel_flip = torch.from_numpy(kernel_numpy).to(self.device)
        result = self.conv_func(input, kernel_flip, padding='same')
        result = self.conv_func(result * mask * coefficient, kernel, padding='same')
        return result

    def dual_conv_grad(self, input, kernel):
        # Single convolution with the flipped kernel (no adjoint pass).
        kernel_numpy = kernel.cpu().numpy()[:, :, ::-1, ::-1]
        kernel_numpy = np.ascontiguousarray(kernel_numpy)
        kernel_flip = torch.from_numpy(kernel_numpy).to(self.device)
        result = self.conv_func(input, kernel_flip, padding='same')
        return result

    def vector_inner_product3(self, x1, x2):
        b, c, h, w = x1.size()
        x1 = x1.view(b, c, -1)
        x2 = x2.view(b, c, -1)
        re = x1 * x2
        re = torch.sum(re, dim=2)
        re = re.view(b, c, 1, 1)
        return re

    def deconv_func(self, input, input_ori, kernel, alpha, beta1, beta2):  # , beta11, beta22, beta12
        # Non-blind deconvolution with the fixed gradient regularizers (g1/g2),
        # solved by conjugate gradient on the normal equations.
        b, c, h, w = input.size()
        assert b == 1, "only support one image deconv operation!"
        kernel = self.auto_crop_kernel(kernel)
        kernel = torch.from_numpy(np.ascontiguousarray(kernel.cpu().numpy()[:, :, ::-1, ::-1])).to(self.device)
        kb, kc, ksize, ksize = kernel.size()
        psize = ksize // 2
        assert kb == b, "kernel batch must be equal to input batch!"
        assert kc == 1, "kernel channel must be 1!"
        assert ksize % 2 == 1, "only support odd kernel size!"
        mask = torch.zeros_like(alpha).to(self.device)
        mask[:, :, psize:-psize, psize:-psize] = 1.
        mask_beta = torch.ones_like(alpha).to(self.device)
        alphanew = torch.ones_like(alpha).to(self.device)
        x = input
        b = self.conv_func(input_ori * mask, kernel, padding='same')  # reuse `b` as the right-hand side
        sigma = self.sigma
        Ax = self.dual_conv(x, kernel, mask, alphanew)
        Ax = Ax + sigma * self.dual_conv(x, self.g1_kernel, mask_beta, beta1) \
             + sigma * self.dual_conv(x, self.g2_kernel, mask_beta, beta2)
        r = b - Ax
        # Conjugate gradient iterations.
        for i in range(25):
            rho = self.vector_inner_product3(r, r)
            if i == 0:
                p = r
            else:
                beta = rho / rho_1
                p = r + beta * p
            Ap = self.dual_conv(p, kernel, mask, alphanew)
            Ap = Ap + sigma * self.dual_conv(p, self.g1_kernel, mask_beta, beta1) \
                 + sigma * self.dual_conv(p, self.g2_kernel, mask_beta, beta2)
            q = Ap
            alp = rho / self.vector_inner_product3(p, q)
            x = x + alp * p
            r = r - alp * q
            rho_1 = rho
        deconv_result = x
        return deconv_result

    def deconv_func2(self, input, input_ori, kernel, filters, filters_data, alpha, beta1, beta2):  # , beta11, beta22, beta12
        # Non-blind deconvolution using the learned data-term filters (filters_data)
        # and learned regularization filters (filters), solved by conjugate gradient.
        b, c, h, w = input.size()
        assert b == 1, "only support one image deconv operation!"
        kernel = self.auto_crop_kernel(kernel)
        kernel = torch.from_numpy(np.ascontiguousarray(kernel.cpu().numpy()[:, :, ::-1, ::-1])).to(self.device)
        kb, kc, ksize, ksize = kernel.size()
        psize = ksize // 2
        assert kb == b, "kernel batch must be equal to input batch!"
        assert kc == 1, "kernel channel must be 1!"
        assert ksize % 2 == 1, "only support odd kernel size!"
        mask = torch.zeros_like(alpha).to(self.device)
        mask[:, :, psize:-psize, psize:-psize] = 1.
        # Right-hand side built from the data-term filters applied to the observation.
        Fy = self.kernel_conv3(input_ori, filters_data, 3)
        FtFy = self.kernel_conv_transpose3(Fy, filters_data, 3)
        FtFy_sum = torch.sum(FtFy, dim=2)
        b = self.conv_func(FtFy_sum * mask, kernel, padding='same')
        x = input
        # Initial residual r = b - A x, with A composed of the data term (K, F)
        # and the learned regularization term (G).
        Kx = self.dual_conv_grad(x, kernel)
        FKx = self.kernel_conv3(Kx, filters_data, 3)
        FtFKx = self.kernel_conv_transpose3(FKx, filters_data, 3)
        FtFKx_sum = torch.sum(FtFKx, dim=2)
        Ax = self.conv_func(FtFKx_sum * mask, kernel, padding='same')
        Gx = self.kernel_conv3(x, filters, 5)
        GtGx = self.kernel_conv_transpose3(Gx, filters, 5)
        GtGx_sum = torch.sum(GtGx, dim=2)
        Ax = Ax + GtGx_sum
        r = b - Ax
        # Conjugate gradient iterations.
        for i in range(5):
            rho = self.vector_inner_product3(r, r)
            if i == 0:
                p = r
            else:
                beta = rho / rho_1
                p = r + beta * p
            Kp = self.dual_conv_grad(p, kernel)
            FKp = self.kernel_conv3(Kp, filters_data, 3)
            FtFKp = self.kernel_conv_transpose3(FKp, filters_data, 3)
            FtFKp_sum = torch.sum(FtFKp, dim=2)
            Ap = self.conv_func(FtFKp_sum * mask, kernel, padding='same')
            Gp = self.kernel_conv3(p, filters, 5)
            GtGp = self.kernel_conv_transpose3(Gp, filters, 5)
            GtGp_sum = torch.sum(GtGp, dim=2)
            Ap = Ap + GtGp_sum
            q = Ap
            alp = rho / self.vector_inner_product3(p, q)
            x = x + alp * p
            r = r - alp * q
            rho_1 = rho
        deconv_result = x
        return deconv_result

    def forward(self, input, kernel):
        b, c, h, w = input.size()
        _, _, ksize, ksize = kernel.size()
        psize = ksize // 2
        input_pad = F.pad(input, (psize, psize, psize, psize), mode='replicate')
        alpha = torch.ones_like(input_pad).to(self.device)
        beta1 = torch.ones_like(input_pad).to(self.device)
        beta2 = torch.ones_like(input_pad).to(self.device)
        # Stage 0: deconvolution with the fixed gradient priors, one image at a time.
        deconv_list = []
        for j in range(b):
            deconv_list.append(
                self.deconv_func(input_pad[j:j + 1, :, :, :], input_pad[j:j + 1, :, :, :],
                                 kernel[j:j + 1, :, :, :],
                                 alpha[j:j + 1, :, :, :], beta1[j:j + 1, :, :, :], beta2[j:j + 1, :, :, :]))
        deconv = torch.cat(deconv_list, dim=0)
        # Stage 1: predict filters from the current estimate, then deconvolve again.
        filters = self.nets1(deconv)
        filters_data = self.nets1_data(deconv)
        deconv_list = []
        for j in range(b):
            deconv_list.append(
                self.deconv_func2(deconv[j:j + 1, :, :, :], input_pad[j:j + 1, :, :, :],
                                  kernel[j:j + 1, :, :, :], filters[j:j + 1, :, :, :],
                                  filters_data[j:j + 1, :, :, :],
                                  alpha[j:j + 1, :, :, :], beta1[j:j + 1, :, :, :], beta2[j:j + 1, :, :, :]))
        deconv = torch.cat(deconv_list, dim=0)
        # Stage 2: refine the filters and deconvolve once more.
        filters = self.nets2(deconv)
        filters_data = self.nets2_data(deconv)
        deconv_list = []
        for j in range(b):
            deconv_list.append(
                self.deconv_func2(deconv[j:j + 1, :, :, :], input_pad[j:j + 1, :, :, :],
                                  kernel[j:j + 1, :, :, :], filters[j:j + 1, :, :, :],
                                  filters_data[j:j + 1, :, :, :],
                                  alpha[j:j + 1, :, :, :], beta1[j:j + 1, :, :, :], beta2[j:j + 1, :, :, :]))
        deconv = torch.cat(deconv_list, dim=0)
        # Crop the replicate padding back off.
        result = deconv[:, :, psize:-psize, psize:-psize]
        return result
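

# --- Usage sketch (not part of the original file) ---
# A minimal, illustrative example of driving the model above. The `args`
# namespace only needs the two attributes the constructor reads
# (`n_colors` and `sigma_for_initialization`); the image size, kernel size,
# and sigma value below are hypothetical placeholders. The module hard-codes
# self.device = "cuda", so this sketch assumes a CUDA device is available.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(n_colors=3, sigma_for_initialization=0.01)
    model = make_model(args).to("cuda")

    blurred = torch.rand(1, 3, 128, 128, device="cuda")     # blurred RGB input
    blur_kernel = torch.rand(1, 1, 31, 31, device="cuda")   # odd-sized blur kernel
    blur_kernel = blur_kernel / blur_kernel.sum()            # normalize to sum to 1

    with torch.no_grad():
        restored = model(blurred, blur_kernel)
    print(restored.shape)  # expected: torch.Size([1, 3, 128, 128])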