Shark源码分析(十二):线性SVM
关于svm算法,这个在我关于机器学习的博客中已经描述的比较详实了,这里就不再赘述。svm主要有三种类型,这里我所介绍的是线性svm算法的代码。相较于使用核函数的svm算法,代码的整体框架应该是一样的,只是在对偶问题的求解上所使用的方法可能是不一样的。
LinearClassifier类
这个类所表示的是算法的决策平面,是一个多分类的线性分类模型。定义在<include/shark/Models/LinearClassifier.h>中。
template<class VectorType = RealVector>
class LinearClassifier : public ArgMaxConverter<LinearModel<VectorType> >
{
public:
LinearClassifier(){}
std::string name() const
{ return "LinearClassifier"; }
};
相当简单的一个类,并没有什么好说明的地方。
ArgMaxConverter类
该类是LinearClassifier的基类,其作用是将一个输出的向量通过arg_max操作转变为一个类标记,就是输出分量最大的那一维。该类定义在<include/shark/Models/Converter.h>。
template<class Model>
class ArgMaxConverter : public AbstractModel<typename Model::InputType, unsigned int>
{
private:
typedef typename Model::BatchOutputType ModelBatchOutputType;
public:
typedef typename Model::InputType InputType;
typedef unsigned int OutputType;
typedef typename Batch<InputType>::type BatchInputType;
typedef Batch<unsigned int>::type BatchOutputType;
ArgMaxConverter()
{ }
ArgMaxConverter(Model const& decisionFunction)
: m_decisionFunction(decisionFunction)
{ }
std::string name() const
{ return "ArgMaxConverter<"+m_decisionFunction.name()+">"; }
RealVector parameterVector() const{
return m_decisionFunction.parameterVector();
}
void setParameterVector(RealVector const& newParameters){
m_decisionFunction.setParameterVector(newParameters);
}
std::size_t numberOfParameters() const{
return m_decisionFunction.numberOfParameters();
}
Model const& decisionFunction()const{
return m_decisionFunction;
}
Model& decisionFunction(){
return m_decisionFunction;
}
// 计算输入数据的类标签
void eval(BatchInputType const& input, BatchOutputType& output)const{
ModelBatchOutputType modelResult;
m_decisionFunction.eval(input,modelResult);
std::size_t batchSize = shark::size(modelResult);
output.resize(batchSize);
if(modelResult.size2()== 1) //对于二分类的情况
{
for(std::size_t i = 0; i != batchSize; ++i){
// 如果输出大于0表示正类,否则为负类
output(i) = modelResult(i,0) > 0.0;
}
}
else{
for(std::size_t i = 0; i != batchSize; ++i){
output(i) = static_cast<unsigned int>(arg_max(row(modelResult,i)));
}
}
}
void eval(BatchInputType const& input, BatchOutputType& output, State& state)const{
eval(input,output);
}
void eval(InputType const & pattern, OutputType& output)const{
typename Model::OutputType modelResult;
m_decisionFunction.eval(pattern,modelResult);
if(modelResult.size()== 1){
output = modelResult(0) > 0.0;
}
else{
output = static_cast<unsigned int>(arg_max(modelResult));
}
}
void read(InArchive& archive){
archive >> m_decisionFunction;
}
void write(OutArchive& archive) const{
archive << m_decisionFunction;
}
private:
Model m_decisionFunction;
};
在LinearClassifier类的代码中,该模板类

本文详细解析了Shark库中的线性SVM算法实现,涵盖了LinearClassifier类、ArgMaxConverter类、AbstractLinearSvmTrainer类、LinearCSvmTrainer类及QpBoxLinear类等关键组件的设计与功能。
:线性SVM&spm=1001.2101.3001.5002&articleId=54743714&d=1&t=3&u=3dac4d4de9d942d8bf63f86228f53a0b)
650

被折叠的 条评论
为什么被折叠?



