# Show the top-k results side by side: print each result's text with its
# score, and draw its image in its own panel titled with the same score.
# NOTE(review): assumes `indices`/`values` come from a top-k search and hold
# at most `top_k` items, and that `data['text'][idx]` / `data['image'][idx]`
# are valid lookups for each `idx` — confirm against the code that builds them.
fig, axes = plt.subplots(1, top_k, figsize=(15, 3), squeeze=False)
# squeeze=False always yields a 2-D array of Axes, so indexing by position
# still works when top_k == 1 (a bare Axes would otherwise be returned).
axes = axes[0]
# Turn every panel's frame/ticks off up front so that panels left unused
# (fewer results than top_k) render blank instead of as empty boxes.
for ax in axes:
    ax.axis('off')
for i, (idx, score) in enumerate(zip(indices, values)):
    # Print text and score
    print(f"{data['text'][idx]}: {score:.3f}")
    # Display image
    ax = axes[i]
    ax.imshow(data['image'][idx])
    ax.set_title(f"Score: {score:.3f}")
plt.tight_layout()
plt.show()