GOFA: A Generative One-For-All Model for Joint Graph Language Modeling — a GFM code walkthrough with GNN + LLMs

     The paper identifies three key desirable properties of a graph foundation model (GFM): self-supervised pre-training, fluidity across tasks, and graph awareness. To account for these properties, it extends conventional language modeling to the graph domain and proposes GOFA, a new generative graph language model. GOFA interleaves randomly initialized GNN layers into a frozen pre-trained LLM so that semantic and structural modeling abilities are organically combined. The model is pre-trained on newly proposed graph-level next-word prediction, question-answering, and structural tasks to acquire the GFM properties above, and is then fine-tuned on downstream tasks to obtain task-solving ability. Evaluating the fine-tuned model on a variety of downstream tasks shows a strong ability to solve structural and contextual problems in zero-shot settings.

 Key properties: large-scale self-supervised training, fluidity across tasks, graph-structure understanding

 The model consists of three parts:

LLM compression layer: compresses text and graph into fixed-length vector representations; in effect, the node text is compressed and then concatenated with the graph embedding;

Token-level GNN layer: essentially an RGAT. The node features x, together with edge information, are fed through the GAT layer, and its output is concatenated with the output of the preceding LLM layer to serve as the input of the next layer;

LLM decoder layer: the output of the graph language encoder is fed into the LLM decoder for supervised fine-tuning (SFT). A minimal sketch of this three-stage pipeline is given below.
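
To make the dataflow concrete, here is a minimal sketch (not the authors' code) of how the three parts compose. All names in it (TokenLevelGNNSketch, compress, llm_decode) are hypothetical placeholders; in the real model.py these roles are played by GOFALlamaHelper.encode, the g_layers interleaved inside the frozen LLM, and GOFALlamaHelper.decode/generate. For brevity the GNN is shown running after compression, whereas in the actual model it runs between the LLM transformer layers during encoding.

import torch
import torch.nn as nn


class TokenLevelGNNSketch(nn.Module):
    # Hypothetical stand-in for the RGAT-style layer: mixes the fixed-length node
    # memories along the edges and adds the result back to the node states.
    def __init__(self, dim):
        super().__init__()
        self.msg = nn.Linear(dim, dim)

    def forward(self, node_mem, edge_index):
        # node_mem: [num_nodes, mem_size, dim]; edge_index: [2, num_edges]
        src, dst = edge_index
        agg = torch.zeros_like(node_mem)
        agg.index_add_(0, dst, self.msg(node_mem[src]))  # sum aggregation; attention in the real model
        return node_mem + agg                            # residual update of the memory tokens


def gofa_forward_sketch(node_texts, edge_index, compress, gnn_layers, llm_decode, question):
    # 1) LLM compression layer: each node text -> fixed-length memory-token embeddings
    node_mem = torch.stack([compress(t) for t in node_texts])   # [N, mem_size, dim]
    # 2) Token-level GNN layers exchange information between connected nodes
    for gnn in gnn_layers:
        node_mem = gnn(node_mem, edge_index)
    # 3) LLM decoder: condition on the question node's memory and generate the answer
    return llm_decode(node_mem, question)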

 Pre-training tasks:

Sentence completion: text on the TAG (text-attributed graph) nodes is truncated, and the remaining part serves as the label for a sentence-completion (next-word prediction) task;

Graph-structure understanding: reasoning tasks derived from the graph structure, such as shortest paths and neighbor identification;

Graph question answering: SFT on question-answering datasets conditioned on the graph information. The sketch after this list illustrates how a sentence-completion example could be built.
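
As a rough illustration (a hypothetical sketch, not the paper's exact recipe), a next-word-prediction example can be built from a TAG node by truncating its text and using the tail as the generation target; the cutoff ratio and field names below are illustrative.

def make_completion_example(node_text: str, cutoff: float = 0.5):
    words = node_text.split()
    k = max(1, int(len(words) * cutoff))
    prefix = " ".join(words[:k])   # kept on the node as the input text
    target = " ".join(words[k:])   # used as the generation label
    return {"x": prefix, "answer": target,
            "question": "Complete the rest of the text on this node."}


example = make_completion_example(
    "Graph neural networks aggregate information from neighboring nodes to update node representations.")
print(example["x"], "->", example["answer"])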

 Related work:

graph prompt learning; LLMs as enhancers; LLMs as predictors.

Model code: model.py

from collections import namedtuple, OrderedDict

import torch
import numpy as np

from gp.nn.models.GNN import MultiLayerMessagePassing
from gp.nn.layer.pyg import RGCNEdgeConv
from gp.nn.layer.pyg import TransformerConv as MConv
from .helper import GOFALlamaHelper, GOFAMistralHelper, LlamaHelper

LLM_DIM_DICT = {"ST": 768, "BERT": 768, "e5": 1024, "llama2_7b": 4096, "llama2_13b": 5120, "mamba": 768, "icae": 4096,
                "icae_mem": 4096}

def print_fixed_length(text, line_width=120):
    t_pointer = 0
    while t_pointer < len(text):
        t_pointer += line_width
        t_cur_text = text[t_pointer - line_width:t_pointer]
        print(t_cur_text)


def print_text_side_by_side(text_1, text_2, line_width=120, space=10):
    t1_len = len(text_1)
    t2_len = len(text_2)
    seg_width = int((line_width-space)/2)
    t1_pointer = 0
    t2_pointer = 0
    text_1 = text_1.replace('\n', ' ').replace('\r', ' ')
    text_2 = text_2.replace('\n', ' ').replace('\r', ' ')
    while t1_pointer<t1_len or t2_pointer<t2_len:
        t1_pointer += seg_width
        t2_pointer += seg_width
        t1_cur_text = text_1[t1_pointer-seg_width:t1_pointer]
        t2_cur_text = text_2[t2_pointer-seg_width:t2_pointer]
        t1_cur_text = t1_cur_text + " "*(seg_width-len(t1_cur_text))
        t2_cur_text = t2_cur_text + " "*(seg_width-len(t2_cur_text))
        full_text = t1_cur_text + " "*space + t2_cur_text
        print(full_text)



class GOFA(torch.nn.Module):
    def __init__(self, transformer_args, mode="autoencoder", base_llm="llama7b", save_dir=""):
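        # mode selects the encode/decode pipeline:
        #   autoencoder / autoencodergen - graph-aware compression (auto_encode) followed by
        #       teacher-forced decoding (auto_decode) or free generation (generate);
        #   direct                       - feed node texts and the graph straight to the LLM (direct_decode);
        #   nograph / nographgen         - plain-LLM baselines that ignore the graph structure.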
        super().__init__()

        self.mode = mode
        self.save_dir = save_dir

        if base_llm == 'llama7b':
            self.llm_model = GOFALlamaHelper(transformer_args)
        elif base_llm == 'mistral7b':
            self.llm_model = GOFAMistralHelper(transformer_args)
        elif base_llm in ('llama7blora', 'mistral7blora'):
            self.llm_model = LlamaHelper(transformer_args)
        else:
            raise NotImplementedError(base_llm + " is not supported. Please choose from: llama7b, mistral7b, llama7blora, mistral7blora.")

        if mode == "autoencoder":
            self.encode = self.auto_encode
            self.decode = self.auto_decode
        elif mode == "autoencodergen":
            self.encode = self.auto_encode
            self.decode = self.generate
        elif mode == "direct":
            self.encode = self.direct_decode
            self.decode = lambda x: x
        elif mode == "nograph":
            self.encode = self.llm_decode
            self.decode = lambda x: x
        elif mode == "nographgen":
            self.encode = self.llm_gen
            self.decode = lambda x: x
        else:
            # TODO: not implemented
            raise NotImplementedError(mode + " mode not implemented")

    def llm_decode(self, g):
        text_inputs = g.x[g.node_map.cpu().numpy()][g.question_index.cpu().numpy()].tolist()
        com_text_inputs = []
        for i in range(len(text_inputs)):
            t = text_inputs[i] + g.question[i]
            com_text_inputs.append(t)
        answer_texts = g.answer[g.answer_map.cpu().numpy()].tolist()

        answer_logits, answer_id, masks = self.llm_model(answer_texts, com_text_inputs)
        GNNLMOutput = namedtuple("GNNLMOutput", ["logits", "answer_id", "pred_text", "answer"])
        return GNNLMOutput(logits=answer_logits[masks], pred_text=self.logit_to_text(answer_logits, masks),
                           answer_id=answer_id, answer=answer_texts)

    def llm_gen(self, g):
        text_inputs = g.x[g.node_map.cpu().numpy()][g.question_index.cpu().numpy()].tolist()
        com_text_inputs = []
        for i in range(len(g.question)):
            # t = '[INST]\n' + g.question[i] + '\n[/INST]\n\n'
            if len(text_inputs) == len(g.question):
                t = '[INST]\n' + text_inputs[i] + g.question[i] + 'Please only answer the category name, no other information or explanation. \n[/INST]\n\n'
            elif len(text_inputs) == len(g.question) * 2:
                t = '[INST]\n' + text_inputs[2*i] + text_inputs[2*i+1] + g.question[
                    i] + 'Please only answer the category name, no other information or explanation. \n[/INST]\n\n'
            elif len(text_inputs) > len(g.question):
                t = '[INST]\n' + g.question[i] + 'Please only answer the category name, no other information or explanation. \n[/INST]\n\n'
            else:
                raise ValueError('Text_input length did not match question length.')
            com_text_inputs.append(t)
        answer_texts = g.answer[g.answer_map.cpu().numpy()].tolist()
        generated_text = self.llm_model.generate(com_text_inputs)
        for i, txt in enumerate(generated_text):
            print_fixed_length(f"question: {com_text_inputs}")
            print("-" * 120)
            print_text_side_by_side("target: " + answer_texts[i], "gen: " + generated_text[i])
            print("=" * 120)
        GNNLMOutput = namedtuple("GNNLMOutput", ["logits", "answer_id", "pred_text", "answer"])
        return GNNLMOutput(logits=torch.randn([1, 32132]), pred_text=generated_text, answer_id=torch.tensor([1]),
                           answer=answer_texts)

    def auto_encode(self, g):
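        # Compression: node (and edge) texts are encoded by the LLM compression layer; the GNN
        # layers interleaved in the frozen LLM propagate information over the graph, and g.x is
        # replaced by the resulting fixed-length memory embeddings of the nodes.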
        g.num_node_feat = g.x.shape[0]
        if g.edge_attr is not None:
            text_inputs = np.concatenate([g.x, g.edge_attr], axis=0)
        else:
            text_inputs = g.x
        llm_output = self.llm_model.encode(text_inputs.tolist(), graph=g, partial_grad=True)
        g.x = llm_output[:g.node_map.size(-1)]
        return g

    def auto_decode(self, g):
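        # Decoding: select the memory embeddings of the question nodes and decode the answer
        # text with the LLM decoder (teacher forcing against the tokenized answers).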
        emb = g.x
        answer_texts = g.answer[g.answer_map.cpu().numpy()].tolist()
        prompt_texts = g.question[g.question_map.cpu().numpy()].tolist()
        emb = emb[g.question_index]
        answer_logits, answer_id, masks = self.llm_model.decode(answer_texts, emb, prompt=prompt_texts)
        GNNLMOutput = namedtuple("GNNLMOutput", ["logits", "answer_id", "pred_text", "answer"])
        return GNNLMOutput(logits=answer_logits[masks], pred_text=self.logit_to_text(answer_logits, masks),
                           answer_id=answer_id, answer=answer_texts)

    def generate(self, g):
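        # Inference: autoregressively generate answers from the question-node embeddings and
        # print target vs. generated text; the returned logits/answer_id are placeholders.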
        emb = g.x
        answer_texts = g.answer[g.answer_map.cpu().numpy()].tolist()
        prompt_texts = g.question[g.question_map.cpu().numpy()].tolist()
        emb = emb[g.question_index]
        generated_text = self.llm_model.generate(emb, prompt=prompt_texts)
        for i, txt in enumerate(generated_text):
            print_fixed_length("question: " + prompt_texts[i])
            print("-"*120)
            print_text_side_by_side("target: "+answer_texts[i], "gen: "+generated_text[i])
            print("="*120)
        GNNLMOutput = namedtuple("GNNLMOutput", ["logits", "answer_id", "pred_text", "answer"])
        return GNNLMOutput(logits=torch.randn([1, 32132]).to(emb.device), pred_text=generated_text, answer_id=torch.tensor([1]).to(emb.device),
                           answer=answer_texts)

    def forward(self, g):
        g = self.encode(g)
        return self.decode(g)

    def direct_decode(self, g):
        num_node_feat = g.x.shape[0]
        g.num_node_feat = num_node_feat
        g.x = g.x[g.node_map.cpu().numpy()]
        answer_texts = g.answer.tolist()
        answer_logits, answer_id, masks = self.llm_model(g.x.tolist(), answer_texts, g.edge_attr.tolist(), graph=g,
                                                         partial_grad=True)
        GNNLMOutput = namedtuple("GNNLMOutput", ["logits", "answer_id", "pred_text", "answer"])
        return GNNLMOutput(logits=answer_logits[masks], pred_text=self.logit_to_text(answer_logits, masks),
                           answer_id=answer_id, answer=answer_texts)

    def save_partial(self, save_dir):
        state_dict = self.llm_model.model.icae.model.model.g_layers.state_dict()
        full_state_dict = self.state_dict()
        for k in full_state_dict:
            if "default" in k:
                state_dict[k] = full_state_dict[k]
        torch.save(state_dict, save_dir)

    def load_partial(self, load_dir=None, state_dict=None):
        if load_dir is not None and state_dict is not None:
            raise RuntimeError("You should only specify either load_dict or load_dir")
        if load_dir is None and state_dict is None:
            print("No state dict loaded")
            return
        if load_dir is not None:
            state_dict = torch.load(load_dir)
        new_state_dict = OrderedDict()
        for name in state_dict:
            if "decadapt" in name:
                new_state_dict[name.replace("decadapt", "default")] = state_dict[name]
            else:
                new_state_dict[name] = state_dict[name]
        self.llm_model.model.icae.model.model.g_layers.load_state_dict(new_state_dict, strict=False)
        self.load_state_dict(new_state_dict, strict=False)

    def logit_to_text(self, logits, masks):
        tokenizer = self.llm_model.get_tokenizer()
        if len(logits.size()) == 2:
            logits = logits.unsqueeze(0)
        decoded_texts = []
        for i in range(logits.size(0)):
            sample_logit = logits[i]
            sample_mask = masks[i]
            sample_logit = sample_logit[sample_mask]
            token_ids = sample_logit[:, :32000].argmax(dim=-1).unsqueeze(0)
            token_ids[token_ids >= 32000] = 1
            sample_text = tokenizer.batch_decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            decoded_texts.extend(sample_text)
        return decoded_texts


class PyGRGCNEdge(MultiLayerMessagePassing):
    def __init__(self, num_layers: int, num_rels: int, inp_dim: int, out_dim: int, drop_ratio=0, JK="last",
                 batch_norm=True, ):
        super().__init__(num_layers, inp_dim, out_dim, drop_ratio, JK, batch_norm)
        self.num_rels = num_rels
        self.build_layers()

    def build_input_layer(self):
        return RGCNEdgeConv(self.inp_dim, self.out_dim, self.num_rels)

    def build_hidden_layer(self):
        return RGCNEdgeConv(self.inp_dim, self.out_dim, self.num_rels)

    def build_message_from_input(self, g):
        return {"g": g.edge_index, "h": g.x, "e": g.edge_type, "he": g.edge_attr, }

    def build_message_from_output(self, g, h):
        return {"g": g.edge_index, "h": h, "e": g.edge_type, "he": g.edge_attr}

    def layer_forward(self, layer, message):
        return self.conv[layer](message["h"], message["he"], message["g"], message["e"])


class PyGRGATCN(MultiLayerMessagePassing):
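    # Token-level GNN referenced above: a stack of TransformerConv-based (MConv) layers that
    # passes messages over the per-node memory-token states and is interleaved with the LLM layers.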
    def __init__(self, num_layers: int, t_layer: int, inp_dim: int, out_dim: int, drop_ratio=0.0, JK="last",
                 batch_norm=True, ):
        super().__init__(num_layers, inp_dim, out_dim, 0, JK, batch_norm)
        self.t_layer = t_layer
        self.layer_dropout = drop_ratio
        self.build_layers()

    def build_input_layer(self):
        # return GATConv(self.inp_dim*self.t_layer, self.inp_dim*self.t_layer)
        # return TransformerConv(self.inp_dim*self.t_layer, self.inp_dim*self.t_layer)
        return MConv(self.inp_dim, self.t_layer, 8, add_self_loops=True, dropout=self.layer_dropout,
                     layer_idx=len(self.conv))

    def build_hidden_layer(self):
        # return GATConv(self.inp_dim * self.t_layer, self.inp_dim * self.t_layer)
        # return TransformerConv(self.inp_dim*self.t_layer, self.inp_dim*self.t_layer)
        return MConv(self.inp_dim, self.t_layer, 8, add_self_loops=True, dropout=self.layer_dropout,
                     layer_idx=len(self.conv))

    def build_message_from_input(self, g):
        return {"g": g.edge_index, "h": g.x, "e": g.edge_type, "he": g.edge_attr, }

    def build_message_from_output(self, g, h):
        return {"g": g.edge_index, "h": h, "e": g.edge_type, "he": g.edge_attr}

    def layer_forward(self, layer, message):
        return self.conv[layer](message["h"], message["g"], message["he"])

Helper code: helper.py

# example code for running inference with fine-tuned checkpoint
import numpy as np
import torch
from modules.gofa_icae_llama_modeling import LlamaICAE
from modules.gofa_icae_mistral_modeling import MistralICAE
from modules.llama_modeling import LlamaLora
from collections import OrderedDict
from safetensors.torch import load_file


class GOFALlamaHelper(torch.nn.Module):
    def __init__(self, transformer_args):
        super().__init__()
        model_args, training_args, gofa_args = transformer_args
        model = LlamaICAE(model_args, training_args, gofa_args)  # restored llama2-7b-chat model
        state_dict = torch.load(
            "./cache_data/model/llama-2-7b-chat-finetuned-icae_zeroweight_llama2.pt")  # change the path for your model
        new_state_dict = OrderedDict()

        for layer_name, weight in state_dict.items():
            if isinstance(weight, torch.Tensor) or weight != 0.0:
                new_state_dict[layer_name.replace("default", "encadapt")] = weight
        model.load_state_dict(new_state_dict, strict=False)
        # model.merge_lora()
        self.dec_lora = model_args.dec_lora
        self.mem_tokens = list(range(model.vocab_size, model.vocab_size + model_args.mem_size))
        self.mem_size = model_args.mem_size
        self.model = model
        self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
        self.model.left_tokenizer.pad_token = self.model.left_tokenizer.bos_token
        for param in self.model.icae.parameters():
            param.requires_grad = False
        for param in self.model.icae.get_base_model().model.g_layers.parameters():
            param.requires_grad = True
        if self.dec_lora:
            for name, param in self.model.icae.named_parameters():
                if "default" in name:
                    param.requires_grad = True

    def get_tokenizer(self):
        return self.model.tokenizer

    def train_mode(self):
        # for param in self.model.dec.parameters():
        #     param.requires_grad = False
        self.model.icae.set_adapter("encadapt")
        for param in self.model.icae.parameters():
            param.requires_grad = False

    def forward(self, data, answer, edge_data, prompt=None, graph=None, partial_grad=None):
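        # Joint objective: each node's text, mem_size memory tokens, optional prompt, and answer
        # are packed into one sequence; edge texts get their own memory tokens; the ICAE model
        # (with graph message passing) then predicts the answer tokens marked by target_mask.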
        cur_device = self.model.memory_token_embed.weight.device
        batch_size = len(data)
        if prompt is None:
            prompt = [""] * len(data)

        text_input = self.model.tokenizer(data, truncation=True, max_length=self.model.training_args.model_max_length,
                                          padding=False, return_attention_mask=False)["input_ids"]
        text_target = \
            self.model.tokenizer(answer, truncation=True, max_length=self.model.training_args.model_max_length,
                                 padding=False, return_attention_mask=False)["input_ids"]
        edge_input = \
            self.model.tokenizer(edge_data, truncation=True, max_length=self.model.training_args.model_max_length,
                                 padding=False, return_attention_mask=False)["input_ids"] if len(edge_data) > 0 else []

        text_target = [p + [self.model.tokenizer.eos_token_id] for p in text_target]
        target_ids = torch.cat([torch.tensor(p, dtype=torch.long) for p in text_target], dim=-1).to(cur_device)

        prompt_input = self.model.tokenizer(prompt, add_special_tokens=False, padding=False)["input_ids"]
        prompt_input = [[self.model.ft_token_id] + a + [self.model.ft_token_id] if len(a) > 0 else a for a in
                        prompt_input]

        text_ids = [a + self.mem_tokens + b + c for a, b, c in zip(text_input, prompt_input, text_target)]
        target_mask = [[False] * (len(a) + self.mem_size + len(b) - 1) + [True] * (len(c)) + [False] for a, b, c in
                       zip(text_input, prompt_input, text_target)]

        edge_text_ids = [a + self.mem_tokens for a in edge_input]

        graph.num_node_feat = len(text_ids)

        input_ids = text_ids + edge_text_ids
        target_mask = target_mask + [[False] * len(a) for a in edge_text_ids]

        text_output = {"input_ids": input_ids, "attention_mask": target_mask}
        text_output = self.model.tokenizer.pad(text_output, padding=True, return_tensors="pt")
        input_ids = text_output["input_ids"].to(device=cur_device)
        target_mask = text_output["attention_mask"].to(torch.bool)
        mem_mask = torch.logical_and(input_ids >= self.model.vocab_size,
                                     input_ids < self.model.vocab_size + self.mem_size)

        mem_mask = mem_mask.to(cur_device)

        autoencoder_input_embedding = self.model.icae.get_base_model().model.embed_tokens(input_ids)
        autoencoder_input_embedding[mem_mask] = self.model.memory_token_embed(
            input_ids[mem_mask] - self.model.vocab_size).to(autoencoder_input_embedding)
        self.model.icae.enable_adapter_layers()
        compress_outputs = self.model.icae(inputs_embeds=autoencoder_input_embedding, output_hidden_states=True,
                                           graph=graph, mem_mask=mem_mask, partial_grad=partial_grad, map_node=False)
        self.model.icae.disable_adapter_layers()

        compress_outputs = compress_outputs.logits
        return compress_outputs, target_ids, target_mask

    def encode(self, data, graph=None, partial_grad=None):
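        # Append mem_size memory tokens to every text, run the encoder adapter ("encadapt"), and
        # read the hidden states at the memory positions as the fixed-length node embeddings.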
        batch_size = len(data)
        text_output = \
        self.model.tokenizer(data, truncation=True, max_length=self.model.training_args.model_max_length, padding=False,
                             return_attention_mask=False)["input_ids"]
        text_output = [t + self.mem_tokens for t in text_output]
        text_output = {"input_ids": text_output}
        text_output = self.model.tokenizer.pad(text_output, padding=True, return_tensors="pt")["input_ids"].to(
            self.model.memory_token_embed.weight.device)
        mem_mask = text_output >= self.model.vocab_size

        mem_mask = mem_mask.to(self.model.memory_token_embed.weight.device)

        autoencoder_input_embedding = self.model.icae.get_base_model().model.embed_tokens(text_output)
        autoencoder_input_embedding[mem_mask] = self.model.memory_token_embed(
            text_output[mem_mask] - self.model.vocab_size).to(autoencoder_input_embedding)
        self.model.icae.set_adapter("encadapt")
        self.model.icae.enable_adapter_layers()
        for name, param in self.model.icae.named_parameters():
            if "encadapt" in name:
                param.requires_grad = False
        compress_outputs = self.model.icae(inputs_embeds=autoencoder_input_embedding, output_hidden_states=True,
                                           graph=graph, mem_mask=mem_mask, partial_grad=partial_grad, map_node=True)
        self.model.icae.disable_adapter_layers()
        compress_outputs = compress_outputs.hidden_states[-1]
        if graph is not None:
            node_emb = compress_outputs[:len(graph.node_map)]
            map_mem_mask = mem_mask[:graph.num_node_feat][graph.node_map]
            memory_embedding = node_emb[map_mem_mask].view(len(node_emb), self.mem_size, -1)
        else:
            memory_embedding = compress_outputs[mem_mask].view(batch_size, self.mem_size, -1)
        return memory_embedding

    def llm_output(self, data, input, prompt=None):
        self.model.icae.disable_adapter_layers()
        cur_device = self.model.memory_token_embed.weight.device
        prompt_output = self.model.tokenizer(data, add_special_tokens=False, padding=False, truncation=True,
                                             max_length=self.model.training_args.model_max_length)["input_ids"]
        input_tokens = self.model.tokenizer(input, add_special_tokens=False, padding=False, truncation=True,
                                            max_length=self.model.training_args.model_max_length)["input_ids"]
        prompt_output = [p + [self.model.tokenizer.eos_token_id] for p in prompt_output]
        if prompt is None:
            prompt = [""] * len(data)
        prompt_input = self.model.tokenizer(prompt, add_special_tokens=False, padding=False)["input_ids"]
        prompt_input = [[self.model.ft_token_id] + a + [self.model.ft_token_id] if len(a) > 0 else a for a in
                        prompt_input]
        prompt_ids = [a + b + c for a, b, c in zip(input_tokens, prompt_input, prompt_output)]
        prompt_mask = [[False] * (len(a) + len(b) - 1) + [True] * (len(c)) + [False] for a, b, c in
                       zip(input_tokens, prompt_input, prompt_output)]
        answer_prompt = torch.cat([torch.tensor(p, dtype=torch.long) for p in prompt_output], dim=-1).to(cur_device)

        prompt_output = {"input_ids": prompt_ids, "attention_mask": prompt_mask}
        prompt_output = self.model.tokenizer.pad(prompt_output, padding=True, return_tensors="pt")

        prompt_answer_ids = prompt_output["input_ids"].to(cur_device)

        target_mask = prompt_output["attention_mask"].to(cur_device).to(torch.bool)

        prompt_answer_embs = self.model.icae.get_base_model().model.embed_tokens(prompt_answer_ids)

        output_emb = self.model.dec(inputs_embeds=prompt_answer_embs).logits

        return output_emb, answer_prompt, target_mask

    def decode(self, data, mem_embs, graph=None, prompt=None):
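        # Prepend the node memory embeddings to the (optional) prompt and answer tokens, then run
        # the decoder (LoRA adapter "default" when dec_lora is set) with teacher forcing.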
        prompt_output = self.model.tokenizer(data, add_special_tokens=False, padding=False, truncation=True,
                                             max_length=self.model.training_args.model_max_length)["input_ids"]
        prompt_output = [p + [self.model.tokenizer.eos_token_id] for p in prompt_output]
        if prompt is None:
            prompt = [""] * len(data)
        prompt_input = self.model.left_tokenizer(prompt, add_special_tokens=False, padding=False, truncation=True, max_length=512)["input_ids"]
        # print(self.model.left_tokenizer.batch_decode(prompt_input))
        prompt_input = [[self.model.ft_token_id] + a + [self.model.ft_token_id] if len(a) > 0 else a for a in
                        prompt_input]
        prompt_ids = [a + b for a, b in zip(prompt_input, prompt_output)]
        prompt_mask = [[False] * len(a) + [True] * (len(b)) + [False] for a, b in zip(prompt_input, prompt_output)]
        mem_mask = torch.tensor([[False] * (self.mem_size - 1) for _ in prompt_output], dtype=torch.long).to(mem_embs.device)
        answer_prompt = torch.cat([torch.tensor(p, dtype=torch.long) for p in prompt_output], dim=-1).to(
            mem_embs.device)
        prompt_output = {"input_ids": prompt_ids, "attention_mask": prompt_mask}
        prompt_output = self.model.tokenizer.pad(prompt_output, padding=True, return_tensors="pt")
        prompt_answer_ids = prompt_output["input_ids"].to(mem_embs.device)
        special_prompt = prompt_answer_ids >= self.model.vocab_size
        target_mask = torch.cat([mem_mask, prompt_output["attention_mask"].to(mem_mask)], dim=-1).to(torch.bool)
        prompt_answer_embs = self.model.icae.get_base_model().model.embed_tokens(prompt_answer_ids)
        prompt_answer_embs[special_prompt] = self.model.memory_token_embed(
            prompt_answer_ids[special_prompt] - self.model.vocab_size).to(prompt_answer_embs)
        decode_embed = torch.cat([mem_embs.to(prompt_answer_embs), prompt_answer_embs], dim=1)
        if self.dec_lora:
            self.model.icae.set_adapter("default")
            self.model.icae.enable_adapter_layers()
        else:
            self.model.icae.disable_adapter_layers()
        output_emb = self.model.icae(inputs_embeds=decode_embed).logits

        return output_emb, answer_prompt, target_mask

    def generate(self, mem_embs, graph=None, prompt=None):
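        # Greedy decoding: pad-token placeholders reserve the memory-slot positions, their
        # embeddings are overwritten with mem_embs, and up to 128 tokens are generated with a
        # KV cache until every sequence hits EOS or a special (>= 32000) token.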
        if prompt is None:
            prompt = [""] * len(mem_embs)
        prompt_input = self.model.tokenizer(prompt, add_special_tokens=False, padding=False)["input_ids"]
        prompt_ids = [[self.model.ft_token_id] + a + [self.model.ft_token_id] if len(a) > 0 else a for a in
                      prompt_input]

        mem_mask = [[True] * self.mem_size + [False] * len(a) for a in prompt_ids]
        att_mask = [[True] * (self.mem_size + len(a)) for a in prompt_ids]
        prompt_ids = [[self.model.tokenizer.pad_token_id] * self.mem_size + a for a in prompt_ids]
        input_prompt_ids = self.model.left_tokenizer.pad({"input_ids": prompt_ids, "attention_mask": mem_mask},
                                                         padding=True, return_tensors="pt")
        mem_mask = input_prompt_ids["attention_mask"].to(device=mem_embs.device, dtype=torch.bool)

        input_prompt_ids = self.model.left_tokenizer.pad({"input_ids": prompt_ids, "attention_mask": att_mask},
                                                         padding=True, return_tensors="pt")

        prompt_ids = input_prompt_ids["input_ids"]
        att_mask = input_prompt_ids["attention_mask"].to(device=mem_embs.device)

        prompt_answer_ids = prompt_ids.to(device=mem_embs.device, dtype=torch.long)
        special_prompt = prompt_answer_ids >= self.model.vocab_size
        prompt_answer_embs = self.model.icae.get_base_model().model.embed_tokens(prompt_answer_ids)
        prompt_answer_embs[special_prompt] = self.model.memory_token_embed(
            prompt_answer_ids[special_prompt] - self.model.vocab_size).to(prompt_answer_embs)

        prompt_answer_embs[mem_mask] = mem_embs.view(-1, mem_embs.size()[-1])

        # decode_embed = torch.cat([mem_embs.to(prompt_answer_embs), prompt_answer_embs], dim=1)
        decode_embed = prompt_answer_embs
        output = decode_embed.clone()

        generate_text = []
        eos_reached = torch.zeros(len(output), dtype=torch.bool).to(output.device)

        past_key_values = None
        if self.dec_lora:
            self.model.icae.set_adapter("default")
            self.model.icae.enable_adapter_layers()
        else:
            self.model.icae.disable_adapter_layers()
        for i in range(128):
            out = self.model.icae(inputs_embeds=output, attention_mask=att_mask, past_key_values=past_key_values,
                                 use_cache=True)

            logits = out.logits[:, -1]

            past_key_values = out.past_key_values

            next_token_id = torch.argmax(logits, dim=-1, keepdim=True)

            eos_reached = torch.logical_or(eos_reached, (next_token_id == self.model.tokenizer.eos_token_id).view(-1))

            eos_reached = torch.logical_or(eos_reached, (next_token_id == self.model.tokenizer.bos_token_id).view(-1))

            eos_reached = torch.logical_or(eos_reached, (next_token_id >= 32000).view(-1))
            generate_text.append(next_token_id.view(-1, 1))
            if torch.all(eos_reached):
                break

            output = self.model.icae.get_base_model().model.embed_tokens(next_token_id).to(mem_embs.device)

            att_mask = torch.cat(
                [att_mask, torch.ones((len(att_mask), 1), dtype=att_mask.dtype, device=att_mask.device)], dim=-1)
        generate_text = torch.cat(generate_text, dim=-1)
        generate_text[generate_text >= 32000] = 1

        generated_text = self.model.tokenizer.batch_decode(generate_text)

        return generated_text


class GOFAMistralHelper(torch.nn.Module):
    def __init__(self, transformer_args):
        super().__init__()
        model_args, training_args, gofa_args = transformer_args
        model = MistralICAE(model_args, training_args, gofa_args)  # restored Mistral-7B ICAE model
        state_dict = load_file("./cache_data/model/mistral_7b_ft_icae.safetensors")  # change the path for your model
        new_state_dict = OrderedDict()
        for layer_name, weight in state_dict.items():
            new_state_dict[layer_name.replace("default", "encadapt")] = weight
        model.load_state_dict(new_state_dict, strict=False)
        # model.merge_lora()
        self.dec_lora = model_args.dec_lora
        self.mem_tokens = list(range(model.vocab_size, model.vocab_size + model_args.mem_size))
        self.mem_size = model_args.mem_size
        self.model = model
        self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
        self.model.left_tokenizer.pad_token = self.model.left_tokenizer.bos_token
        for param in self.model.icae.parameters():
            param.requires_grad = False
        for param in self.model.icae.get_base_model().model.g_layers.parameters():
            param.requires_grad = True
        if self.dec_lora:
            for name, param in self.model.icae.named_parameters():
                if "default" in name:
                    param.requires_grad = True

    def get_tokenizer(self):
        return self.model.tokenizer

    def train_mode(self):
        self.model.icae.set_adapter("encadapt")
        for param in self.model.icae.parameters():
            param.requires_grad = False

    def forward(self, data, answer, edge_data, prompt=None, graph=None, partial_grad=None):
        cur_device = self.model.memory_token_embed.weight.device
        batch_size = len(data)
        if prompt is None:
            prompt = [""] * len(data)

        text_input = \
            self.model.tokenizer(data, truncation=True, max_length=5120, padding=False, return_attention_mask=False)[
                "input_ids"]
        text_target = \
            self.model.tokenizer(answer, truncation=True, max_length=self.model.training_args.model_max_length,
                                 padding=False, return_attention_mask=False)["input_ids"]
        edge_input = \
            self.model.tokenizer(edge_data, truncation=True, max_length=self.model.training_args.model_max_length,
                                 padding=False, return_attention_mask=False)["input_ids"] if len(edge_data) > 0 else []

        text_target = [p + [self.model.tokenizer.eos_token_id] for p in text_target]
        target_ids = torch.cat([torch.tensor(p, dtype=torch.long) for p in text_target], dim=-1).to(cur_device)

        prompt_input = self.model.tokenizer(prompt, add_special_tokens=False, padding=False)["input_ids"]
        prompt_left_ids = [[1, 733, 16289, 28793]]
        prompt_right_ids = [[self.model.ft_token_id] + a + [733, 28748, 16289, 28793] if len(a) > 0 else a for a in
                            prompt_input]
        prompt_right_ids = torch.LongTensor([prompt_right_ids]).to(cur_device)

        text_ids = [a + b + self.mem_tokens + c + d for a, b, c, d in
                    zip(text_input, prompt_left_ids, prompt_input, text_target)]
        print(text_ids)
        target_mask = [[False] * (len(a) + self.mem_size + len(b) + len(c) - 1) + [True] * (len(d)) + [False] for
                       a, b, c, d in zip(text_input, prompt_left_ids, prompt_input, text_target)]
        edge_text_ids = [a + self.mem_tokens for a in edge_input]

        graph.num_node_feat = len(text_ids)
        print(graph.num_node_feat)

        input_ids = text_ids + edge_text_ids
        target_mask = target_mask + [[False] * len(a) for a in edge_text_ids]

        text_output = {"input_ids": input_ids, "attention_mask": target_mask}
        text_output = self.model.tokenizer.pad(text_output, padding=True, return_tensors="pt")
        input_ids = text_output["input_ids"].to(device=cur_device)
        target_mask = text_output["attention_mask"].to(torch.bool)
        mem_mask = torch.logical_and(input_ids >= self.model.vocab_size,
                                     input_ids < self.model.vocab_size + self.mem_size)

        mem_mask = mem_mask.to(cur_device)

        autoencoder_input_embedding = self.model.icae.get_base_model().model.embed_tokens(input_ids)
        autoencoder_input_embedding[mem_mask] = self.model.memory_token_embed(
            input_ids[mem_mask] - self.model.vocab_size).to(autoencoder_input_embedding)
        self.model.icae.enable_adapter_layers()
        print(autoencoder_input_embedding.shape)
        print('---' * 30)

        compress_outputs = self.model.icae(inputs_embeds=autoencoder_input_embedding, output_hidden_states=True,
                                           graph=graph, mem_mask=mem_mask, partial_grad=partial_grad, map_node=False)
        self.model.icae.disable_adapter_layers()

        compress_outputs = compress_outputs.logits

        return compress_outputs[target_mask], target_ids

    def encode(self, data, graph=None, partial_grad=None):
        cur_device = self.model.memory_token_embed.weight.device
        batch_size = len(data)
        text_output = \
        self.model.tokenizer(data, truncation=True, max_length=self.model.training_args.model_max_length, padding=False,
                             return_attention_mask=False)["input_ids"]

        text_output = [t + self.mem_tokens for t in text_output]
        text_output = {"input_ids": text_output}
        text_output = self.model.tokenizer.pad(text_output, padding=True, return_tensors="pt")["input_ids"].to(
            cur_device)
        mem_mask = text_output >= self.model.vocab_size

        mem_mask = mem_mask.to(cur_device)
        autoencoder_input_embedding = self.model.tokens_to_embeddings(text_output)

        self.model.icae.set_adapter("encadapt")
        self.model.icae.enable_adapter_layers()
        for name, param in self.model.icae.named_parameters():
            if "encadapt" in name:
                param.requires_grad = False
        compress_outputs = self.model.icae(inputs_embeds=autoencoder_input_embedding, output_hidden_states=True,
                                           graph=graph, mem_mask=mem_mask, partial_grad=partial_grad, map_node=True)
        self.model.icae.disable_adapter_layers()
        compress_outputs = compress_outputs.hidden_states[-1]

        if graph is not None:
            node_emb = compress_outputs[:len(graph.node_map)]
            map_mem_mask = mem_mask[:graph.num_node_feat][graph.node_map]
            memory_embedding = node_emb[map_mem_mask].view(len(node_emb), self.mem_size, -1)
        else:
            memory_embedding = compress_outputs[mem_mask].view(batch_size, self.mem_size, -1)
        return memory_embedding

    def decode(self, data, mem_embs, graph=None, prompt=None):
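        # Same structure as the Llama decoder, except the prompt is wrapped with hard-coded
        # Mistral chat-template ids (presumably the [INST] ... [/INST] wrapper) and the memory
        # slots sit between the left and right wrappers.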
        prompt_output = self.model.tokenizer(data, add_special_tokens=False, padding=False, truncation=True,
                                             max_length=self.model.training_args.model_max_length)["input_ids"]
        prompt_output = [p + [self.model.tokenizer.eos_token_id] for p in prompt_output]
        original_prompt_output = prompt_output

        if prompt is None:
            prompt = [""] * len(data)
        prompt_input = self.model.left_tokenizer(prompt, add_special_tokens=False, padding=False, truncation=True, max_length=512)["input_ids"]
        batch_size = len(prompt_input)

        # For Mistral, decode contains: prefix, memory slots and suffix
        prompt_left_ids = [[1, 733, 16289, 28793] if len(a) > 0 else [] for a in prompt_input]
        prompt_right_ids = [[self.model.ft_token_id] + a + [733, 28748, 16289, 28793] if len(a) > 0 else a for a in
                            prompt_input]
        prompt_ids = [a + [self.model.tokenizer.pad_token_id] * self.mem_size + b + c for a, b, c in
                      zip(prompt_left_ids, prompt_right_ids, prompt_output)]
        prompt_mask = [
            [False] * (len(prompt_left_ids[i]) + self.mem_size - 1 + len(prompt_right_ids[i])) + [True] * len(
                prompt_output[i]) + [False] for i in range(batch_size)]

        answer_prompt = torch.cat([torch.tensor(p, dtype=torch.long) for p in prompt_output], dim=-1).to(
            mem_embs.device)

        prompt_output = {"input_ids": prompt_ids, "attention_mask": prompt_mask}
        prompt_output = self.model.tokenizer.pad(prompt_output, padding=True, return_tensors="pt")
        prompt_answer_ids = prompt_output["input_ids"].to(mem_embs.device)
        prompt_answer_embs = self.model.tokens_to_embeddings(prompt_answer_ids)

        mem_mask = [[False] * len(prompt_left_ids[i]) + [True] * self.mem_size + [False] * (
                len(prompt_output["input_ids"][i]) - len(prompt_left_ids[i]) - self.mem_size) for i in
                    range(batch_size)]
        prompt_mask = [
            [False] * (len(prompt_left_ids[i]) + self.mem_size - 1 + len(prompt_right_ids[i])) + [True] * len(
                original_prompt_output[i]) + [False] * (1 + len(prompt_output["input_ids"][i]) - len(prompt_ids[i])) for
            i in range(batch_size)]

        prompt_answer_embs[torch.tensor(mem_mask)] = mem_embs.view(-1, mem_embs.size()[-1])

        target_mask = torch.tensor(prompt_mask, dtype=torch.long, device=mem_embs.device).to(torch.bool)

        if self.dec_lora:
            self.model.icae.set_adapter("default")
            self.model.icae.enable_adapter_layers()
        else:
            self.model.icae.disable_adapter_layers()
        output_emb = self.model.icae(inputs_embeds=prompt_answer_embs).logits

        return output_emb, answer_prompt, target_mask

    def generate(self, mem_embs, graph=None, prompt=None):
        cur_device = self.model.memory_token_embed.weight.device

        if prompt is None:
            prompt = [""] * len(mem_embs)
        prompt_input = self.model.tokenizer(prompt, add_special_tokens=False, padding=False)["input_ids"]
        batch_size = len(prompt_input)

        prompt_left_ids = [[1, 733, 16289, 28793] if len(a) > 0 else [] for a in prompt_input]

        prompt_right_ids = [[self.model.ft_token_id] + a + [733, 28748, 16289, 28793] if len(a) > 0 else a for a in
                            prompt_input]

        mem_mask = [[False] * len(prompt_left_ids[i]) + [True] * self.mem_size + [False] * len(prompt_right_ids[i]) for
                    i in range(batch_size)]
        att_mask = [[True] * (len(prompt_left_ids[i]) + self.mem_size + len(prompt_right_ids[i])) for i in
                    range(batch_size)]
        prompt_ids = [prompt_left_ids[i] + [self.model.tokenizer.pad_token_id] * self.mem_size + prompt_right_ids[i] for
                      i in range(batch_size)]

        input_prompt_ids = self.model.left_tokenizer.pad({"input_ids": prompt_ids, "attention_mask": mem_mask},
                                                         padding=True, return_tensors="pt")
        mem_mask = input_prompt_ids["attention_mask"].to(device=mem_embs.device, dtype=torch.bool)

        input_prompt_ids = self.model.left_tokenizer.pad({"input_ids": prompt_ids, "attention_mask": att_mask},
                                                         padding=True, return_tensors="pt")
        prompt_ids = input_prompt_ids["input_ids"]
        att_mask = input_prompt_ids["attention_mask"].to(device=mem_embs.device)

        prompt_answer_ids = prompt_ids.to(device=mem_embs.device, dtype=torch.long)
        prompt_answer_embs = self.model.tokens_to_embeddings(prompt_answer_ids)
        prompt_answer_embs[mem_mask] = mem_embs.view(-1, mem_embs.size()[-1])

        decode_embed = prompt_answer_embs
        output = decode_embed.clone()

        generate_text = []
        eos_reached = torch.zeros(len(output), dtype=torch.bool).to(output.device)

        past_key_values = None
        if self.dec_lora:
            self.model.icae.set_adapter("default")
            self.model.icae.enable_adapter_layers()
        else:
            self.model.icae.disable_adapter_layers()
        for i in range(128):
            out = self.model.icae(inputs_embeds=output, attention_mask=att_mask, past_key_values=past_key_values,
                                 use_cache=True)

            logits = out.logits[:, -1, :self.model.vocab_size - 1]

            past_key_values = out.past_key_values

            next_token_id = torch.argmax(logits, dim=-1, keepdim=True)

            eos_reached = torch.logical_or(eos_reached, (next_token_id == self.model.tokenizer.eos_token_id).view(-1))

            # eos_reached = torch.logical_or(eos_reached, (next_token_id==self.model.tokenizer.bos_token_id).view(-1))

            # eos_reached = torch.logical_or(eos_reached, (next_token_id>=32000).view(-1))

            output = self.model.icae.get_base_model().model.embed_tokens(next_token_id).to(mem_embs.device)

            generate_text.append(next_token_id.view(-1, 1))
            att_mask = torch.cat(
                [att_mask, torch.ones((len(att_mask), 1), dtype=att_mask.dtype, device=att_mask.device)], dim=-1)

            if torch.all(eos_reached):
                break

        generate_text = torch.cat(generate_text, dim=-1)
        generate_text[generate_text >= 32000] = 1

        generated_text = self.model.tokenizer.batch_decode(generate_text)

        return generated_text


class LlamaHelper(torch.nn.Module):
    def __init__(self, transformer_args):
        super().__init__()
        model_args, training_args, gofa_args = transformer_args
        model = LlamaLora(model_args, training_args, gofa_args)  # restored llama2-7b-chat model

        self.model = model
        self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
        self.model.left_tokenizer.pad_token = self.model.left_tokenizer.bos_token

    def get_tokenizer(self):
        return self.model.tokenizer

    def train_mode(self):
        # for param in self.model.dec.parameters():
        #     param.requires_grad = False
        pass

    def forward(self, data, input, prompt=None):
        cur_device = self.model.icae.get_base_model().model.embed_tokens.weight.device
        prompt_output = self.model.tokenizer(data, add_special_tokens=False, padding=False, truncation=True,
                                       max_length=self.model.training_args.model_max_length)["input_ids"]
        input_tokens = self.model.tokenizer(input, add_special_tokens=False, padding=False, truncation=True,
                                      max_length=self.model.training_args.model_max_length)["input_ids"]
        prompt_output = [p + [self.model.tokenizer.eos_token_id] for p in prompt_output]
        if prompt is None:
            prompt = [""] * len(data)
        prompt_input = self.model.tokenizer(prompt, add_special_tokens=False, padding=False)["input_ids"]
        prompt_ids = [a + b + c for a, b, c in zip(input_tokens, prompt_input, prompt_output)]
        prompt_mask = [[False] * (len(a) + len(b) - 1) + [True] * (len(c)) + [False] for a, b, c in
                       zip(input_tokens, prompt_input, prompt_output)]
        answer_prompt = torch.cat([torch.tensor(p, dtype=torch.long) for p in prompt_output], dim=-1).to(cur_device)

        prompt_output = {"input_ids": prompt_ids, "attention_mask": prompt_mask}
        prompt_output = self.model.tokenizer.pad(prompt_output, padding=True, return_tensors="pt")

        prompt_answer_ids = prompt_output["input_ids"].to(cur_device)

        target_mask = prompt_output["attention_mask"].to(cur_device).to(torch.bool)

        prompt_answer_embs = self.model.icae.get_base_model().model.embed_tokens(prompt_answer_ids)

        output_emb = self.model.icae(inputs_embeds=prompt_answer_embs).logits
        # for name, p in self.model.named_parameters():
        #     if "default" in name:
        #         print(p.abs().sum())
        #         break

        return output_emb, answer_prompt, target_mask

    def encode(self, data, input, prompt=None):
        raise NotImplementedError("no encdoe for llama")


    def decode(self, data, input, prompt=None):
        return self(data, input, prompt)

    def generate(self, input, prompt=None):
        cur_device = self.model.icae.get_base_model().model.embed_tokens.weight.device
        if prompt is None:
            prompt = [""] * len(input)
        prompt_ids = self.model.tokenizer(input, add_special_tokens=False, padding=False, truncation=True,
                                      max_length=self.model.training_args.model_max_length)["input_ids"]

        att_mask = [[True] * (len(a)) for a in prompt_ids]

        input_prompt_ids = self.model.tokenizer.pad({"input_ids": prompt_ids, "attention_mask": att_mask},
                                                         padding=True, return_tensors="pt")

        prompt_ids = input_prompt_ids["input_ids"]
        att_mask = input_prompt_ids["attention_mask"].to(device=cur_device)

        prompt_answer_ids = prompt_ids.to(device=cur_device, dtype=torch.long)

        with torch.no_grad():
            outputs = self.model.icae.generate(prompt_answer_ids, max_length=1024, num_return_sequences=1, pad_token_id = self.model.eos_id)

        generated_text = [self.model.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        generated_text = [self.extract_content_after_inst(t) for t in generated_text]

        return generated_text

    def extract_content_after_inst(self, generated_text):
        # Find the index of the closing tag [/INST]
        closing_tag = "[/INST]"
        start_index = generated_text.find(closing_tag)

        if start_index == -1:
            # If the closing tag is not found, return the entire text
            return generated_text

        # Extract the content after the closing tag
        content_after_inst = generated_text[start_index + len(closing_tag):].strip()

        return content_after_inst