diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 249a36dc..4de57416 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -1,6 +1,10 @@ import ast +import io import logging import re +import tokenize +from collections import deque +from typing import Dict, Set from traincheck.config.config import INSTR_MODULES_TO_INSTR @@ -502,6 +506,284 @@ def instrument_model_tracker_sampler( return source +def annotate_stage( + source: str, +) -> str: + """DEBT: Refactor the source tree exploration part with a AST-based approach""" + + def _ctx(msg: str) -> str: + return f"[annotate_stage] {msg}" + + def has_stage(src: str, name: str) -> bool: + return re.search(rf'annotate_stage\(\s*[\'"]{name}[\'"]\s*\)', src) is not None + + orig_has = { + "init": has_stage(source, "init"), + "training": has_stage(source, "training"), + "testing": has_stage(source, "testing"), + "checkpointing": has_stage(source, "checkpointing"), + } + orig_has_any = any(orig_has.values()) or ("annotate_stage(" in source) + + for stage_name, present in orig_has.items(): + if present: + logger.info( + _ctx( + f"Stage '{stage_name}' already present in source; skip adding this stage." + ) + ) + + training_lines: Set[int] = set() + testing_lines: Set[int] = set() + checkpointing_lines: Set[int] = set() + + q: deque = deque(maxlen=3) + for tok in tokenize.generate_tokens(io.StringIO(source).readline): + q.append(tok) + if len(q) < 2: + continue + a = q[-3] if len(q) >= 3 else None + b = q[-2] + c = q[-1] + + def at_attr(name: str) -> bool: + return ( + a is not None + and a.type == tokenize.OP + and a.string == "." + and b.type == tokenize.NAME + and b.string == name + and c.type == tokenize.OP + and c.string == "(" + ) + + if (at_attr("train") or at_attr("step")) and not orig_has["training"]: + training_lines.add(b.start[0]) + + if (at_attr("eval") or at_attr("no_grad")) and not orig_has["testing"]: + testing_lines.add(b.start[0]) + + if at_attr("save") and not orig_has["checkpointing"]: + checkpointing_lines.add(b.start[0]) + + TRAINING_PRIORITY = 3 + TESTING_PRIORITY = 2 + CHECKPOINTING_PRIORITY = 1 + priority = { + "training": TRAINING_PRIORITY, + "testing": TESTING_PRIORITY, + "checkpointing": CHECKPOINTING_PRIORITY, + } + line_to_stage: Dict[int, str] = {} + for ln in checkpointing_lines: + line_to_stage[ln] = "checkpointing" + for ln in training_lines: + if priority["training"] > priority.get(line_to_stage.get(ln, ""), 0): + line_to_stage[ln] = "training" + for ln in testing_lines: + if priority["testing"] > priority.get(line_to_stage.get(ln, ""), 0): + line_to_stage[ln] = "testing" + + lines = source.splitlines(keepends=True) + new_lines: list[str] = [] + inserted_count = { + "training": 0, + "testing": 0, + "checkpointing": 0, + "init": 0, + "import": 0, + } + for i, line in enumerate(lines): + lineno = i + 1 + stage = line_to_stage.get(lineno) + if stage: + k = len(new_lines) - 1 + while k >= 0 and new_lines[k].strip() == "": + k -= 1 + prev = new_lines[k] if k >= 0 else "" + if not ( + ("annotate_stage" in prev) + and (f'"{stage}"' in prev or f"'{stage}'" in prev) + ): + if (m := re.match(r"\s*", line)) is None: + raise ValueError("pattern not found") + indent = m.group(0) + new_lines.append(f'{indent}annotate_stage("{stage}")\n') + inserted_count[stage] += 1 + logger.info( + _ctx( + f"Inserted stage '{stage}' before line {lineno}: {line.strip()}" + ) + ) + else: + logger.info( + _ctx( + f"Skip inserting '{stage}' at line {lineno} (previous non-empty line already has it)." + ) + ) + new_lines.append(line) + + new_src = "".join(new_lines) + + def _find_annotate_import_idx(lines): + for idx, line in enumerate(lines): + if re.match(r"^\s*from\s+traincheck\s+import\s+annotate_stage\s*$", line): + return idx + return -1 + + lines_list = new_src.splitlines(keepends=True) + annot_import_idx = _find_annotate_import_idx(lines_list) + + if annot_import_idx == -1: + insert_idx = 0 + while insert_idx < len(lines_list): + s = lines_list[insert_idx].strip() + if ( + lines_list[insert_idx].startswith("#!") + or (s.startswith("#") and "coding" in s) + or s.startswith("from __future__ import") + ): + insert_idx += 1 + else: + break + lines_list.insert(insert_idx, "from traincheck import annotate_stage\n") + annot_import_idx = insert_idx + inserted_count["import"] += 1 + logger.info( + _ctx( + f"Inserted import 'from traincheck import annotate_stage' at line {annot_import_idx + 1}." + ) + ) + + new_src = "".join(lines_list) + + if not orig_has["init"]: + has_guard = ( + re.search( + r'^\s*if\s+__name__\s*==\s*[\'"]__main__[\'"]\s*:\s*$', new_src, re.M + ) + is not None + ) + main_def = re.search( + r"^([ \t]*)def\s+main\s*\(.*?\)\s*:\s*(?:#.*)?$", new_src, re.M + ) + + if has_guard and main_def: + def_line_start = main_def.start() + before_def = new_src[:def_line_start] + def_line_idx = before_def.count("\n") + indent = main_def.group(1) + step = "\t" if ("\t" in indent and " " not in indent) else " " + body_indent = indent + step + + nl = new_src.splitlines(keepends=True) + insert_at = def_line_idx + 1 + while insert_at < len(nl) and nl[insert_at].strip() == "": + insert_at += 1 + + def _is_triple_quote(s: str) -> bool: + t = s.lstrip() + return t.startswith('"""') or t.startswith("'''") + + def is_single_line_triple_quoted_string(line: str, quote: str) -> bool: + """Return True if the line is a single-line triple-quoted string using the given quote.""" + return line.count(quote) >= 2 and line.lstrip().startswith(quote) + + if insert_at < len(nl) and _is_triple_quote(nl[insert_at]): + quote = '"""' if nl[insert_at].lstrip().startswith('"""') else "'''" + if is_single_line_triple_quoted_string(nl[insert_at], quote): + insert_at += 1 + else: + insert_at += 1 + while insert_at < len(nl): + if quote in nl[insert_at]: + insert_at += 1 + break + insert_at += 1 + + k = insert_at - 1 + while k >= 0 and nl[k].strip() == "": + k -= 1 + prev = nl[k] if k >= 0 else "" + if not (("annotate_stage" in prev) and ("init" in prev)): + nl.insert(insert_at, f'{body_indent}annotate_stage("init")\n') + inserted_count["init"] += 1 + logger.info( + _ctx( + f"Inserted stage 'init' at start of main() body (line {insert_at + 1})." + ) + ) + else: + logger.info( + _ctx( + "Skip inserting 'init' inside main(): previous non-empty line already has it." + ) + ) + new_src = "".join(nl) + else: + lines2 = new_src.splitlines(keepends=True) + annot_import_idx = _find_annotate_import_idx(lines2) + if annot_import_idx == -1: + i = 0 + while i < len(lines2): + s = lines2[i].strip() + if ( + lines2[i].startswith("#!") + or (s.startswith("#") and "coding" in s) + or s.startswith("from __future__ import") + ): + i += 1 + else: + break + while i < len(lines2): + s = lines2[i].strip() + if ( + s.startswith("import ") + or s.startswith("from ") + or s == "" + or s.startswith("#") + ): + i += 1 + else: + break + insert_at = i + else: + insert_at = annot_import_idx + 1 + + k = insert_at + while k < len(lines2) and lines2[k].strip() == "": + k += 1 + next_line = lines2[k] if k < len(lines2) else "" + if not (("annotate_stage" in next_line) and ("init" in next_line)): + lines2.insert(insert_at, 'annotate_stage("init")\n') + inserted_count["init"] += 1 + logger.info( + _ctx( + f"Inserted stage 'init' right after annotate_stage import at line {insert_at + 1}." + ) + ) + else: + logger.info( + _ctx( + "Skip inserting 'init': next non-empty line after annotate_stage import is already init." + ) + ) + + new_src = "".join(lines2) + + if "annotate_stage(" not in new_src and not orig_has_any: + logger.error( + _ctx( + "Automatic insertion failed: no annotate_stage(...) found or added. Manual insertion required." + ) + ) + raise RuntimeError( + _ctx("annotate_stage insertion failed; see logs for details.") + ) + + return new_src + + def instrument_file( path: str, modules_to_instr: list[str], @@ -532,7 +814,8 @@ def instrument_file( funcs_to_instr, API_dump_stack_trace, ) - + # annotate stages + instrumented_source = annotate_stage(instrumented_source) # logging configs logging_start_code = f""" import os diff --git a/traincheck/trace/trace_pandas.py b/traincheck/trace/trace_pandas.py index 57023a0c..c51d946d 100644 --- a/traincheck/trace/trace_pandas.py +++ b/traincheck/trace/trace_pandas.py @@ -183,6 +183,13 @@ def _rm_incomplete_trailing_func_calls(self): self.events.groupby("func_call_id").size().reset_index(name="count") ) + multiple_func_call_ids = func_call_groups[func_call_groups["count"] > 2][ + "func_call_id" + ] + assert ( + len(multiple_func_call_ids) == 0 + ), "more than 2 events for one func call id" + incomplete_func_call_ids = func_call_groups[func_call_groups["count"] == 1][ "func_call_id" ]