torchscript格式和safetensors格式区别

两种格式是什么?一句话对比

TorchScript (.pt / .ptl) SafeTensors (.safetensors)
存了什么 网络结构 + 权重 只有权重(张量数据)
谁维护 PyTorch 官方 Hugging Face
能否脱离 Python 类定义运行 ✅ 可以 ❌ 不行,必须配合模型类代码
主要用途 部署、跨语言调用(C++/Java/移动端) 模型共享、HuggingFace Hub 分发
安全性 ⚠️ 可能执行任意代码(pickle) ✅ 纯数据,无代码执行风险
加载速度 较慢(反序列化结构) 非常快(mmap 直接映射)

一句话总结:
SafeTensors 是一个”纯数据仓库”,只管存权重;TorchScript 是一个”完整可执行程序”,结构和权重都打包在里面,拿来就能跑。

内部文件结构对比

1、TorchScript

TorchScript .pt 本质是一个 ZIP 压缩包,可以直接解压查看:

1
2
3
4
5
6
7
8
model.pt(解压后)
├── code/
│ └── __torch__/
│ └── model.py # 编译后的 TorchScript IR 代码
├── data/
│ └── 0 # 权重二进制数据(pickle 格式)
├── constants.pkl # 常量
└── version # 格式版本号
1
2
# 可以直接解压验证
unzip -l model.pt
2、SafeTensors

SafeTensors 只存张量数据本身,不包含任何代码或结构描述:

  • 一个 JSON header:记录每个张量的名字、dtype、shape、在文件中的字节偏移
  • 紧跟其后的原始二进制数据
1
2
3
4
# 读取 safetensors 内容,直接拿到一个 {名字: 张量} 的字典
from safetensors.torch import load_file
state_dict = load_file("model.safetensors")
# {'encoder.weight': tensor(...), 'decoder.bias': tensor(...), ...}

这个字典就是 PyTorch 的 state_dict,需要配合模型类代码才能使用。

1
2
3
4
5
6
7
8
9
10
11
[ 8 bytes ] header 长度 N(小端 uint64)
[ N bytes ] JSON header,例如:
{
"encoder.weight": {
"dtype": "F32",
"shape": [768, 768],
"data_offsets": [0, 2359296]
},
...
}
[ 剩余字节 ] 所有张量的原始二进制数据(紧密排列)

加载时可以用 内存映射(mmap) 直接定位到某个张量的数据偏移,无需加载整个文件,速度极快,内存占用低。

SafeTensors为什么”安全”?
PyTorch 传统的 .pt / .pth state dict 用 pickle 序列化,pickle 可以在反序列化时执行任意 Python 代码,存在供应链攻击风险(恶意模型文件可以在 load 时运行任意代码)。
SafeTensors 的文件格式是纯数据,header 是 JSON,body 是原始字节,完全没有可执行代码,加载过程不会触发任何代码执行。

如何加载 TorchScript 模型进行推理

1、基础加载
1
2
3
4
import torch

model = torch.jit.load("model.pt")
print(model.eval())
2、CPU / GPU 指定
1
2
3
4
5
6
7
8
9
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open("model.pt", "rb") as f:
model = torch.jit.load(f, map_location=device)

model.eval()
model.to(device) # 确保模型在目标设备
3、调用推理
1
2
3
4
5
6
7
import torch

with torch.no_grad(): # 关闭梯度计算,节省显存
output = model(input_tensor) # 直接调用,与普通 nn.Module 一样

# 返回结果转 numpy
result = output.cpu().numpy()
4、查看模型输入输出签名

TorchScript 模型自带类型信息,可以直接打印查看:

1
2
3
4
5
6
7
8
9
# 查看 forward 方法的签名
print(model.forward.schema)
# 例如:forward(__self: ..., image: Tensor, mask: Tensor) -> Tensor

# 查看所有方法
print(dir(model))

# 查看模型图(高级调试)
print(model.graph)

如何加载 SafeTensors 模型进行推理

SafeTensors 必须配合模型类代码使用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 安装:pip install safetensors
from safetensors.torch import load_file
import torch

# 1. 定义或导入模型类(必须有)
from my_model import MyModel # 你自己的模型代码

# 2. 创建模型实例
model = MyModel()

# 3. 加载权重
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()

# 4. 推理
with torch.no_grad():
output = model(input_tensor)

对于 Hugging Face 模型,更简单:

1
2
3
4
5
6
7
8
from transformers import AutoModel

# HuggingFace 会自动选择 safetensors(如果存在)
model = AutoModel.from_pretrained("bert-base-uncased")
model.eval()

with torch.no_grad():
output = model(**inputs)

SafeTensors → TorchScript 转换

1
2
3
4
5
6
SafeTensors 文件
↓ load_file() → state_dict
↓ model.load_state_dict()
↓ 准备示例输入
↓ torch.jit.trace() 或 torch.jit.script()
TorchScript .pt 文件

SafeTensors 本身不能直接转成 TorchScript,必须先还原成带权重的模型实例,再编译。

用 jit.trace 转换(推荐)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from safetensors.torch import load_file
from my_model import MyModel # 需要模型类定义

# 第一步:还原模型
model = MyModel()
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()

# 第二步:准备示例输入(形状、dtype 必须与推理时完全一致)
example_input = torch.randn(1, 3, 512, 512)

# 第三步:trace 编译
with torch.no_grad():
traced_model = torch.jit.trace(model, example_input)

# 第四步:保存
traced_model.save("model_traced.pt")
print("转换完成,文件大小:", os.path.getsize("model_traced.pt") / 1e6, "MB")

多输入模型(如 LaMa 接受 image + mask):

1
2
3
4
5
6
7
8
# 多输入用 tuple 传给 trace
example_image = torch.randn(1, 3, 512, 512)
example_mask = torch.randn(1, 1, 512, 512)

with torch.no_grad():
traced_model = torch.jit.trace(model, (example_image, example_mask))

traced_model.save("model_traced.pt")

特殊:Hugging Face Transformers 模型的转换

Transformers 模型由于大量动态逻辑,推荐用 ONNX 或 TorchScript trace

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import BertModel, BertTokenizer
import torch

# 加载 HuggingFace 模型(自动用 safetensors)
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

# 准备示例输入
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello world", return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

# Trace(注意:trace 会固化当前的序列长度,推理时序列长度需一致)
with torch.no_grad():
traced = torch.jit.trace(
model,
(input_ids, attention_mask),
strict=False # Transformers 模型需要 strict=False
)

traced.save("bert_traced.pt")

注意: 大型语言模型(LLM)结构复杂,包含大量动态控制流,TorchScript 转换困难,实际场景更推荐用 ONNXvLLM/TGI 部署。