33class CartPoleConfigModule ():
44 # parameters
55 ENV_NAME = "CartPole-v0"
6+ PLANNER_TYPE = "Const"
67 TYPE = "Nonlinear"
78 TASK_HORIZON = 500
89 PRED_LEN = 50
910 STATE_SIZE = 4
1011 INPUT_SIZE = 1
1112 DT = 0.02
1213 # cost parameters
13- R = np .diag ([1. ]) # 0.01 is worked for MPPI and CEM and MPPIWilliams
14+ R = np .diag ([0.01 ]) # 0.01 is worked for MPPI and CEM and MPPIWilliams
1415 # 1. is worked for iLQR
15- Terminal_Weight = 1.
16+ TERMINAL_WEIGHT = 1.
1617 Q = None
1718 Sf = None
1819 # bounds
@@ -23,6 +24,7 @@ class CartPoleConfigModule():
2324 MC = 1.
2425 L = 0.5
2526 G = 9.81
27+ CART_SIZE = (0.15 , 0.1 )
2628
2729 def __init__ (self ):
2830 """
@@ -76,6 +78,7 @@ def __init__(self):
7678 @staticmethod
7779 def input_cost_fn (u ):
7880 """ input cost functions
81+
7982 Args:
8083 u (numpy.ndarray): input, shape(pred_len, input_size)
8184 or shape(pop_size, pred_len, input_size)
@@ -88,6 +91,7 @@ def input_cost_fn(u):
8891 @staticmethod
8992 def state_cost_fn (x , g_x ):
9093 """ state cost function
94+
9195 Args:
9296 x (numpy.ndarray): state, shape(pred_len, state_size)
9397 or shape(pop_size, pred_len, state_size)
@@ -118,6 +122,7 @@ def state_cost_fn(x, g_x):
118122 @staticmethod
119123 def terminal_state_cost_fn (terminal_x , terminal_g_x ):
120124 """
125+
121126 Args:
122127 terminal_x (numpy.ndarray): terminal state,
123128 shape(state_size, ) or shape(pop_size, state_size)
@@ -133,13 +138,13 @@ def terminal_state_cost_fn(terminal_x, terminal_g_x):
133138 + 12. * ((np .cos (terminal_x [:, 2 ]) + 1. )** 2 ) \
134139 + 0.1 * (terminal_x [:, 1 ]** 2 ) \
135140 + 0.1 * (terminal_x [:, 3 ]** 2 ))[:, np .newaxis ] \
136- * CartPoleConfigModule .Terminal_Weight
141+ * CartPoleConfigModule .TERMINAL_WEIGHT
137142
138143 return (6. * (terminal_x [0 ]** 2 ) \
139144 + 12. * ((np .cos (terminal_x [2 ]) + 1. )** 2 ) \
140145 + 0.1 * (terminal_x [1 ]** 2 ) \
141146 + 0.1 * (terminal_x [3 ]** 2 )) \
142- * CartPoleConfigModule .Terminal_Weight
147+ * CartPoleConfigModule .TERMINAL_WEIGHT
143148
144149 @staticmethod
145150 def gradient_cost_fn_with_state (x , g_x , terminal = False ):
@@ -168,7 +173,7 @@ def gradient_cost_fn_with_state(x, g_x, terminal=False):
168173 cost_dx3 = 0.2 * x [3 ]
169174 cost_dx = np .array ([[cost_dx0 , cost_dx1 , cost_dx2 , cost_dx3 ]])
170175
171- return cost_dx * CartPoleConfigModule .Terminal_Weight
176+ return cost_dx * CartPoleConfigModule .TERMINAL_WEIGHT
172177
173178 @staticmethod
174179 def gradient_cost_fn_with_input (x , u ):
@@ -177,7 +182,6 @@ def gradient_cost_fn_with_input(x, u):
177182 Args:
178183 x (numpy.ndarray): state, shape(pred_len, state_size)
179184 u (numpy.ndarray): goal state, shape(pred_len, input_size)
180-
181185 Returns:
182186 l_u (numpy.ndarray): gradient of cost, shape(pred_len, input_size)
183187 """
@@ -190,7 +194,6 @@ def hessian_cost_fn_with_state(x, g_x, terminal=False):
190194 Args:
191195 x (numpy.ndarray): state, shape(pred_len, state_size)
192196 g_x (numpy.ndarray): goal state, shape(pred_len, state_size)
193-
194197 Returns:
195198 l_xx (numpy.ndarray): gradient of cost,
196199 shape(pred_len, state_size, state_size) or
@@ -220,7 +223,7 @@ def hessian_cost_fn_with_state(x, g_x, terminal=False):
220223 * - np .cos (x [2 ])
221224 hessian [3 , 3 ] = 0.2
222225
223- return hessian [np .newaxis , :, :] * CartPoleConfigModule .Terminal_Weight
226+ return hessian [np .newaxis , :, :] * CartPoleConfigModule .TERMINAL_WEIGHT
224227
225228 @staticmethod
226229 def hessian_cost_fn_with_input (x , u ):
@@ -229,7 +232,6 @@ def hessian_cost_fn_with_input(x, u):
229232 Args:
230233 x (numpy.ndarray): state, shape(pred_len, state_size)
231234 u (numpy.ndarray): goal state, shape(pred_len, input_size)
232-
233235 Returns:
234236 l_uu (numpy.ndarray): gradient of cost,
235237 shape(pred_len, input_size, input_size)
@@ -245,7 +247,6 @@ def hessian_cost_fn_with_input_state(x, u):
245247 Args:
246248 x (numpy.ndarray): state, shape(pred_len, state_size)
247249 u (numpy.ndarray): goal state, shape(pred_len, input_size)
248-
249250 Returns:
250251 l_ux (numpy.ndarray): gradient of cost ,
251252 shape(pred_len, input_size, state_size)
0 commit comments