code2seq Reproduction Notes (PyTorch Version)

This post documents reproducing the code2seq model with a PyTorch implementation, covering data preprocessing, the project layout, the key code modules, and the run results. The model was trained on the java-small subset with an adjusted config file; the results were not satisfying, but the workflow was verified to be feasible.


Abstract

This post is my study notes for the code generation paper code2seq: Generating Sequences from Structured Representations of Code.
Code:
TensorFlow version: https://github.com/tech-srl/code2seq
jupyter + PyTorch version: https://github.com/m3yrin/code2seq
This post runs the jupyter version, with a few small changes.

Note: if you just want a quick run, it is easier to fork the jupyter version of the project on GitHub directly; any code not shown in this post can be found in that repository.

Project Structure

The project folder contains four subfolders: code, dataset, logs and runs.
Under code there are three important subfolders: configs (configuration files), notebooks (the source code: preparation downloads and pre-processes the data, and code2seq is the main project code; the notebook files are the GitHub originals, while the .py files are my copies of the same code into plain Python files so they can run on a server), and src (utility code imported at the top of code2seq).
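
Since the screenshot is not included here, the layout described above looks roughly like this:

.
├── code
│   ├── configs      # config_code2seq.yml etc.
│   ├── notebooks    # preparation and code2seq (notebooks plus my .py copies)
│   └── src          # utility modules imported by code2seq (utils, messenger)
├── dataset
├── logs
└── runs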

Data preparation (the preparation notebook)

# Before downloading the data, create the three empty folders the project layout needs (lines prefixed with ! are run in a terminal or in jupyter)
!mkdir dataset runs logs
# Download the dataset into the dataset/ folder just created
!wget https://s3.amazonaws.com/code2seq/datasets/java-small-preprocessed.tar.gz -P dataset/
# Extract the downloaded archive (note the path must match the dataset/ folder used above)
!tar -xvzf dataset/java-small-preprocessed.tar.gz -C dataset/
# Change into the java-small folder produced by the extraction
%cd dataset/java-small/
# for dev: takes the first 20,000 training lines as a smaller train_dev split (I haven't yet seen where this is actually used)
!head -20000 java-small.train.c2s > java-small.train_dev.c2s
# Create four folders inside java-small: train, train_dev, val, test
!mkdir train train_dev val test
# If split is not available in your shell, you can run it with git bash inside this folder. It takes quite a while and produces very small pieces
# (this step is a bit odd: it writes each example, i.e. one line of path contexts, into its own .txt file, which I suspect slows training down a lot,
# but since I only wanted a trial run I used it as-is; the sketch after this block shows what one such line looks like)
!split -d -a 6 -l 1 --additional-suffix=.txt java-small.test.c2s test/
!split -d -a 6 -l 1 --additional-suffix=.txt java-small.val.c2s val/
!split -d -a 6 -l 1 --additional-suffix=.txt java-small.train.c2s train/
!split -d -a 6 -l 1 --additional-suffix=.txt java-small.train_dev.c2s train_dev/
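
For reference, each line of a .c2s file is one example: the target method name (subtokens joined by |), then space-separated AST path contexts, each made of three comma-separated parts: the subtokens of the starting terminal, the nonterminal nodes along the path, and the subtokens of the ending terminal. The line and node names below are made up, only to show the shape of the format:

# A made-up .c2s line, only to illustrate the format.
line = "set|name void,Prim0|Mth|Nm1,set|name value,Nm0|Mth|Prm|VDID0,name"

target, *contexts = line.strip().split(' ')
print(target.split('|'))        # target method-name subtokens, e.g. ['set', 'name']

for ctx in contexts:
    left, path, right = ctx.split(',')
    print(left.split('|'),      # subtokens of the starting terminal
          path.split('|'),      # nonterminal nodes along the AST path
          right.split('|'))     # subtokens of the ending terminal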

Main code (the code2seq notebook)

# To run this project, cd the terminal into code/notebooks first, otherwise importing the src package will fail (judge for yourself).
# Everything path-related lives in the configs files (the paths in this post were changed by me, so they differ a bit from the GitHub version).

import sys
sys.path.append('../')

import os
import time
import yaml
import random
import numpy as np
import warnings
import logging
import pickle
from datetime import datetime
from tqdm import tqdm_notebook as tqdm

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

import torch
from torch import einsum
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

from src import utils, messenger

config_file = '../configs/config_code2seq.yml'

config = yaml.load(open(config_file), Loader=yaml.FullLoader)

# Data source
DATA_HOME = config['data']['home']
DICT_FILE = DATA_HOME + config['data']['dict']
TRAIN_DIR = DATA_HOME + config['data']['train']
VALID_DIR = DATA_HOME + config['data']['valid']
TEST_DIR  = DATA_HOME + config['data']['test']

# Training parameter
batch_size = config['training']['batch_size']
num_epochs = config['training']['num_epochs']
lr = config['training']['lr']
teacher_forcing_rate = config['training']['teacher_forcing_rate']
nesterov = config['training']['nesterov']
weight_decay = config['training']['weight_decay']
momentum = config['training']['momentum']
decay_ratio = config['training']['decay_ratio']
save_name = config['training']['save_name']
warm_up = config['training']['warm_up']
patience = config['training']['patience']
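
# The nesterov / momentum / weight_decay parameters above suggest SGD with momentum,
# and decay_ratio / warm_up / patience presumably drive learning-rate decay and early
# stopping. A minimal sketch of how these values could be wired up; the placeholder
# model and the ExponentialLR choice are assumptions of mine, not necessarily what
# the notebook actually does (note nesterov=True requires momentum > 0):
sketch_model = nn.Linear(8, 8)  # placeholder, just to make the sketch runnable
sketch_optimizer = optim.SGD(sketch_model.parameters(), lr=lr, momentum=momentum,
                             nesterov=nesterov, weight_decay=weight_decay)
sketch_scheduler = optim.lr_scheduler.ExponentialLR(sketch_optimizer, gamma=decay_ratio)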



# Model parameter
token_size = config['model']['token_size']
hidden_size = config['model']['hidden_size']
num_layers = config['model']['num_layers']
bidirectional = config['model']['bidirectional']
rnn_dropout = config['model']['rnn_dropout']
embeddings_dropout = config['model']['embeddings_dropout']
num_k = config['model']['num_k']

# etc
slack_url_path = config['etc']['slack_url_path']
info_prefix = config['etc']['info_prefix']
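
# For reference, config_code2seq.yml has to provide at least the keys read above.
# A sketch of the structure yaml.load() is expected to return; every value below
# is a placeholder of mine, not the repository's default.
example_config = {
    'data': {
        'home': '../../dataset/java-small/',
        'dict': 'java-small.dict.c2s',
        'train': 'train/',
        'valid': 'val/',
        'test': 'test/',
    },
    'training': {
        'batch_size': 128, 'num_epochs': 10, 'lr': 0.01,
        'teacher_forcing_rate': 1.0, 'nesterov': True,
        'weight_decay': 0.0, 'momentum': 0.95, 'decay_ratio': 0.95,
        'save_name': 'code2seq', 'warm_up': 0, 'patience': 5,
    },
    'model': {
        'token_size': 128, 'hidden_size': 128, 'num_layers': 1,
        'bidirectional': True, 'rnn_dropout': 0.5,
        'embeddings_dropout': 0.25, 'num_k': 200,
    },
    'etc': {
        'slack_url_path': '../configs/slack_url.yml',
        'info_prefix': '[code2seq] ',
    },
}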


slack_url = None
if os.path.exists(slack_url_path):
    slack_url = yaml.load(open(slack_url_path), Loader=yaml.FullLoader)['slack_url']

warnings.filterwarnings('ignore')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(1)
random_state = 42

run_id = datetime.now().strftime('%Y-%m-%d--%H-%M-%S')
log_file = '../../logs/' + run_id + '.log'
exp_dir = '../../runs/' + run_id
os.mkdir(exp_dir)

logging.basicConfig(format='%(asctime)s | %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', filename=log_file, level=logging.DEBUG)
msgr = messenger.Info(info_prefix, slack_url)

msgr.print_msg('run_id : {}'.format(run_id))
msgr.print_msg('log_file : {}'.format(log_file))
msgr.print_msg('exp_dir : {}'.format(exp_dir))
msgr.print_msg('device : {}'.format(device))
msgr.print_msg(str(config))
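
# src/messenger is not shown in this post; roughly, the Info helper used above only
# needs to prefix a message, print it, write it to the logging module, and optionally
# post it to a Slack incoming webhook. A minimal sketch under those assumptions (not
# the repository's actual code; requests is an assumed dependency for the Slack call):
import requests

class InfoSketch:
    """Hypothetical stand-in for src.messenger.Info."""

    def __init__(self, prefix, slack_url=None):
        self.prefix = prefix
        self.slack_url = slack_url

    def print_msg(self, msg):
        text = '{} {}'.format(self.prefix, msg)
        print(text)
        logging.info(text)
        if self.slack_url:
            try:
                requests.post(self.slack_url, json={'text': text}, timeout=5)
            except requests.RequestException:
                pass  # a failed notification should never interrupt training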

PAD_TOKEN = '<PAD>' 
BOS_TOKEN = '<S>' 
EOS_TOKEN = '</S>'
UNK_TOKEN = '<UNK>'
PAD = 0
BOS = 1
EOS = 2
UNK = 3

# load vocab dict
with open(DICT_FILE, 'rb') as file:
    subtoken_to_count = pickle.load(file)
    node_to_count = pickle.load(file) 
    target_to_count = pickle.load(file)
    max_contexts = pickle.load(file)
    num_training_examples = pickle.load(file)
    msgr.print_msg('Dictionaries loaded.')

# making vocab dicts for terminal subtoken, nonterminal node and target.

word2id = {
    PAD_TOKEN: PAD,
    BOS_TOKEN: BOS,
    EOS_TOKEN: EOS,
    UNK_TOKEN: UNK,
}

vocab_subtoken = utils.Vocab(word2id=word2id)
vocab_nodes = utils.Vocab(word2id=word2id)
vocab_target = utils.Vocab(word2id=word2id)

vocab_subtoken.build_vocab(list(subtoken_to_count.keys()), min_count=0)
vocab_nodes.build_vocab(list(node_to_count.keys()), min_count=0)
vocab_target.build_vocab(list(target_to_count.keys()), min_count=0)

vocab_size_subtoken = len(vocab_subtoken.id2word)
vocab_size_nodes = len(vocab_nodes.id2word)
vocab_size_target = len(vocab_target.id2word)


msgr.print_msg('vocab_size_subtoken:' + str(vocab_size_subtoken))
msgr.print_msg('vocab_size_nodes:' + str(vocab_size_nodes))
msgr.print_msg('vocab_size_target:' + str(vocab_size_target))

num_length_train = num_training_examples
msgr.print_msg('num_examples : ' + str(num_length_train))
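
# src/utils is not shown in this post either. The way utils.Vocab is used above
# implies an interface with word2id / id2word dicts and a build_vocab(words, min_count)
# method; a minimal sketch under that assumption (not the repository's implementation):
from collections import Counter

class VocabSketch:
    """Hypothetical stand-in for src.utils.Vocab."""

    def __init__(self, word2id):
        self.word2id = dict(word2id)                      # seeded with PAD/BOS/EOS/UNK
        self.id2word = {i: w for w, i in self.word2id.items()}

    def build_vocab(self, words, min_count=1):
        # add every word seen at least min_count times, keeping insertion order
        counter = Counter(words)
        for word, count in counter.items():
            if count >= min_count and word not in self.word2id:
                new_id = len(self.word2id)
                self.word2id[word] = new_id
                self.id2word[new_id] = word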

class DataLoader(object):

    def __init__(self, data_path, batch_size, num_k, vocab_subtoken, vocab_nodes, vocab_target, shuffle=True, batch_time = False):
        
        """
        data_path : path for data 
        num_examples : total lines of data file
        batch_size : batch size
        num_k : max ast pathes included to one examples
        vocab_subtoken : dict of subtoken and its id
        vocab_nodes : dict of node simbol and its id
        vocab_target : dict of target simbol and its id
        """
        
        self.data_path = data_path
        self.batch_size = batch_size
        
        self.num_examples = self.file_count(data_path)
        self.num_k = num_k
        
        self.vocab_subtoken = vocab_subtoken
        self.vocab_nodes = vocab_nodes
        self.vocab_target = vocab_target
        
        self.index = 0
        self.pointer = np.array(range(self.num_examples))
        self.shuffle = shuffle
        
        # ... (the __init__ continues in the jupyter version linked above; the rest of the class is there too)
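
# As a rough illustration of what the loader has to do per example, here is a
# simplified sketch of my own (not the repository's code): it turns one .c2s line
# into id sequences, samples at most num_k path contexts, and falls back to UNK
# for unseen symbols.
def example_to_ids(line, num_k, vocab_subtoken, vocab_nodes, vocab_target):
    target, *contexts = line.strip().split(' ')

    # each example may contain more AST paths than the model consumes;
    # keep a random subset of at most num_k of them
    if len(contexts) > num_k:
        contexts = random.sample(contexts, num_k)

    target_ids = [vocab_target.word2id.get(t, UNK) for t in target.split('|')]

    context_ids = []
    for ctx in contexts:
        left, path, right = ctx.split(',')
        context_ids.append((
            [vocab_subtoken.word2id.get(t, UNK) for t in left.split('|')],
            [vocab_nodes.word2id.get(n, UNK) for n in path.split('|')],
            [vocab_subtoken.word2id.get(t, UNK) for t in right.split('|')],
        ))

    return target_ids, context_ids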