diff --git a/containers/BasicTerm_ME_python/main.py b/containers/BasicTerm_ME_python/main.py index ba13c04..9e9cc0f 100644 --- a/containers/BasicTerm_ME_python/main.py +++ b/containers/BasicTerm_ME_python/main.py @@ -3,7 +3,7 @@ def main(): parser = argparse.ArgumentParser(description="Term ME model runner") parser.add_argument("--multiplier", type=int, default=100, help="Multiplier for model points") - # add an argument that must be either "torch" or "jax" + # add an argument that must be either "torch_recursive" or "jax_iterative" parser.add_argument("--model", type=str, default="jax_iterative", choices=["torch_recursive", "jax_iterative"], help="Model to run") args = parser.parse_args()