Problem Description
tensor(1, device='cuda:0')
File "/mnt/lab/xxx/code/InstantStyle-Plus/infer_style.py", line 326, in save_feature_map
cur_idx = idx_time_dict[time]
File "/mnt/lab/xxx/code/InstantStyle-Plus/infer_style.py", line 316, in save_feature_maps
save_feature_map(q, f"{feature_type}_{x}_{y}_self_attn_q", i)
File "/mnt/lab/xxxy/code/InstantStyle-Plus/infer_style.py", line 321, in save_feature_maps_callback
save_feature_maps(atte_model, i, "up_block")
File "/mnt/lab/xxx/code/InstantStyle-Plus/infer_style.py", line 301, in ddim_sampler_callback
save_feature_maps_callback(i)
File "/mnt/lab/xxx/code/InstantStyle-Plus/src/pipes/sdxl_inversion_pipeline.py", line 263, in __call__
img_callback(latents, t)
File "/mnt/lab/xxx/code/InstantStyle-Plus/inversion.py", line 41, in run
res = pipe_inversion(prompt = prompt,
File "/mnt/lab/xxx/code/InstantStyle-Plus/infer_style.py", line 354, in <module>
invert(style_image,
File "/mnt/lab/xxx/anaconda3/envs/InstantStyle-Plus/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/mnt/lab/xxx/anaconda3/envs/InstantStyle-Plus/lib/python3.10/runpy.py", line 196, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,
KeyError: tensor(1, device='cuda:0')
Problem Analysis
The error comes from two places:
- The lookup with time:
cur_idx = idx_time_dict[time]
Here time is the tensor tensor(1, device='cuda:0'), not a plain number.
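- The construction of the dictionary idx_time_dict:
for i, t in enumerate(timesteps):
    idx_time_dict[t] = i
Iterating over timesteps yields tensors of the form tensor(x, device='cuda:0'), so the keys of idx_time_dict are tensor objects. PyTorch hashes a tensor by object identity rather than by value, so the tensor passed in as time never matches a key, even when both hold the value 1, and the lookup raises KeyError.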
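The mismatch described above can be reproduced without a GPU. The sketch below is illustrative only: FakeTensor is a hypothetical stand-in for a 0-d torch.Tensor (it mimics only identity-based hashing and .item()), and the timestep values 981, 961 and 1 are arbitrary. It shows that an object with the same value but a different identity is not found as a key, while plain Python numbers are.

```python
class FakeTensor:
    """Hypothetical stand-in for a 0-d torch.Tensor: hashed by identity, not by value."""

    def __init__(self, value):
        self.value = value

    def __hash__(self):
        # Identity-based hash, the same idea torch.Tensor uses.
        return id(self)

    def item(self):
        # Return the wrapped value as a plain Python number.
        return self.value


# Arbitrary example schedule; the real values come from the scheduler.
timesteps = [FakeTensor(v) for v in (981, 961, 1)]

# Keys are the objects themselves, as in the original loop.
idx_time_dict = {t: i for i, t in enumerate(timesteps)}

# Same value as timesteps[2], but a different object.
time = FakeTensor(1)
found_by_object = time in idx_time_dict

# Keys are plain Python numbers, so the lookup compares values.
idx_by_value = {t.item(): i for i, t in enumerate(timesteps)}
found_by_value = time.item() in idx_by_value
```

In this sketch found_by_object is False and found_by_value is True, which matches the KeyError seen with real tensors.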
Solution
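The goal is to extract only the value from tensor(1, device='cuda:0'), without its type and device, and to use that value as the dictionary key. Both changes below are needed so that the keys and the lookup value are plain Python numbers:
- Use .item() to extract the value of a single tensor:
cur_idx = idx_time_dict[time.item()]
- Use .tolist() to extract all values of a tensor as a list:
for i, t in enumerate(timesteps.tolist()):
    idx_time_dict[t] = i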
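One possible way to keep the keys and the lookup consistent is to route both through a single normalizing helper. This is a minimal sketch, not the repository's code: to_key, build_idx_time_dict and lookup_idx are hypothetical names introduced here, and the helper assumes that anything exposing .item() is a single-element tensor.

```python
def to_key(t):
    """Return a plain Python number for t (hypothetical helper).

    Assumes that an object with an .item() method is a single-element tensor;
    plain numbers are returned unchanged.
    """
    return t.item() if hasattr(t, "item") else t


def build_idx_time_dict(timesteps):
    """Map each timestep value to its position in the schedule."""
    return {to_key(t): i for i, t in enumerate(timesteps)}


def lookup_idx(idx_time_dict, time):
    """Look up the index for time, normalized the same way as the keys."""
    return idx_time_dict[to_key(time)]
```

With this approach the callback would call lookup_idx(idx_time_dict, time) and would work whether time arrives as a tensor or as a plain number.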