P-Tuning 源码分析
class PromptEncoder(torch.nn.Module):
    """
    Prompt encoder for P-tuning: produces the continuous virtual-token embeddings.

    A trainable ``torch.nn.Embedding`` stores one vector per virtual token (for every
    transformer submodule). Unless the config is in inference mode, a reparameterization
    head is built on top of it:

    - ``LSTM``: a bidirectional ``torch.nn.LSTM`` followed by a 2-layer MLP
      (``2 * hidden_size -> 2 * hidden_size -> token_dim``).
    - ``MLP``: a 3-layer MLP (``token_dim -> hidden_size -> hidden_size -> token_dim``).

    Example:

    ```py
    >>> from peft import PromptEncoder, PromptEncoderConfig
    >>> config = PromptEncoderConfig(
    ...     peft_type="P_TUNING",
    ...     task_type="SEQ_2_SEQ_LM",
    ...     num_virtual_tokens=20,
    ...     token_dim=768,
    ...     num_transformer_submodules=1,
    ...     num_attention_heads=12,
    ...     num_layers=12,
    ...     encoder_reparameterization_type="MLP",
    ...     encoder_hidden_size=768,
    ... )
    >>> prompt_encoder = PromptEncoder(config)
    ```

    Args:
        config: The P-tuning config. The fields read here are ``token_dim``,
            ``encoder_hidden_size``, ``num_virtual_tokens``, ``num_transformer_submodules``,
            ``encoder_reparameterization_type``, ``inference_mode``, ``encoder_dropout``
            and ``encoder_num_layers``.

    Attributes:
        embedding (torch.nn.Embedding): Table of shape
            ``(num_virtual_tokens * num_transformer_submodules, token_dim)``.
        lstm_head (torch.nn.LSTM): Only present for the LSTM type when not in inference mode.
        mlp_head (torch.nn.Sequential): Only present when not in inference mode.
        token_dim (int): Embedding size of the virtual tokens (also the head's input/output size).
        hidden_size (int): Hidden size of the reparameterization head.
        total_virtual_tokens (int): ``num_virtual_tokens * num_transformer_submodules``.
        encoder_type: The configured reparameterization type (LSTM or MLP).

    Input shape: integer indices into ``embedding``. For the LSTM type the head is
    built with ``batch_first=True``, so ``(batch_size, seq_len)`` indices are expected.

    Output shape: the embedded indices with the last dimension equal to ``token_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.token_dim = config.token_dim
        self.input_size = self.token_dim
        self.output_size = self.token_dim
        self.hidden_size = config.encoder_hidden_size
        self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        self.encoder_type = config.encoder_reparameterization_type

        # One trainable vector per virtual token (per transformer submodule).
        self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim)

        # The reparameterization head is only built outside inference mode; in
        # inference mode only ``embedding`` exists (and the type is not validated).
        if config.inference_mode:
            return
        if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
            self._build_lstm_head(config)
        elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
            self._build_mlp_head(config)
        else:
            raise self._unrecognized_type_error()

    def _build_lstm_head(self, config):
        """Create ``lstm_head`` (bidirectional LSTM) and the 2-layer ``mlp_head`` after it."""
        self.lstm_head = torch.nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=config.encoder_num_layers,
            dropout=config.encoder_dropout,
            bidirectional=True,
            batch_first=True,
        )
        # Bidirectional LSTM concatenates both directions -> 2 * hidden_size features.
        self.mlp_head = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size * 2, self.output_size),
        )

    def _build_mlp_head(self, config):
        """Create the fixed-depth ``mlp_head``; ``encoder_num_layers`` is ignored (with a warning)."""
        encoder_num_layers_default = PromptEncoderConfig.encoder_num_layers
        if config.encoder_num_layers != encoder_num_layers_default:
            # Use the member's value when it has one: f-string formatting of a
            # mixed-in Enum member changed in Python 3.11 (it renders as
            # ``ClassName.MEMBER``), which would make this message version-dependent.
            type_name = getattr(self.encoder_type, "value", self.encoder_type)
            warnings.warn(
                f"for {type_name}, the argument `encoder_num_layers` is ignored. "
                f"Exactly {encoder_num_layers_default} MLP layers are used."
            )
        self.mlp_head = torch.nn.Sequential(
            torch.nn.Linear(self.input_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.output_size),
        )

    @staticmethod
    def _unrecognized_type_error():
        """Build the error raised for an unsupported ``encoder_reparameterization_type``."""
        return ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")

    def forward(self, indices):
        """
        Embed ``indices`` and pass them through the reparameterization head.

        Raises:
            ValueError: If the configured encoder type is neither LSTM nor MLP.

        NOTE(review): when the module was built with ``inference_mode=True`` no head
        exists, so this raises ``AttributeError``; presumably callers read
        ``self.embedding.weight`` directly in that mode — confirm.
        """
        input_embeds = self.embedding(indices)
        if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
            # torch.nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
            lstm_output = self.lstm_head(input_embeds)[0]
            return self.mlp_head(lstm_output)
        if self.encoder_type == PromptEncoderReparameterizationType.MLP:
            return self.mlp_head(input_embeds)
        raise self._unrecognized_type_error()