113 lines
3.0 KiB
C++
113 lines
3.0 KiB
C++
|
|
#include "EngineBase.h"

#include <utility>
|
||
|
|
|
||
|
|
namespace ai_matrix
|
||
|
|
{
|
||
|
|
//初始化engine
|
||
|
|
void EngineBase::AssignInitArgs(const EngineInitArguments &initArgs)
|
||
|
|
{
|
||
|
|
deviceId_ = initArgs.deviceId;
|
||
|
|
|
||
|
|
#ifdef ASCEND
|
||
|
|
aclContext_ = initArgs.context;
|
||
|
|
runMode_ = initArgs.runMode;
|
||
|
|
#endif
|
||
|
|
|
||
|
|
engineName_ = initArgs.engineName;
|
||
|
|
engineId_ = initArgs.engineId;
|
||
|
|
|
||
|
|
#ifdef ASCEND
|
||
|
|
APP_ERROR ret;
|
||
|
|
ret = aclrtSetCurrentContext(aclContext_);
|
||
|
|
if (ret != APP_ERR_OK)
|
||
|
|
{
|
||
|
|
LogError << "Fail to set context for " << engineName_ << "[" << engineId_ << "]"
|
||
|
|
<< ", ret=" << ret << "(" << GetAppErrCodeInfo(ret) << ").";
|
||
|
|
}
|
||
|
|
#endif
|
||
|
|
}
|
||
|
|
|
||
|
|
// Intentionally a no-op: the logger pointer is accepted but not stored.
// (An earlier version assigned it to a `gLogger_obj` member; that line is
// disabled.) The parameter is kept for interface compatibility.
void EngineBase::setLoggerPoint(Logger *gLogger_in)
{
    static_cast<void>(gLogger_in);  // silence unused-parameter warnings
}
|
||
|
|
|
||
|
|
// get the data from input queue then call Process function in the new thread
|
||
|
|
void EngineBase::ProcessThread()
|
||
|
|
{
|
||
|
|
#ifdef ASCEND
|
||
|
|
APP_ERROR ret;
|
||
|
|
ret = aclrtSetCurrentContext(aclContext_);
|
||
|
|
if (ret != APP_ERR_OK)
|
||
|
|
{
|
||
|
|
LogError << "Fail to set context for " << engineName_ << "[" << engineId_ << "]"
|
||
|
|
<< ", ret=" << ret << "(" << GetAppErrCodeInfo(ret) << ").";
|
||
|
|
return;
|
||
|
|
}
|
||
|
|
#endif
|
||
|
|
|
||
|
|
Process();
|
||
|
|
}
|
||
|
|
|
||
|
|
//设置engine输入队列
|
||
|
|
void EngineBase::SetInputMap(std::string engineAddress, std::shared_ptr<MyQueue<std::shared_ptr<void>>> inputQueue)
|
||
|
|
{
|
||
|
|
inputQueMap_[engineAddress] = inputQueue;
|
||
|
|
}
|
||
|
|
//得到输入队列
|
||
|
|
std::shared_ptr<MyQueue<std::shared_ptr<void>>> EngineBase::GetInputMap(std::string engineAddress)
|
||
|
|
{
|
||
|
|
if (inputQueMap_.find(engineAddress) == inputQueMap_.end())
|
||
|
|
{
|
||
|
|
return nullptr;
|
||
|
|
}
|
||
|
|
return inputQueMap_.at(engineAddress);
|
||
|
|
}
|
||
|
|
|
||
|
|
//设置engine输出队列
|
||
|
|
void EngineBase::SetOutputMap(std::string engineAddress, std::shared_ptr<MyQueue<std::shared_ptr<void>>> outputQue)
|
||
|
|
{
|
||
|
|
outputQueMap_[engineAddress] = outputQue;
|
||
|
|
}
|
||
|
|
|
||
|
|
//启动engine
|
||
|
|
APP_ERROR EngineBase::Run()
|
||
|
|
{
|
||
|
|
LogInfo << engineName_ << "[" << engineId_ << "] Run";
|
||
|
|
isStop_ = false;
|
||
|
|
processThr_ = std::thread(&EngineBase::ProcessThread, this);
|
||
|
|
return APP_ERR_OK;
|
||
|
|
}
|
||
|
|
|
||
|
|
// 停止engine
|
||
|
|
APP_ERROR EngineBase::Stop()
|
||
|
|
{
|
||
|
|
#ifdef ASCEND
|
||
|
|
//设置上下文
|
||
|
|
APP_ERROR ret = aclrtSetCurrentContext(aclContext_);
|
||
|
|
if (ret != APP_ERR_OK)
|
||
|
|
{
|
||
|
|
LogError << "ModuleManager: fail to set context, ret[%d]" << ret << ".";
|
||
|
|
return ret;
|
||
|
|
}
|
||
|
|
#endif
|
||
|
|
|
||
|
|
//停止线程
|
||
|
|
isStop_ = true;
|
||
|
|
|
||
|
|
//停止所有输入队列
|
||
|
|
for (auto it = inputQueMap_.begin(); it != inputQueMap_.end(); it++)
|
||
|
|
{
|
||
|
|
it->second->stop();
|
||
|
|
}
|
||
|
|
|
||
|
|
//等待线程结束
|
||
|
|
if (processThr_.joinable())
|
||
|
|
{
|
||
|
|
processThr_.join();
|
||
|
|
}
|
||
|
|
|
||
|
|
//其他清理
|
||
|
|
return DeInit();
|
||
|
|
}
|
||
|
|
}
|