Skip to content

Commit c90b85e

Browse files
- Add --skip-gguf flag to avoid forced pip install (fixes microsoft#498, microsoft#499)
- Fix sys.exit(1) indentation bug in run_command() (fixes microsoft#447)
- Change exit(0) to sys.exit(1) for unsupported arch
- Fix ARCH_ALIAS KeyError using .get() for unknown architectures
- Add guard for unsupported architectures in parse_args()
- Fix same indentation bug in e2e_benchmark.py (fixes microsoft#504)
1 parent 01eb415 commit c90b85e

2 files changed

Lines changed: 15 additions & 6 deletions

File tree

setup_env.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,9 @@
8282
}
8383

8484
def system_info():
85-
return platform.system(), ARCH_ALIAS[platform.machine()]
85+
machine = platform.machine()
86+
arch = ARCH_ALIAS.get(machine, machine)
87+
return platform.system(), arch
8688

8789
def get_model_name():
8890
if args.hf_repo:
@@ -104,7 +106,7 @@ def run_command(command, shell=False, log_step=None):
104106
subprocess.run(command, shell=shell, check=True)
105107
except subprocess.CalledProcessError as e:
106108
logging.error(f"Error occurred while running command: {e}")
107-
sys.exit(1)
109+
sys.exit(1)
108110

109111
def prepare_model():
110112
_, arch = system_info()
@@ -149,7 +151,10 @@ def prepare_model():
149151
else:
150152
logging.info(f"GGUF model already exists at {gguf_path}")
151153

152-
def setup_gguf():
154+
def setup_gguf(skip=False):
155+
if skip:
156+
logging.info("Skipping GGUF pip installation (--skip-gguf flag set)")
157+
return
153158
# Install the pip package
154159
run_command([sys.executable, "-m", "pip", "install", "3rdparty/llama.cpp/gguf-py"], log_step="install_gguf")
155160

@@ -209,27 +214,31 @@ def compile():
209214
_, arch = system_info()
210215
if arch not in COMPILER_EXTRA_ARGS.keys():
211216
logging.error(f"Arch {arch} is not supported yet")
212-
exit(0)
217+
sys.exit(1)
213218
logging.info("Compiling the code using CMake.")
214219
run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), []), "-DCMAKE_C_COMPILER=clang", "-DCMAKE_CXX_COMPILER=clang++"], log_step="generate_build_files")
215220
# run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"])
216221
run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile")
217222

218223
def main():
219-
setup_gguf()
224+
setup_gguf(skip=args.skip_gguf)
220225
gen_code()
221226
compile()
222227
prepare_model()
223228

224229
def parse_args():
225230
_, arch = system_info()
231+
if arch not in SUPPORTED_QUANT_TYPES:
232+
logging.error(f"Architecture {arch} is not supported")
233+
sys.exit(1)
226234
parser = argparse.ArgumentParser(description='Setup the environment for running the inference')
227235
parser.add_argument("--hf-repo", "-hr", type=str, help="Model used for inference", choices=SUPPORTED_HF_MODELS.keys())
228236
parser.add_argument("--model-dir", "-md", type=str, help="Directory to save/load the model", default="models")
229237
parser.add_argument("--log-dir", "-ld", type=str, help="Directory to save the logging info", default="logs")
230238
parser.add_argument("--quant-type", "-q", type=str, help="Quantization type", choices=SUPPORTED_QUANT_TYPES[arch], default="i2_s")
231239
parser.add_argument("--quant-embd", action="store_true", help="Quantize the embeddings to f16")
232240
parser.add_argument("--use-pretuned", "-p", action="store_true", help="Use the pretuned kernel parameters")
241+
parser.add_argument("--skip-gguf", action="store_true", help="Skip GGUF pip installation")
233242
return parser.parse_args()
234243

235244
def signal_handler(sig, frame):

utils/e2e_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ def run_command(command, shell=False, log_step=None):
2020
subprocess.run(command, shell=shell, check=True)
2121
except subprocess.CalledProcessError as e:
2222
logging.error(f"Error occurred while running command: {e}")
23-
sys.exit(1)
23+
sys.exit(1)
2424

2525
def run_benchmark():
2626
build_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "build")

0 commit comments

Comments
 (0)