通过预训练模型初步了解AI和pytorch

2 minute read

Published:

通过预训练模型了解pytorch

GauGAN模型为例,该模型是nVidia团队对语义图像合成的一个研究。该模型可以将用户通过基于语义的画笔画成的图像生成一个相对真实的图像。

(GAN(Generative Adversarial Networks):生成对抗网络,是一类在无监督学习中使用的神经网络,主要是两个模型:生成器generator和判别器discriminator,生成器生成图像交由判别器判别和分辨。生成器的任务是尽可能的让判别器将它生成的图像判定为真实的,而判别器的任务是引导生成器生成更真实的图像)

网上有预训练好的模型权重,以.pth格式的文件保存。我们下载下来保存在/pretrained/文件夹中。

接下来展示一个自内向外的加载预训练模型并使用的过程。

加载模型前:

加载pretrained模型前,需要先获取定义好的模型的结构。通常一个模型的结构如下面的代码所示。

class SPADEGenerator(nn.Module):
  def __init__(self, opt):
    super().__init__()

    nf = 64

    self.sw, self.sh = self.compute_latent_vector_size(opt['crop_size'], opt['aspect_ratio'])

    self.fc = nn.Conv2d(opt['label_nc'], 16 * nf, 3, padding=1)

    self.head_0 = SPADEResnetBlock(opt, 16 * nf, 16 * nf)

    self.G_middle_0 = SPADEResnetBlock(opt, 16 * nf, 16 * nf)
    self.G_middle_1 = SPADEResnetBlock(opt, 16 * nf, 16 * nf)

    self.up_0 = SPADEResnetBlock(opt, 16 * nf, 8 * nf)
    self.up_1 = SPADEResnetBlock(opt, 8 * nf, 4 * nf)
    self.up_2 = SPADEResnetBlock(opt, 4 * nf, 2 * nf)
    self.up_3 = SPADEResnetBlock(opt, 2 * nf, 1 * nf)

    self.conv_img = nn.Conv2d(1 * nf, 3, 3, padding=1)

    self.up = nn.Upsample(scale_factor=2)

  def forward(self, seg):
    x = Func.interpolate(seg, size=(self.sh, self.sw))
    x = self.fc(x)

    x = self.head_0(x, seg)

    x = self.up(x)
    x = self.G_middle_0(x, seg)
    x = self.G_middle_1(x, seg)

    x = self.up(x)
    x = self.up_0(x, seg)
    x = self.up(x)
    x = self.up_1(x, seg)
    x = self.up(x)
    x = self.up_2(x, seg)
    x = self.up(x)
    x = self.up_3(x, seg)

    x = self.conv_img(Func.leaky_relu(x, 2e-1))
    x = torch.tanh(x)

    return x

模型对象继承自nn.Module,在__init__()中定义该模型网络的结构,在forward中定义网络传播的逻辑,即我们实例化一个模型对象后再次调用时则调用模型的forward方法。(由于我们使用的是预训练模型,不涉及到模型的训练,实际上该类中中还有train和test方法,在train方法中包含计算loss、反向传播等操作,在test方法中检验模型训练效果。)

加载模型:

由于我们保存的是模型的权重,为了加载模型参数,我们需要对该模型定义一个对象,与前述代码类似,但我们不需要定义网络结构,因为该类是将pth文件中的权重导入到SPADEGenerator中。而如何将pth导入到SPADEGenerator中则是我们在 initialize_networks 方法中要做的。

class LoadedModel(nn.Module):
  def __init__(self, opt):
    super().__init__()
    self.opt = opt
    self.FloatTensor = torch.cuda.FloatTensor if opt['use_gpu'] \
      else torch.FloatTensor
    # 初始化网络方法,具体实现见下面的定义
    self.netG = self.initialize_networks(opt)
  
  def forward(self, data, mode):
    # preprocess_input是对输入的图像data信息进行处理,这里不展示具体实现
    input_semantics, real_image = self.preprocess_input(data)
    # 实际上mode有多种选项,由于这里直接使用预训练的模型机型推理所以只展示inference选项
    if mode == 'inference':
      with torch.no_grad():
        fake_image = self.netG(input_semantics)
      return fake_image
    else:
      raise ValueError("mode err")

  # 初始化神经网络
  def initialize_networks(self, opt):
    # 将输入的opt载入预定义模型
    netG = SPADEGenerator(opt)

    # 为预定义模型配置初始权重,init_weights为初始权重,此处省略获取过程
    netG.apply(init_weights)

    # 使用cuda
    if self.opt['use_gpu']:
      netG.cuda()
    
    # isTrain选项为False表示不是训练模型,实际上这里有多种判断条件,我们只使用预训练模型所以不做展示
    if not opt['isTrain']:
      weights = torch.load("xxx.pth")
      netG.load_state_dict(weights)
    
    return netG

由以上代码可知,在initialize_networks中,我们先用load方法加载模型参数,然后使用load_state_dict方法将参数加载到模型中。 模型的保存与加载都涉及到了state_dict 这个方法,我们模型的参数都是以字典的方式保存起来的。

使用模型:

opt = {
        'label_nc': 182, 
        'crop_size': 512,
        'load_size': 512,
        'aspect_ratio': 1.0,
        'isTrain': False,
        'use_gpu': True
      }
model = LoadedModel(opt)
model.eval()

这里我们终于看到了前面代码里频繁出现的opt,我们在使用模型时都是通过opt参数控制模型调用的方式和过程。这里只是做了简化处理,去掉了本次实验里用不到的选项,实际上选项非常多,并且通常单独保存在一个文件中。

通过以上代码加载好模型后,就可以进行推理并得到结果了。

# 将输入input传入model中,结果返回为output
output = model(input, mode='inference')

部署模型:

为了将模型作为服务为用户使用,我们可以使用flask作为服务端部署该模型。

flask部署模型是提供模型服务最简单的方法,但是受限于性能(通常需要一个有gpu的服务器)这个方法比较适合简单的人工智能服务。还有多种部署的方式比如:

使用flask部署模型非常简单。在目录下安装一个flask包

pip install Flask

在目录下的app.py中启动一个简单的服务,其中index.html是我们服务的页面,在这里用户可以通过画笔进行绘画(这里前端代码不作展示)。

from flask import Flask, current_app
app = Flask(__name__)

import io
import numpy as np

@app.route('/')
def index():
    return current_app.send_static_file('index.html')

if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=6006)

随后输入 python app.py 则会启动该服务,我们在浏览器中输入 localhost:6006 就可以访问该服务了。

然后我们需要在flask中引入我们的模型,并暴露一个接口给前端调用。前端发送一个post请求到/generate接口,请求体中携带语义图像信息。

import model_inst # import evaluate, to_image

@app.route('/generate', methods = ['POST'])
def generate():
    # 首先获取请求中的信息
    labelmap = np.asarray(request.json)
    # 处理传输的信息为模型接受的信息,此处不做具体展示,抽象为一个handle_lablemap函数
    input = handle_lablemap(lablemap)
    # 加载模型
    model = LoadedModel(opt)
    model.eval()
    # 将处理好的输入丢入模型中
    output = model(input, mode='inference')
    # 处理模型的输出将其转换为图片
    image = ToPILImage(output)
    # 对图片进行处理
    file_object = io.BytesIO()
    image.save(file_object, 'PNG')
    file_object.seek(0)
    # 接口返回该图片
    return send_file(file_object, mimetype='image/PNG')

总结

以上就是在pytorch中加载预训练模型并部署作为服务使用的过程,过程较为精简,专注过程的核心,有很多代码如处理输入图像和网络模型配置的一些细节都没有展示。