diff --git a/.github/workflows/correctness_checks.yml b/.github/workflows/correctness_checks.yml
new file mode 100644
index 00000000..3574fca3
--- /dev/null
+++ b/.github/workflows/correctness_checks.yml
@@ -0,0 +1,59 @@
+# Installs this checkout in editable mode and runs the correctness script from
+# the TrainCheck-Benchmarks repository against it. Triggered on pushes to main
+# and on pull requests touching the checker, invariant code, or workflows.
+name: Correctness Checks
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    paths:
+      - '.github/workflows/**'
+      - 'traincheck/invariant/**'
+      - 'traincheck/onlinechecker/**'
+      - 'traincheck/checker_online.py'
+      - 'traincheck/checker.py'
+
+# Least privilege: no visible step writes to the repository, deployments, or
+# pull requests, so the token is limited to reading contents.
+permissions:
+  contents: read
+
+jobs:
+  correctness-check:
+    name: Run Correctness Checks
+    # NOTE(review): pull_request jobs on a self-hosted runner execute PR code
+    # on that machine; confirm fork PRs require approval before running.
+    runs-on: self-hosted
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Checkout TrainCheck-Benchmarks
+        uses: actions/checkout@v4
+        with:
+          repository: OrderLab/TrainCheck-Benchmarks
+          path: benchmarks
+          lfs: true
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -e .
+
+      - name: Run correctness script
+        working-directory: benchmarks/correctness_check
+        run: python3 correct_check.py
+
+      # always() keeps cleanup running when the correctness script fails.
+      # Workspace-relative paths let it succeed even if the checkout is absent.
+      - name: Clear check files
+        if: always()
+        run: |
+          rm -rf benchmarks/correctness_check/trace_*
+          rm -rf benchmarks/correctness_check/traincheck_*
diff --git a/README.md b/README.md
index 30c8ffad..6812c1bb 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,7 @@
TrainCheck: Training with Confidence
[](https://github.com/OrderLab/traincheck/actions/workflows/pre-commit-checks.yml)
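+[![Correctness Checks](https://github.com/OrderLab/traincheck/actions/workflows/correctness_checks.yml/badge.svg)](https://github.com/OrderLab/traincheck/actions/workflows/correctness_checks.yml)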
[](https://discord.gg/ZvYewjsQ9D)
diff --git a/traincheck/invariant/consistency_transient_vars.py b/traincheck/invariant/consistency_transient_vars.py
index 570ed42f..b0640dff 100644
--- a/traincheck/invariant/consistency_transient_vars.py
+++ b/traincheck/invariant/consistency_transient_vars.py
@@ -1034,7 +1034,7 @@ def online_check(
if input_value != output_value:
check_passed = False
except (IndexError, KeyError):
- logger.warning(
+ logger.debug(
f"Could not find the value to be checked in input or output tensors for the hypothesis {inv}, skipping this function call."
)
diff --git a/traincheck/onlinechecker/streamhandler_filesystem.py b/traincheck/onlinechecker/streamhandler_filesystem.py
index 310fc70e..2b69c5cc 100644
--- a/traincheck/onlinechecker/streamhandler_filesystem.py
+++ b/traincheck/onlinechecker/streamhandler_filesystem.py
@@ -184,6 +184,7 @@ def _set_func_map(self, trace_record):
self.pt_map[ptname][func_call_id].args = trace_record["args"]
self.pt_map[ptname][func_call_id].kwargs = trace_record["kwargs"]
elif trace_type == TraceLineType.FUNC_CALL_POST:
+ assert self.pt_map[ptname][func_call_id].pre_record is not None
self.pt_map[ptname][func_call_id].post_record = trace_record
self.pt_map[ptname][func_call_id].return_values = trace_record[
"return_values"
@@ -324,7 +325,7 @@ def run_stream_monitor(traces, trace_folders, checker_data: Checker_data):
if trace_folders is not None:
for trace_folder in trace_folders:
- for file in os.listdir(trace_folder):
+ for file in sorted(os.listdir(trace_folder)):
if file.startswith("trace_") or file.endswith("proxy_log.json"):
file_path = os.path.join(trace_folder, file)
handler = StreamLogHandler(file_path, checker_data)