testGymCartpoleEnv.py

You can view and download this file on GitHub: testGymCartpoleEnv.py

  1#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  2# This is an EXUDYN example
  3#
  4# Details:  This file serves as an input to testGymCartpole.py
  5#
  6# Author:   Johannes Gerstmayr, Grzegorz Orzechowski
  7# Date:     2022-05-17
  8# Update:   2023-05-20: derive from gym.Env to ensure compatibility with newer stable-baselines3
  9#
 10# Copyright:This file is part of Exudyn. Exudyn is free software. You can redistribute it and/or modify it under the terms of the Exudyn license. See 'LICENSE.txt' for more details.
 11#
 12#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 13
 14import exudyn as exu
 15from exudyn.utilities import * #includes itemInterface and rigidBodyUtilities
 16import exudyn.graphics as graphics #only import if it does not conflict
 17import math
 18
 19import math
 20from typing import Optional, Union
 21
 22import numpy as np
 23
 24import gym
 25from gym import logger, spaces, Env
 26from gym.error import DependencyNotInstalled
 27
 28import stable_baselines3
 29useOldGym = tuple(map(int, stable_baselines3.__version__.split('.'))) <= tuple(map(int, '1.8.0'.split('.')))
 30
 31
 32class CartPoleEnv(Env):
 33
 34    #metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50}
 35    metadata = {"render_modes": ["human"], "render_fps": 50}
 36
 37    def __init__(self, thresholdFactor = 1., forceFactor = 1.):
 38
 39        self.SC = exu.SystemContainer()
 40        self.mbs = self.SC.AddSystem()
 41
 42
 43        #%%+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 44        #+++++++++++++++++++++++++++++++++++++++++++++++++++++
 45        #take variables from cartpole example in openAIgym
 46        self.gravity = 9.8
 47        self.masscart = 1.0
 48        self.masspole = 0.1
 49        self.total_mass = self.masspole + self.masscart
 50        self.lengthHalf = 0.5  # actually half the pole's self.length
 51        self.length = self.lengthHalf*2
 52        self.polemass_length = self.masspole * self.length
 53        self.force_mag = 10.0*forceFactor
 54        self.stepUpdateTime = 0.02  # seconds between state updates
 55        self.kinematics_integrator = "euler"
 56
 57        # Angle at which to fail the episode
 58        self.theta_threshold_radians = thresholdFactor*12 * 2 * math.pi / 360
 59        self.x_threshold = thresholdFactor*2.4
 60
 61        high = np.array(
 62            [
 63                self.x_threshold * 2,
 64                np.finfo(np.float32).max,
 65                self.theta_threshold_radians * 2,
 66                np.finfo(np.float32).max,
 67            ],
 68            dtype=np.float32,
 69        )
 70
 71        #+++++++++++++++++++++++++++++++++++++++++++++++++++++
 72        #see https://github.com/openai/gym/blob/64b4b31d8245f6972b3d37270faf69b74908a67d/gym/core.py#L16
 73        #for Env:
 74        self.action_space = spaces.Discrete(2)
 75        self.observation_space = spaces.Box(-high, high, dtype=np.float32)
 76        #+++++++++++++++++++++++++++++++++++++++++++++++++++++
 77        self.state = None
 78        self.rendererRunning=None
 79        self.useRenderer = False #turn this on if needed
 80
 81        background = graphics.CheckerBoard(point= [0,0,0], normal= [0,0,1], size=4)
 82
 83        oGround=self.mbs.AddObject(ObjectGround(referencePosition= [0,0,0],  #x-pos,y-pos,angle
 84                                           visualization=VObjectGround(graphicsData= [background])))
 85        nGround=self.mbs.AddNode(NodePointGround())
 86
 87        gCart = graphics.Brick(size=[0.5*self.length, 0.1*self.length, 0.1*self.length],
 88                                           color=graphics.color.dodgerblue)
 89        self.nCart = self.mbs.AddNode(Rigid2D(referenceCoordinates=[0,0,0]));
 90        oCart = self.mbs.AddObject(RigidBody2D(physicsMass=self.masscart,
 91                                          physicsInertia=0.1*self.masscart, #not needed
 92                                          nodeNumber=self.nCart,
 93                                          visualization=VObjectRigidBody2D(graphicsData= [gCart])))
 94
 95        gPole = graphics.Brick(size=[0.1*self.length, self.length, 0.1*self.length], color=graphics.color.red)
 96        self.nPole = self.mbs.AddNode(Rigid2D(referenceCoordinates=[0,0.5*self.length,0]));
 97        oPole = self.mbs.AddObject(RigidBody2D(physicsMass=self.masspole,
 98                                          physicsInertia=1e-6, #not included in original paper
 99                                          nodeNumber=self.nPole,
100                                          visualization=VObjectRigidBody2D(graphicsData= [gPole])))
101
102        mCartCOM = self.mbs.AddMarker(MarkerNodePosition(nodeNumber=self.nCart))
103        mPoleCOM = self.mbs.AddMarker(MarkerNodePosition(nodeNumber=self.nPole))
104        mPoleJoint = self.mbs.AddMarker(MarkerBodyPosition(bodyNumber=oPole, localPosition=[0,-0.5*self.length,0]))
105
106        mCartCoordX = self.mbs.AddMarker(MarkerNodeCoordinate(nodeNumber=self.nCart, coordinate=0))
107        mCartCoordY = self.mbs.AddMarker(MarkerNodeCoordinate(nodeNumber=self.nCart, coordinate=1))
108        mGroundNode = self.mbs.AddMarker(MarkerNodeCoordinate(nodeNumber=nGround, coordinate=0))
109
110        #gravity
111        self.mbs.AddLoad(Force(markerNumber=mCartCOM, loadVector=[0,-self.masscart*self.gravity,0]))
112        self.mbs.AddLoad(Force(markerNumber=mPoleCOM, loadVector=[0,-self.masspole*self.gravity,0]))
113
114        #control force
115        self.lControl = self.mbs.AddLoad(LoadCoordinate(markerNumber=mCartCoordX, load=1.))
116
117        #constraints:
118        self.mbs.AddObject(RevoluteJoint2D(markerNumbers=[mCartCOM, mPoleJoint]))
119        self.mbs.AddObject(CoordinateConstraint(markerNumbers=[mCartCoordY, mGroundNode]))
120
121
122
123
124        #%%++++++++++++++++++++++++
125        self.mbs.Assemble() #computes initial vector
126
127        self.simulationSettings = exu.SimulationSettings() #takes currently set values or default values
128
129
130        self.simulationSettings.timeIntegration.numberOfSteps = 1
131        self.simulationSettings.timeIntegration.endTime = 0 #will be overwritten in step
132        self.simulationSettings.timeIntegration.verboseMode = 0
133        self.simulationSettings.solutionSettings.writeSolutionToFile = False
134        #self.simulationSettings.timeIntegration.simulateInRealtime = True
135
136        self.simulationSettings.timeIntegration.newton.useModifiedNewton = True
137
138        self.SC.visualizationSettings.general.drawWorldBasis=True
139        self.SC.visualizationSettings.general.graphicsUpdateInterval = 0.01 #50Hz
140
141        self.simulationSettings.solutionSettings.solutionInformation = "Open AI gym"
142
143        self.dynamicSolver = exudyn.MainSolverImplicitSecondOrder()
144        self.dynamicSolver.InitializeSolver(self.mbs, self.simulationSettings)
145        self.dynamicSolver.SolveSteps(self.mbs, self.simulationSettings) #to initialize all data
146
147
148    def integrateStep(self, force):
149        #exudyn simulation part
150        #index 2 solver
151        self.mbs.SetLoadParameter(self.lControl, 'load', force)
152
153        #progress integration time
154        currentTime = self.simulationSettings.timeIntegration.endTime
155        self.simulationSettings.timeIntegration.startTime = currentTime
156        self.simulationSettings.timeIntegration.endTime = currentTime+self.stepUpdateTime
157
158        # exu.SolveDynamic(self.mbs, self.simulationSettings, solverType=exu.DynamicSolverType.TrapezoidalIndex2,
159        #                  updateInitialValues=True) #use final value as new initial values
160
161        self.dynamicSolver.InitializeSolverInitialConditions(self.mbs, self.simulationSettings)
162        self.dynamicSolver.SolveSteps(self.mbs, self.simulationSettings)
163        currentState = self.mbs.systemData.GetSystemState() #get current values
164        self.mbs.systemData.SetSystemState(systemStateList=currentState,
165                                        configuration = exu.ConfigurationType.Initial)
166        self.mbs.systemData.SetODE2Coordinates_tt(coordinates = self.mbs.systemData.GetODE2Coordinates_tt(),
167                                                configuration = exu.ConfigurationType.Initial)
168
169
170
171
172    def step(self, action):
173        err_msg = f"{action!r} ({type(action)}) invalid"
174        assert self.action_space.contains(action), err_msg
175        assert self.state is not None, "Call reset before using step method."
176        x, x_dot, theta, theta_dot = self.state
177
178        force = self.force_mag if action == 1 else -self.force_mag
179
180        #++++++++++++++++++++++++++++++++++++++++++++++++++
181        #++++++++++++++++++++++++++++++++++++++++++++++++++
182        self.integrateStep(force)
183        #+++++++++++++++++++++++++
184        #compute some output:
185        cartPosX = self.mbs.GetNodeOutput(self.nCart, variableType=exu.OutputVariableType.Coordinates)[0]
186        poleAngle = self.mbs.GetNodeOutput(self.nPole, variableType=exu.OutputVariableType.Coordinates)[2]
187        cartPosX_t = self.mbs.GetNodeOutput(self.nCart, variableType=exu.OutputVariableType.Coordinates_t)[0]
188        poleAngle_t = self.mbs.GetNodeOutput(self.nPole, variableType=exu.OutputVariableType.Coordinates_t)[2]
189
190        #finally write updated state:
191        self.state = (cartPosX, cartPosX_t, poleAngle, poleAngle_t)
192        #++++++++++++++++++++++++++++++++++++++++++++++++++
193        #++++++++++++++++++++++++++++++++++++++++++++++++++
194
195        done = bool(
196            cartPosX < -self.x_threshold
197            or cartPosX > self.x_threshold
198            or poleAngle < -self.theta_threshold_radians
199            or poleAngle > self.theta_threshold_radians
200        )
201
202        if not done:
203            reward = 1.0
204        elif self.steps_beyond_done is None:
205            # Pole just fell!
206            self.steps_beyond_done = 0
207            reward = 1.0
208        else:
209            if self.steps_beyond_done == 0:
210                logger.warn(
211                    "You are calling 'step()' even though this "
212                    "environment has already returned done = True. You "
213                    "should always call 'reset()' once you receive 'done = "
214                    "True' -- any further steps are undefined behavior."
215                )
216            self.steps_beyond_done += 1
217            reward = 0.0
218
219        info = {}
220        terminated, truncated = done, False # since stable-baselines3 > 1.8.0 implementations terminated and truncated
221        if useOldGym:
222            return np.array(self.state, dtype=np.float32), reward, terminated, info
223        else:
224            return np.array(self.state, dtype=np.float32), reward, terminated, truncated, info
225
226
227    def reset(
228        self,
229        *,
230        seed: Optional[int] = None,
231        return_info: bool = False,
232        options: Optional[dict] = None,
233    ):
234        #super().reset(seed=seed)
235        #self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
236        self.state = np.random.uniform(low=-0.05, high=0.05, size=(4,))
237        self.steps_beyond_done = None
238
239
240        #+++++++++++++++++++++++++++++++++++++++++++++
241        #set initial values:
242        #+++++++++++++++++++++++++++++++++++++++++++++
243        #set specific initial state:
244        (xCart, xCart_t, phiPole, phiPole_t) = self.state
245
246        initialValues = np.zeros(6)
247        initialValues_t = np.zeros(6)
248        initialValues[0] = xCart
249        initialValues[3+0] = xCart - 0.5*self.length * sin(phiPole)
250        initialValues[3+1] = 0.5*self.length * (cos(phiPole)-1)
251        initialValues[3+2] = phiPole
252
253        initialValues_t[0] = xCart_t
254        initialValues_t[3+0] = xCart_t - phiPole_t*0.5*self.length * cos(phiPole)
255        initialValues_t[3+1] = -0.5*self.length * sin(phiPole)  * phiPole_t
256        initialValues_t[3+2] = phiPole_t
257
258        self.mbs.systemData.SetODE2Coordinates(initialValues, exu.ConfigurationType.Initial)
259        self.mbs.systemData.SetODE2Coordinates_t(initialValues_t, exu.ConfigurationType.Initial)
260
261        self.simulationSettings.timeIntegration.endTime = 0
262        #self.dynamicSolver.FinalizeSolver(self.mbs, self.simulationSettings) #needed to update initial conditions
263        self.dynamicSolver.InitializeSolver(self.mbs, self.simulationSettings) #needed to update initial conditions
264        # self.dynamicSolver.InitializeSolverInitialConditions(self.mbs, self.simulationSettings) #needed to update initial conditions
265
266        if not return_info and useOldGym:
267            return np.array(self.state, dtype=np.float32)
268        else:
269            return np.array(self.state, dtype=np.float32), {}
270
271    def render(self, mode="human"):
272        if self.rendererRunning==None and self.useRenderer:
273            exu.StartRenderer()
274            self.rendererRunning = True
275
276    def close(self):
277        self.dynamicSolver.FinalizeSolver(self.mbs, self.simulationSettings)
278        if self.rendererRunning==True:
279            # SC.WaitForRenderEngineStopFlag()
280            exu.StopRenderer() #safely close rendering window!
281
282
283
284# #+++++++++++++++++++++++++++++++++++++++++++++
285# #reset:
286# self.mbs.systemData.SetODE2Coordinates(initialValues, exu.ConfigurationType.Initial)
287# self.mbs.systemData.SetODE2Coordinates_t(initialValues, exu.ConfigurationType.Initial)
288# self.mbs.systemData.SetODE2Coordinates_tt(initialValues, exu.ConfigurationType.Initial)