From 1faa1052a85a81efa482aa1a41a7d73daa55042a Mon Sep 17 00:00:00 2001 From: liurt1218 Date: Mon, 1 Sep 2025 17:00:06 +0000 Subject: [PATCH 1/9] add annotate stage logic --- traincheck/instrumentor/source_file.py | 276 ++++++++++++++++++++++++- 1 file changed, 274 insertions(+), 2 deletions(-) diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 249a36dc..cf40a983 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -1,8 +1,11 @@ import ast import logging import re - +import tokenize +import io from traincheck.config.config import INSTR_MODULES_TO_INSTR +from collections import deque +from typing import Dict, Set logger = logging.getLogger(__name__) @@ -502,6 +505,274 @@ def instrument_model_tracker_sampler( return source +def annotate_stage( + source: str, +) -> str: + + def _ctx(msg: str) -> str: + return f"[annotate_stage] {msg}" + + def has_stage(src: str, name: str) -> bool: + return re.search(rf'annotate_stage\(\s*[\'"]{name}[\'"]\s*\)', src) is not None + + orig_has = { + "init": has_stage(source, "init"), + "training": has_stage(source, "training"), + "testing": has_stage(source, "testing"), + "checkpointing": has_stage(source, "checkpointing"), + } + orig_has_any = any(orig_has.values()) or ("annotate_stage(" in source) + + for stage_name, present in orig_has.items(): + if present: + logger.info( + _ctx( + f"Stage '{stage_name}' already present in source; skip adding this stage." + ) + ) + + training_lines: Set[int] = set() + testing_lines: Set[int] = set() + checkpointing_lines: Set[int] = set() + + q = deque(maxlen=3) + for tok in tokenize.generate_tokens(io.StringIO(source).readline): + q.append(tok) + if len(q) < 2: + continue + a = q[-3] if len(q) >= 3 else None + b = q[-2] + c = q[-1] + + def at_attr(name: str) -> bool: + return ( + a is not None + and a.type == tokenize.OP + and a.string == "." + and b.type == tokenize.NAME + and b.string == name + and c.type == tokenize.OP + and c.string == "(" + ) + + if (at_attr("train") or at_attr("step")) and not orig_has["training"]: + training_lines.add(b.start[0]) + + if (at_attr("eval") or at_attr("no_grad")) and not orig_has["testing"]: + testing_lines.add(b.start[0]) + + if at_attr("save") and not orig_has["checkpointing"]: + checkpointing_lines.add(b.start[0]) + + priority = {"training": 3, "testing": 2, "checkpointing": 1} + line_to_stage: Dict[int, str] = {} + for ln in checkpointing_lines: + line_to_stage[ln] = "checkpointing" + for ln in training_lines: + if priority["training"] > priority.get(line_to_stage.get(ln, ""), 0): + line_to_stage[ln] = "training" + for ln in testing_lines: + if priority["testing"] > priority.get(line_to_stage.get(ln, ""), 0): + line_to_stage[ln] = "testing" + + lines = source.splitlines(keepends=True) + new_lines = [] + inserted_count = { + "training": 0, + "testing": 0, + "checkpointing": 0, + "init": 0, + "import": 0, + } + for i, line in enumerate(lines): + lineno = i + 1 + stage = line_to_stage.get(lineno) + if stage: + k = len(new_lines) - 1 + while k >= 0 and new_lines[k].strip() == "": + k -= 1 + prev = new_lines[k] if k >= 0 else "" + if not ( + ("annotate_stage" in prev) + and (f'"{stage}"' in prev or f"'{stage}'" in prev) + ): + indent = re.match(r"\s*", line).group(0) + new_lines.append(f'{indent}annotate_stage("{stage}")\n') + inserted_count[stage] += 1 + logger.info( + _ctx( + f"Inserted stage '{stage}' before line {lineno}: {line.strip()}" + ) + ) + else: + logger.info( + _ctx( + f"Skip inserting '{stage}' at line {lineno} (previous non-empty line already has it)." + ) + ) + new_lines.append(line) + + new_src = "".join(new_lines) + + def _find_annotate_import_idx(lines): + for idx, l in enumerate(lines): + if re.match(r"^\s*from\s+traincheck\s+import\s+annotate_stage\s*$", l): + return idx + return -1 + + lines_list = new_src.splitlines(keepends=True) + annot_import_idx = _find_annotate_import_idx(lines_list) + + if annot_import_idx == -1: + insert_idx = 0 + while insert_idx < len(lines_list): + s = lines_list[insert_idx].strip() + if ( + lines_list[insert_idx].startswith("#!") + or (s.startswith("#") and "coding" in s) + or s.startswith("from __future__ import") + ): + insert_idx += 1 + else: + break + lines_list.insert(insert_idx, "from traincheck import annotate_stage\n") + annot_import_idx = insert_idx + inserted_count["import"] += 1 + logger.info( + _ctx( + f"Inserted import 'from traincheck import annotate_stage' at line {annot_import_idx + 1}." + ) + ) + + new_src = "".join(lines_list) + + if not orig_has["init"]: + has_guard = ( + re.search( + r'^\s*if\s+__name__\s*==\s*[\'"]__main__[\'"]\s*:\s*$', new_src, re.M + ) + is not None + ) + main_def = re.search( + r"^([ \t]*)def\s+main\s*\(.*?\)\s*:\s*(?:#.*)?$", new_src, re.M + ) + + if has_guard and main_def: + def_line_start = main_def.start() + before_def = new_src[:def_line_start] + def_line_idx = before_def.count("\n") + indent = main_def.group(1) + step = "\t" if ("\t" in indent and " " not in indent) else " " + body_indent = indent + step + + nl = new_src.splitlines(keepends=True) + insert_at = def_line_idx + 1 + while insert_at < len(nl) and nl[insert_at].strip() == "": + insert_at += 1 + + def _is_triple_quote(s: str) -> bool: + t = s.lstrip() + return t.startswith('"""') or t.startswith("'''") + + if insert_at < len(nl) and _is_triple_quote(nl[insert_at]): + quote = '"""' if nl[insert_at].lstrip().startswith('"""') else "'''" + if nl[insert_at].count(quote) >= 2 and nl[ + insert_at + ].lstrip().startswith(quote): + insert_at += 1 + else: + insert_at += 1 + while insert_at < len(nl): + if quote in nl[insert_at]: + insert_at += 1 + break + insert_at += 1 + + k = insert_at - 1 + while k >= 0 and nl[k].strip() == "": + k -= 1 + prev = nl[k] if k >= 0 else "" + if not (("annotate_stage" in prev) and ("init" in prev)): + nl.insert(insert_at, f'{body_indent}annotate_stage("init")\n') + inserted_count["init"] += 1 + logger.info( + _ctx( + f"Inserted stage 'init' at start of main() body (line {insert_at + 1})." + ) + ) + else: + logger.info( + _ctx( + "Skip inserting 'init' inside main(): previous non-empty line already has it." + ) + ) + new_src = "".join(nl) + else: + lines2 = new_src.splitlines(keepends=True) + annot_import_idx = _find_annotate_import_idx(lines2) + if annot_import_idx == -1: + i = 0 + while i < len(lines2): + s = lines2[i].strip() + if ( + lines2[i].startswith("#!") + or (s.startswith("#") and "coding" in s) + or s.startswith("from __future__ import") + ): + i += 1 + else: + break + while i < len(lines2): + s = lines2[i].strip() + if ( + s.startswith("import ") + or s.startswith("from ") + or s == "" + or s.startswith("#") + ): + i += 1 + else: + break + insert_at = i + else: + insert_at = annot_import_idx + 1 + + k = insert_at + while k < len(lines2) and lines2[k].strip() == "": + k += 1 + next_line = lines2[k] if k < len(lines2) else "" + if not (("annotate_stage" in next_line) and ("init" in next_line)): + lines2.insert(insert_at, 'annotate_stage("init")\n') + inserted_count["init"] += 1 + logger.info( + _ctx( + f"Inserted stage 'init' right after annotate_stage import at line {insert_at + 1}." + ) + ) + else: + logger.info( + _ctx( + "Skip inserting 'init': next non-empty line after annotate_stage import is already init." + ) + ) + + new_src = "".join(lines2) + + if "annotate_stage(" not in new_src and not orig_has_any: + logger.error( + _ctx( + "Automatic insertion failed: no annotate_stage(...) found or added. Manual insertion required." + ) + ) + raise RuntimeError( + _ctx( + "Automatic insertion failed: no annotate_stage(...) found or added. Manual insertion required." + ) + ) + + return new_src + + def instrument_file( path: str, modules_to_instr: list[str], @@ -532,7 +803,8 @@ def instrument_file( funcs_to_instr, API_dump_stack_trace, ) - + # annotate stages + instrumented_source = annotate_stage(instrumented_source) # logging configs logging_start_code = f""" import os From 78a31a06d596c29bb5c70ef833932a07b4ba1da9 Mon Sep 17 00:00:00 2001 From: liurt1218 Date: Mon, 1 Sep 2025 19:42:59 +0000 Subject: [PATCH 2/9] fix mypy error --- traincheck/instrumentor/source_file.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index cf40a983..9f57034e 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -535,7 +535,7 @@ def has_stage(src: str, name: str) -> bool: testing_lines: Set[int] = set() checkpointing_lines: Set[int] = set() - q = deque(maxlen=3) + q: deque = deque(maxlen=3) for tok in tokenize.generate_tokens(io.StringIO(source).readline): q.append(tok) if len(q) < 2: @@ -576,7 +576,7 @@ def at_attr(name: str) -> bool: line_to_stage[ln] = "testing" lines = source.splitlines(keepends=True) - new_lines = [] + new_lines: list[str] = [] inserted_count = { "training": 0, "testing": 0, @@ -596,7 +596,9 @@ def at_attr(name: str) -> bool: ("annotate_stage" in prev) and (f'"{stage}"' in prev or f"'{stage}'" in prev) ): - indent = re.match(r"\s*", line).group(0) + if (m := re.match(r"\s*", line)) is None: + raise ValueError("pattern not found") + indent = m.group(0) new_lines.append(f'{indent}annotate_stage("{stage}")\n') inserted_count[stage] += 1 logger.info( From c812b600e7401cca4fbe46e4ab0658557ed3910d Mon Sep 17 00:00:00 2001 From: liurt1218 Date: Mon, 1 Sep 2025 20:36:52 +0000 Subject: [PATCH 3/9] fix import --- traincheck/instrumentor/source_file.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 9f57034e..50760ece 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -1,12 +1,13 @@ import ast +import io import logging import re import tokenize -import io -from traincheck.config.config import INSTR_MODULES_TO_INSTR from collections import deque from typing import Dict, Set +from traincheck.config.config import INSTR_MODULES_TO_INSTR + logger = logging.getLogger(__name__) """ From 436d41617fbfaeb77f5df5adad8ed6022c29e406 Mon Sep 17 00:00:00 2001 From: liurt1218 Date: Tue, 2 Sep 2025 04:26:03 +0000 Subject: [PATCH 4/9] fix ruff error --- traincheck/instrumentor/source_file.py | 4 ++-- traincheck/trace/trace_pandas.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 50760ece..ce06979e 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -618,8 +618,8 @@ def at_attr(name: str) -> bool: new_src = "".join(new_lines) def _find_annotate_import_idx(lines): - for idx, l in enumerate(lines): - if re.match(r"^\s*from\s+traincheck\s+import\s+annotate_stage\s*$", l): + for idx, line in enumerate(lines): + if re.match(r"^\s*from\s+traincheck\s+import\s+annotate_stage\s*$", line): return idx return -1 diff --git a/traincheck/trace/trace_pandas.py b/traincheck/trace/trace_pandas.py index 57023a0c..4a4119eb 100644 --- a/traincheck/trace/trace_pandas.py +++ b/traincheck/trace/trace_pandas.py @@ -183,6 +183,11 @@ def _rm_incomplete_trailing_func_calls(self): self.events.groupby("func_call_id").size().reset_index(name="count") ) + multiple_func_call_ids = func_call_groups[func_call_groups["count"] >2][ + "func_call_id" + ] + assert len(multiple_func_call_ids)==0, "more than 2 events for one func call id" + incomplete_func_call_ids = func_call_groups[func_call_groups["count"] == 1][ "func_call_id" ] From 15c7cb42b441aa04a648968dc5d2ae958c330299 Mon Sep 17 00:00:00 2001 From: liurt1218 Date: Tue, 2 Sep 2025 04:42:14 +0000 Subject: [PATCH 5/9] fix format error --- traincheck/trace/trace_pandas.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/traincheck/trace/trace_pandas.py b/traincheck/trace/trace_pandas.py index 4a4119eb..a2372e98 100644 --- a/traincheck/trace/trace_pandas.py +++ b/traincheck/trace/trace_pandas.py @@ -183,11 +183,13 @@ def _rm_incomplete_trailing_func_calls(self): self.events.groupby("func_call_id").size().reset_index(name="count") ) - multiple_func_call_ids = func_call_groups[func_call_groups["count"] >2][ + multiple_func_call_ids = func_call_groups[func_call_groups["count"] > 2][ "func_call_id" ] - assert len(multiple_func_call_ids)==0, "more than 2 events for one func call id" - + assert ( + len(multiple_func_call_ids) == 0 + ), "more than 2 events for one func call id" + incomplete_func_call_ids = func_call_groups[func_call_groups["count"] == 1][ "func_call_id" ] @@ -866,9 +868,9 @@ def get_var_insts(self) -> dict[VarInstId, dict[str, list[AttrState]]]: safe_isnan(attr_values[attr_name][-1].value) and safe_isnan(curr_value) ): - attr_values[attr_name][-1].liveness.end_time = ( - state_change["time"] - ) + attr_values[attr_name][ + -1 + ].liveness.end_time = state_change["time"] attr_values[attr_name].append( AttrState( curr_value, From 197503a4266bc27095043ae6dd32734cd74fa394 Mon Sep 17 00:00:00 2001 From: liurt1218 Date: Tue, 2 Sep 2025 07:47:37 +0000 Subject: [PATCH 6/9] fix format error --- traincheck/trace/trace_pandas.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/traincheck/trace/trace_pandas.py b/traincheck/trace/trace_pandas.py index a2372e98..c51d946d 100644 --- a/traincheck/trace/trace_pandas.py +++ b/traincheck/trace/trace_pandas.py @@ -868,9 +868,9 @@ def get_var_insts(self) -> dict[VarInstId, dict[str, list[AttrState]]]: safe_isnan(attr_values[attr_name][-1].value) and safe_isnan(curr_value) ): - attr_values[attr_name][ - -1 - ].liveness.end_time = state_change["time"] + attr_values[attr_name][-1].liveness.end_time = ( + state_change["time"] + ) attr_values[attr_name].append( AttrState( curr_value, From ff2bdcc92b90673f75fe906af0b5328e8edecdaf Mon Sep 17 00:00:00 2001 From: Yuxuan <31838999+Essoz@users.noreply.github.com> Date: Tue, 2 Sep 2025 10:13:10 -0400 Subject: [PATCH 7/9] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- traincheck/instrumentor/source_file.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index ce06979e..7b2884f5 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -565,7 +565,14 @@ def at_attr(name: str) -> bool: if at_attr("save") and not orig_has["checkpointing"]: checkpointing_lines.add(b.start[0]) - priority = {"training": 3, "testing": 2, "checkpointing": 1} + TRAINING_PRIORITY = 3 + TESTING_PRIORITY = 2 + CHECKPOINTING_PRIORITY = 1 + priority = { + "training": TRAINING_PRIORITY, + "testing": TESTING_PRIORITY, + "checkpointing": CHECKPOINTING_PRIORITY, + } line_to_stage: Dict[int, str] = {} for ln in checkpointing_lines: line_to_stage[ln] = "checkpointing" @@ -677,11 +684,13 @@ def _is_triple_quote(s: str) -> bool: t = s.lstrip() return t.startswith('"""') or t.startswith("'''") + def is_single_line_triple_quoted_string(line: str, quote: str) -> bool: + """Return True if the line is a single-line triple-quoted string using the given quote.""" + return line.count(quote) >= 2 and line.lstrip().startswith(quote) + if insert_at < len(nl) and _is_triple_quote(nl[insert_at]): quote = '"""' if nl[insert_at].lstrip().startswith('"""') else "'''" - if nl[insert_at].count(quote) >= 2 and nl[ - insert_at - ].lstrip().startswith(quote): + if is_single_line_triple_quoted_string(nl[insert_at], quote): insert_at += 1 else: insert_at += 1 @@ -767,9 +776,7 @@ def _is_triple_quote(s: str) -> bool: "Automatic insertion failed: no annotate_stage(...) found or added. Manual insertion required." ) ) - raise RuntimeError( - _ctx( - "Automatic insertion failed: no annotate_stage(...) found or added. Manual insertion required." + "annotate_stage insertion failed; see logs for details." ) ) From 95465bcf5fe9da1ec8834ed0aa29ccb5652a37cf Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Tue, 2 Sep 2025 10:19:07 -0400 Subject: [PATCH 8/9] fix errors by copilot --- traincheck/instrumentor/source_file.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 7b2884f5..3012ae47 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -776,8 +776,8 @@ def is_single_line_triple_quoted_string(line: str, quote: str) -> bool: "Automatic insertion failed: no annotate_stage(...) found or added. Manual insertion required." ) ) - "annotate_stage insertion failed; see logs for details." - ) + raise RuntimeError( + _ctx("annotate_stage insertion failed; see logs for details.") ) return new_src From f9d389033c0edd1d50723e8fec69fbb346474e58 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Tue, 2 Sep 2025 10:22:33 -0400 Subject: [PATCH 9/9] add information w.r.t technical debt --- traincheck/instrumentor/source_file.py | 1 + 1 file changed, 1 insertion(+) diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 3012ae47..4de57416 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -509,6 +509,7 @@ def instrument_model_tracker_sampler( def annotate_stage( source: str, ) -> str: + """DEBT: Refactor the source tree exploration part with a AST-based approach""" def _ctx(msg: str) -> str: return f"[annotate_stage] {msg}"