diff --git a/pufferlib/ocean/cartpole/cartpole.h b/pufferlib/ocean/cartpole/cartpole.h index 7d87e11d6..3650ef766 100644 --- a/pufferlib/ocean/cartpole/cartpole.h +++ b/pufferlib/ocean/cartpole/cartpole.h @@ -182,10 +182,10 @@ void c_step(Cartpole* env) { float sintheta = sinf(env->theta); float total_mass = env->cart_mass + env->pole_mass; - float polemass_length = total_mass + env->pole_mass; + float polemass_length = env->pole_mass * env->pole_length; float temp = (force + polemass_length * env->theta_dot * env->theta_dot * sintheta) / total_mass; - float thetaacc = (env->gravity * sintheta - costheta * temp) / - (env->pole_length * (4.0f / 3.0f - total_mass * costheta * costheta / total_mass)); + float thetaacc = (env->gravity * sintheta - costheta * temp) / + (env->pole_length * (4.0f / 3.0f - env->pole_mass * costheta * costheta / total_mass)); float xacc = temp - polemass_length * thetaacc * costheta / total_mass; env->x += env->tau * env->x_dot;