Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions openmc/deplete/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,8 +718,8 @@ def solver(self, func):
return

# Inspect arguments
if len(sig.parameters) != 3:
raise ValueError("Function {} does not support three arguments: "
if len(sig.parameters) < 3:
raise ValueError("Function {} does not support less than three arguments: "
"{!s}".format(func, sig))

for ix, param in enumerate(sig.parameters.values()):
Expand All @@ -729,11 +729,12 @@ def solver(self, func):

self._solver = func

def _timed_deplete(self, n, rates, dt, i=None, matrix_func=None):
def _timed_deplete(self, n, rates, dt, i=None, matrix_func=None,
use_cache=False):
start = time.time()
results = deplete(
self._solver, self.chain, n, rates, dt, i, matrix_func,
self.transfer_rates, self.external_source_rates)
self.transfer_rates, self.external_source_rates, use_cache=use_cache)
return time.time() - start, results

@abstractmethod
Expand All @@ -743,7 +744,8 @@ def __call__(
rates: ReactionRates,
dt: float,
source_rate: float,
i: int
i: int,
use_cache=False
):
"""Perform the integration across one time step

Expand Down Expand Up @@ -871,7 +873,10 @@ def integrate(
n = self.operator.initial_condition()
t, self._i_res = self._get_start_data()

prev_dt = None
prev_source_rate = None
for i, (dt, source_rate) in enumerate(self):
use_cache = (prev_dt == dt) and (prev_source_rate == source_rate)
if output and comm.rank == 0:
print(f"[openmc.deplete] t={t} s, dt={dt} s, source={source_rate}")

Expand All @@ -882,7 +887,7 @@ def integrate(
n, res = self._get_bos_data_from_restart(source_rate, n)

# Solve Bateman equations over time interval
proc_time, n_end = self(n, res.rates, dt, source_rate, i)
proc_time, n_end = self(n, res.rates, dt, source_rate, i, use_cache=use_cache)

StepResult.save(
self.operator,
Expand All @@ -898,6 +903,8 @@ def integrate(

# Update for next step
n = n_end
prev_dt = dt
prev_source_rate = source_rate
t += dt

# Final simulation -- in the case that final_step is False, a zero
Expand Down
13 changes: 8 additions & 5 deletions openmc/deplete/cram.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ def __init__(self, alpha, theta, alpha0):
self.alpha = alpha
self.theta = theta
self.alpha0 = alpha0
self._splu_cache = None

def __call__(self, A, n0, dt):
def __call__(self, A, n0, dt, use_cache=False):
"""Solve depletion equations using IPF CRAM

Parameters
Expand All @@ -75,11 +76,13 @@ def __call__(self, A, n0, dt):
Final compositions after ``dt``

"""
A = dt * csc_array(A, dtype=np.float64)
y = n0.copy()
ident = eye_array(A.shape[0], format='csc')
for alpha, theta in zip(self.alpha, self.theta):
y += 2*np.real(alpha*sla.spsolve(A - theta*ident, y))
if not use_cache or not self._splu_cache:
A = dt * csc_array(A, dtype=np.float64)
ident = eye_array(A.shape[0], format='csc')
self._splu_cache = [sla.splu(A - theta*ident) for alpha, theta in zip(self.alpha, self.theta)]
for alpha, splu in zip(self.alpha, self._splu_cache):
y += 2*np.real(alpha*splu.solve(y))
return y * self.alpha0


Expand Down
18 changes: 9 additions & 9 deletions openmc/deplete/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class PredictorIntegrator(Integrator):
"""
_num_stages = 1

def __call__(self, n, rates, dt, source_rate, _i=None):
def __call__(self, n, rates, dt, source_rate, _i=None, use_cache=False):
"""Perform the integration across one time step

Parameters
Expand All @@ -50,7 +50,7 @@ def __call__(self, n, rates, dt, source_rate, _i=None):
Concentrations at end of interval

"""
proc_time, n_end = self._timed_deplete(n, rates, dt, _i)
proc_time, n_end = self._timed_deplete(n, rates, dt, _i, use_cache=use_cache)
return proc_time, n_end


Expand All @@ -74,7 +74,7 @@ class CECMIntegrator(Integrator):
"""
_num_stages = 2

def __call__(self, n, rates, dt, source_rate, _i=None):
def __call__(self, n, rates, dt, source_rate, _i=None, use_cache=False):
"""Integrate using CE/CM

Parameters
Expand Down Expand Up @@ -135,7 +135,7 @@ class CF4Integrator(Integrator):
"""
_num_stages = 4

def __call__(self, n_bos, bos_rates, dt, source_rate, _i=None):
def __call__(self, n_bos, bos_rates, dt, source_rate, _i=None, use_cache=False):
"""Perform the integration across one time step

Parameters
Expand Down Expand Up @@ -207,7 +207,7 @@ class CELIIntegrator(Integrator):
"""
_num_stages = 2

def __call__(self, n_bos, rates, dt, source_rate, _i=None):
def __call__(self, n_bos, rates, dt, source_rate, _i=None, use_cache=False):
"""Perform the integration across one time step

Parameters
Expand Down Expand Up @@ -269,7 +269,7 @@ class EPCRK4Integrator(Integrator):
"""
_num_stages = 4

def __call__(self, n, rates, dt, source_rate, _i=None):
def __call__(self, n, rates, dt, source_rate, _i=None, use_cache=False):
"""Perform the integration across one time step

Parameters
Expand Down Expand Up @@ -347,7 +347,7 @@ class LEQIIntegrator(Integrator):
"""
_num_stages = 2

def __call__(self, n_bos, bos_rates, dt, source_rate, i):
def __call__(self, n_bos, bos_rates, dt, source_rate, i, use_cache=False):
"""Perform the integration across one time step

Parameters
Expand Down Expand Up @@ -423,7 +423,7 @@ class SICELIIntegrator(SIIntegrator):
"""
_num_stages = 2

def __call__(self, n_bos, bos_rates, dt, source_rate, _i=None):
def __call__(self, n_bos, bos_rates, dt, source_rate, _i=None, use_cache=False):
"""Perform the integration across one time step

Parameters
Expand Down Expand Up @@ -488,7 +488,7 @@ class SILEQIIntegrator(SIIntegrator):
"""
_num_stages = 2

def __call__(self, n_bos, bos_rates, dt, source_rate, i):
def __call__(self, n_bos, bos_rates, dt, source_rate, i, use_cache=False):
"""Perform the integration across one time step

Parameters
Expand Down
6 changes: 3 additions & 3 deletions openmc/deplete/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _distribute(items):
j += chunk_size

def deplete(func, chain, n, rates, dt, current_timestep=None, matrix_func=None,
transfer_rates=None, external_source_rates=None, *matrix_args):
transfer_rates=None, external_source_rates=None, *matrix_args, use_cache=False):
"""Deplete materials using given reaction rates for a specified time

Parameters
Expand Down Expand Up @@ -164,7 +164,7 @@ def deplete(func, chain, n, rates, dt, current_timestep=None, matrix_func=None,

# Concatenate vectors of nuclides in one
n_multi = np.concatenate(n)
n_result = func(matrix, n_multi, dt)
n_result = func(matrix, n_multi, dt, use_cache=use_cache)

# Split back the nuclide vector result into the original form
n_result = np.split(n_result, np.cumsum([len(i) for i in n])[:-1])
Expand Down Expand Up @@ -198,7 +198,7 @@ def deplete(func, chain, n, rates, dt, current_timestep=None, matrix_func=None,
matrix.resize(matrix.shape[1], matrix.shape[1])
n[i] = np.append(n[i], 1.0)

inputs = zip(matrices, n, repeat(dt))
inputs = zip(matrices, n, repeat(dt), repeat(use_cache))

if USE_MULTIPROCESSING:
with Pool(NUM_PROCESSES) as pool:
Expand Down
Loading