基于 TensorFlow 的快速图像风格迁移系统

该模型分为两个部分:一个是风格图像生成网络(Image Transform Net),另一个是用于计算损失的预训练网络(损失网络)。注意:下文代码中 VGG_PATH 默认加载的是 VGG-19 权重文件(imagenet-vgg-verydeep-19.mat),而不是 VGG-16。

程序设计

风格训练

# 导包操作
from __future__ import print_function
import sys, os, pdb
sys.path.insert(0, 'src')
import numpy as np, scipy.misc
from src.optimize import optimize
from argparse import ArgumentParser
from src.utils import save_img, get_img, exists, list_files
import evaluate

导入库和模块

# Default weights of the content, style and TV loss terms
# (TV presumably stands for total variation - confirm against src/optimize).
CONTENT_WEIGHT = 7.5e0
STYLE_WEIGHT = 1e2
TV_WEIGHT = 2e2

# Default training settings: learning rate, number of epochs, checkpoint
# location and interval, data locations, batch size and device.
LEARNING_RATE = 1e-3
NUM_EPOCHS = 2
CHECKPOINT_DIR = 'checkpoints/0505'
CHECKPOINT_ITERATIONS = 2000
VGG_PATH = 'data/imagenet-vgg-verydeep-19.mat'
# Alternative (Windows) location of the training set, kept for reference:
# TRAIN_PATH = 'F:/hua/coco/train2014'
TRAIN_PATH = '/data1/scm/ssd/data/coco/train2014'
BATCH_SIZE = 1
# NOTE(review): DEVICE and FRAC_GPU are not referenced by the code shown in
# this script - confirm whether they are still needed.
DEVICE = '/gpu:0'
FRAC_GPU = 1

设置超参数

def build_parser():
    """Create the command-line parser of the training script.

    ``--style`` is the only mandatory option.  ``--test`` and ``--test-dir``
    default to ``False``, ``--slow`` is a flag that defaults to ``False``,
    and the remaining options default to the module-level constants.

    :return: the configured ``ArgumentParser``.
    """
    cli = ArgumentParser()
    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = (
        ('--checkpoint-dir', dict(type=str, default=CHECKPOINT_DIR)),
        ('--style', dict(type=str, required=True)),
        ('--train-path', dict(type=str, default=TRAIN_PATH)),
        ('--test', dict(type=str, default=False)),
        ('--test-dir', dict(type=str, default=False)),
        ('--slow', dict(action='store_true', default=False)),
        ('--epochs', dict(type=int, default=NUM_EPOCHS)),
        ('--batch-size', dict(type=int, default=BATCH_SIZE)),
        ('--checkpoint-iterations',
         dict(type=int, default=CHECKPOINT_ITERATIONS)),
        ('--vgg-path', dict(type=str, default=VGG_PATH)),
        ('--content-weight', dict(type=float, default=CONTENT_WEIGHT)),
        ('--style-weight', dict(type=float, default=STYLE_WEIGHT)),
        ('--tv-weight', dict(type=float, default=TV_WEIGHT)),
        ('--learning-rate', dict(type=float, default=LEARNING_RATE)),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli

参数解析

def check_opts(opts):
    """Validate the parsed command-line options before training starts.

    Paths are handed to the ``exists`` helper together with the message that
    describes what is missing; numeric options are checked with assertions.

    :param opts: namespace returned by ``build_parser().parse_args()``.
    """
    always_needed = (
        ('checkpoint_dir', "checkpoint dir not found!"),
        ('style', "style path not found!"),
        ('train_path', "train path not found!"),
    )
    for attr, complaint in always_needed:
        exists(getattr(opts, attr), complaint)

    # Supplying either test option triggers the check of both of them.
    testing_requested = bool(opts.test or opts.test_dir)
    if testing_requested:
        exists(opts.test, "test img not found!")
        exists(opts.test_dir, "test directory not found!")

    exists(opts.vgg_path, "vgg network data not found!")

    # Counters must be strictly positive.
    for attr in ('epochs', 'batch_size', 'checkpoint_iterations'):
        assert getattr(opts, attr) > 0
    assert os.path.exists(opts.vgg_path)
    # Weights and the learning rate may be zero but not negative.
    for attr in ('content_weight', 'style_weight', 'tv_weight',
                 'learning_rate'):
        assert getattr(opts, attr) >= 0

参数及路径检查

def _get_files(img_dir):
    """Return the entries of ``img_dir`` as paths prefixed with ``img_dir``.

    :param img_dir: directory handed to ``list_files``.
    :return: list with one joined path per name returned by ``list_files``.
    """
    return [os.path.join(img_dir, entry) for entry in list_files(img_dir)]

文件处理函数

def main():
    """Parse the command line, run style training and report progress.

    The style image is loaded with ``get_img``.  The content targets are the
    files of ``--train-path``, or the single ``--test`` image in ``--slow``
    mode.  Every result yielded by ``optimize`` is printed; when ``--test``
    is given and ``--slow`` is not, a preview of the test image is written to
    ``--test-dir`` through ``evaluate.ffwd_to_img``.

    Exits through ``parser.error`` (usage message, status 2) when ``--slow``
    is given without ``--test`` or ``--test`` without ``--test-dir``.
    """
    parser = build_parser()
    options = parser.parse_args()
    check_opts(options)

    # Reject inconsistent option combinations up front.  Previously the first
    # one surfaced as a NameError on ``content_targets`` and the second one
    # only as an assertion inside the training loop.
    if options.slow and not options.test:
        parser.error('--slow requires --test (the image to optimize)')
    if options.test and not options.test_dir:
        parser.error('--test requires --test-dir (where previews are saved)')

    style_target = get_img(options.style)
    if options.slow:
        content_targets = [options.test]
    else:
        content_targets = _get_files(options.train_path)

    kwargs = {
        "slow": options.slow,
        "epochs": options.epochs,
        "print_iterations": options.checkpoint_iterations,
        "batch_size": options.batch_size,
        "save_path": os.path.join(options.checkpoint_dir, 'fns.ckpt'),
        "learning_rate": options.learning_rate
    }

    # Slow mode replaces small epoch counts and learning rates with larger
    # values.
    if options.slow:
        if options.epochs < 10:
            kwargs['epochs'] = 1000
        if options.learning_rate < 1:
            kwargs['learning_rate'] = 1e1

    args = [
        content_targets,
        style_target,
        options.content_weight,
        options.style_weight,
        options.tv_weight,
        options.vgg_path
    ]

    for preds, losses, i, epoch in optimize(*args, **kwargs):
        style_loss, content_loss, tv_loss, loss = losses
        print('Epoch %d, Iteration: %d, Loss: %s' % (epoch, i, loss))
        to_print = (style_loss, content_loss, tv_loss)
        print('style: %s, content:%s, tv: %s' % to_print)
        # Previews are only produced outside slow mode.
        if options.test and not options.slow:
            preds_path = '%s/%s_%s.png' % (options.test_dir, epoch, i)
            evaluate.ffwd_to_img(options.test, preds_path,
                                 options.checkpoint_dir)

    ckpt_dir = options.checkpoint_dir
    cmd_text = 'python evaluate.py --checkpoint %s ...' % ckpt_dir
    print("Training complete. For evaluation:\n    `%s`" % cmd_text)

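主函数:首先解析命令行参数并进行检查,然后读取风格图像并确定内容图像(默认为训练集中的图像;slow 模式下为 --test 指定的单张图像)。接着设置优化参数并调用 optimize 函数进行训练;optimize 每返回一次结果,就打印当前的损失值(总损失以及风格、内容、TV 三项损失),如果指定了测试图像且不是 slow 模式,还会生成并保存测试结果图像。训练结束后打印使用 evaluate.py 进行评估的提示信息。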

风格迁移处理

这一部分是使用 PyQt5 创建的桌面应用程序,用于实现基于深度学习的图像风格迁移:用户先选择一张图片和一个风格模型,程序再将选定的风格应用到该图片上并显示结果。

import sys
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from evaluate import for_UI


# 调整图片大小
def shrinkImage(img_path, output_height):
    """Load an image and scale it to a fixed height.

    The width is derived from the source proportions, so the result keeps
    the original aspect ratio up to integer rounding.

    :param img_path: path of the image file to load.
    :param output_height: height in pixels of the returned pixmap.
    :return: ``QPixmap`` holding the scaled image, or a null ``QPixmap`` when
        the file is missing or cannot be decoded.
    """
    img = QImage(img_path)
    # A failed load yields a null image whose height is 0; return an empty
    # pixmap instead of dividing by zero below.
    if img.isNull():
        return QPixmap()
    scale = output_height / img.height()  # scaling factor for the width
    output_width = int(img.width() * scale)
    size = QSize(output_width, output_height)
    # The target size already carries the source proportions, hence the
    # exact (non-fitting) scaling mode.
    return QPixmap.fromImage(img.scaled(size, Qt.IgnoreAspectRatio))


# 先是组件,组件完事之后再添加布局
class MainWindow(QTabWidget):
    """Main window of the style-transfer UI.

    Two tabs are created: the main page (input image with a file picker and
    a style selector on the left, the transfer result and the button that
    starts the transfer on the right) and an "about" page.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle('基于Fast StyleTransfer图像风格迁移系统')
        self.setWindowIcon(QIcon('images/logo.png'))
        self.resize(1600, 600)
        self.initUI()

    def initUI(self):
        """Create all widgets first, then arrange them in layouts and tabs."""
        main_widget = QWidget()
        about_widget = QWidget()
        generally_font = QFont('楷体', 20)
        main_layout = QHBoxLayout()

        # Left panel: input image, file picker and style selector.
        left_widget = QWidget()
        left_layout = QVBoxLayout()
        label_input = QLabel("原图")
        label_input.setFont(generally_font)
        self.img_input = QLabel()
        self.input_image_path = "images/chicago.jpg"
        self.img_input.setPixmap(QPixmap(shrinkImage(self.input_image_path, 300)))
        upload_btn = QPushButton(" 选择图片 ")
        upload_btn.setFont(generally_font)
        upload_btn.clicked.connect(self.chose_file)
        # Each entry is used by change_style as 'models/<name>.ckpt'.
        self.cb = QComboBox()
        self.cb.addItem("la_muse")
        self.cb.addItem("rain_princess")
        self.cb.addItem("scream")
        self.cb.addItem("udnie")
        self.cb.addItem("wave")
        self.cb.addItem("wreck")
        self.cb.setFont(generally_font)

        left_layout.addWidget(label_input, 0, Qt.AlignCenter | Qt.AlignTop)
        left_layout.addWidget(self.img_input, 0, Qt.AlignCenter)
        left_layout.addWidget(upload_btn, 0, Qt.AlignCenter)
        left_layout.addWidget(self.cb, 0, Qt.AlignCenter)
        left_widget.setLayout(left_layout)

        # Middle panel: transfer result and the button that triggers it.
        middle_widget = QWidget()
        middle_layout = QVBoxLayout()
        label_title_2 = QLabel('迁移结果')
        label_title_2.setFont(generally_font)
        self.img_middle = QLabel("中间结果")
        self.img_middle.setPixmap(QPixmap(shrinkImage('images/temp_transfer.jpg', 300)))
        btn_chong = QPushButton(" 风格迁移 ")
        btn_chong.setFont(generally_font)
        btn_chong.clicked.connect(self.change_style)
        # Empty label placed between the result image and the button.
        xx = QLabel()
        xx.setFont(generally_font)

        middle_layout.addWidget(label_title_2, 0, Qt.AlignCenter | Qt.AlignTop)
        middle_layout.addWidget(self.img_middle, 0, Qt.AlignCenter)
        middle_layout.addWidget(xx, 0, Qt.AlignCenter)
        middle_layout.addWidget(btn_chong, 0, Qt.AlignCenter)
        middle_widget.setLayout(middle_layout)

        # "About" page.
        about_layout = QVBoxLayout()
        about_title = QLabel('欢迎使用风格迁移系统\n'
                             'QQ:718005487')
        about_title.setFont(QFont('楷体', 18))
        about_title.setAlignment(Qt.AlignCenter)
        about_img = QLabel()
        about_img.setPixmap(QPixmap('images/logo.png'))
        about_img.setAlignment(Qt.AlignCenter)
        label_super = QLabel()
        label_super.setText("<a href='http://yjxyzxyz.cn'>我的个人主页</a>")
        label_super.setFont(QFont('楷体', 12))
        label_super.setOpenExternalLinks(True)
        label_super.setAlignment(Qt.AlignRight)
        about_layout.addWidget(about_title)
        about_layout.addStretch()
        about_layout.addWidget(about_img)
        about_layout.addStretch()
        about_layout.addWidget(label_super)
        about_widget.setLayout(about_layout)

        # Assemble the main page and register both tabs.
        main_layout.addWidget(left_widget)
        main_layout.addWidget(middle_widget)
        main_widget.setLayout(main_layout)
        self.addTab(main_widget, '主页面')
        self.addTab(about_widget, '关于')
        self.setTabIcon(0, QIcon('images/主页面.png'))
        self.setTabIcon(1, QIcon('images/关于.png'))

    def chose_file(self):
        """Let the user pick the input image and show it in the left panel."""
        print("选择图片")
        # Only one image is used, so a single-file dialog is opened.  The
        # filter previously read "*png" (missing dot) and the title was
        # misspelled.
        # NOTE(review): the start directory is a machine-specific absolute
        # path - consider replacing it with a project-relative one.
        fname, _ = QFileDialog.getOpenFileName(self, 'open file',
                                               'D:\\ubuntu\\transfer\\fast-style-transfer_lff\\images',
                                               "Image files(*.jpg *.png)")
        # An empty string means the dialog was cancelled.
        if fname:
            self.input_image_path = fname
            self.img_input.setPixmap(QPixmap(shrinkImage(self.input_image_path, 300)))

    def change_style(self):
        """Apply the selected style to the current image and show the result."""
        print("风格迁移")
        inpath = self.input_image_path
        checkdir = 'models/' + self.cb.currentText() + '.ckpt'
        out_path = 'images/temp_transfer.jpg'
        # for_UI is called synchronously from the click handler.
        for_UI(inpath, out_path, checkdir)
        self.img_middle.setPixmap(QPixmap(shrinkImage(out_path, 300)))

    def closeEvent(self, event):
        """Ask for confirmation before the window is closed."""
        reply = QMessageBox.question(self,
                                     'quit',
                                     "Are you sure?",
                                     QMessageBox.Yes | QMessageBox.No,
                                     QMessageBox.No)
        # Accepting the event is what closes the window; calling close()
        # again from inside the close handler is unnecessary.
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()

if __name__ == "__main__":
    # Show the main window, run the Qt event loop and hand its exit code
    # back to the interpreter.
    qt_app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)