Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/ModelingToolkitTearing/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ BipartiteGraphs = "caf10ac8-0290-4205-88aa-f15908547e8d"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ModelingToolkitBase = "7771a370-6774-4173-bd38-47e70ca0b839"
Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand All @@ -25,6 +26,7 @@ BipartiteGraphs = "0.1.3"
CommonSolve = "0.2"
DocStringExtensions = "0.7, 0.8, 0.9"
Graphs = "1"
LinearAlgebra = "1"
ModelingToolkitBase = "1.2.0"
Moshi = "0.3"
OffsetArrays = "1"
Expand Down
100 changes: 95 additions & 5 deletions lib/ModelingToolkitTearing/src/ModelingToolkitTearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ using Symbolics: SymbolicT, VartypeT
using SymbolicUtils: BSImpl, unwrap
using SciMLBase: LinearProblem
using SparseArrays: nonzeros
import LinearAlgebra

const TimeDomain = SciMLBase.AbstractClock

Expand Down Expand Up @@ -53,17 +54,106 @@ abstract type ReassembleAlgorithm end

include("reassemble.jl")

function MTKBase.unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation})
mask = trues(length(obseqs))
function MTKBase.unhack_system(sys::System)
# Observed are copied by the masking operation
obseqs = observed(sys)
eqs = copy(equations(sys))
obs_mask = trues(length(obseqs))
for (i, eq) in enumerate(obseqs)
mask[i] = @match eq.rhs begin
obs_mask[i] = @match eq.rhs begin
BSImpl.Term(; f) => f !== change_origin
_ => true
end
end
obseqs = obseqs[obs_mask]

obseqs = obseqs[mask]
return obseqs, eqs
# Map from ldiv operation to index of the equations where it is the RHS. A
# positive index is into `obseqs`, a negative index is into `eqs`. The variable
# order for the ldiv comes from the LHS of the corresponding equations.
inline_linear_scc_map = Dict{SymbolicT, Vector{Int}}()

for (i, eq) in enumerate(obseqs)
populate_inline_scc_map!(inline_linear_scc_map, eq, i, false)
end
for (i, eq) in enumerate(eqs)
populate_inline_scc_map!(inline_linear_scc_map, eq, i, true)
end

# Now, we want to turn all inlined linear SCCs into algebraic equations. If an element
# of the SCC is a differential variable, we'll introduce the `toterm` as a new algebraic.
# Otherwise, the observed equation is removed.
resize!(obs_mask, length(obseqs))
fill!(obs_mask, true)
additional_eqs = Equation[]
additional_vars = SymbolicT[]

# Also need to update schedule
sched = MTKBase.get_schedule(sys)
if sched isa MTKBase.Schedule
sched = copy(sched)
end
for (linsolve, idxs) in inline_linear_scc_map
A, b = @match linsolve begin
BSImpl.Term(; args) => args
end
A = collect(A)::Matrix{SymbolicT}
b = collect(b)::Vector{SymbolicT}
x = Vector{SymbolicT}(undef, length(b))
for (i, idx) in enumerate(idxs)
is_obs = idx > 0
idx = abs(idx)
if is_obs
var = obseqs[idx].lhs
x[i] = var
obs_mask[idx] = false
else
var = MTKBase.default_toterm(eqs[idx].lhs)
if sched isa MTKBase.Schedule
sched.dummy_sub[eqs[idx].lhs] = x[i] = var
end
eqs[idx] = eqs[idx].lhs ~ var
end
push!(additional_vars, var)
end

resid = A * x - b
for res in resid
push!(additional_eqs, Symbolics.COMMON_ZERO ~ res)
end
end
obseqs = obseqs[obs_mask]
append!(eqs, additional_eqs)

dvs = [unknowns(sys); additional_vars]

@set! sys.observed = obseqs
@set! sys.eqs = eqs
@set! sys.unknowns = dvs
@set! sys.schedule = sched
return sys
end

function populate_inline_scc_map!(
inline_linear_scc_map::Dict{SymbolicT, Vector{Int}}, eq::Equation, eq_i::Int,
neg::Bool)
@match eq.rhs begin
BSImpl.Term(; f, args) && if f === getindex && length(args) == 2 end => begin
maybe_ldiv = args[1]
_idx = args[2]
@match maybe_ldiv begin
BSImpl.Term(; f) && if f === INLINE_LINEAR_SCC_OP end => begin
ldiv = maybe_ldiv
len = length(ldiv)
buffer = get!(() -> zeros(Int, len), inline_linear_scc_map, ldiv)
idx = unwrap_const(_idx)::Int
buffer[idx] = ifelse(neg, -eq_i, eq_i)
end
_ => nothing
end
end
_ => nothing
end

end

include("clock_inference/clock.jl")
Expand Down
Loading
Loading