285 changes: 284 additions & 1 deletion traincheck/instrumentor/source_file.py
@@ -1,6 +1,10 @@
import ast
import io
import logging
import re
import tokenize
from collections import deque
from typing import Dict, Set

from traincheck.config.config import INSTR_MODULES_TO_INSTR

@@ -502,6 +506,284 @@ def instrument_model_tracker_sampler(
return source


def annotate_stage(
source: str,
) -> str:
"""DEBT: Refactor the source tree exploration part with a AST-based approach"""

def _ctx(msg: str) -> str:
return f"[annotate_stage] {msg}"

def has_stage(src: str, name: str) -> bool:
return re.search(rf'annotate_stage\(\s*[\'"]{name}[\'"]\s*\)', src) is not None

orig_has = {
"init": has_stage(source, "init"),
"training": has_stage(source, "training"),
"testing": has_stage(source, "testing"),
"checkpointing": has_stage(source, "checkpointing"),
}
orig_has_any = any(orig_has.values()) or ("annotate_stage(" in source)

for stage_name, present in orig_has.items():
if present:
logger.info(
_ctx(
f"Stage '{stage_name}' already present in source; skip adding this stage."
)
)

training_lines: Set[int] = set()
testing_lines: Set[int] = set()
checkpointing_lines: Set[int] = set()

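    # Scan the token stream with a sliding window of three tokens to find
    # attribute-call sites: `.train(` / `.step(` -> training, `.eval(` /
    # `.no_grad(` -> testing, `.save(` -> checkpointing.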
q: deque = deque(maxlen=3)
for tok in tokenize.generate_tokens(io.StringIO(source).readline):
q.append(tok)
if len(q) < 2:
continue
a = q[-3] if len(q) >= 3 else None
b = q[-2]
c = q[-1]

def at_attr(name: str) -> bool:
return (
a is not None
and a.type == tokenize.OP
and a.string == "."
and b.type == tokenize.NAME
and b.string == name
and c.type == tokenize.OP
and c.string == "("
)

if (at_attr("train") or at_attr("step")) and not orig_has["training"]:
training_lines.add(b.start[0])

if (at_attr("eval") or at_attr("no_grad")) and not orig_has["testing"]:
testing_lines.add(b.start[0])

if at_attr("save") and not orig_has["checkpointing"]:
checkpointing_lines.add(b.start[0])

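    # If a single line matches more than one stage, resolve the conflict by
    # priority: training > testing > checkpointing.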
TRAINING_PRIORITY = 3
TESTING_PRIORITY = 2
CHECKPOINTING_PRIORITY = 1
priority = {
"training": TRAINING_PRIORITY,
"testing": TESTING_PRIORITY,
"checkpointing": CHECKPOINTING_PRIORITY,
}
line_to_stage: Dict[int, str] = {}
for ln in checkpointing_lines:
line_to_stage[ln] = "checkpointing"
for ln in training_lines:
if priority["training"] > priority.get(line_to_stage.get(ln, ""), 0):
line_to_stage[ln] = "training"
for ln in testing_lines:
if priority["testing"] > priority.get(line_to_stage.get(ln, ""), 0):
line_to_stage[ln] = "testing"

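    # Walk the source line by line and insert the matching annotate_stage(...)
    # call, reusing the line's indentation, unless the previous non-empty line
    # already carries the same annotation.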
lines = source.splitlines(keepends=True)
new_lines: list[str] = []
inserted_count = {
"training": 0,
"testing": 0,
"checkpointing": 0,
"init": 0,
"import": 0,
}
for i, line in enumerate(lines):
lineno = i + 1
stage = line_to_stage.get(lineno)
if stage:
k = len(new_lines) - 1
while k >= 0 and new_lines[k].strip() == "":
k -= 1
prev = new_lines[k] if k >= 0 else ""
if not (
("annotate_stage" in prev)
and (f'"{stage}"' in prev or f"'{stage}'" in prev)
):
if (m := re.match(r"\s*", line)) is None:
raise ValueError("pattern not found")
indent = m.group(0)
new_lines.append(f'{indent}annotate_stage("{stage}")\n')
inserted_count[stage] += 1
logger.info(
_ctx(
f"Inserted stage '{stage}' before line {lineno}: {line.strip()}"
)
)
else:
logger.info(
_ctx(
f"Skip inserting '{stage}' at line {lineno} (previous non-empty line already has it)."
)
)
new_lines.append(line)

new_src = "".join(new_lines)

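    # Make sure the `from traincheck import annotate_stage` import exists,
    # inserting it after any shebang, coding comment, or __future__ imports.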
def _find_annotate_import_idx(lines):
for idx, line in enumerate(lines):
if re.match(r"^\s*from\s+traincheck\s+import\s+annotate_stage\s*$", line):
return idx
return -1

lines_list = new_src.splitlines(keepends=True)
annot_import_idx = _find_annotate_import_idx(lines_list)

if annot_import_idx == -1:
insert_idx = 0
while insert_idx < len(lines_list):
s = lines_list[insert_idx].strip()
if (
lines_list[insert_idx].startswith("#!")
or (s.startswith("#") and "coding" in s)
or s.startswith("from __future__ import")
):
insert_idx += 1
else:
break
lines_list.insert(insert_idx, "from traincheck import annotate_stage\n")
annot_import_idx = insert_idx
inserted_count["import"] += 1
logger.info(
_ctx(
f"Inserted import 'from traincheck import annotate_stage' at line {annot_import_idx + 1}."
)
)

new_src = "".join(lines_list)

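    # Insert the "init" annotation: at the start of main() when both a
    # __main__ guard and a main() definition exist, otherwise at module level
    # right after the annotate_stage import (or, if that is somehow missing,
    # after the leading import block).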
if not orig_has["init"]:
has_guard = (
re.search(
r'^\s*if\s+__name__\s*==\s*[\'"]__main__[\'"]\s*:\s*$', new_src, re.M
)
is not None
)
main_def = re.search(
r"^([ \t]*)def\s+main\s*\(.*?\)\s*:\s*(?:#.*)?$", new_src, re.M
)

if has_guard and main_def:
def_line_start = main_def.start()
before_def = new_src[:def_line_start]
def_line_idx = before_def.count("\n")
indent = main_def.group(1)
            # Assume the conventional 4-space indent unless main() is tab-indented.
            step = "\t" if ("\t" in indent and " " not in indent) else "    "
body_indent = indent + step

nl = new_src.splitlines(keepends=True)
insert_at = def_line_idx + 1
while insert_at < len(nl) and nl[insert_at].strip() == "":
insert_at += 1

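            # Skip a leading docstring in main() so the annotation lands after it.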
def _is_triple_quote(s: str) -> bool:
t = s.lstrip()
return t.startswith('"""') or t.startswith("'''")

def is_single_line_triple_quoted_string(line: str, quote: str) -> bool:
"""Return True if the line is a single-line triple-quoted string using the given quote."""
return line.count(quote) >= 2 and line.lstrip().startswith(quote)

if insert_at < len(nl) and _is_triple_quote(nl[insert_at]):
quote = '"""' if nl[insert_at].lstrip().startswith('"""') else "'''"
if is_single_line_triple_quoted_string(nl[insert_at], quote):
insert_at += 1
else:
insert_at += 1
while insert_at < len(nl):
if quote in nl[insert_at]:
insert_at += 1
break
insert_at += 1

k = insert_at - 1
while k >= 0 and nl[k].strip() == "":
k -= 1
prev = nl[k] if k >= 0 else ""
if not (("annotate_stage" in prev) and ("init" in prev)):
nl.insert(insert_at, f'{body_indent}annotate_stage("init")\n')
inserted_count["init"] += 1
logger.info(
_ctx(
f"Inserted stage 'init' at start of main() body (line {insert_at + 1})."
)
)
else:
logger.info(
_ctx(
"Skip inserting 'init' inside main(): previous non-empty line already has it."
)
)
new_src = "".join(nl)
else:
lines2 = new_src.splitlines(keepends=True)
annot_import_idx = _find_annotate_import_idx(lines2)
if annot_import_idx == -1:
i = 0
while i < len(lines2):
s = lines2[i].strip()
if (
lines2[i].startswith("#!")
or (s.startswith("#") and "coding" in s)
or s.startswith("from __future__ import")
):
i += 1
else:
break
while i < len(lines2):
s = lines2[i].strip()
if (
s.startswith("import ")
or s.startswith("from ")
or s == ""
or s.startswith("#")
):
i += 1
else:
break
insert_at = i
else:
insert_at = annot_import_idx + 1

k = insert_at
while k < len(lines2) and lines2[k].strip() == "":
k += 1
next_line = lines2[k] if k < len(lines2) else ""
if not (("annotate_stage" in next_line) and ("init" in next_line)):
lines2.insert(insert_at, 'annotate_stage("init")\n')
inserted_count["init"] += 1
logger.info(
_ctx(
f"Inserted stage 'init' right after annotate_stage import at line {insert_at + 1}."
)
)
else:
logger.info(
_ctx(
"Skip inserting 'init': next non-empty line after annotate_stage import is already init."
)
)

new_src = "".join(lines2)

if "annotate_stage(" not in new_src and not orig_has_any:
logger.error(
_ctx(
"Automatic insertion failed: no annotate_stage(...) found or added. Manual insertion required."
)
)
raise RuntimeError(
_ctx("annotate_stage insertion failed; see logs for details.")
)

return new_src

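# Illustrative sketch of the transformation (hypothetical snippet, not taken
# from any real script): a module body such as
#
#     model.train()
#     optimizer.step()
#     torch.save(model.state_dict(), "ckpt.pt")
#
# with no __main__ guard or main() present, would come out roughly as
#
#     from traincheck import annotate_stage
#     annotate_stage("init")
#     annotate_stage("training")
#     model.train()
#     annotate_stage("training")
#     optimizer.step()
#     annotate_stage("checkpointing")
#     torch.save(model.state_dict(), "ckpt.pt")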

def instrument_file(
path: str,
modules_to_instr: list[str],
@@ -532,7 +814,8 @@ def instrument_file(
funcs_to_instr,
API_dump_stack_trace,
)

# annotate stages
instrumented_source = annotate_stage(instrumented_source)
# logging configs
logging_start_code = f"""
import os
7 changes: 7 additions & 0 deletions traincheck/trace/trace_pandas.py
@@ -183,6 +183,13 @@ def _rm_incomplete_trailing_func_calls(self):
self.events.groupby("func_call_id").size().reset_index(name="count")
)

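        # A func_call_id is expected to map to at most two events (presumably
        # the call's start and end); more than two indicates a malformed trace.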
multiple_func_call_ids = func_call_groups[func_call_groups["count"] > 2][
"func_call_id"
]
assert (
len(multiple_func_call_ids) == 0
), "more than 2 events for one func call id"

incomplete_func_call_ids = func_call_groups[func_call_groups["count"] == 1][
"func_call_id"
]