此文为作者在实际跑代码中所遇到的问题,然后尝试了多种方法后找到了一种可行方案,仅供参考。
报错:
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)
报错位置:
q, k, v = F.linear(query, w_q, b_q), F.linear(key, w_k, b_k), F.linear(value, w_v, b_v)
具体原因仍未知,但我打印出了 query,w_q,b_q 这些量的形状,已确保是正确的,但仍不行。
解决方法:
尝试使用 torch.matmul 手动实现这个线性变换操作
q = torch.matmul(query, w_q) + b_q # 对应 F.linear
k = torch.matmul(key, w_k) + b_k
v = torch.matmul(value, w_v) + b_v
可行!