1. Data Preprocessing

(1) Hyperparameters and special tokens (config.py)
import os
import torch

# __file__ is a special Python variable holding the path of the current script.
# os.path.dirname() extracts the directory part of that path.
BASE_PATH = os.path.dirname(__file__)

TRAIN_SIMPLE_PATH = BASE_PATH + "/small_data/train.json"
VAL_SIMPLE_PATH = BASE_PATH + "/small_data/val.json"

ZH_VOCAB_PATH = BASE_PATH + "/small_data/zh.txt"
EN_VOCAB_PATH = BASE_PATH + "/small_data/en.txt"

# ids of the special tokens; they must match the first four entries in the vocab files
PAD_ID = 0
UNK_ID = 1
SOS_ID = 2
EOS_ID = 3

# Transformer hyperparameters (the base-model settings from "Attention Is All You Need")
D_MODEL = 512
N_HEAD = 8
D_FF = 2048
N = 6
DROPOUT = 0.1

# training hyperparameters
BATCH_SIZE = 15
LABEL_SMOOTHING = 0.1
LR = 1e-5
EPOCH = 100

device = "cuda" if torch.cuda.is_available() else "cpu"

# where the best model weights (by validation BLEU) are saved
SAVE_MODEL_PATH = BASE_PATH + '/param/best.pt'

# maximum number of tokens generated during greedy decoding
MAX_LEN = 50
(2) Build the vocabularies (data_processor.py)
from config import *
import json
from utils import *
from collections import Counter


# Build the English and Chinese vocabularies from the training data
def generate_vocab():
    en_vocab = ['<pad>', '<unk>', '<sos>', '<eos>']
    zh_vocab = ['<pad>', '<unk>', '<sos>', '<eos>']

    en_vocab_list = []
    zh_vocab_list = []

    # Parse the JSON training file
    with open(TRAIN_SIMPLE_PATH, encoding='utf-8') as file:
        # file.read() reads the whole file into one string.
        # json.loads() parses a JSON string into Python objects (dicts, lists, etc.)
        lines = json.loads(file.read())
        for en_sent, zh_sent in lines:
            en_vocab_list += divided_en(en_sent)
            zh_vocab_list += divided_zh(zh_sent)
        print("train count: ", len(lines))
        # Build the vocab by frequency; with a large enough corpus you can set a minimum count to filter out rare tokens
        # most_common(n) returns the n most frequent elements with their counts; with n=None it returns all of them
        en_vocab_kv = Counter(en_vocab_list).most_common()
        en_vocab += [k.lower() for k, v in en_vocab_kv]
        # en_vocab += [k.lower() for k, v in en_vocab_kv if v > 1]  # drop words that appear only once

        zh_vocab_kv = Counter(zh_vocab_list).most_common()
        zh_vocab += [k.lower() for k, v in zh_vocab_kv]

        print("en_vocab size: ", len(en_vocab))
        print("zh_vocab size: ", len(zh_vocab))

        # Write the vocabulary files (mode 'x' fails if the file already exists)
        with open(EN_VOCAB_PATH, "x", encoding="utf-8") as file:
            file.write("\n".join(en_vocab))
        with open(ZH_VOCAB_PATH, "x", encoding="utf-8") as file:
            file.write("\n".join(zh_vocab))


if __name__ == '__main__':
    generate_vocab()
(3) Vocabulary parsing + token-by-token greedy decoding (utils.py)
import jieba
import re
from config import *
import torch
from model import get_padding_mask, get_subsequent_mask
import sacrebleu


# Chinese word segmentation
def divided_zh(sentence):
    return jieba.lcut(sentence)  # lcut returns a list; cut returns a lazy generator, which saves memory on large texts


# English tokenization
def divided_en(sentence):
    # match words and punctuation with a regular expression
    pattern = r'\w+|[^\w\s]'
    return re.findall(pattern, sentence)  # findall returns every match of the pattern as a list


# Load a vocabulary file and build the id->token and token->id mappings
def get_vocab(lang='en'):
    if lang == 'en':
        file_path = EN_VOCAB_PATH
    elif lang == 'zh':
        file_path = ZH_VOCAB_PATH
    with open(file_path, encoding='utf-8') as file:
        lines = file.read()

    id2vocab = lines.split('\n')
    vocab2id = {v: k for k, v in enumerate(id2vocab)}
    return id2vocab, vocab2id


# Greedy decoding: generate the target tokens one at a time
def batch_greedy_decode(model, src_x, src_mask, max_len=MAX_LEN):
    src_x = src_x.to(device)
    src_mask = src_mask.to(device)
    # Chinese vocabulary (id -> token)
    zh_id2vocab, _ = get_vocab('zh')
    # encoder output, e.g. shape (3, 4, 512)
    memory = model.encoder(src_x, src_mask)
    # initialize the decoder input with <sos>: shape (3,1) -> (3,2) -> (3,3) -> (3,4) -> (3,5)
    prob_x = torch.tensor([[SOS_ID]] * src_x.size(0))
    prob_x = prob_x.to(device)

    for _ in range(max_len):
        # prob_mask = get_padding_mask(prob_x, PAD_ID)
        # at inference time the subsequent mask is not strictly needed, because future tokens have not been generated yet
        tgt_pad_mask = get_padding_mask(prob_x, PAD_ID).to(device)
        tgt_subsequent_mask = get_subsequent_mask(prob_x.size(1)).to(device)
        tgt_mask = tgt_pad_mask | tgt_subsequent_mask
        prob_mask = tgt_mask != 0

        output = model.decoder(prob_x, prob_mask, memory, src_mask)  # shape (3,1,512) -> (3,2,512)
        # model.generator(output)  # shape (3,1,19), where 19 is the target vocab size
        output = model.generator(output[:, -1, :])  # shape (3,19): logits for the last generated position only
        predict = torch.argmax(output, dim=-1, keepdim=True)  # shape (3,1)
        # append the new prediction to the tokens generated so far
        prob_x = torch.concat([prob_x, predict], dim=-1)  # shape (3,2) -> (3,3) -> (3,4) -> (3,5)
        # if every sequence has predicted <eos>, generation is finished
        if torch.all(predict == EOS_ID).item():
            break
    # convert the predicted indices back into tokens
    batch_prob_text = []
    for prob in prob_x:
        prob_text = []
        for prob_id in prob:
            if prob_id == SOS_ID:
                continue
            elif prob_id == EOS_ID:
                break
            prob_text.append(zh_id2vocab[prob_id])
        batch_prob_text.append(''.join(prob_text))
    return batch_prob_text


def bleu_score(hyp, refs):
    bleu = sacrebleu.corpus_bleu(hyp, refs, tokenize='zh')
    return round(bleu.score, 2)


if __name__ == '__main__':
    # sen = "小明和小红是好朋友"
    # print(divided_zh(sen))
    # sen = "hello world!"
    # print(divided_en(sen))
    target = "我喜欢读书。"
    vocabs = divided_zh(target)
    zh_id2vocab, zh_vocab2id = get_vocab('zh')
    print(zh_id2vocab)
    print(zh_vocab2id)
    print(vocabs)
    tokens = [zh_vocab2id.get(v, UNK_ID) for v in vocabs]
    print(tokens)
(4) Dataset loading + padding and collation (data_loader.py)
from config import *
from utils import *
import torch.utils.data as data
from torch.nn.utils.rnn import pad_sequence
import json
import torch
from model import get_padding_mask, get_subsequent_mask


class Dataset(data.Dataset):
    def __init__(self, type='train'):
        super().__init__()
        if type == 'train':
            self.file_path = TRAIN_SIMPLE_PATH
        elif type == 'val':
            self.file_path = VAL_SIMPLE_PATH
        # read the dataset file
        with open(self.file_path, encoding='utf-8') as file:
            self.lines = json.loads(file.read())
        # load the vocabularies
        _, self.en_vocab2id = get_vocab('en')
        _, self.zh_vocab2id = get_vocab('zh')

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, item):
        en_text, zh_text = self.lines[item]
        source = [self.en_vocab2id.get(v.lower(), UNK_ID) for v in divided_en(en_text)]
        target = [self.zh_vocab2id.get(v.lower(), UNK_ID) for v in divided_zh(zh_text)]
        return source, target, zh_text

    # Pad and collate a batch:
    # pad with PAD_ID, prepend SOS_ID to the decoder input and append EOS_ID to the decoder target
    def collate_fn(self, batch):
        batch_src, batch_tgt, tgt_text = zip(*batch)
        # source
        # the encoder input only needs padding
        src_x = pad_sequence([torch.LongTensor(src) for src in batch_src], batch_first=True, padding_value=PAD_ID)
        # src_x = pad_sequence([torch.LongTensor([SOS_ID] + src + [EOS_ID]) for src in batch_src], batch_first=True, padding_value=PAD_ID)
        src_mask = get_padding_mask(src_x, PAD_ID)
        # target
        # decoder input: prepend <sos>, then pad
        tgt_x = [torch.LongTensor([SOS_ID] + tgt) for tgt in batch_tgt]
        tgt_x = pad_sequence(tgt_x, batch_first=True, padding_value=PAD_ID)

        tgt_pad_mask = get_padding_mask(tgt_x, PAD_ID)
        tgt_subsequent_mask = get_subsequent_mask(tgt_x.size(1))
        tgt_mask = tgt_pad_mask | tgt_subsequent_mask
        tgt_mask = tgt_mask != 0
        # decoder target: append <eos>, then pad
        tgt_y = [torch.LongTensor(tgt + [EOS_ID]) for tgt in batch_tgt]
        tgt_y = pad_sequence(tgt_y, batch_first=True, padding_value=PAD_ID)
        return src_x, src_mask, tgt_x, tgt_mask, tgt_y, tgt_text


if __name__ == '__main__':
    dataset = Dataset()
    print(dataset[0])  # ([5, 12, 4, 13], [8, 5, 6, 11], '我是一个学生')
    loader = data.DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)
    it = iter(loader)
    print(next(it))  # first batch
    print(next(it))  # second batch
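
The two masking helpers imported from model.py (get_padding_mask and get_subsequent_mask) are defined in the model file, which is not listed in this section. For reference only, here is a minimal sketch that is consistent with how they are used above; the exact shapes and dtypes are assumptions and the original model.py may differ:

import torch

# padding mask: True at padded positions, shape (batch, 1, seq_len) so it broadcasts over attention scores
def get_padding_mask(x, pad_id):
    return (x == pad_id).unsqueeze(1)

# subsequent (look-ahead) mask: True above the diagonal, i.e. at future positions, shape (1, size, size)
def get_subsequent_mask(size):
    return torch.triu(torch.ones(1, size, size), diagonal=1).bool()

With these definitions, tgt_pad_mask | tgt_subsequent_mask broadcasts to shape (batch, seq_len, seq_len), and tgt_mask != 0 simply keeps the result as a boolean tensor.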

2. Training

(1) Label Smoothing

Label smoothing is a regularization technique used to improve a model's generalization in classification tasks, particularly when overfitting is a concern. It adjusts the target label distribution so the model does not become overly confident in any single class, which makes it more robust.

In a standard classification task the targets are usually one-hot encoded; for a three-class problem, for example, class 2 is encoded as [0, 1, 0]. This encoding pushes the model toward over-confident predictions and can hurt generalization.

Label smoothing softens the label distribution by replacing the 1 in the one-hot vector with a value slightly below 1 (such as 1 - ε) and the remaining 0s with a small value (such as ε / (K - 1), where K is the number of classes).
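
As a quick numeric illustration (a standalone sketch, not part of the project code): with K = 3 classes and ε = 0.1, the one-hot target [0, 1, 0] is softened exactly as described above. In the training script below, the same idea is enabled simply by passing label_smoothing=LABEL_SMOOTHING to CrossEntropyLoss.

import torch

eps, K = 0.1, 3
one_hot = torch.tensor([0., 1., 0.])
# the true class keeps 1 - eps; the remaining eps is spread over the other K - 1 classes
smoothed = one_hot * (1 - eps) + (1 - one_hot) * (eps / (K - 1))
print(smoothed)  # tensor([0.0500, 0.9000, 0.0500])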

(2) BLEU metric

BLEU is the evaluation metric commonly used for translation tasks: it gives the model a concrete quality score, and during training the parameters of the highest-scoring model are cached for use in the later inference step.

BLEU (Bilingual Evaluation Understudy) measures machine-translation quality by comparing the system output with one or more reference translations. The original metric lies between 0 and 1, but tools such as sacrebleu report it on a 0-100 scale (as in the example below); higher means better.

# pip install sacrebleu
import sacrebleu
# reference translations: each inner list is one complete set of references, aligned with hyp
refs = [['我喜欢吃苹果。', '这本书很有意思。', '他是一个出色的演员。'],
        ['我喜欢吃水果。', '这本书很好玩。', '他是一名杰出的演员。']]
# candidate (system) sentences
hyp = ['我爱吃苹果。', '这本书非常有趣。', '他是一位优秀的演员。']
bleu = sacrebleu.corpus_bleu(hyp, refs, tokenize='zh')
print(bleu.score)  # corpus BLEU on a 0-100 scale
(3) Training script

from config import *
from utils import *
from data_loader import *
from model import make_model
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter


# Learning-rate factor for LambdaLR: linear warmup for `warmup` steps, then inverse decay.
# LambdaLR multiplies the base LR by the returned factor, so the effective LR peaks at 10 * LR
# and never falls below 0.1 * LR.
def lr_lambda_fn(step, warmup):
    if step <= warmup:
        lr = (step / warmup) * 10
    else:
        lr = (warmup / step) * 10
    return max(lr, 0.1)


def run_epoch(loader, model, loss_fn, optimizer=None):
    # running loss and batch count
    total_batchs = 0.
    total_loss = 0.
    model.to(device)
    # iterate over the batches and (optionally) train
    for src_x, src_mask, tgt_x, tgt_mask, tgt_y, tgt_text in loader:
        src_x = src_x.to(device)
        src_mask = src_mask.to(device)
        tgt_x = tgt_x.to(device)
        tgt_mask = tgt_mask.to(device)
        tgt_y = tgt_y.to(device)
        output = model(src_x, src_mask, tgt_x, tgt_mask)
        # CrossEntropyLoss expects 2-D logits and 1-D targets, so flatten the sequence dimension
        loss = loss_fn(output.reshape(-1, output.shape[-1]), tgt_y.reshape(-1))
        # accumulate the batch count and the loss
        total_batchs += 1
        total_loss += loss.item()

        # backpropagate only when an optimizer is provided (training mode)
        if optimizer:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # return the average loss over the epoch (outside the loop, after all batches)
    return total_loss / total_batchs


# Evaluation: greedy-decode the loader and score the output with corpus BLEU
def evaluate(loader, model, max_len=MAX_LEN):
    tgt_sent = []  # reference sentences (labels)
    prob_sent = []  # predicted sentences

    for src_x, src_mask, _, _, _, tgt_text in loader:
        batch_prob_text = batch_greedy_decode(model, src_x, src_mask, max_len)
        tgt_sent += tgt_text
        prob_sent += batch_prob_text

    print(prob_sent)
    print(tgt_sent)
    return bleu_score(prob_sent, [tgt_sent])


if __name__ == '__main__':
    # writer = SummaryWriter('logs')

    en_id2vocab, _ = get_vocab('en')
    zh_id2vocab, _ = get_vocab('zh')

    SRC_VOCAB_SIZE = len(en_id2vocab)
    TGT_VOCAB_SIZE = len(zh_id2vocab)

    model = make_model(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, D_MODEL, N_HEAD, D_FF, N, DROPOUT).to(device)

    train_dataset = Dataset('train')
    train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                   collate_fn=train_dataset.collate_fn)
    val_dataset = Dataset('val')
    val_loader = data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                 collate_fn=val_dataset.collate_fn)
    # ignore padding positions in the loss and apply label-smoothing regularization
    loss_fn = CrossEntropyLoss(ignore_index=PAD_ID, label_smoothing=LABEL_SMOOTHING)
    optimizer = Adam(model.parameters(), lr=LR)
    lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda_fn(step, EPOCH / 4))

    best_bleu = 0
    for e in range(EPOCH):
        # training
        model.train()
        train_loss = run_epoch(train_loader, model, loss_fn, optimizer)
        # step the LR scheduler once per epoch
        lr_scheduler.step()
        # print the current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        print(current_lr)

        # validation
        model.eval()
        val_loss = run_epoch(val_loader, model, loss_fn)
        # writer.add_scalars('loss', {"train_loss": train_loss, "val_loss": val_loss})
        print('epoch:', e, 'train_loss:', train_loss, 'val_loss:', val_loss)

        # evaluate BLEU on the validation set
        val_bleu = evaluate(val_loader, model, MAX_LEN)
        # save the checkpoint with the best validation BLEU
        if val_bleu > best_bleu:
            torch.save(model.state_dict(), SAVE_MODEL_PATH)
            best_bleu = val_bleu

3. Inference

from config import *
from utils import *
from model import *
import torch
from torch.nn.utils.rnn import pad_sequence

if __name__ == '__main__':
    en_id2vocab, en_vocab2id = get_vocab('en')
    zh_id2vocab, zh_vocab2id = get_vocab('zh')

    SRC_VOCAB_SIZE = len(en_id2vocab)
    TGT_VOCAB_SIZE = len(zh_id2vocab)
    # build the model
    model = make_model(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, D_MODEL, N_HEAD, D_FF, N, DROPOUT).to(device)
    # load the trained weights (best checkpoint from training)
    model.load_state_dict(torch.load(SAVE_MODEL_PATH, map_location=device))
    model.eval()

    texts = [
        "I like playing basketball",
        "He is a doctor",
        "I am a student"
    ]

    # convert tokens to vocabulary ids (unknown words map to UNK_ID)
    batch_src_token = [[en_vocab2id.get(v.lower(), UNK_ID) for v in divided_en(text)] for text in texts]
    # alternative: wrap each sentence with <sos>/<eos> before padding
    # src_x = pad_sequence([torch.LongTensor([SOS_ID] + src + [EOS_ID]) for src in batch_src_token], batch_first=True, padding_value=PAD_ID)

    # pad each sentence to the same length
    src_x = pad_sequence([torch.LongTensor(src) for src in batch_src_token], batch_first=True, padding_value=PAD_ID)

    src_mask = get_padding_mask(src_x, PAD_ID)

    prob_sent = batch_greedy_decode(model, src_x, src_mask)
    print(prob_sent)
