Train_Identify/ai_matrix/framework/EngineManager.cpp

356 lines
11 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

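/**
 * EngineManager implementation: process-wide resource setup/teardown
 * (libcurl and, with ASCEND, AscendCL), creating and wiring engines from a
 * YAML configuration, and starting/stopping all loaded engines.
 * */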
#include "EngineManager.h"

#include <cstdlib>
#include <exception>
#include <memory>
#include <string>

#include <curl/curl.h>

#include "MyYaml.h"
#include "myutils.h"
namespace ai_matrix
{
// Construction does no work of its own; resources are acquired in Init().
EngineManager::EngineManager() = default;
// Nothing is released explicitly; members are destroyed implicitly.
// NOTE(review): gLogger_ is allocated with `new` in Init() and nothing in this
// file deletes it — confirm who owns the Logger.
EngineManager::~EngineManager() = default;
/**
 * Initialize process-wide resources before any engine is loaded.
 *
 * Steps:
 *  1. libcurl global init, only when gc_http_open or gc_ftp_open is 1.
 *  2. Resolve the device ids to use (InitDeviceIds()).
 *  3. Allocate the Logger that RunAllEngine() hands to every engine.
 *  4. With ASCEND defined, initialize AscendCL resources for those devices,
 *     using the ACL config file named by gc_acl_path.
 *
 * @return APP_ERR_OK on success;
 *         APP_ERR_COMM_FAILURE if libcurl or AscendCL initialization fails;
 *         APP_ERR_COMM_INVALID_PARAM if the device ids cannot be resolved.
 */
APP_ERROR EngineManager::Init()
{
    // Initialize libcurl only when an HTTP or FTP feature is switched on.
    if (1 == MyYaml::GetIns()->GetIntValue("gc_http_open") || 1 == MyYaml::GetIns()->GetIntValue("gc_ftp_open"))
    {
        // curl_global_init can fail; later curl calls would then run without
        // the global initialization libcurl requires, so stop here.
        const CURLcode curlRet = curl_global_init(CURL_GLOBAL_ALL);
        if (curlRet != CURLE_OK)
        {
            LogError << "curl_global_init failed, ret = " << static_cast<int>(curlRet)
                     << " (" << curl_easy_strerror(curlRet) << ")";
            return APP_ERR_COMM_FAILURE;
        }
    }
    if (!InitDeviceIds())
    {
        LogError << "InitDeviceIds err";
        return APP_ERR_COMM_INVALID_PARAM;
    }
    // NOTE(review): raw owning pointer; nothing in this file deletes it, and a
    // second Init() call would leak the previous Logger — confirm ownership.
    gLogger_ = new Logger;
#ifdef ASCEND
    // AscendCL initialization for every device selected by InitDeviceIds().
    ResourceInfo resourceInfo;
    resourceInfo.aclConfigPath = MyYaml::GetIns()->GetStringValue("gc_acl_path");
    resourceInfo.deviceIds = init_deviceIds_;
    APP_ERROR ret = ResourceManager::GetInstance()->InitResource(resourceInfo);
    if (ret != APP_ERR_OK)
    {
        LogError << "Failed to InitResource, ret = " << ret;
        return APP_ERR_COMM_FAILURE;
    }
    return ret;
#else
    return APP_ERR_OK;
#endif
}
/**
 * Release the process-wide resources acquired by Init().
 *
 * libcurl is cleaned up under the same configuration check Init() used for
 * curl_global_init; with ASCEND defined, AscendCL resources are released too.
 * NOTE(review): the config is re-read here, so init/cleanup only stay paired
 * if gc_http_open / gc_ftp_open do not change between Init() and DeInit().
 *
 * @return always APP_ERR_OK.
 */
APP_ERROR EngineManager::DeInit(void)
{
    // Short-circuit kept: gc_ftp_open is only read when gc_http_open != 1.
    const bool curlEnabled =
        MyYaml::GetIns()->GetIntValue("gc_http_open") == 1 || MyYaml::GetIns()->GetIntValue("gc_ftp_open") == 1;
    if (curlEnabled)
    {
        curl_global_cleanup();
    }
#ifdef ASCEND
    ResourceManager::GetInstance()->Release();
#endif
    return APP_ERR_OK;
}
/**
 * Build and wire the engine graph described by a YAML file.
 *
 * Top-level keys read here:
 *   use_deviceid : optional map "<key>" -> device id. Every id must be one
 *                  initialized by Init(). Keys are matched in
 *                  InitEngineInstance() as "<engineName>_<engineId>", then
 *                  "<engineId>", then "ALL". "ALL" defaults to the first
 *                  element of init_deviceIds_ (the smallest id, per the
 *                  original comment).
 *   engines      : entries "<engineName>" -> engine id. One instance is created
 *                  per distinct "<engineName>_<engineId>"; repeats are skipped.
 *   connects     : entries "<engineName>_<engineId>_<suffix>" ->
 *                  "<engineName>_<engineId>_<suffix>[ <queueSize>]". The
 *                  receiver's input queue for that key is created on first use
 *                  (bounded when queueSize > 0) and shared with every sender.
 *
 * Engines that cannot be created or initialized, and connections that are
 * malformed or name unknown engines, are logged and skipped.
 *
 * @param path  path of the YAML file.
 * @return APP_ERR_OK on success; APP_ERR_COMM_INVALID_PARAM if the file cannot
 *         be parsed, is empty, references a device id that was not
 *         initialized, if no device id has been initialized at all, or if any
 *         exception escapes while building the graph.
 */
APP_ERROR EngineManager::load_yaml_config(std::string path)
{
    try
    {
        YAML::Node config = YAML::LoadFile(path);
        // An empty file parses to a null node: nothing can be built from it.
        if (config.IsNull())
        {
            LogError << "matrix.yaml err";
            return APP_ERR_COMM_INVALID_PARAM;
        }

        // Default device for every engine. Dereferencing begin() of an empty
        // set is UB, which happens if Init() was not called or failed.
        if (init_deviceIds_.empty())
        {
            LogError << "no initialized deviceid, Init() must succeed before load_yaml_config";
            return APP_ERR_COMM_INVALID_PARAM;
        }
        mapUseDevice_["ALL"] = *init_deviceIds_.begin();

        // Per-engine device overrides; only initialized devices are allowed.
        const YAML::Node useDevice = config["use_deviceid"];
        if (useDevice.IsDefined())
        {
            for (YAML::const_iterator it = useDevice.begin(); it != useDevice.end(); ++it)
            {
                const std::string engineInfo = it->first.as<std::string>();
                const int deviceid = it->second.as<int>();
                if (init_deviceIds_.count(deviceid) == 0)
                {
                    LogError << "use_deviceid set err value:" << deviceid;
                    return APP_ERR_COMM_INVALID_PARAM;
                }
                mapUseDevice_[engineInfo] = deviceid;
            }
        }

        // Engine instances, keyed "<name>_<id>".
        const YAML::Node engines = config["engines"];
        for (YAML::const_iterator it = engines.begin(); it != engines.end(); ++it)
        {
            const std::string engine_name = it->first.as<std::string>();
            const int engine_id = it->second.as<int>();
            const std::string engine_unique = engine_name + "_" + std::to_string(engine_id);
            // Previously printf(engine_unique.c_str()): a '%' in the name was
            // interpreted as a format directive (undefined behavior).
            LogInfo << "load engine: " << engine_unique;
            if (engine_map_.find(engine_unique) != engine_map_.end())
            {
                continue;  // duplicate engine entry
            }
            EngineBase* base = static_cast<EngineBase*>(EngineFactory::MakeEngine(engine_name));
            if (base == nullptr)
            {
                LogError << "EngineFactory cannot create engine " << engine_name;
                continue;
            }
            std::shared_ptr<EngineBase> engineInstance(base);
            if (InitEngineInstance(engineInstance, engine_name, engine_id) != APP_ERR_OK)
            {
                continue;  // InitEngineInstance() already logged the failure
            }
            engine_map_[engine_unique] = engineInstance;
        }

        // Engine connections: sender output key -> receiver input key.
        const YAML::Node connects = config["connects"];
        for (YAML::const_iterator it = connects.begin(); it != connects.end(); ++it)
        {
            const std::string from = it->first.as<std::string>();
            std::string to = it->second.as<std::string>();

            // Optional " <queueSize>" suffix bounds the receiver queue.
            int iQueueSize = 0;
            const std::size_t spacePos = to.find(' ');
            if (spacePos != std::string::npos)
            {
                iQueueSize = atoi(to.substr(spacePos + 1).c_str());
                to.erase(spacePos);
            }

            // Strip the trailing "_<suffix>" to get the owning engine key "<name>_<id>".
            const std::size_t fromPos = from.find_last_of('_');
            const std::size_t toPos = to.find_last_of('_');
            if (fromPos == std::string::npos || toPos == std::string::npos)
            {
                LogError << "connects format err: " << from << " -> " << to;
                continue;
            }
            const std::string src_engine = from.substr(0, fromPos);
            const std::string dst_engine = to.substr(0, toPos);

            auto iterSend = engine_map_.find(src_engine);
            auto iterRecv = engine_map_.find(dst_engine);
            if (iterSend == engine_map_.end())
            {
                LogError << "Cann't find engine " << src_engine;
                continue;
            }
            if (iterRecv == engine_map_.end())
            {
                LogError << "Cann't find engine " << dst_engine;
                continue;
            }

            // Reuse the receiver's existing input queue so several senders can feed one input.
            std::shared_ptr<MyQueue<std::shared_ptr<void>>> dataQueue = iterRecv->second->GetInputMap(to);
            if (dataQueue == nullptr)
            {
                dataQueue = std::make_shared<MyQueue<std::shared_ptr<void>>>();
                if (iQueueSize > 0)
                {
                    dataQueue->setMaxSize(iQueueSize);
                }
                iterRecv->second->SetInputMap(to, dataQueue);
            }
            iterSend->second->SetOutputMap(from, dataQueue);
        }
    }
    catch (const std::exception& e)  // includes YAML::Exception (parse / conversion errors)
    {
        LogError << "load_yaml_config " << path << " failed: " << e.what();
        return APP_ERR_COMM_INVALID_PARAM;
    }
    catch (...)
    {
        LogError << "catch error";
        return APP_ERR_COMM_INVALID_PARAM;
    }
    return APP_ERR_OK;
}
/**
 * Fill in the init arguments of one engine instance and run its Init().
 *
 * The device is chosen from mapUseDevice_ in priority order:
 * "<engineName>_<engineId>", then "<engineId>", then "ALL".
 * With ASCEND defined, that device's ACL context is fetched from
 * ResourceManager, stored in aclContext_, and passed to the engine together
 * with runMode_.
 *
 * @param engineInstance  engine created by the factory, not yet initialized.
 * @param engineName      engine type name.
 * @param engineId        engine instance id.
 * @return APP_ERR_OK on success, otherwise the error returned by the engine's Init().
 */
APP_ERROR EngineManager::InitEngineInstance(std::shared_ptr<EngineBase> engineInstance, std::string engineName, int engineId)
{
    LogInfo << "EngineManager: begin to init engine instance,name=" << engineName << ", engine id = " << engineId << ".";

    // Resolve the device this engine runs on.
    const std::string idKey = std::to_string(engineId);
    const std::string engineInfo = engineName + "_" + idKey;
    auto found = mapUseDevice_.find(engineInfo);
    if (found == mapUseDevice_.end())
    {
        found = mapUseDevice_.find(idKey);
    }
    const int deviceid = (found != mapUseDevice_.end()) ? found->second : mapUseDevice_["ALL"];

#ifdef ASCEND
    aclContext_ = ResourceManager::GetInstance()->GetContext(deviceid);
#endif
    EngineInitArguments initArgs;
    // Bug fix: this was the unrelated member deviceId_, so the per-engine
    // device chosen above (and used for the ACL context) never reached the engine.
    initArgs.deviceId = deviceid;
#ifdef ASCEND
    initArgs.context = aclContext_;
    initArgs.runMode = runMode_;
#endif
    initArgs.engineName = engineName;
    initArgs.engineId = engineId;
    engineInstance->AssignInitArgs(initArgs);
    APP_ERROR ret = engineInstance->Init();
    if (ret != APP_ERR_OK)
    {
        LogError << "EngineManager: fail to init engine, name = " << engineName << ", engine id = " << engineId << ".";
        return ret;
    }
    LogInfo << "EngineManager: engine " << engineName << "[" << engineId << "] init success.";
    return ret;
}
/**
 * Hand the shared logger to every loaded engine and start it.
 *
 * Engines are started in engine_map_ iteration order. The return value of
 * each engine's Run() is not checked here.
 *
 * @return always APP_ERR_OK.
 */
APP_ERROR EngineManager::RunAllEngine()
{
    LogInfo << "begin to run engine."<<engine_map_.size();
    for (auto& entry : engine_map_)
    {
        auto& engine = entry.second;
        engine->setLoggerPoint(gLogger_);
        engine->Run();
        LogInfo << "begin to run engine. end-------------------------";
    }
    return APP_ERR_OK;
}
/**
 * Stop every loaded engine, in engine_map_ iteration order.
 *
 * The return value of each engine's Stop() is not checked here.
 *
 * @return always APP_ERR_OK.
 */
APP_ERROR EngineManager::StopAllEngine()
{
    LogInfo << "begin to stop engine.";
    for (auto& entry : engine_map_)
    {
        auto& engine = entry.second;
        engine->Stop();
        LogInfo << "begin to stop engine. end-------------------------";
    }
    return APP_ERR_OK;
}
/**
 * Look up a loaded engine by the key it was stored under in engine_map_.
 *
 * @param engineName  "<engineName>_<engineId>" key, as built in load_yaml_config().
 * @return non-owning pointer to the engine (ownership stays with engine_map_),
 *         or nullptr when no engine has that key.
 */
EngineBase *EngineManager::get_engine(std::string engineName)
{
    const auto found = engine_map_.find(engineName);
    return (found != engine_map_.end()) ? found->second.get() : nullptr;
}
/**
 * Fill init_deviceIds_ with the device ids to initialize.
 *
 * Without ASCEND only device 0 is used. With ASCEND:
 *  1. in device-side run mode (ACL_DEVICE) only device 0 is used;
 *  2. otherwise the number of visible devices is queried;
 *  3. gc_init_deviceid selects the devices: "ALL" or empty means every visible
 *     device, otherwise a comma-separated list of ids in [0, device_count).
 *
 * @return true on success; false if an ACL query fails, a configured id is out
 *         of range, or no device id ends up selected. An empty set would make
 *         load_yaml_config() dereference init_deviceIds_.begin() of an empty set.
 */
bool EngineManager::InitDeviceIds()
{
#ifdef ASCEND
    // 1. Device-side run mode: only device 0 is available.
    APP_ERROR ret = aclrtGetRunMode(&runMode_);
    if (ret != APP_ERR_OK)
    {
        LogError << "ModuleManager: fail to get run mode of device, ret=" << ret << ".";
        return false;
    }
    if (runMode_ == ACL_DEVICE)
    {
        init_deviceIds_.insert(0);
        return true;
    }

    // 2. Number of visible devices.
    uint32_t device_count = 0;
    ret = aclrtGetDeviceCount(&device_count);
    if (ret != APP_ERR_OK)
    {
        LogError << "aclrtGetDeviceCount failed, ret:" << ret;
        return false;
    }
    if (device_count == 0)
    {
        LogError << "aclrtGetDeviceCount returned 0 devices";
        return false;
    }

    // 3. Devices to initialize, from config.
    std::string deviceids = MyYaml::GetIns()->GetStringValue("gc_init_deviceid");
    if (deviceids == "ALL" || deviceids.empty())
    {
        for (uint32_t i = 0; i < device_count; ++i)  // unsigned index: no signed/unsigned compare
        {
            init_deviceIds_.insert(static_cast<int>(i));
        }
        return true;
    }
    std::string delimiter(",");
    std::vector<std::string> splits = MyUtils::getins()->split(deviceids, delimiter);
    for (const std::string& token : splits)
    {
        // NOTE(review): atoi maps non-numeric tokens to 0, so such entries are
        // accepted as device 0 — consider stricter parsing.
        const int oneDeviceId = atoi(token.c_str());
        // The negative case is rejected first, so the unsigned compare below is exact.
        if (oneDeviceId < 0 || static_cast<uint32_t>(oneDeviceId) >= device_count)
        {
            LogError<<"config.yaml init_deviceid err value:"<<deviceids;
            return false;
        }
        init_deviceIds_.insert(oneDeviceId);
    }
    if (init_deviceIds_.empty())
    {
        LogError << "config.yaml init_deviceid err value:" << deviceids;
        return false;
    }
#else
    init_deviceIds_.insert(0);
#endif
    return true;
}
}