word2Vec实现demo版本

package com.zishi.aiwer.fenci;

import java.util.*;

public class SimpleWord2Vec {

    // --- 1. 全局配置参数 ---
    private static final int EMBEDDING_DIM = 10; // 词向量维度 (对应每个词向量有10个数字)
    private static final int WINDOW_SIZE = 2;    // 滑动窗口大小 (前后看2个词)
    private static final double LEARNING_RATE = 0.05; // 学习率 (迈步的步伐)
    private static final int EPOCHS = 150;       // 整个文本刷多少遍 (刷得越多越精准)

    // --- 2. 模型核心组件 ---
    private List<String> vocab = new ArrayList<>();          // 词表:存单词字符串
    private Map<String, Integer> wordToId = new HashMap<>(); // 词表索引:快速通过单词查 ID
    private int vocabSize;                                   // 词表大小

    private double[][] W1; // 矩阵 W1:存储中心词向量 (行数=词表大小, 列数=维度)
    private double[][] W2; // 矩阵 W2:存储窗口词特征 (行数=维度, 列数=词表大小)

    public static void main(String[] args) {
        // 原始文本
        //String text = "I love natural language processing and deep learning";
        String text = Constant.words;

        SimpleWord2Vec model = new SimpleWord2Vec();
        model.train(text);
        
        // 训练完成后,打印来看看成果!
        model.printSimilarPairs();
    }

    public void train(String text) {
        // 【第一步】切词并建立词表
        String[] tokens = text.toLowerCase().split("\\s+");
        buildVocab(tokens);
        
        // 【第二步】随机初始化矩阵 W1 和 W2
        initMatrices();

        System.out.println("【开始训练】词表大小: " + vocabSize + ",正在玩命迭代中...");

        // 【第七步】大循环迭代 (Epochs)
        for (int epoch = 1; epoch <= EPOCHS; epoch++) {
            double totalLoss = 0;
            int sampleCount = 0;

            // 滑动窗口遍历整句话
            for (int centerIdx = 0; centerIdx < tokens.length; centerIdx++) {
                int centerId = wordToId.get(tokens[centerIdx]);

                // 左右看 WINDOW_SIZE 个词
                for (int w = -WINDOW_SIZE; w <= WINDOW_SIZE; w++) {
                    if (w == 0) continue; // 跳过中心词自己
                    
                    int contextIdx = centerIdx + w;
                    // 边界处理:碰到了开头或结尾,直接忽略
                    if (contextIdx >= 0 && contextIdx < tokens.length) {
                        int contextId = wordToId.get(tokens[contextIdx]);

                        // 抓到一个中心词和上下文词的“对”,送入核心训练
                        totalLoss += trainStep(centerId, contextId);
                        sampleCount++;
                    }
                }
            }

            // 每隔 100 遍打印一次 Loss,看看模型有没有变聪明
            if (epoch % 50 == 0 || epoch == 1) {
                System.out.format("   Epoch %3d/%d -> 平均损失 (Loss): %.6f\n", epoch, EPOCHS, (totalLoss / sampleCount));
            }
        }
        System.out.println("【训练完成】矩阵中的随机数已经成功进化为富有语义的词向量!\n");
    }

    /**
     * 单个样本的核心训练步骤 (包含前向、Softmax、Loss、反向传播)
     */
    private double trainStep(int centerId, int contextId) {
        // --- 【第三步】前向传播 (点积计算) ---
        // 从 W1 矩阵里捞出中心词对应的行向量 h
        double[] h = W1[centerId]; 
        
        // 算它和整个 W2 矩阵每一列的点积得分
        double[] scores = new double[vocabSize];
        for (int wordId = 0; wordId < vocabSize; wordId++) {
            for (int dim = 0; dim < EMBEDDING_DIM; dim++) {
                scores[wordId] += h[dim] * W2[dim][wordId];
            }
        }

        // --- 【第四步】根据得分算 Softmax (转换成预测概率) ---
        double[] probabilities = softmax(scores);

        // --- 【第五步】计算 Loss 函数 ---
        // 只关心正确答案项的预测概率
        double targetProb = probabilities[contextId];
        double loss = -Math.log(Math.max(targetProb, 1e-15)); // 防止概率为 0 导致数学报错

        // 算出全词表的误差对账单 (Error = 预测概率 - 真实答案)
        double[] errors = new double[vocabSize];
        for (int i = 0; i < vocabSize; i++) {
            double target = (i == contextId) ? 1.0 : 0.0; // 只有正确答案是 1.0,其余是 0
            errors[i] = probabilities[i] - target;
        }

        // --- 【第六步】一气呵成反向传播,修正两个矩阵 ---
        
        // 工程细节:在修改 W2 之前,先借用旧的 W2 数字把 W1 的修改梯度攒下来
        double[] hGradient = new double[EMBEDDING_DIM];
        for (int dim = 0; dim < EMBEDDING_DIM; dim++) {
            for (int wordId = 0; wordId < vocabSize; wordId++) {
                hGradient[dim] += errors[wordId] * W2[dim][wordId];
            }
        }

        // A. 修正 W2 矩阵
        for (int dim = 0; dim < EMBEDDING_DIM; dim++) {
            for (int wordId = 0; wordId < vocabSize; wordId++) {
                W2[dim][wordId] -= LEARNING_RATE * errors[wordId] * h[dim];
            }
        }

        // B. 修正 W1 矩阵
        for (int dim = 0; dim < EMBEDDING_DIM; dim++) {
            W1[centerId][dim] -= LEARNING_RATE * hGradient[dim];
        }

        return loss; // 返回这一步的损失值,用于宏观统计
    }

    /**
     * Softmax 核心计算
     */
    private double[] softmax(double[] scores) {
        double[] expScores = new double[scores.length];
        double sumExp = 0;
        
        // 技巧:找出最大值防止指数爆炸报错 (减去最大值再算指数)
        double maxScore = Double.NEGATIVE_INFINITY;
        for (double s : scores) { if (s > maxScore) maxScore = s; }

        for (int i = 0; i < scores.length; i++) {
            expScores[i] = Math.exp(scores[i] - maxScore);
            sumExp += expScores[i];
        }
        
        double[] probs = new double[scores.length];
        for (int i = 0; i < probs.length; i++) {
            probs[i] = expScores[i] / sumExp;
        }
        return probs;
    }

    /**
     * 建立词表
     */
    private void buildVocab(String[] tokens) {
        Set<String> uniqueWords = new LinkedHashSet<>(Arrays.asList(tokens));
        vocab.addAll(uniqueWords);
        vocabSize = vocab.size();
        for (int i = 0; i < vocabSize; i++) {
            wordToId.put(vocab.get(i), i);
        }
    }

    /**
     * 矩阵随机初始化 (填充很小的非零随机数)
     */
    private void initMatrices() {
        W1 = new double[vocabSize][EMBEDDING_DIM];
        W2 = new double[EMBEDDING_DIM][vocabSize];
        Random rand = new Random(42); // 固定随机种子,保证每次运行结果一致

        for (int i = 0; i < vocabSize; i++) {
            for (int j = 0; j < EMBEDDING_DIM; j++) {
                W1[i][j] = (rand.nextDouble() - 0.5) / EMBEDDING_DIM;
            }
        }
        for (int i = 0; i < EMBEDDING_DIM; i++) {
            for (int j = 0; j < vocabSize; j++) {
                W2[i][j] = (rand.nextDouble() - 0.5) / vocabSize;
            }
        }
    }

    /**
     * 训练结果验证:计算两个词向量的其余弦相似度 (看挨得近不近)
     */
    public void printSimilarPairs() {
        System.out.println("【词向量演练结果】来看看几个词之间的“亲密相似度”(1.0 为完美贴合):");
        String[][] testPairs = {
            {"walked", "for"},     // 经常在一起出现的同桌
            {"you", "heard"},  // 死党组合
            {"mouth", "open"}         // 距离相对较远的词
        };

        for (String[] pair : testPairs) {
            double sim = cosineSimilarity(W1[wordToId.get(pair[0])], W1[wordToId.get(pair[1])]);
            System.out.format("   单词 [%-8s] 和 [%-8s] 的向量相似度: %.4f\n", pair[0], pair[1], sim);
        }
    }

    private double cosineSimilarity(double[] vecA, double[] vecB) {
        double dotProduct = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < vecA.length; i++) {
            dotProduct += vecA[i] * vecB[i];
            normA += Math.pow(vecA[i], 2);
            normB += Math.pow(vecB[i], 2);
        }
        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值