Shark源码分析(十二):线性SVM

本文详细解析了Shark库中的线性SVM算法实现,涵盖了LinearClassifier类、ArgMaxConverter类、AbstractLinearSvmTrainer类、LinearCSvmTrainer类及QpBoxLinear类等关键组件的设计与功能。

Shark源码分析(十二):线性SVM

关于svm算法,这个在我关于机器学习的博客中已经描述的比较详实了,这里就不再赘述。svm主要有三种类型,这里我所介绍的是线性svm算法的代码。相较于使用核函数的svm算法,代码的整体框架应该是一样的,只是在对偶问题的求解上所使用的方法可能是不一样的。

LinearClassifier类

这个类所表示的是算法的决策平面,是一个多分类的线性分类模型。定义在<include/shark/Models/LinearClassifier.h>中。

template<class VectorType = RealVector>
class LinearClassifier : public ArgMaxConverter<LinearModel<VectorType> >
{
public:
    LinearClassifier(){}

    std::string name() const
    { return "LinearClassifier"; }
};

相当简单的一个类,并没有什么好说明的地方。

ArgMaxConverter类

该类是LinearClassifier的基类,其作用是将一个输出的向量通过arg_max操作转变为一个类标记,就是输出分量最大的那一维。该类定义在<include/shark/Models/Converter.h>

template<class Model>
class ArgMaxConverter : public AbstractModel<typename Model::InputType, unsigned int>
{
private:
    typedef typename Model::BatchOutputType ModelBatchOutputType;
public:
    typedef typename Model::InputType InputType;
    typedef unsigned int OutputType;
    typedef typename Batch<InputType>::type BatchInputType;
    typedef Batch<unsigned int>::type BatchOutputType;

    ArgMaxConverter()
    { }
    ArgMaxConverter(Model const& decisionFunction)
    : m_decisionFunction(decisionFunction)
    { }

    std::string name() const
    { return "ArgMaxConverter<"+m_decisionFunction.name()+">"; }

    RealVector parameterVector() const{
        return m_decisionFunction.parameterVector();
    }

    void setParameterVector(RealVector const& newParameters){
        m_decisionFunction.setParameterVector(newParameters);
    }

    std::size_t numberOfParameters() const{
        return m_decisionFunction.numberOfParameters();
    }

    Model const& decisionFunction()const{
        return m_decisionFunction;
    }

    Model& decisionFunction(){
        return m_decisionFunction;
    }

    // 计算输入数据的类标签
    void eval(BatchInputType const& input, BatchOutputType& output)const{

        ModelBatchOutputType modelResult;
        m_decisionFunction.eval(input,modelResult);
        std::size_t batchSize = shark::size(modelResult);
        output.resize(batchSize);
        if(modelResult.size2()== 1) //对于二分类的情况
        {
            for(std::size_t i = 0; i != batchSize; ++i){
  
  // 如果输出大于0表示正类,否则为负类
                output(i) = modelResult(i,0) > 0.0;
            }
        }
        else{
            for(std::size_t i = 0; i != batchSize; ++i){
                output(i) = static_cast<unsigned int>(arg_max(row(modelResult,i)));
            }
        }
    }

    void eval(BatchInputType const& input, BatchOutputType& output, State& state)const{
        eval(input,output);
    }

    void eval(InputType const & pattern, OutputType& output)const{
        typename Model::OutputType modelResult;
        m_decisionFunction.eval(pattern,modelResult);
        if(modelResult.size()== 1){
            output = modelResult(0) > 0.0;
        }
        else{
            output = static_cast<unsigned int>(arg_max(modelResult));
        }
    }

    void read(InArchive& archive){
        archive >> m_decisionFunction;
    }

    void write(OutArchive& archive) const{
        archive << m_decisionFunction;
    }

private:
    Model m_decisionFunction;
};

在LinearClassifier类的代码中,该模板类

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值