Skip to content

Commit abe61c5

Browse files
committed
[Script] Add mising no_grad context
1 parent e2b0867 commit abe61c5

3 files changed

Lines changed: 4 additions & 4 deletions

File tree

experiments/BERT.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99
from Simulator.simulator import TOGSimulator
1010

11-
config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml')
11+
config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_timing_only.yml')
1212
os.environ['TOGSIM_CONFIG'] = config
1313

1414
# Try Fusion EncoderBlock first, fall back to standard test_transformer
@@ -36,7 +36,7 @@
3636
model_input = torch.randn(args.input_size, hidden_dim).to(device=device)
3737
opt_fn = torch.compile(dynamic=False)(model)
3838

39-
with TOGSimulator(config_path=config):
39+
with TOGSimulator(config_path=config), torch.no_grad():
4040
torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0)
4141
torch.npu.synchronize()
4242
print(f"BERT-{args.size} Simulation Done")

experiments/resnet18.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
opt_fn = torch.compile(dynamic=False)(model)
2323
model_input = torch.randn(args.batch, 3, 224, 224).to(device=device)
2424

25-
with TOGSimulator(config_path=config):
25+
with TOGSimulator(config_path=config), torch.no_grad():
2626
torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0)
2727
torch.npu.synchronize()
2828
print("ResNet18 Simulation Done")

experiments/resnet50.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
opt_fn = torch.compile(dynamic=False)(model)
2323
model_input = torch.randn(args.batch, 3, 224, 224).to(device=device)
2424

25-
with TOGSimulator(config_path=config):
25+
with TOGSimulator(config_path=config), torch.no_grad():
2626
torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0)
2727
torch.npu.synchronize()
2828
print("ResNet50 Simulation Done")

0 commit comments

Comments
 (0)