基于Transformer的自回归图像生成模型实现

发布时间:2026/7/4 12:27:20
基于Transformer的自回归图像生成模型实现 1. 项目概述与背景在计算机视觉领域图像生成一直是一个极具挑战性的研究方向。传统的生成对抗网络(GAN)和变分自编码器(VAE)虽然取得了不错的效果但都存在训练不稳定或生成质量有限的问题。近年来基于Transformer的自回归模型在图像生成领域展现出强大的潜力。本项目实现了一个基于Transformer的自回归图像生成模型其核心思想是将图像分割为小块(patch)通过BSQ二值量化模块将每个patch编码为离散的token序列然后使用Transformer模型对这些token进行自回归预测和生成。这种方法结合了离散表示的优势和Transformer强大的序列建模能力。提示自回归生成的核心特点是每个token的预测都依赖于之前生成的所有token这与人类书写或绘画的过程非常相似。2. 数据预处理流程2.1 图像token化处理数据预处理的第一步是将原始图像转换为token序列。这个过程依赖于预训练好的BSQPatchAutoEncoder模型python -m homework.tokenize checkpoints/YOUR_BSQPatchAutoEncoder.pth data/tokenized_train.pth data/train/*.jpg python -m homework.tokenize checkpoints/YOUR_BSQPatchAutoEncoder.pth data/tokenized_valid.pth data/valid/*.jpg这段代码会遍历指定目录下的所有图像文件使用BSQPatchAutoEncoder将它们编码为token序列并保存为.pth文件。关键参数说明patch_size5图像被分割为5×5的小块codebook_bits10每个patch被编码为10位二进制码对应1024种可能的token文件大小验证du -hs data/tokenized_train.pth对于典型的配置生成的token文件大小约为76MB具体取决于图像数量和分辨率。2.2 数据格式解析生成的token数据集具有以下结构每个样本是一个3维张量 (1, h, w)h和w取决于原始图像尺寸和patch大小每个元素是0到1023之间的整数代表对应patch的token ID3. 模型架构设计3.1 核心组件模型的核心是AutoregressiveModel类它继承自torch.nn.Module并实现了Autoregressive抽象基类class AutoregressiveModel(torch.nn.Module, Autoregressive): def __init__(self, d_latent: int 128, n_tokens: int 2**10): super().__init__() self.d_latent d_latent # 潜在空间维度 self.n_tokens n_tokens # token词汇表大小 self.L_max 1024 # 最大序列长度 # 嵌入层 self.embedding torch.nn.Embedding(num_embeddingsn_tokens, embedding_dimd_latent) # 位置编码 self.pos_emb torch.nn.Embedding(num_embeddingsself.L_max, embedding_dimd_latent) # Transformer编码器 encoder_layer torch.nn.TransformerEncoderLayer( d_modeld_latent, nhead8, dim_feedforward4*d_latent, activationgelu, batch_firstTrue, norm_firstTrue, dropout0.1 ) self.transformer torch.nn.TransformerEncoder( encoder_layerencoder_layer, num_layers2, normtorch.nn.LayerNorm(d_latent) ) # 输出层 self.fc_out torch.nn.Linear(d_latent, n_tokens)3.2 因果掩码机制自回归模型的关键是确保每个位置的预测只能依赖于之前的位置这通过因果掩码实现def _generate_causal_mask(self, L: int, device: torch.device) - torch.Tensor: 生成因果掩码确保序列中第i个位置只能看到前i-1个位置 :param L: 序列长度 h*w :param device: 设备 :return: 掩码 (L, L)float型上三角-inf下三角0 mask torch.nn.Transformer.generate_square_subsequent_mask(L, devicedevice) return mask这种掩码会阻止Transformer关注未来的token保证生成过程的因果性。4. 前向预测实现4.1 前向传播流程模型的前向传播过程包含以下步骤输入整形将输入从(B, h, w)展平为(B, L)其中Lh*wtoken嵌入通过Embedding层将整数token转换为连续向量位置编码为每个位置添加位置信息序列右移将整个序列向右移动一位实现自回归特性Transformer编码使用带因果掩码的Transformer处理序列输出预测通过线性层预测下一个token的概率分布def forward(self, x: torch.Tensor) - tuple[torch.Tensor, dict[str, torch.Tensor]]: if x.dim() 4: x x.squeeze(1) B, h, w x.shape L h * w # 展平成序列 x_flat x.reshape(B, L) # 嵌入 位置编码 token_emb self.embedding(x_flat) pos_idx torch.arange(L, devicex.device) pos_emb self.pos_emb(pos_idx) x_emb token_emb pos_emb # 自回归右移关键 x_emb F.pad(x_emb, (0,0,1,0))[:, :-1] # 因果掩码 mask self._generate_causal_mask(L, x.device) trans_out self.transformer(x_emb, maskmask) # 输出 logits self.fc_out(trans_out) logits_2d logits.reshape(B, h, w, self.n_tokens) return logits_2d, {}4.2 训练细节模型训练使用标准的交叉熵损失函数criterion torch.nn.CrossEntropyLoss() optimizer torch.optim.Adam(model.parameters(), lr1e-4) for epoch in range(num_epochs): for batch in dataloader: tokens batch.to(device) logits, _ model(tokens) # 计算损失 loss criterion(logits.view(-1, n_tokens), tokens.view(-1)) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step()训练过程中需要注意学习率不宜过大建议从1e-4开始可以使用学习率调度器动态调整监控训练和验证损失防止过拟合5. 自回归生成实现5.1 生成算法自回归生成是从空序列开始逐步预测每个位置的tokentorch.no_grad() def generate(self, B: int 1, h: int 30, w: int 20, deviceNone) - torch.Tensor: device device or next(self.parameters()).device L h * w tokens torch.zeros((B, L), dtypetorch.long, devicedevice) for i in range(L): # 获取当前序列的logits logits, _ self(tokens.reshape(B, h, w)) logits logits.reshape(B, L, -1) # 只取当前位置的预测 curr_logits logits[:, i, :] # 采样下一个token probs F.softmax(curr_logits, dim-1) next_tokens torch.multinomial(probs, num_samples1) # 更新序列 if i L - 1: tokens[:, i1] next_tokens.squeeze() return tokens.reshape(B, h, w)5.2 生成策略在实际应用中可以采用不同的生成策略贪心搜索直接选择概率最大的tokennext_tokens torch.argmax(probs, dim-1, keepdimTrue)温度采样通过温度参数控制生成的多样性temperature 0.7 scaled_logits curr_logits / temperature probs F.softmax(scaled_logits, dim-1)Top-k采样只从概率最高的k个token中采样top_k 40 values, indices torch.topk(probs, top_k) probs torch.zeros_like(probs).scatter_(-1, indices, values)注意温度参数越小生成结果越确定温度参数越大生成结果越多样但可能不连贯。6. 模型评估与结果6.1 训练过程监控训练过程中需要监控以下指标训练损失验证损失生成样本质量典型的训练曲线如下图所示6.2 评估指标除了常规的损失函数还可以使用以下指标评估模型性能生成多样性计算生成样本的token分布熵重建质量通过BSQ解码器将生成的token还原为图像计算与原图的PSNR/SSIM人类评估人工评估生成图像的视觉质量6.3 评分结果项目评分系统给出的最终评估结果7. 实际应用与扩展7.1 图像补全该模型可用于图像补全任务给定部分图像token使用自回归模型预测缺失部分通过BSQ解码器还原完整图像7.2 风格迁移通过条件化生成可以实现风格迁移在模型输入中添加风格编码训练时使用风格分类器提供额外监督生成时指定目标风格7.3 模型优化方向更大规模的训练使用更多数据和更大模型分层生成先生成低分辨率图像再逐步细化混合架构结合CNN和Transformer的优势8. 常见问题与解决方案8.1 训练不稳定问题现象损失值波动大生成质量不一致解决方案降低学习率增加batch size使用梯度裁剪尝试不同的优化器如AdamW8.2 生成重复模式问题现象生成图像出现重复的局部模式解决方案增加温度参数使用Top-k或Top-p采样在训练数据中添加更多多样性8.3 长序列生成质量差问题现象生成大尺寸图像时质量下降解决方案使用相对位置编码实现分块生成策略增加Transformer层数9. 工程实践建议内存优化对于大图像使用梯度检查点减少内存占用torch.utils.checkpoint.checkpoint(self.transformer, x_emb, mask)并行生成利用GPU并行处理多个生成任务torch.no_grad() def batch_generate(self, B: int, h: int, w: int): # 批量生成实现 pass量化部署使用TorchScript量化模型提升推理速度quantized_model torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtypetorch.qint8 )在实际部署中我发现将模型转换为ONNX格式可以显著提升推理速度特别是在边缘设备上。具体做法是dummy_input torch.zeros(1, h, w, dtypetorch.long) torch.onnx.export( model, dummy_input, model.onnx, input_names[input], output_names[output], dynamic_axes{ input: {0: batch_size}, output: {0: batch_size} } )对于需要生成高分辨率图像的应用建议采用分块生成策略先将图像分成若干块分别生成然后使用特殊的边界token确保块与块之间的连续性。这种方法可以突破Transformer序列长度的限制同时保持生成质量。