package com.zishi.aiwer.fenci;
import java.util.*;
/**
 * Minimal skip-gram word2vec demo trained with a full softmax and plain SGD.
 *
 * <p>Usage: call {@link #train(String)} with whitespace-separated text, then
 * {@link #printSimilarPairs()} to print cosine similarities for a few word pairs.
 * Calling {@code train} again discards the previous vocabulary and weights.
 * Instances hold mutable training state and are not synchronized.
 */
public class SimpleWord2Vec {
    // --- 1. Global hyper-parameters ---
    private static final int EMBEDDING_DIM = 10;       // length of every word vector
    private static final int WINDOW_SIZE = 2;          // context words taken on each side of the center word
    private static final double LEARNING_RATE = 0.05;  // SGD step size
    private static final int EPOCHS = 150;             // number of full passes over the token sequence

    // --- 2. Model state ---
    private final List<String> vocab = new ArrayList<>();            // id -> word, in first-occurrence order
    private final Map<String, Integer> wordToId = new HashMap<>();   // word -> id
    private int vocabSize;                                           // number of distinct words
    private double[][] W1; // center-word vectors: [vocabSize][EMBEDDING_DIM]
    private double[][] W2; // output weights:      [EMBEDDING_DIM][vocabSize]

    /**
     * Trains on the project-provided corpus ({@code Constant.words}) and prints the demo similarities.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        String text = Constant.words;
        SimpleWord2Vec model = new SimpleWord2Vec();
        model.train(text);
        // Training is done; show the resulting similarities.
        model.printSimilarPairs();
    }

    /**
     * Tokenizes {@code text} on whitespace (lower-cased), rebuilds the vocabulary, re-initializes
     * both weight matrices and runs {@code EPOCHS} passes of skip-gram training. The average loss
     * is printed on the first epoch and on every 50th epoch.
     *
     * @param text whitespace-separated training text
     * @throws NullPointerException     if {@code text} is null
     * @throws IllegalArgumentException if {@code text} contains no tokens
     */
    public void train(String text) {
        Objects.requireNonNull(text, "text");
        // Trim first: split("\\s+") would otherwise emit an empty leading token for text that
        // starts with whitespace. Locale.ROOT keeps lower-casing independent of the JVM locale.
        String normalized = text.trim().toLowerCase(Locale.ROOT);
        if (normalized.isEmpty()) {
            throw new IllegalArgumentException("text must contain at least one token");
        }

        // Step 1: tokenize and build the vocabulary.
        String[] tokens = normalized.split("\\s+");
        buildVocab(tokens);

        // Encode the tokens once so the training loop works on primitive ids only.
        int[] tokenIds = new int[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            tokenIds[i] = wordToId.get(tokens[i]);
        }

        // Step 2: randomly initialize W1 and W2.
        initMatrices();
        System.out.println("【开始训练】词表大小: " + vocabSize + ",正在玩命迭代中...");

        // Step 7: outer loop over epochs.
        for (int epoch = 1; epoch <= EPOCHS; epoch++) {
            double totalLoss = 0;
            int sampleCount = 0;
            // Slide the window over the whole token sequence.
            for (int centerIdx = 0; centerIdx < tokenIds.length; centerIdx++) {
                int centerId = tokenIds[centerIdx];
                // Clamp the window to the sequence bounds instead of testing every offset.
                int from = Math.max(0, centerIdx - WINDOW_SIZE);
                int to = Math.min(tokenIds.length - 1, centerIdx + WINDOW_SIZE);
                for (int contextIdx = from; contextIdx <= to; contextIdx++) {
                    if (contextIdx == centerIdx) {
                        continue; // the center word is not its own context
                    }
                    // One (center, context) pair is one training sample.
                    totalLoss += trainStep(centerId, tokenIds[contextIdx]);
                    sampleCount++;
                }
            }
            // Report on the first epoch and on every 50th epoch.
            if (epoch % 50 == 0 || epoch == 1) {
                // A single-token text yields no pairs; report 0 instead of NaN.
                double averageLoss = (sampleCount == 0) ? 0.0 : totalLoss / sampleCount;
                System.out.format(" Epoch %3d/%d -> 平均损失 (Loss): %.6f\n", epoch, EPOCHS, averageLoss);
            }
        }
        System.out.println("【训练完成】矩阵中的随机数已经成功进化为富有语义的词向量!\n");
    }

    /**
     * Runs one SGD step for a single (center, context) pair: forward pass, softmax,
     * cross-entropy loss and the update of both matrices.
     *
     * @param centerId  vocabulary id of the center word (row of W1 that is read and updated)
     * @param contextId vocabulary id of the context word that should be predicted
     * @return the cross-entropy loss of this sample, computed before the update
     */
    private double trainStep(int centerId, int contextId) {
        // Step 3: forward pass. h aliases the center word's row in W1.
        double[] h = W1[centerId];
        // Score of every vocabulary word = dot product of h with that word's column of W2.
        double[] scores = new double[vocabSize];
        for (int wordId = 0; wordId < vocabSize; wordId++) {
            for (int dim = 0; dim < EMBEDDING_DIM; dim++) {
                scores[wordId] += h[dim] * W2[dim][wordId];
            }
        }

        // Step 4: turn the scores into predicted probabilities.
        double[] probabilities = softmax(scores);

        // Step 5: loss depends only on the probability assigned to the true context word.
        double targetProb = probabilities[contextId];
        double loss = -Math.log(Math.max(targetProb, 1e-15)); // clamp so log(0) cannot occur

        // Error per vocabulary word = predicted probability - one-hot target.
        double[] errors = new double[vocabSize];
        for (int i = 0; i < vocabSize; i++) {
            double target = (i == contextId) ? 1.0 : 0.0;
            errors[i] = probabilities[i] - target;
        }

        // Step 6: backward pass.
        // The gradient for h must be accumulated from the OLD W2 values, so compute it
        // before W2 is modified.
        double[] hGradient = new double[EMBEDDING_DIM];
        for (int dim = 0; dim < EMBEDDING_DIM; dim++) {
            for (int wordId = 0; wordId < vocabSize; wordId++) {
                hGradient[dim] += errors[wordId] * W2[dim][wordId];
            }
        }
        // A. Update W2 (h still holds the pre-update center vector here).
        for (int dim = 0; dim < EMBEDDING_DIM; dim++) {
            for (int wordId = 0; wordId < vocabSize; wordId++) {
                W2[dim][wordId] -= LEARNING_RATE * errors[wordId] * h[dim];
            }
        }
        // B. Update the center word's row of W1.
        for (int dim = 0; dim < EMBEDDING_DIM; dim++) {
            W1[centerId][dim] -= LEARNING_RATE * hGradient[dim];
        }
        return loss;
    }

    /**
     * Computes the softmax of {@code scores}.
     *
     * @param scores raw scores
     * @return a new array of the same length holding exp(score - max) / sum
     */
    private static double[] softmax(double[] scores) {
        // Subtract the maximum before exponentiating so exp() cannot overflow.
        double maxScore = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            if (s > maxScore) {
                maxScore = s;
            }
        }
        double[] expScores = new double[scores.length];
        double sumExp = 0;
        for (int i = 0; i < scores.length; i++) {
            expScores[i] = Math.exp(scores[i] - maxScore);
            sumExp += expScores[i];
        }
        double[] probs = new double[scores.length];
        for (int i = 0; i < probs.length; i++) {
            probs[i] = expScores[i] / sumExp;
        }
        return probs;
    }

    /**
     * Rebuilds the vocabulary from {@code tokens}; ids follow first-occurrence order.
     * Previous contents are cleared so that repeated training does not leave duplicate
     * entries in {@code vocab}.
     *
     * @param tokens the tokenized training text
     */
    private void buildVocab(String[] tokens) {
        vocab.clear();
        wordToId.clear();
        for (String token : tokens) {
            if (!wordToId.containsKey(token)) {
                wordToId.put(token, vocab.size());
                vocab.add(token);
            }
        }
        vocabSize = vocab.size();
    }

    /**
     * Allocates W1 and W2 and fills them with small random values centered on zero.
     * The seed is fixed so that every run produces the same result.
     */
    private void initMatrices() {
        W1 = new double[vocabSize][EMBEDDING_DIM];
        W2 = new double[EMBEDDING_DIM][vocabSize];
        Random rand = new Random(42);
        for (int i = 0; i < vocabSize; i++) {
            for (int j = 0; j < EMBEDDING_DIM; j++) {
                W1[i][j] = (rand.nextDouble() - 0.5) / EMBEDDING_DIM;
            }
        }
        for (int i = 0; i < EMBEDDING_DIM; i++) {
            for (int j = 0; j < vocabSize; j++) {
                W2[i][j] = (rand.nextDouble() - 0.5) / vocabSize;
            }
        }
    }

    /**
     * Prints the cosine similarity of the W1 vectors for a few hard-coded word pairs.
     * A pair is skipped with a notice when either word is not in the current vocabulary
     * (which is also the case for every pair before {@link #train(String)} has been called).
     */
    public void printSimilarPairs() {
        System.out.println("【词向量演练结果】来看看几个词之间的“亲密相似度”(1.0 为完美贴合):");
        String[][] testPairs = {
            {"walked", "for"},
            {"you", "heard"},
            {"mouth", "open"}
        };
        for (String[] pair : testPairs) {
            Integer firstId = wordToId.get(pair[0]);
            Integer secondId = wordToId.get(pair[1]);
            if (firstId == null || secondId == null) {
                // Unboxing a missing id would throw NullPointerException; report and move on.
                System.out.format(" 单词 [%-8s] 或 [%-8s] 不在词表中,已跳过\n", pair[0], pair[1]);
                continue;
            }
            double sim = cosineSimilarity(W1[firstId], W1[secondId]);
            System.out.format(" 单词 [%-8s] 和 [%-8s] 的向量相似度: %.4f\n", pair[0], pair[1], sim);
        }
    }

    /**
     * Cosine similarity of two vectors; iterates over {@code vecA.length} components.
     *
     * @param vecA first vector
     * @param vecB second vector, at least as long as {@code vecA}
     * @return dot(a, b) / (|a| * |b|), or 0.0 when either vector has zero norm
     */
    private static double cosineSimilarity(double[] vecA, double[] vecB) {
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < vecA.length; i++) {
            dotProduct += vecA[i] * vecB[i];
            normA += vecA[i] * vecA[i];
            normB += vecB[i] * vecB[i];
        }
        double denominator = Math.sqrt(normA) * Math.sqrt(normB);
        // A zero vector has no direction; avoid returning NaN from 0/0.
        return (denominator == 0.0) ? 0.0 : dotProduct / denominator;
    }
}
word2Vec实现demo版本
最新推荐文章于 2026-06-23 21:55:01 发布

2552

被折叠的 条评论
为什么被折叠?



