1010from ..confounds import CompCor , TCompCor , ACompCor
1111
1212
13+ def close_up_to_column_sign (a , b , rtol = 1e-05 , atol = 1e-08 , equal_nan = False ):
14+ """SVD can produce sign flips on a per-column basis."""
15+ kwargs = dict (rtol = rtol , atol = atol , equal_nan = equal_nan )
16+ if np .allclose (a , b , ** kwargs ):
17+ return True
18+
19+ ret = True
20+ for acol , bcol in zip (a .T , b .T ):
21+ ret &= np .allclose (acol , bcol , ** kwargs ) or np .allclose (acol , - bcol , ** kwargs )
22+ if not ret :
23+ break
24+
25+ return ret
26+
27+
28+ @pytest .mark .parametrize (
29+ "a, b, close" ,
30+ [
31+ ([[0.1 , 0.2 ], [0.3 , 0.4 ]], [[- 0.1 , 0.2 ], [- 0.3 , 0.4 ]], True ),
32+ ([[0.1 , 0.2 ], [0.3 , 0.4 ]], [[- 0.1 , 0.2 ], [0.3 , - 0.4 ]], False ),
33+ ],
34+ )
35+ def test_close_up_to_column_sign (a , b , close ):
36+ a = np .asanyarray (a )
37+ b = np .asanyarray (b )
38+ assert close_up_to_column_sign (a , b ) == close
39+ # Sign flips of all columns never changes result
40+ assert close_up_to_column_sign (a , - b ) == close
41+ assert close_up_to_column_sign (- a , b ) == close
42+ assert close_up_to_column_sign (- a , - b ) == close
43+ # Trivial case
44+ assert close_up_to_column_sign (a , a )
45+ assert close_up_to_column_sign (b , b )
46+
47+
1348class TestCompCor :
1449 """Note: Tests currently do a poor job of testing functionality"""
1550
@@ -42,11 +77,11 @@ def setup_class(self, tmpdir):
4277
4378 def test_compcor (self ):
4479 expected_components = [
45- [" -0.1989607212" , " -0.5753813646" ],
46- [" 0.5692369697" , " 0.5674945949" ],
47- [" -0.6662573243" , " 0.4675843432" ],
48- [" 0.4206466244" , " -0.3361270124" ],
49- [" -0.1246655485" , " -0.1235705610" ],
80+ [- 0.1989607212 , - 0.5753813646 ],
81+ [0.5692369697 , 0.5674945949 ],
82+ [- 0.6662573243 , 0.4675843432 ],
83+ [0.4206466244 , - 0.3361270124 ],
84+ [- 0.1246655485 , - 0.1235705610 ],
5085 ]
5186
5287 self .run_cc (
@@ -73,11 +108,11 @@ def test_compcor(self):
73108
74109 def test_compcor_variance_threshold_and_metadata (self ):
75110 expected_components = [
76- [" -0.2027150345" , " -0.4954813834" ],
77- [" 0.2565929051" , " 0.7866217875" ],
78- [" -0.3550986008" , " -0.0089784905" ],
79- [" 0.7512786244" , " -0.3599828482" ],
80- [" -0.4500578942" , " 0.0778209345" ],
111+ [- 0.2027150345 , - 0.4954813834 ],
112+ [0.2565929051 , 0.7866217875 ],
113+ [- 0.3550986008 , - 0.0089784905 ],
114+ [0.7512786244 , - 0.3599828482 ],
115+ [- 0.4500578942 , 0.0778209345 ],
81116 ]
82117 expected_metadata = {
83118 "component" : "CompCor00" ,
@@ -111,11 +146,11 @@ def test_tcompcor(self):
111146 self .run_cc (
112147 ccinterface ,
113148 [
114- [" -0.1114536190" , " -0.4632908609" ],
115- [" 0.4566907310" , " 0.6983205193" ],
116- [" -0.7132557407" , " 0.1340170559" ],
117- [" 0.5022537643" , " -0.5098322262" ],
118- [" -0.1342351356" , " 0.1407855119" ],
149+ [- 0.1114536190 , - 0.4632908609 ],
150+ [0.4566907310 , 0.6983205193 ],
151+ [- 0.7132557407 , 0.1340170559 ],
152+ [0.5022537643 , - 0.5098322262 ],
153+ [- 0.1342351356 , 0.1407855119 ],
119154 ],
120155 "tCompCor" ,
121156 )
@@ -138,11 +173,11 @@ def test_compcor_no_regress_poly(self):
138173 pre_filter = False ,
139174 ),
140175 [
141- [" 0.4451946442" , " -0.7683311482" ],
142- [" -0.4285129505" , " -0.0926034137" ],
143- [" 0.5721540256" , " 0.5608764842" ],
144- [" -0.5367548139" , " 0.0059943226" ],
145- [" -0.0520809054" , " 0.2940637551" ],
176+ [0.4451946442 , - 0.7683311482 ],
177+ [- 0.4285129505 , - 0.0926034137 ],
178+ [0.5721540256 , 0.5608764842 ],
179+ [- 0.5367548139 , 0.0059943226 ],
180+ [- 0.0520809054 , 0.2940637551 ],
146181 ],
147182 )
148183
@@ -225,27 +260,20 @@ def run_cc(
225260 assert os .path .getsize (expected_file ) > 0
226261
227262 with open (ccresult .outputs .components_file , "r" ) as components_file :
228- if expected_n_components is None :
229- expected_n_components = min (
230- ccinterface .inputs .num_components , self .fake_data .shape [3 ]
231- )
263+ header = components_file .readline ().rstrip ().split ("\t " )
264+ components_data = np .loadtxt (components_file , delimiter = "\t " )
265+
266+ if expected_n_components is None :
267+ expected_n_components = min (
268+ ccinterface .inputs .num_components , self .fake_data .shape [3 ]
269+ )
270+
271+ assert header == [
272+ f"{ expected_header } { i :02d} " for i in range (expected_n_components )
273+ ]
232274
233- components_data = [line .rstrip ().split ("\t " ) for line in components_file ]
234-
235- # the first item will be '#', we can throw it out
236- header = components_data .pop (0 )
237- expected_header = [
238- expected_header + "{:02d}" .format (i )
239- for i in range (expected_n_components )
240- ]
241- for i , heading in enumerate (header ):
242- assert expected_header [i ] in heading
243-
244- num_got_timepoints = len (components_data )
245- assert num_got_timepoints == self .fake_data .shape [3 ]
246- for index , timepoint in enumerate (components_data ):
247- assert len (timepoint ) == expected_n_components
248- assert timepoint [:2 ] == expected_components [index ]
275+ assert components_data .shape == (self .fake_data .shape [3 ], expected_n_components )
276+ assert close_up_to_column_sign (components_data [:, :2 ], expected_components )
249277
250278 if ccinterface .inputs .save_metadata :
251279 expected_metadata_file = ccinterface ._list_outputs ()["metadata_file" ]
0 commit comments