基于tensorflow的快速图像风格迁移系统
该模型分为两个部分:一个风格图像生成网络(Image Transform Net),以及一个用于计算损失的网络(原论文中为 VGG-16;下文代码默认加载的是 VGG-19 权重文件 imagenet-vgg-verydeep-19.mat)。
程序设计
风格训练
# 导包操作
from __future__ import print_function
import sys, os, pdb
sys.path.insert(0, 'src')
import numpy as np, scipy.misc
from src.optimize import optimize
from argparse import ArgumentParser
from src.utils import save_img, get_img, exists, list_files
import evaluate
导入库和模块
# Default weights of the three loss terms: content, style and TV
# (TV presumably stands for total variation -- confirm against src/optimize).
CONTENT_WEIGHT = 7.5
STYLE_WEIGHT = 100.0
TV_WEIGHT = 200.0

# Default optimisation settings: learning rate, epochs and checkpointing.
LEARNING_RATE = 0.001
NUM_EPOCHS = 2
CHECKPOINT_ITERATIONS = 2000
CHECKPOINT_DIR = 'checkpoints/0505'

# Default data locations: VGG network data and the training image folder.
VGG_PATH = 'data/imagenet-vgg-verydeep-19.mat'
# Alternative training-set location, kept for reference:
# TRAIN_PATH = 'F:/hua/coco/train2014'
TRAIN_PATH = '/data1/scm/ssd/data/coco/train2014'

# Batch size and device settings.
BATCH_SIZE = 1
DEVICE = '/gpu:0'
FRAC_GPU = 1
设置超参数
def build_parser():
    """Create the command-line parser for the training script.

    Every option except ``--style`` (required) has a default; most defaults
    come from the module-level constants.

    Returns:
        ArgumentParser: parser with all training options registered.
    """
    # (flag, add_argument keyword arguments), listed in registration order.
    option_specs = (
        ('--checkpoint-dir', {'type': str, 'default': CHECKPOINT_DIR}),
        ('--style', {'type': str, 'required': True}),
        ('--train-path', {'type': str, 'default': TRAIN_PATH}),
        # --test / --test-dir default to False so "not given" is falsy.
        ('--test', {'type': str, 'default': False}),
        ('--test-dir', {'type': str, 'default': False}),
        ('--slow', {'action': 'store_true', 'default': False}),
        ('--epochs', {'type': int, 'default': NUM_EPOCHS}),
        ('--batch-size', {'type': int, 'default': BATCH_SIZE}),
        ('--checkpoint-iterations',
         {'type': int, 'default': CHECKPOINT_ITERATIONS}),
        ('--vgg-path', {'type': str, 'default': VGG_PATH}),
        ('--content-weight', {'type': float, 'default': CONTENT_WEIGHT}),
        ('--style-weight', {'type': float, 'default': STYLE_WEIGHT}),
        ('--tv-weight', {'type': float, 'default': TV_WEIGHT}),
        ('--learning-rate', {'type': float, 'default': LEARNING_RATE}),
    )

    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli
参数解析
def check_opts(opts):
    """Validate the parsed command-line options before training starts.

    Path options are checked with ``exists`` (imported from ``src.utils``;
    presumably it asserts that the path exists -- confirm there), numeric
    options are range-checked with assertions.

    Args:
        opts: Namespace returned by ``build_parser().parse_args()``.

    Raises:
        AssertionError: if a numeric option is out of range, if ``--test``
            is given without ``--test-dir``, or if the VGG path is missing.
    """
    exists(opts.checkpoint_dir, "checkpoint dir not found!")
    exists(opts.style, "style path not found!")
    exists(opts.train_path, "train path not found!")

    # --test and --test-dir both default to False.  Only check the ones that
    # were actually supplied, so the False placeholder is never handed to a
    # path check (os.path.exists(False) treats it as file descriptor 0).
    if opts.test:
        # main() writes test predictions into --test-dir, so fail here
        # rather than after the first round of training.
        assert opts.test_dir, "--test requires --test-dir!"
        exists(opts.test, "test img not found!")
    if opts.test_dir:
        exists(opts.test_dir, "test directory not found!")

    exists(opts.vgg_path, "vgg network data not found!")
    assert opts.epochs > 0, "epochs must be > 0"
    assert opts.batch_size > 0, "batch size must be > 0"
    assert opts.checkpoint_iterations > 0, "checkpoint iterations must be > 0"
    assert os.path.exists(opts.vgg_path), "vgg network data not found!"
    assert opts.content_weight >= 0, "content weight must be >= 0"
    assert opts.style_weight >= 0, "style weight must be >= 0"
    assert opts.tv_weight >= 0, "tv weight must be >= 0"
    assert opts.learning_rate >= 0, "learning rate must be >= 0"
参数及路径检查
def _get_files(img_dir):
    """Return every entry reported by ``list_files`` joined onto *img_dir*.

    ``list_files`` is imported from ``src.utils``; presumably it yields bare
    file names -- confirm there.
    """
    joined_paths = []
    for entry in list_files(img_dir):
        joined_paths.append(os.path.join(img_dir, entry))
    return joined_paths
文件处理函数
def main():
    """Parse the CLI options, run style training and report progress.

    In the default (fast) mode the content images are all files under
    ``--train-path``.  With ``--slow`` the single ``--test`` image is the
    only content target, and small ``--epochs`` / ``--learning-rate`` values
    are replaced by 1000 and 1e1 respectively.

    Each time ``optimize`` yields, the losses are printed; when ``--test`` is
    given in fast mode the current checkpoint is also applied to the test
    image and the result is written into ``--test-dir``.

    Exits via ``parser.error`` (status 2) when ``--slow`` is given without
    ``--test`` or ``--test`` without ``--test-dir``.
    """
    parser = build_parser()
    options = parser.parse_args()
    check_opts(options)

    # Reject inconsistent flag combinations up front instead of failing with
    # a NameError (no content targets) or only after the first training round.
    if options.slow and not options.test:
        parser.error('--slow requires --test (the image to optimise)')
    if options.test and not options.test_dir:
        parser.error('--test requires --test-dir')

    style_target = get_img(options.style)
    if options.slow:
        content_targets = [options.test]
    else:
        content_targets = _get_files(options.train_path)

    kwargs = {
        "slow": options.slow,
        "epochs": options.epochs,
        "print_iterations": options.checkpoint_iterations,
        "batch_size": options.batch_size,
        "save_path": os.path.join(options.checkpoint_dir, 'fns.ckpt'),
        "learning_rate": options.learning_rate
    }

    if options.slow:
        # Slow mode overrides settings that are too small for it.
        if options.epochs < 10:
            kwargs['epochs'] = 1000
        if options.learning_rate < 1:
            kwargs['learning_rate'] = 1e1

    args = [
        content_targets,
        style_target,
        options.content_weight,
        options.style_weight,
        options.tv_weight,
        options.vgg_path
    ]

    for _preds, losses, i, epoch in optimize(*args, **kwargs):
        style_loss, content_loss, tv_loss, loss = losses
        print('Epoch %d, Iteration: %d, Loss: %s' % (epoch, i, loss))
        to_print = (style_loss, content_loss, tv_loss)
        print('style: %s, content:%s, tv: %s' % to_print)
        if options.test and not options.slow:
            preds_path = '%s/%s_%s.png' % (options.test_dir, epoch, i)
            evaluate.ffwd_to_img(options.test, preds_path,
                                 options.checkpoint_dir)
        # NOTE(review): in slow mode the yielded predictions are not written
        # anywhere (the original branch was a bare `pass`) -- confirm whether
        # they should be saved with save_img.

    ckpt_dir = options.checkpoint_dir
    cmd_text = 'python evaluate.py --checkpoint %s ...' % ckpt_dir
    print("Training complete. For evaluation:\n `%s`" % cmd_text)
主函数:首先解析参数并进行检查,然后获取风格图像和内容图像(训练图像),设置优化参数并调用 optimize 函数进行训练。optimize 每次产出结果时打印总损失以及风格、内容、TV 各项损失值;在测试模式(且非 slow 模式)下,还会调用 evaluate.ffwd_to_img 将测试结果图像保存到测试目录。训练完成后打印用于评估的命令提示信息。
风格迁移处理
使用 PyQt5 创建的桌面应用程序,旨在实现基于深度学习的图像风格迁移。用户可以选择一张图片和一个风格模型,然后将选定的风格应用到该图片上。
import sys
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from evaluate import for_UI
# Resize an image for display.
def shrinkImage(img_path, output_height):
    """Load an image and scale it to a fixed height, keeping its proportions.

    Args:
        img_path: Path of the image file to load.
        output_height: Height of the returned pixmap in pixels.

    Returns:
        QPixmap: the scaled image, or an empty (null) QPixmap when the file
        is missing or cannot be decoded.
    """
    img = QImage(img_path)
    if img.isNull():
        # A failed load has height 0; bail out instead of dividing by zero.
        return QPixmap()

    scale = output_height / img.height()
    output_width = int(img.width() * scale)
    size = QSize(output_width, output_height)
    # The width was already derived from the scale factor, so the image is
    # scaled to exactly this size.
    return QPixmap.fromImage(img.scaled(size, Qt.IgnoreAspectRatio))
# Widgets are created first and then arranged in layouts.
class MainWindow(QTabWidget):
    """Main window of the style-transfer desktop application.

    The window has two tabs: a main page showing the input image next to the
    transfer result, and an "about" page.  The user picks an image and a
    style model on the main page and applies the style with a button.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle('基于Fast StyleTransfer图像风格迁移系统')
        self.setWindowIcon(QIcon('images/logo.png'))
        self.resize(1600, 600)
        self.initUI()

    def initUI(self):
        """Build both tabs and add them to the tab widget."""
        generally_font = QFont('楷体', 20)

        # Main page: input panel on the left, result panel next to it.
        main_widget = QWidget()
        main_layout = QHBoxLayout()
        left_widget = self._build_input_panel(generally_font)
        middle_widget = self._build_result_panel(generally_font)
        main_layout.addWidget(left_widget)
        main_layout.addWidget(middle_widget)
        main_widget.setLayout(main_layout)

        about_widget = self._build_about_page()

        self.addTab(main_widget, '主页面')
        self.addTab(about_widget, '关于')
        self.setTabIcon(0, QIcon('images/主页面.png'))
        self.setTabIcon(1, QIcon('images/关于.png'))

    def _build_input_panel(self, font):
        """Create the left panel: input preview, file button, style selector."""
        left_widget = QWidget()
        left_layout = QVBoxLayout()

        label_input = QLabel("原图")
        label_input.setFont(font)

        self.img_input = QLabel()
        self.input_image_path = "images/chicago.jpg"
        self.img_input.setPixmap(
            QPixmap(shrinkImage(self.input_image_path, 300)))

        upload_btn = QPushButton(" 选择图片 ")
        upload_btn.setFont(font)
        upload_btn.clicked.connect(self.chose_file)

        # Each entry is used as a model file name in change_style().
        self.cb = QComboBox()
        for style_name in ("la_muse", "rain_princess", "scream",
                           "udnie", "wave", "wreck"):
            self.cb.addItem(style_name)
        self.cb.setFont(font)

        left_layout.addWidget(label_input, 0, Qt.AlignCenter | Qt.AlignTop)
        left_layout.addWidget(self.img_input, 0, Qt.AlignCenter)
        left_layout.addWidget(upload_btn, 0, Qt.AlignCenter)
        left_layout.addWidget(self.cb, 0, Qt.AlignCenter)
        left_widget.setLayout(left_layout)
        return left_widget

    def _build_result_panel(self, font):
        """Create the middle panel: result preview and the transfer button."""
        middle_widget = QWidget()
        middle_layout = QVBoxLayout()

        label_title_2 = QLabel('迁移结果')
        label_title_2.setFont(font)

        self.img_middle = QLabel("中间结果")
        self.img_middle.setPixmap(
            QPixmap(shrinkImage('images/temp_transfer.jpg', 300)))

        btn_chong = QPushButton(" 风格迁移 ")
        btn_chong.setFont(font)
        btn_chong.clicked.connect(self.change_style)

        # Empty label placed between the preview and the button.
        xx = QLabel()
        xx.setFont(font)

        middle_layout.addWidget(label_title_2, 0,
                                Qt.AlignCenter | Qt.AlignTop)
        middle_layout.addWidget(self.img_middle, 0, Qt.AlignCenter)
        middle_layout.addWidget(xx, 0, Qt.AlignCenter)
        middle_layout.addWidget(btn_chong, 0, Qt.AlignCenter)
        middle_widget.setLayout(middle_layout)
        return middle_widget

    def _build_about_page(self):
        """Create the "about" tab: welcome text, logo and a home-page link."""
        about_widget = QWidget()
        about_layout = QVBoxLayout()

        about_title = QLabel('欢迎使用风格迁移系统\n'
                             'QQ:718005487')
        about_title.setFont(QFont('楷体', 18))
        about_title.setAlignment(Qt.AlignCenter)

        about_img = QLabel()
        about_img.setPixmap(QPixmap('images/logo.png'))
        about_img.setAlignment(Qt.AlignCenter)

        label_super = QLabel()
        label_super.setText("<a href='http://yjxyzxyz.cn'>我的个人主页</a>")
        label_super.setFont(QFont('楷体', 12))
        label_super.setOpenExternalLinks(True)
        label_super.setAlignment(Qt.AlignRight)

        about_layout.addWidget(about_title)
        about_layout.addStretch()
        about_layout.addWidget(about_img)
        about_layout.addStretch()
        about_layout.addWidget(label_super)
        about_widget.setLayout(about_layout)
        return about_widget

    def chose_file(self):
        """Let the user pick an input image and show its preview."""
        print("选择图片")
        # Only one image is ever used, so ask for a single file.  The filter
        # needs the dot in "*.png"; "*png" matches any name ending in "png".
        fname, _ = QFileDialog.getOpenFileName(
            self, 'open file',
            'D:\\ubuntu\\transfer\\fast-style-transfer_lff\\images',
            "Image files(*.jpg *.png)")
        if fname:  # empty string when the dialog is cancelled
            self.input_image_path = fname
            self.img_input.setPixmap(
                QPixmap(shrinkImage(self.input_image_path, 300)))

    def change_style(self):
        """Apply the selected style model to the current input image."""
        print("风格迁移")
        inpath = self.input_image_path
        checkdir = 'models/' + self.cb.currentText() + '.ckpt'
        out_path = 'images/temp_transfer.jpg'
        for_UI(inpath, out_path, checkdir)
        # Reload the file passed to for_UI as the output path.
        self.img_middle.setPixmap(QPixmap(shrinkImage(out_path, 300)))

    def closeEvent(self, event):
        """Ask for confirmation before the window is closed."""
        reply = QMessageBox.question(self,
                                     'quit',
                                     "Are you sure?",
                                     QMessageBox.Yes | QMessageBox.No,
                                     QMessageBox.No)
        if reply == QMessageBox.Yes:
            # Accepting the event is enough; calling close() again from
            # inside the close handler is redundant.
            event.accept()
        else:
            event.ignore()
if __name__ == "__main__":
    # Create the application object before any widget, show the main
    # window, then hand control to the Qt event loop until it exits.
    app = QApplication(sys.argv)
    mainWindow = MainWindow()
    mainWindow.show()
    exit_code = app.exec_()
    sys.exit(exit_code)