pytorch版本PSENet训练并部署方式(pytorch训练好的模型如何部署)深度揭秘

随心笔谈12个月前发布 admin
95 0

import torch
import numpy as np
import argparse
import os
import os.path as osp
import sys
import time
import json
from mmcv import Config
import cv2
from torchvision import transforms
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
def prepare_image(image, target_size, device="cuda:0"):
    """Preprocess an OpenCV image into a batched tensor for prediction.

    :param image: original image as returned by ``cv2.imread`` (BGR channel order)
    :param target_size: size handed to ``cv2.resize``, i.e. ``(width, height)``
    :param device: torch device the tensor is moved to; defaults to ``"cuda:0"``,
        the value the original hard-coded
    :return: tensor of shape ``(1, channels, height, width)`` produced by
        ``transforms.ToTensor`` (float, scaled to ``[0, 1]`` for uint8 input)
    :raises ValueError: if ``image`` is ``None`` (e.g. ``cv2.imread`` failed)
    """
    if image is None:
        raise ValueError("prepare_image() got None; the image could not be read")

    # OpenCV loads images as BGR; convert to RGB before building the tensor.
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # cv2.resize takes dsize as (width, height).
    img = cv2.resize(img, target_size)

    # (h, w, c) ndarray -> (c, h, w) tensor -> (1, c, h, w) batch of one.
    tensor = transforms.ToTensor()(img).unsqueeze(0)
    return tensor.to(torch.device(device))
def report_speed(outputs, speed_meters):
    """Update per-stage timing meters from one forward pass and print them.

    :param outputs: dict returned by the model; every key whose name contains
        ``'time'`` is treated as a timing value for that stage
    :param speed_meters: dict of meters exposing ``update(value)`` and ``avg``
        (e.g. ``AverageMeter``), keyed by the same names as the timing entries
        in ``outputs`` plus a ``'total_time'`` meter
    :return: None; prints each stage's running average and the derived FPS
    """
    total_time = 0
    for key, value in outputs.items():
        if "time" not in key:
            continue
        total_time += value
        meter = speed_meters[key]
        meter.update(value)
        print("%s: %.4f" % (key, meter.avg))

    total_meter = speed_meters["total_time"]
    total_meter.update(total_time)
    # A zero average (e.g. no timing keys in outputs) would otherwise raise
    # ZeroDivisionError; report an infinite rate instead.
    avg = total_meter.avg
    fps = 1.0 / avg if avg > 0 else float("inf")
    print("FPS: %.1f" % fps)
def load_model(cfg, checkpoint="psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"):
    """Build the model described by ``cfg`` and load trained weights into it.

    :param cfg: mmcv ``Config``; only ``cfg.model`` is read here
    :param checkpoint: path to a checkpoint file containing a ``'state_dict'``
        entry; ``None`` skips weight loading. Defaults to the path the original
        script hard-coded.
    :return: the model moved to the GPU, in eval mode, passed through
        ``fuse_module`` (conv + bn fusion)
    :raises FileNotFoundError: if ``checkpoint`` is given but is not a file
    """
    # Fail fast, before building the network and allocating GPU memory.
    # (The original used a bare `raise` here, which is itself a RuntimeError.)
    if checkpoint is not None and not os.path.isfile(checkpoint):
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint))

    model = build_model(cfg.model)
    model = model.cuda()
    model.eval()

    if checkpoint is not None:
        print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
        sys.stdout.flush()
        # Load tensors onto the CPU first; load_state_dict copies them into
        # the model's (GPU) parameters.
        state = torch.load(checkpoint, map_location="cpu")
        # Keys are expected to carry the 7-character "module." prefix that
        # nn.DataParallel adds; strip it only where present instead of blindly
        # dropping 7 characters from every key.
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state["state_dict"].items()
        }
        model.load_state_dict(state_dict)

    # fuse conv and bn
    model = fuse_module(model)
    return model
if __name__ == "__main__":
    # Command-line overrides; running with no arguments reproduces the
    # original hard-coded setup.
    parser = argparse.ArgumentParser(
        description="Run PSENet on every image in a folder and save the detections drawn on it."
    )
    parser.add_argument("--src_dir", default="testimg/", help="folder of input images")
    parser.add_argument("--save_dir", default="test_save/", help="folder for annotated outputs")
    parser.add_argument(
        "--config",
        default="PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py",
        help="mmcv config file describing the model",
    )
    parser.add_argument(
        "--size",
        type=int,
        nargs=2,
        default=[1376, 1024],
        metavar=("WIDTH", "HEIGHT"),
        help="network input size, in cv2.resize (width, height) order",
    )
    args = parser.parse_args()
    target_size = tuple(args.size)

    os.makedirs(args.save_dir, exist_ok=True)

    cfg = Config.fromfile(args.config)
    # Speed reporting is forced off in both the top-level and the test-data config.
    for d in [cfg, cfg.data.test]:
        d.update(dict(report_speed=False))

    if cfg.report_speed:
        speed_meters = dict(
            backbone_time=AverageMeter(500),
            neck_time=AverageMeter(500),
            det_head_time=AverageMeter(500),
            det_pse_time=AverageMeter(500),
            rec_time=AverageMeter(500),
            total_time=AverageMeter(500),
        )

    # load_model() already switches the network to eval mode.
    model = load_model(cfg)

    count = 0
    for img_name in sorted(os.listdir(args.src_dir)):
        img_path = osp.join(args.src_dir, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for unreadable or non-image files.
            print("skip (not a readable image):", img_path)
            continue

        tensor = prepare_image(img, target_size=target_size)

        # Both sizes are given as (height, width), matching org_img_size taken
        # from img.shape. img_size is read from the tensor actually fed to the
        # network, so it always agrees with the (width, height) resize above
        # (the original hard-coded [[1376, 1024]], i.e. transposed).
        img_metas = dict(
            org_img_size=torch.tensor([[img.shape[0], img.shape[1]]]),
            img_size=torch.tensor([[tensor.shape[2], tensor.shape[3]]]),
        )
        data = dict(imgs=tensor, img_metas=img_metas, cfg=cfg)

        with torch.no_grad():
            outputs = model(**data)

        if cfg.report_speed:
            report_speed(outputs, speed_meters)

        # Each bbox is a flat sequence of x, y coordinates (the original read
        # x1, y1 from [0], [1] and x2, y2 from [4], [5]); draw the full polygon
        # instead of an axis-aligned box between two of its corners.
        for bbox in outputs["bboxes"]:
            pts = np.asarray(bbox, dtype=np.int32).reshape(-1, 2)
            cv2.polylines(img, [pts], isClosed=True, color=(0, 0, 255), thickness=3)

        cv2.imwrite(osp.join(args.save_dir, img_name), img)
        count += 1
        print("img test:", count)

© 版权声明

相关文章