AnimeGANv2 onnx 模型调用尝试

系列仓库地址:https://github.com/xuanhao44/AnimeGANv2


在 v2 的代码库中,作者提供了新海诚风格的训练好的 onnx 模型,还有相应的 test 代码和样例。这样就可以不经过 TensorFlow 运行。

0 服务器

带显卡的服务器:RTX A4000。

1 test_by_onnx.py

这里仅需要修改最后面的 input_imgs_path,也就是样例图片路径。

input_imgs_path = '../dataset/test/HR_photo'

到该目录下运行:

conda activate /cloud/newanime
cd AnimeGANv2/pb_and_onnx_model
python test_by_onnx.py

2 编写 onnx_video2anime.py

video2anime.py 的基础上,模仿 test_by_onnx.py

需要在前篇的基础上卸载 onnxruntime,并安装 onnxruntime-gpu

pip uninstall onnxruntime
pip install --user onnxruntime-gpu

onnx_video2anime.py

更新:文档中代码可能不是最新版本的,最新请参考:https://github.com/xuanhao44/AnimeGANv2/blob/main/onnx_video2anime.py

import argparse
import os
import cv2
from tqdm import tqdm
import numpy as np
import onnxruntime as ort

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def parse_args():
    desc = "Tensorflow implementation of AnimeGANv2"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--video', type=str, default='video/input/'+ '2.mp4',
                        help='video file or number for webcam')
    parser.add_argument('--output', type=str, default='video/output/' + 'Paprika',
                        help='output path')
    parser.add_argument('--model', type=str, default='Shinkai',
                        help='model name')
    parser.add_argument('--onnx', type=str, default='pb_and_onnx_model/Shinkai_53.onnx',
                        help='path of onnx')
    parser.add_argument('--output_format', type=str, default='mp4v',
                        help='codec used in VideoWriter when saving video to file')
    return parser.parse_args()


def check_folder(path):
    if not os.path.exists(path):
        os.makedirs(path)
    return path

def process_image(img, x32=True):
    h, w = img.shape[:2]
    if x32: # resize image to multiple of 32s
        def to_32s(x):
            return 256 if x < 256 else x - x%32
        img = cv2.resize(img, (to_32s(w), to_32s(h)))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)/ 127.5 - 1.0
    return img

def post_precess(img, wh):
    img = (img.squeeze() + 1.) / 2 * 255
    img = img.astype(np.uint8)
    img = cv2.resize(img, (wh[0], wh[1]))
    return img

def cvt2anime_video(video_filepath, output, model, onnx = 'model.onnx', output_format='mp4v'):  # 小写就不报错了,只是仍然无法在浏览器上播放

    # check onnx model
    exists = os.path.isfile(onnx)
    if not exists:
        print('Model file not found:', onnx)
        return

    # 加载模型,若有 GPU, 则用 GPU 推理
    # https://zhuanlan.zhihu.com/p/645720587
    # 慎入!https://zhuanlan.zhihu.com/p/492040015
    if ort.get_device()=='GPU':
        print('use gpu')
        providers = ['CUDAExecutionProvider','CPUExecutionProvider',]
        session = ort.InferenceSession(onnx, providers=providers)
        session.set_providers(['CUDAExecutionProvider'], [ {'device_id': 0}]) #gpu 0
    else:
        print('use cpu')
        providers = ['CPUExecutionProvider',]
        session = ort.InferenceSession(onnx, providers=providers)

    input_name = session.get_inputs()[0].name

    # load video
    vid = cv2.VideoCapture(video_filepath)
    vid_name = os.path.basename(video_filepath)  # 只取文件名
    # https://blog.csdn.net/lsoxvxe/article/details/131999217
    # https://pythonjishu.com/python-os-28/
    total = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vid.get(cv2.CAP_PROP_FPS)
    width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
    codec = cv2.VideoWriter_fourcc(*output_format)

    # 输出视频名称、路径
    video_out_name = vid_name.rsplit('.', 1)[0] + '_' + model + '.mp4'
    video_out_path = os.path.join(output, video_out_name)

    video_out = cv2.VideoWriter(video_out_path, codec, fps, (width, height))

    pbar = tqdm(total=total, ncols=80)
    pbar.set_description(f"Making: {video_out_name}")

    while True:
        ret, frame = vid.read()
        if not ret:
            break
        frame = np.expand_dims(process_image(frame),0)
        fake_img = session.run(None, {input_name : frame})
        fake_img = post_precess(fake_img[0], (width, height))
        video_out.write(cv2.cvtColor(fake_img, cv2.COLOR_BGR2RGB))
        pbar.update(1)

    pbar.close()
    vid.release()
    video_out.release()
    return video_out_path

if __name__ == '__main__':
    # python onnx_video2anime.py --video video/input/お花見.mp4 --output video/output --model Shinkai --onnx pb_and_onnx_model/Shinkai_53.onnx
    arg = parse_args()
    check_folder(arg.output)
    info = cvt2anime_video(arg.video, arg.output, arg.model, arg.onnx)
    print(f'output video: {info}')

改动:

  1. 重写了 parse_args 和主函数,意在修改参数。下面是相应命令。
python onnx_video2anime.py --video video/input/お花見.mp4 --output video/output --model Shinkai --onnx pb_and_onnx_model/Shinkai_53.onnx
  1. 大幅度修改了 cvt2anime_video。详细参考和注释都放在了函数中。主要就是调用 onnx 模型。但其中有一处比较折磨人的部分:
(/cloud/newanime) ➜  pb_and_onnx_model git:(main) ✗ pip install --user onnxruntime-gpu==1.1.0
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
ERROR: Could not find a version that satisfies the requirement onnxruntime-gpu==1.1.0 (from versions: 1.2.0, 1.3.0, 1.4.0, 1.5.1, 1.5.2, 1.6.0, 1.7.0, 1.8.0, 1.8.1, 1.9.0, 1.10.0, 1.11.0, 1.11.1, 1.12.0, 1.12.1, 1.13.1, 1.14.0, 1.14.1, 1.15.0, 1.15.1, 1.16.0, 1.16.1)
ERROR: No matching distribution found for onnxruntime-gpu==1.1.0

我按照上一次的经验来编译:

git clone --recursive https://github.com/microsoft/onnxruntime.git
cd onnxruntime
git reset --hard c33dab394f4984ab28a370d96c373f0fca84d826
export CMAKE_ARGS="-DWITH_FREETYPE=ON"
export ENABLE_CONTRIB=1
export ENABLE_HEADLESS=0
export MAKEFLAGS="-j $(($(nproc)-1))"
pip wheel . --verbose

可是同样遇到了问题,并且无法解决,最后放弃,百般无奈的就随便使用了 pip install --user onnxruntime-gpu,没想到居然就成功了!我不懂这是为什么。

虽然网上确实很多文章提到了需要版本对应,但是实在无法解决的时候还是要挣扎一把 hh。

3 导入 v3 onnx 模型

AnimeGANv3: https://github.com/TachibanaYoshino/AnimeGANv3

v3 的仓库里有展示用的 exe 文件和两个加密的 onnx 模型(均为新海诚风格)压缩包,另外还有未被加密的两个模型。

一开始以为 v3 没有提供代码,但是细细想来,这种载入 onnx 的模式并没有大变化,那么 v2 的代码也能套到 v3 的 onnx 模型上。

根据其下 issue 所说:https://github.com/TachibanaYoshino/AnimeGANv3/issues/2

exe 文件使用 PyInstaller 打包,可以在 test_by_onnxZIP 的中看到硬编码的密码。至于是多少就算了,可以自己看教程破解:https://blog.csdn.net/as604049322/article/details/119834495

拆完之后发现作者写的部分基本和我推测的一样,确实可以直接套用。

之后大致修正了一下 onnx_video2anime.py,将 onnx 模型放在 pb_and_onnx_model/ 目录下。


最后同步的写了 onnx_app.py,直接给出链接:

https://github.com/xuanhao44/AnimeGANv2/blob/main/onnx_app.py

4 一个尝试

注意到作者在 hugging face 上放了一个 demo:https://huggingface.co/spaces/TachibanaYoshino/AnimeGANv3

观察文件目录可知这是把一部分 py 文件打包,然后通过动态链接的形式导入到 app.py 中。

v3_demo

观察其 app.py 中动态链接的使用,尝试推测用法。

以之前 3 中拆包得到的 test_by_onnxZIP.py 来推测,把所有函数调用去掉,然后改变一些函数的顺序,可以推测出链接部分对应哪些语句。

下面 start end 中间的部分便是 output = AnimeGANv3_src.Convert(img, f, det_face) 对应的代码。

        sample_image = cv2.imread(sample_file)

        sample_image = cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB).astype(np.float32) / 127.5 - 1.0

        # start

        h, w = sample_image.shape[:2]

        def to_32s(x):
            if x < 256:
                return 256
            return x - x % 32
        sample_image = cv2.resize(sample_image, (to_32s(w), to_32s(h)))

        sample_image = np.asarray(np.expand_dims(sample_image, 0))

        session = ort.InferenceSession(onnx, providers=['CPUExecutionProvider'])

        fake_img = session.run(None, {session.get_inputs()[0].name: sample_image})

        fake_img[0] = (fake_img[0].squeeze() + 1.0) / 2 * 255
        fake_img[0] = fake_img[0].astype(np.uint8)
        fake_img[0] = cv2.resize(fake_img[0], [w, h])

        #end

        image_path = os.path.join(result_dir, '{0}'.format(os.path.basename(sample_file)))

        cv2.imwrite(image_path, cv2.cvtColor(fake_img[0], cv2.COLOR_RGB2BGR))

同理可以推测 video_test_by_onnxZIP.py

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32) / 127.5 - 1.0

        # start

        h, w = frame.shape[:2]
        def to_32s(x):
            if x < 256:
                return 256
            return x - x % 32
        frame = cv2.resize(frame, (to_32s(w), to_32s(h)))

        frame = np.asarray(np.expand_dims(frame, 0))

        session = ort.InferenceSession(onnx, providers=['CPUExecutionProvider'])

        fake_img = session.run(None, {session.get_inputs()[0].name: frame})

        fake_img[0] = (fake_img[0].squeeze() + 1.0) / 2 * 255
        fake_img[0] = fake_img[0].astype(np.uint8)
        fake_img[0] = cv2.resize(fake_img[0], (w, h))

        # end

        video_out.write(cv2.cvtColor(fake_img[0], cv2.COLOR_RGB2BGR))

所以可以写出 so_video2anime.py,也调用他的动态链接:

import argparse
import os
import cv2
from tqdm import tqdm
import numpy as np
import AnimeGANv3_src

def parse_args():
    desc = "Tensorflow implementation of AnimeGANv2"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--video', type=str, default='video/input/'+ '2.mp4',
                        help='video file or number for webcam')
    parser.add_argument('--output', type=str, default='video/output/' + 'Paprika',
                        help='output path')
    parser.add_argument('--style', type=str, default='Shinkai',
                        help='style name')
    parser.add_argument('--output_format', type=str, default='mp4v',
                        help='codec used in VideoWriter when saving video to file')
    return parser.parse_args()


def check_folder(path):
    if not os.path.exists(path):
        os.makedirs(path)
    return path

def cvt2anime_video(video_path, output, style, output_format='mp4v'):  # 小写就不报错了,只是仍然无法在浏览器上播放
    print(video_path, style)

    # load video
    video_in = cv2.VideoCapture(video_path)
    video_in_name = os.path.basename(video_path)  # 只取文件名

    total = int(video_in.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(video_in.get(cv2.CAP_PROP_FPS))
    width = int(video_in.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video_in.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*output_format)

    # 输出视频名称、路径
    video_out_name = video_in_name.rsplit('.', 1)[0] + '_' + style + '.mp4'
    video_out_path = os.path.join(output, video_out_name)

    video_out = cv2.VideoWriter(video_out_path, fourcc, fps, (width, height))

    pbar = tqdm(total=total, ncols=80)
    pbar.set_description(f"Making: {video_out_name}")

    while True:
        ret, frame = video_in.read()
        if not ret:
            break

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32) / 127.5 - 1.0

        if style == "AnimeGANv3_Arcane":
            f = "A"
        elif style == "AnimeGANv3_Trump v1.0":
            f = "T"
        elif style == "AnimeGANv3_Shinkai":
            f = "S"
        elif style == "AnimeGANv3_PortraitSketch":
            f = "P"
        elif style == "AnimeGANv3_Hayao":
            f = "H"
        elif style == "AnimeGANv3_Disney v1.0":
            f = "D"
        elif style == "AnimeGANv3_JP_face v1.0":
            f = "J"
        else:
            f = "U"

        output = AnimeGANv3_src.Convert(frame, f, False)

        video_out.write(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
        pbar.update(1)

    pbar.close()
    video_in.release()
    video_out.release()

    return video_out_path

if __name__ == '__main__':

    # v3

    # python so_video2anime.py --video examples/2.mp4 --output output --style AnimeGANv3_Arcane
    # python so_video2anime.py --video examples/2.mp4 --output output --style AnimeGANv3_Trump v1.0
    # python so_video2anime.py --video examples/2.mp4 --output output --style AnimeGANv3_Shinkai
    # python so_video2anime.py --video examples/2.mp4 --output output --style AnimeGANv3_PortraitSketch
    # python so_video2anime.py --video examples/2.mp4 --output output --style AnimeGANv3_Hayao
    # python so_video2anime.py --video examples/2.mp4 --output output --style AnimeGANv3_Disney v1.0
    # python so_video2anime.py --video examples/2.mp4 --output output --style AnimeGANv3_JP_face v1.0

    arg = parse_args()
    check_folder(arg.output)
    info = cvt2anime_video(arg.video, arg.output, arg.style)
    print(f'output video: {info}')

结果是可行的,这令人开心。

但有一个问题,他在制作 so 文件的时候还是使用的是 onnx 模型的 cpu 推理 CPUExecutionProvider。这一点从他的依赖项有 onnxruntime 而没有 onnxruntime-gpu 也能看出。处理图片尚没问题,但是处理视频就有些麻烦了。

尝试破解 so 文件也是极为困难的,看到依赖项中有 pyarmorpycryptodome 两个加密的包就知道反编译是很难的。在使用 IDA Pro + 010 Editor 尝试多次之后放弃。

5 作者 patreon

https://www.patreon.com/Asher_Chan

花钱可以得到模型,但是好贵啊。

另外,一个个翻找仓库的 commit,发现一个 colab。

最后修改:2023 年 10 月 26 日
如果觉得我的文章对你有用,请随意赞赏